Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 7f030e3

Browse files
committed
pgml embed and cosine functions
1 parent 2d7a335, commit 7f030e3

File tree

4 files changed

+100
-1
lines changed

4 files changed

+100
-1
lines changed

‎pgml-sdks/python/pgml/pgml/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
run_create_or_insert_statement,
55
run_select_statement,
66
run_drop_or_delete_statement,
7-
)
7+
)

‎pgml-sdks/python/pgml/pgml/collection.py‎

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222

2323
fromlangchain.text_splitterimportRecursiveCharacterTextSplitter
24+
frompypika.queriesimportSchema,Table,Query,QueryBuilder
2425

2526
FORMAT="%(message)s"
2627
logging.basicConfig(
@@ -800,3 +801,82 @@ def vector_search(
800801
self.pool.putconn(conn)
801802

802803
returnsearch_results
804+
805+
def execute(self, sql_statement: QueryBuilder) -> List[Dict[str, Any]]:
    """Run a pypika query on a pooled connection and return its rows.

    Args:
        sql_statement: Query builder whose ``get_sql()`` output is executed.

    Returns:
        Whatever ``run_select_statement`` returns for the rendered SQL.
    """
    # Every double quote is stripped from the rendered SQL before it is run.
    # NOTE(review): presumably this removes pypika's identifier quoting; it
    # would also strip double quotes that appear inside values -- confirm.
    statement = sql_statement.get_sql().replace('"', "")

    conn = self.pool.getconn()
    try:
        return run_select_statement(conn, statement)
    finally:
        # Return the connection even when the query raises, so a failing
        # statement cannot leak a connection from the pool.
        self.pool.putconn(conn)
810+
811+
def vector_recall(
    self,
    query: str,
    query_parameters: Optional[Dict[str, Any]] = None,
    top_k: int = 5,
    model_id: int = 1,
    splitter_id: int = 1,
) -> List[Dict[str, Any]]:
    """Embed ``query`` in the database and return the closest stored chunks.

    The model name is read from ``self.models_table`` and the embeddings
    table name from ``self.transforms_table``; both lookups are cached on
    the instance. The final statement embeds the query with ``pgml.embed``
    and scores every row of the embeddings table as ``1 - (embedding <=>
    query_embedding)``.

    Args:
        query: Text to embed and search for.
        query_parameters: Keyword arguments forwarded to ``pgml.embed`` as
            JSON. ``None`` is treated as an empty dict.
        top_k: Maximum number of rows to return.
        model_id: Id of the row in the models table to embed with.
        splitter_id: Id of the splitter whose embeddings are searched.

    Returns:
        Rows with ``score``, ``chunk`` and ``metadata`` columns as produced
        by ``run_select_statement``, or ``[]`` when no embeddings table is
        registered for the model/splitter pair.
    """
    if query_parameters is None:
        query_parameters = {}

    # Resolve (and cache) the model name for this model id.
    if model_id in self._cache_model_names:
        model = self._cache_model_names[model_id]
    else:
        models = Table(self.models_table)
        q = Query.from_(models).select("name").where(models.id == model_id)
        results = self.execute(q)
        # NOTE(review): an unknown model_id presumably yields no rows, in
        # which case this indexing raises -- confirm the desired handling.
        model = results[0]["name"]
        self._cache_model_names[model_id] = model

    # Resolve (and cache) the embeddings table for this model/splitter pair.
    embeddings_table = self._cache_embeddings_table_names.get(
        model_id, {}
    ).get(splitter_id, "")

    if not embeddings_table:
        transforms_table = Table(self.transforms_table)
        q = (
            Query.from_(transforms_table)
            .select("table_name")
            .where(transforms_table.model_id == model_id)
            .where(transforms_table.splitter_id == splitter_id)
        )
        embedding_table_results = self.execute(q)
        if not embedding_table_results:
            rprint(
                "Embeddings for model id %d and splitter id %d do not exist.\nPlease run collection.generate_embeddings(model_id = %d, splitter_id = %d)"
                % (model_id, splitter_id, model_id, splitter_id)
            )
            return []

        embeddings_table = embedding_table_results[0]["table_name"]
        # Add to the per-model mapping instead of replacing it, so entries
        # cached for other splitters of the same model are kept.
        self._cache_embeddings_table_names.setdefault(model_id, {})[
            splitter_id
        ] = embeddings_table

    conn = self.pool.getconn()
    try:
        # The query text is rendered as a quoted, escaped SQL literal rather
        # than being spliced between quotes, so text containing quotes can
        # neither break the statement nor inject SQL.
        # The outer SELECT carries its own ORDER BY: the ordering inside the
        # CTE is not guaranteed to survive the joins, and LIMIT without an
        # ORDER BY may keep an arbitrary subset of rows.
        cte_select_statement = """
            WITH query_cte AS (
                SELECT pgml.embed(transformer => {model}, text => {query_text}, kwargs => {model_params}) AS query_embedding
            ),
            cte AS (
                SELECT chunk_id, 1 - ({embeddings_table}.embedding <=> query_cte.query_embedding::float8[]::vector) AS score
                FROM {embeddings_table}
                CROSS JOIN query_cte
                ORDER BY score DESC
            )
            SELECT cte.score, chunks.chunk, documents.metadata
            FROM cte
            INNER JOIN {chunks_table} chunks ON chunks.id = cte.chunk_id
            INNER JOIN {documents_table} documents ON documents.id = chunks.document_id
            ORDER BY cte.score DESC
            LIMIT {top_k}
        """.format(
            model=sql.Literal(model).as_string(conn),
            query_text=sql.Literal(query).as_string(conn),
            model_params=sql.Literal(json.dumps(query_parameters)).as_string(conn),
            embeddings_table=embeddings_table,
            chunks_table=self.chunks_table,
            documents_table=self.documents_table,
            # Coerce so only an integer can ever reach the statement.
            top_k=int(top_k),
        )

        return run_select_statement(
            conn, cte_select_statement, order_by="score", ascending=False
        )
    finally:
        # Return the connection even if building or running the query fails.
        self.pool.putconn(conn)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
fromtypingimportAny
2+
frompypika.functionsimportFunction
3+
fromtypingimportDict
4+
frompypikaimportJSON,Array
5+
6+
class Embed(Function):
    """pypika function term that renders a call to ``pgml.embed``."""

    def __init__(
        self,
        transformer: str,
        text: str,
        parameters: Dict[str, Any] = {},
        alias: str = "embedding",
    ) -> None:
        """Build the ``pgml.embed(transformer, text, parameters)`` term.

        Args:
            transformer: First argument passed to ``pgml.embed``.
            text: Second argument passed to ``pgml.embed``.
            parameters: Wrapped in a pypika ``JSON`` term and passed as the
                third argument. The shared default dict is only read here.
            alias: Alias given to the resulting term.
        """
        json_parameters = JSON(parameters)
        super().__init__(
            "pgml.embed",
            transformer,
            text,
            json_parameters,
            alias=alias,
        )
9+
10+
class CosineDistance(Function):
    """pypika function term that renders a call to ``cosine_distance``."""

    def __init__(
        self,
        lhs: Array,
        rhs: Array,
        alias: str = "cosine",
    ) -> None:
        """Build the ``cosine_distance(lhs, rhs)`` term.

        Args:
            lhs: First argument passed to ``cosine_distance``.
            rhs: Second argument passed to ``cosine_distance``.
            alias: Alias given to the resulting term.
        """
        operands = (lhs, rhs)
        super().__init__("cosine_distance", *operands, alias=alias)

‎pgml-sdks/python/pgml/tests/test_collection.py‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,5 +164,12 @@ def test_vector_search_generic_and_metadata_filter(self):
164164
)
165165
assertresults[0]["metadata"]["user"]=="John Doe"
166166

167+
def test_vector_recall(self):
    """vector_recall returns scored chunks once embeddings have been built."""
    self.collection.upsert_documents(self.documents_with_reviews_metadata)
    self.collection.generate_chunks()
    self.collection.generate_embeddings()
    results = self.collection.vector_recall("product is abc")

    # The test previously only printed the rows, so it could never fail on
    # a wrong result; assert on what vector_recall's own query promises.
    # NOTE(review): assumes the default model/splitter ids used by
    # generate_embeddings() match vector_recall's defaults -- confirm.
    assert results, "vector_recall returned no rows"
    # Default top_k is 5.
    assert len(results) <= 5
    for row in results:
        assert {"score", "chunk", "metadata"} <= set(row)
173+
167174
# def tearDown(self) -> None:
168175
# self.db.archive_collection(self.collection_name)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp