Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitadf64d2

Browse files
committed
Querybuilder for vector search prototyped
1 parente2e7085 commitadf64d2

File tree

4 files changed

+70
-34
lines changed

4 files changed

+70
-34
lines changed
Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,39 @@
1-
frompypikaimportQuery,Table,AliasedQuery,Order
1+
frompypikaimportQuery,Table,AliasedQuery,Order,Field
2+
frompypika.functionsimportCast
23
frompgml.queriesimportEmbed,CosineDistance
34

4-
embeddings_table=Table("embeddings_table")
5-
chunks_table=Table("chunks_table")
6-
documents_table=Table("documents_table")
7-
8-
query_embed=Query().select(Embed(transformer="instructxl",text="hello"))
9-
print(query_embed)
10-
11-
query_table=AliasedQuery("query_cte")
12-
query_cte=Query().with_(query_embed,"query_cte").from_(query_table).select('*')
13-
print(query_cte)
5+
embeddings_table=Table("test_collection_1.embeddings_d2beb7")
6+
chunks_table=Table("test_collection_1.chunks")
7+
documents_table=Table("test_collection_1.documents")
148

9+
model="intfloat/e5-small"
10+
text="hello world"
11+
query_embed=Query().select(Embed(transformer=model,text=text))
12+
query_cte=AliasedQuery("query_cte")
13+
cte=AliasedQuery("cte")
1514
table_embed= (
1615
Query()
17-
.with_(
18-
Query()
19-
.from_(embeddings_table)
20-
.cross_join(query_table)
21-
.on(embeddings_table.embedding==query_table.embedding).select('*'),
22-
"cte",
16+
.from_(embeddings_table)
17+
.select(
18+
"chunk_id",
19+
CosineDistance(
20+
embeddings_table.embedding,Cast(query_cte.embedding,"vector")
21+
).as_("score"),
2322
)
24-
.from_(AliasedQuery("cte"))
25-
.select("score")
23+
.inner_join(AliasedQuery("query_cte"))
24+
.on(Field(1)==Field(1))
25+
)
26+
27+
query_cte= (
28+
Query()
29+
.with_(query_embed,"query_cte")
30+
.with_(table_embed,"cte")
31+
.from_("cte")
32+
.select(cte.score,chunks_table.chunk,documents_table.metadata).orderby(cte.score,order=Order.desc)
33+
.inner_join(chunks_table)
34+
.on(chunks_table.id==cte.chunk_id)
35+
.inner_join(documents_table)
36+
.on(documents_table.id==chunks_table.document_id)
2637
)
27-
print(table_embed)
38+
print(query_cte.get_sql().replace('"',""))
39+

‎pgml-sdks/python/pgml/pgml/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
run_select_statement,
66
run_drop_or_delete_statement,
77
)
8-
from .queriesimportEmbed,CosineDistance
8+
from .queriesimportEmbed,CosineDistance

‎pgml-sdks/python/pgml/pgml/collection.py‎

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -805,23 +805,23 @@ def vector_search(
805805

806806
defexecute(self,sql_statement:QueryBuilder)->List[Dict[str,Any]]:
807807
conn=self.pool.getconn()
808-
results=run_select_statement(conn,sql_statement.get_sql().replace("\"",""))
808+
results=run_select_statement(conn,sql_statement.get_sql().replace('"',""))
809809
self.pool.putconn(conn)
810810
returnresults
811811

812-
defvector_recall(self,
812+
defvector_recall(
813+
self,
813814
query:str,
814815
query_parameters:Optional[Dict[str,Any]]= {},
815816
top_k:int=5,
816817
model_id:int=1,
817-
splitter_id:int=1)->List[Dict[str,Any]]:
818-
819-
818+
splitter_id:int=1,
819+
)->List[Dict[str,Any]]:
820820
ifmodel_idinself._cache_model_names.keys():
821821
model=self._cache_model_names[model_id]
822822
else:
823823
models=Table(self.models_table)
824-
q=Query.from_(models).select('name').where(models.id==model_id)
824+
q=Query.from_(models).select("name").where(models.id==model_id)
825825
results=self.execute(q)
826826
model=results[0]["name"]
827827
self._cache_model_names[model_id]=model
@@ -835,7 +835,12 @@ def vector_recall(self,
835835

836836
ifnotembeddings_table:
837837
transforms_table=Table(self.transforms_table)
838-
q=Query.from_(transforms_table).select('table_name').where(transforms_table.model_id==model_id).where(transforms_table.splitter_id==splitter_id)
838+
q= (
839+
Query.from_(transforms_table)
840+
.select("table_name")
841+
.where(transforms_table.model_id==model_id)
842+
.where(transforms_table.splitter_id==splitter_id)
843+
)
839844
embedding_table_results=self.execute(q)
840845
ifembedding_table_results:
841846
embeddings_table=embedding_table_results[0]["table_name"]
@@ -851,8 +856,17 @@ def vector_recall(self,
851856

852857
conn=self.pool.getconn()
853858

854-
cte_query=Query.select(Embed(transformer=model,text=query,parameters=query_parameters)).with_()
855-
table_embedding=Query.from_(embeddings_table).select('chunk_id',CosineDistance(embeddings_table.embedding,query_embedding.cosine)).cross_join(query_embedding)
859+
cte_query=Query.select(
860+
Embed(transformer=model,text=query,parameters=query_parameters)
861+
).with_()
862+
table_embedding= (
863+
Query.from_(embeddings_table)
864+
.select(
865+
"chunk_id",
866+
CosineDistance(embeddings_table.embedding,query_embedding.cosine),
867+
)
868+
.cross_join(query_embedding)
869+
)
856870
cte_select_statement="""
857871
WITH query_cte AS (
858872
SELECT pgml.embed(transformer => {model}, text => '{query_text}', kwargs => {model_params}) AS query_embedding
@@ -883,4 +897,4 @@ def vector_recall(self,
883897
)
884898
self.pool.putconn(conn)
885899

886-
returnsearch_results
900+
returnsearch_results

‎pgml-sdks/python/pgml/pgml/queries.py‎

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,20 @@
33
fromtypingimportDict
44
frompypikaimportJSON,Array
55

6+
67
classEmbed(Function):
7-
def__init__(self,transformer:str,text:str,parameters:Dict[str,Any]= {},alias:str="embedding")->None:
8-
super(Embed,self).__init__('pgml.embed',transformer,text,JSON(parameters),alias=alias)
8+
def__init__(
9+
self,
10+
transformer:str,
11+
text:str,
12+
parameters:Dict[str,Any]= {},
13+
alias:str="embedding",
14+
)->None:
15+
super(Embed,self).__init__(
16+
"pgml.embed",transformer,text,JSON(parameters),alias=alias
17+
)
18+
919

1020
classCosineDistance(Function):
1121
def__init__(self,lhs:Array,rhs:Array,alias:str="cosine")->None:
12-
super(CosineDistance,self).__init__('cosine_distance',lhs,rhs,alias=alias)
22+
super(CosineDistance,self).__init__("cosine_distance",lhs,rhs,alias=alias)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp