Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitd17303f

Browse files
committed
Working vector recall with query builder
1 parentadf64d2 commitd17303f

File tree

4 files changed

+65
-49
lines changed

4 files changed

+65
-49
lines changed

‎pgml-sdks/python/pgml/.gitignore‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
examples/pika_ex1.py

‎pgml-sdks/python/pgml/examples/pika_ex1.py‎

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
frompypikaimportQuery,Table,AliasedQuery,Order,Field
22
frompypika.functionsimportCast
3+
frompypika.enumsimportSqlTypes
34
frompgml.queriesimportEmbed,CosineDistance
5+
frompypika.utilsimportformat_quotes
6+
frompsycopgimportsql
47

5-
embeddings_table=Table("test_collection_1.embeddings_d2beb7")
6-
chunks_table=Table("test_collection_1.chunks")
7-
documents_table=Table("test_collection_1.documents")
8+
9+
10+
embeddings_table=Table("embeddings_d2beb7",schema="test_collection_1")
11+
chunks_table=Table("chunks",schema="test_collection_1")
12+
documents_table=Table("documents",schema="test_collection_1")
813

914
model="intfloat/e5-small"
1015
text="hello world"
@@ -21,7 +26,7 @@
2126
).as_("score"),
2227
)
2328
.inner_join(AliasedQuery("query_cte"))
24-
.on(Field(1)==Field(1))
29+
.cross()
2530
)
2631

2732
query_cte= (
@@ -35,5 +40,6 @@
3540
.inner_join(documents_table)
3641
.on(documents_table.id==chunks_table.document_id)
3742
)
38-
print(query_cte.get_sql().replace('"',""))
3943

44+
final_query=query_cte.where(documents_table.metadata.contains({"reviews" :42})).limit(5)
45+
print(final_query)

‎pgml-sdks/python/pgml/pgml/collection.py‎

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
run_select_statement,
2121
)
2222

23-
from .queriesimportEmbed,CosineDistance
2423
fromlangchain.text_splitterimportRecursiveCharacterTextSplitter
25-
frompypika.queriesimportSchema,Table,Query,QueryBuilder
24+
frompypikaimportQuery,Table,AliasedQuery,Order,Field
25+
frompypika.queriesimportQueryBuilder
26+
frompypika.functionsimportCast
27+
from .queriesimportEmbed,CosineDistance
28+
2629

2730
FORMAT="%(message)s"
2831
logging.basicConfig(
@@ -805,22 +808,21 @@ def vector_search(
805808

806809
defexecute(self,sql_statement:QueryBuilder)->List[Dict[str,Any]]:
807810
conn=self.pool.getconn()
808-
results=run_select_statement(conn,sql_statement.get_sql().replace('"',""))
811+
results=run_select_statement(conn,sql_statement.get_sql())
809812
self.pool.putconn(conn)
810813
returnresults
811814

812815
defvector_recall(
813816
self,
814817
query:str,
815818
query_parameters:Optional[Dict[str,Any]]= {},
816-
top_k:int=5,
817819
model_id:int=1,
818820
splitter_id:int=1,
819821
)->List[Dict[str,Any]]:
820822
ifmodel_idinself._cache_model_names.keys():
821823
model=self._cache_model_names[model_id]
822824
else:
823-
models=Table(self.models_table)
825+
models=Table(self.models_table.split(".")[1],schema=self.name)
824826
q=Query.from_(models).select("name").where(models.id==model_id)
825827
results=self.execute(q)
826828
model=results[0]["name"]
@@ -834,7 +836,9 @@ def vector_recall(
834836
]
835837

836838
ifnotembeddings_table:
837-
transforms_table=Table(self.transforms_table)
839+
transforms_table=Table(
840+
self.transforms_table.split(".")[1],schema=self.name
841+
)
838842
q= (
839843
Query.from_(transforms_table)
840844
.select("table_name")
@@ -854,47 +858,43 @@ def vector_recall(
854858
)
855859
return []
856860

857-
conn=self.pool.getconn()
861+
embeddings_table=embeddings_table.split(".")[1]
862+
chunks_table=self.chunks_table.split(".")[1]
863+
documents_table=self.documents_table.split(".")[1]
864+
865+
embeddings_table=Table(embeddings_table,schema=self.name)
866+
chunks_table=Table(chunks_table,schema=self.name)
867+
documents_table=Table(documents_table,schema=self.name)
858868

859-
cte_query=Query.select(
869+
query_embed=Query().select(
860870
Embed(transformer=model,text=query,parameters=query_parameters)
861-
).with_()
862-
table_embedding= (
863-
Query.from_(embeddings_table)
871+
)
872+
query_cte=AliasedQuery("query_cte")
873+
cte=AliasedQuery("cte")
874+
table_embed= (
875+
Query()
876+
.from_(embeddings_table)
864877
.select(
865878
"chunk_id",
866-
CosineDistance(embeddings_table.embedding,query_embedding.cosine),
879+
CosineDistance(
880+
embeddings_table.embedding,Cast(query_cte.embedding,"vector")
881+
).as_("score"),
867882
)
868-
.cross_join(query_embedding)
869-
)
870-
cte_select_statement="""
871-
WITH query_cte AS (
872-
SELECT pgml.embed(transformer => {model}, text => '{query_text}', kwargs => {model_params}) AS query_embedding
873-
),
874-
cte AS (
875-
SELECT chunk_id, 1 - ({embeddings_table}.embedding <=> query_cte.query_embedding::float8[]::vector) AS score
876-
FROM {embeddings_table}
877-
CROSS JOIN query_cte
878-
ORDER BY score DESC
883+
.inner_join(AliasedQuery("query_cte"))
884+
.cross()
879885
)
880-
SELECT cte.score, chunks.chunk, documents.metadata
881-
FROM cte
882-
INNER JOIN {chunks_table} chunks ON chunks.id = cte.chunk_id
883-
INNER JOIN {documents_table} documents ON documents.id = chunks.document_id
884-
""".format(
885-
model=sql.Literal(model).as_string(conn),
886-
query_text=query,
887-
model_params=sql.Literal(json.dumps(query_parameters)).as_string(conn),
888-
embeddings_table=embeddings_table,
889-
chunks_table=self.chunks_table,
890-
documents_table=self.documents_table,
891-
)
892-
893-
cte_select_statement+=" LIMIT {top_k}".format(top_k=top_k)
894886

895-
search_results=run_select_statement(
896-
conn,cte_select_statement,order_by="score",ascending=False
887+
query_cte= (
888+
Query()
889+
.with_(query_embed,"query_cte")
890+
.with_(table_embed,"cte")
891+
.from_("cte")
892+
.select(cte.score,chunks_table.chunk,documents_table.metadata)
893+
.orderby(cte.score,order=Order.desc)
894+
.inner_join(chunks_table)
895+
.on(chunks_table.id==cte.chunk_id)
896+
.inner_join(documents_table)
897+
.on(documents_table.id==chunks_table.document_id)
897898
)
898-
self.pool.putconn(conn)
899899

900-
returnsearch_results
900+
returnquery_cte

‎pgml-sdks/python/pgml/tests/test_collection.py‎

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
frompgml.dbutilsimport*
44
importhashlib
55
importos
6+
frompypikaimportTable
7+
frompypika.functionsimportCast
68

79

810
classTestCollection(unittest.TestCase):
@@ -168,8 +170,15 @@ def test_vector_recall(self):
168170
self.collection.upsert_documents(self.documents_with_reviews_metadata)
169171
self.collection.generate_chunks()
170172
self.collection.generate_embeddings()
171-
results=self.collection.vector_recall("product is abc")
172-
print(results)
173-
173+
documents_table=Table("documents",schema=self.collection.name)
174+
query= (
175+
self.collection.vector_recall("product is abc")
176+
.where(documents_table.metadata.contains({"source":"amazon"}))
177+
.where(Cast(documents_table.metadata.get_json_value("reviews"),'INTEGER')<45)
178+
.limit(10)
179+
)
180+
results=self.collection.execute(query)
181+
assertresults[0]["metadata"]["user"]=="John Doe"
182+
174183
# def tearDown(self) -> None:
175184
# self.db.archive_collection(self.collection_name)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp