Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f753be8

Browse files
committed
new query class
1 parent 0f3bf87, commit f753be8

File tree

3 files changed

+52
-33
lines changed

3 files changed

+52
-33
lines changed

‎pgml-sdks/python/pgml/pgml/collection.py‎

Lines changed: 46 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -21,11 +21,16 @@
2121
)
2222

2323
fromlangchain.text_splitterimportRecursiveCharacterTextSplitter
24-
frompypikaimportQuery,Table,AliasedQuery,Order,Field
24+
frompypikaimportJSON,Array,Table,AliasedQuery,Order
25+
frompypikaimportQuery
26+
2527
frompypika.queriesimportQueryBuilder
2628
frompypika.functionsimportCast
2729
from .queriesimportEmbed,CosineDistance
2830

31+
_cache_model_names= {}
32+
_cache_embeddings_table_names= {}
33+
2934

3035
FORMAT="%(message)s"
3136
logging.basicConfig(
@@ -808,47 +813,64 @@ def vector_search(
808813

809814
defexecute(self,sql_statement:QueryBuilder)->List[Dict[str,Any]]:
810815
conn=self.pool.getconn()
811-
results=run_select_statement(conn,sql_statement.get_sql())
816+
results=run_select_statement(conn,str(sql_statement))
812817
self.pool.putconn(conn)
813818
returnresults
814819

820+
defquery(self):
821+
returnPGMLQuery(self)
822+
823+
824+
classPGMLQuery(QueryBuilder):
825+
def__init__(self,collection:Collection)->None:
826+
self.collection=collection
827+
828+
def__str__(self)->str:
829+
returnself.get_sql()
830+
831+
deflimit(self,_limit:int):
832+
self=self.limit(_limit)
833+
returnself
834+
815835
defvector_recall(
816836
self,
817837
query:str,
818838
query_parameters:Optional[Dict[str,Any]]= {},
839+
top_k:int=5,
819840
model_id:int=1,
820841
splitter_id:int=1,
821-
)->List[Dict[str,Any]]:
822-
ifmodel_idinself._cache_model_names.keys():
823-
model=self._cache_model_names[model_id]
842+
):
843+
ifmodel_idin_cache_model_names.keys():
844+
model=_cache_model_names[model_id]
824845
else:
825-
models=Table(self.models_table.split(".")[1],schema=self.name)
846+
models=Table(
847+
self.collection.models_table.split(".")[1],schema=self.collection.name
848+
)
826849
q=Query.from_(models).select("name").where(models.id==model_id)
827-
results=self.execute(q)
850+
results=self.collection.execute(q)
828851
model=results[0]["name"]
829-
self._cache_model_names[model_id]=model
852+
_cache_model_names[model_id]=model
830853

831854
embeddings_table=""
832-
ifmodel_idinself._cache_embeddings_table_names.keys():
833-
ifsplitter_idinself._cache_embeddings_table_names[model_id].keys():
834-
embeddings_table=self._cache_embeddings_table_names[model_id][
835-
splitter_id
836-
]
855+
ifmodel_idin_cache_embeddings_table_names.keys():
856+
ifsplitter_idin_cache_embeddings_table_names[model_id].keys():
857+
embeddings_table=_cache_embeddings_table_names[model_id][splitter_id]
837858

838859
ifnotembeddings_table:
839860
transforms_table=Table(
840-
self.transforms_table.split(".")[1],schema=self.name
861+
self.collection.transforms_table.split(".")[1],
862+
schema=self.collection.name,
841863
)
842864
q= (
843865
Query.from_(transforms_table)
844866
.select("table_name")
845867
.where(transforms_table.model_id==model_id)
846868
.where(transforms_table.splitter_id==splitter_id)
847869
)
848-
embedding_table_results=self.execute(q)
870+
embedding_table_results=self.collection.execute(q)
849871
ifembedding_table_results:
850872
embeddings_table=embedding_table_results[0]["table_name"]
851-
self._cache_embeddings_table_names[model_id]= {
873+
_cache_embeddings_table_names[model_id]= {
852874
splitter_id:embeddings_table
853875
}
854876
else:
@@ -859,12 +881,12 @@ def vector_recall(
859881
return []
860882

861883
embeddings_table=embeddings_table.split(".")[1]
862-
chunks_table=self.chunks_table.split(".")[1]
863-
documents_table=self.documents_table.split(".")[1]
884+
chunks_table=self.collection.chunks_table.split(".")[1]
885+
documents_table=self.collection.documents_table.split(".")[1]
864886

865-
embeddings_table=Table(embeddings_table,schema=self.name)
866-
chunks_table=Table(chunks_table,schema=self.name)
867-
documents_table=Table(documents_table,schema=self.name)
887+
embeddings_table=Table(embeddings_table,schema=self.collection.name)
888+
chunks_table=Table(chunks_table,schema=self.collection.name)
889+
documents_table=Table(documents_table,schema=self.collection.name)
868890

869891
query_embed=Query().select(
870892
Embed(transformer=model,text=query,parameters=query_parameters)
@@ -887,7 +909,7 @@ def vector_recall(
887909
.cross()
888910
)
889911

890-
query_cte= (
912+
self= (
891913
Query()
892914
.with_(query_embed,"query_cte")
893915
.with_(table_embed,"cte")
@@ -897,7 +919,7 @@ def vector_recall(
897919
.inner_join(chunks_table)
898920
.on(chunks_table.id==cte.chunk_id)
899921
.inner_join(documents_table)
900-
.on(documents_table.id==chunks_table.document_id)
922+
.on(documents_table.id==chunks_table.document_id).limit(top_k)
901923
)
902924

903-
returnquery_cte
925+
returnself

‎pgml-sdks/python/pgml/pgml/queries.py‎

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
fromtypingimportAny
22
frompypika.functionsimportFunction,Cast
3-
fromtypingimportDict,List
4-
frompypikaimportJSON,Array,Field
3+
fromtypingimportDict,List,Optional
4+
frompypikaimportJSON,Array,Table,AliasedQuery,Order
5+
frompypikaimportQueryasPyPikaQuery
56
importjson
7+
fromrichimportprintasrprint
68

79

810
classEmbed(Function):

‎pgml-sdks/python/pgml/tests/test_collection.py‎

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -170,16 +170,11 @@ def test_vector_recall(self):
170170
self.collection.upsert_documents(self.documents_with_reviews_metadata)
171171
self.collection.generate_chunks()
172172
self.collection.generate_embeddings()
173-
documents_table=Table("documents",schema=self.collection.name)
174173
query= (
175-
self.collection.vector_recall("product is abc")
176-
.where(documents_table.metadata.contains({"source":"amazon"}))
177-
.where(
178-
Cast(documents_table.metadata.get_json_value("reviews"),"INTEGER")<45
179-
)
180-
.limit(2)
174+
self.collection.query().vector_recall("product is abc").limit(2).limit(1)
181175
)
182176
results=self.collection.execute(query)
177+
print(results)
183178
assertresults[0]["metadata"]["user"]=="John Doe"
184179

185180
# def tearDown(self) -> None:

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp