Commit 873ca9b: metadata and generic filters in vector search (#689)
1 parent 96dd570

3 files changed: +125 / -43 lines

pgml-sdks/python/pgml/examples/question_answering.py

Lines changed: 2 additions & 2 deletions

@@ -33,10 +33,10 @@
 
 start = time()
 query = "Who won 20 grammy awards?"
-results = collection.vector_search(query, top_k=5, title="Beyoncé")
+results = collection.vector_search(query, top_k=5, metadata_filter={"title": "Beyoncé"})
 _end = time()
 console.print("\nResults for '%s'" % (query), style="bold")
 console.print(results)
 console.print("Query time = %0.3f" % (_end - start))
 
-db.archive_collection(collection_name)
+# db.archive_collection(collection_name)

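The example above now routes the title constraint through the new metadata_filter parameter instead of a bare keyword argument. As a quick orientation, here is a minimal usage sketch, not a verbatim excerpt: it assumes `collection` is a pgml Collection that has already been upserted, chunked, and embedded (as in question_answering.py), and the 'reviews' predicate in the second call is borrowed from this commit's tests rather than from the SQuAD example data.

# Minimal sketch, not a verbatim excerpt: assumes `collection` is a pgml
# Collection that has already been upserted, chunked, and embedded, as in
# question_answering.py above.
query = "Who won 20 grammy awards?"

# Before this commit, filters were passed as arbitrary keyword arguments:
# results = collection.vector_search(query, top_k=5, title="Beyoncé")

# Exact-match constraints now go through metadata_filter (JSONB containment):
results = collection.vector_search(
    query,
    top_k=5,
    metadata_filter={"title": "Beyoncé"},
)

# Free-form SQL predicates over documents.metadata go through generic_filter;
# the 'reviews' field is taken from this commit's test data, not the SQuAD set.
results = collection.vector_search(
    query,
    top_k=5,
    generic_filter="(documents.metadata->>'reviews')::int < 45",
)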
pgml-sdks/python/pgml/pgml/collection.py

Lines changed: 23 additions & 16 deletions

@@ -298,18 +298,23 @@ def upsert_documents(
                 )
                 continue
 
+            metadata = document
+
             _uuid = ""
             if id_key not in list(document.keys()):
                 log.info("id key is not present.. hashing")
-                source_uuid = hashlib.md5(text.encode("utf-8")).hexdigest()
+                source_uuid = hashlib.md5(
+                    (text + " " + json.dumps(document)).encode("utf-8")
+                ).hexdigest()
             else:
                 _uuid = document.pop(id_key)
                 try:
                     source_uuid = str(uuid.UUID(_uuid))
                 except Exception:
-                    source_uuid = hashlib.md5(text.encode("utf-8")).hexdigest()
+                    source_uuid = hashlib.md5(str(_uuid).encode("utf-8")).hexdigest()
 
-            metadata = document
+            if _uuid:
+                document[id_key] = source_uuid
 
             upsert_statement = "INSERT INTO {documents_table} (text, source_uuid, metadata) VALUES ({text}, {source_uuid}, {metadata})\
                 ON CONFLICT (source_uuid)\
@@ -323,9 +328,6 @@ def upsert_documents(
 
             # put the text and id back in document
             document[text_key] = text
-            if _uuid:
-                document[id_key] = source_uuid
-
             self.pool.putconn(conn)
 
     def register_text_splitter(
@@ -683,7 +685,8 @@ def vector_search(
         top_k: int = 5,
         model_id: int = 1,
         splitter_id: int = 1,
-        **kwargs: Any,
+        metadata_filter: Optional[Dict[str, Any]] = {},
+        generic_filter: Optional[str] = "",
     ) -> List[Dict[str, Any]]:
         """
         This function performs a vector search on a database using a query and returns the top matching
@@ -753,13 +756,6 @@ def vector_search(
                 % (model_id, splitter_id, model_id, splitter_id)
             )
             return []
-
-        if kwargs:
-            metadata_filter = [f"documents.metadata->>'{k}' = '{v}'" if isinstance(v, str) else f"documents.metadata->>'{k}' = {v}" for k, v in kwargs.items()]
-            metadata_filter = " AND ".join(metadata_filter)
-            metadata_filter = f"AND {metadata_filter}"
-        else:
-            metadata_filter = ""
 
         cte_select_statement = """
         WITH query_cte AS (
@@ -775,7 +771,7 @@ def vector_search(
         SELECT cte.score, chunks.chunk, documents.metadata
         FROM cte
         INNER JOIN {chunks_table} chunks ON chunks.id = cte.chunk_id
-        INNER JOIN {documents_table} documents ON documents.id = chunks.document_id {metadata_filter}
+        INNER JOIN {documents_table} documents ON documents.id = chunks.document_id
         """.format(
             model=sql.Literal(model).as_string(conn),
             query_text=query,
@@ -784,9 +780,20 @@ def vector_search(
             top_k=top_k,
             chunks_table=self.chunks_table,
             documents_table=self.documents_table,
-            metadata_filter=metadata_filter,
         )
 
+        if metadata_filter:
+            cte_select_statement += (
+                " AND documents.metadata @> {metadata_filter}".format(
+                    metadata_filter=sql.Literal(json.dumps(metadata_filter)).as_string(
+                        conn
+                    )
+                )
+            )
+
+        if generic_filter:
+            cte_select_statement += " AND " + generic_filter
+
         search_results = run_select_statement(
             conn, cte_select_statement, order_by="score", ascending=False
         )

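In summary, vector_search no longer builds ->> equality comparisons out of **kwargs. A metadata_filter dict is serialized to JSON, escaped with psycopg's sql.Literal, and applied with the JSONB containment operator @>, while generic_filter is appended verbatim as an extra predicate on the final documents join. The sketch below is a simplified, standalone illustration of that clause-building step; the table names, the helper name build_filter_clauses, and the plain json.dumps escaping are stand-ins, not the exact code in collection.py.

import json


def build_filter_clauses(metadata_filter=None, generic_filter=""):
    """Illustration only: mirrors the filtering logic added in this commit.

    metadata_filter becomes a JSONB containment test (documents.metadata @> '{...}'),
    so every key/value pair in the dict must match the stored metadata exactly.
    generic_filter is appended as-is, so it can express casts and range
    comparisons such as (documents.metadata->>'reviews')::int < 45.
    """
    clauses = []
    if metadata_filter:
        # collection.py escapes this value with psycopg's sql.Literal; plain
        # json.dumps is used here only to keep the sketch self-contained.
        clauses.append("documents.metadata @> '%s'" % json.dumps(metadata_filter))
    if generic_filter:
        clauses.append(generic_filter)
    return "".join(" AND " + clause for clause in clauses)


# Hypothetical base query standing in for the CTE built in vector_search.
base_query = (
    "SELECT cte.score, chunks.chunk, documents.metadata FROM cte "
    "INNER JOIN chunks ON chunks.id = cte.chunk_id "
    "INNER JOIN documents ON documents.id = chunks.document_id"
)

print(
    base_query
    + build_filter_clauses(
        metadata_filter={"source": "amazon"},
        generic_filter="(documents.metadata->>'reviews')::int < 45",
    )
)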
pgml-sdks/python/pgml/tests/test_collection.py

Lines changed: 100 additions & 25 deletions

@@ -4,45 +4,86 @@
 import hashlib
 import os
 
-class TestCollection(unittest.TestCase):
 
+class TestCollection(unittest.TestCase):
     def setUp(self) -> None:
         local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
-        conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
+        conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
         self.db = Database(conninfo)
         self.collection_name = "test_collection_1"
         self.documents = [
             {
-                "id": hashlib.md5(f"abcded-{i}".encode('utf-8')).hexdigest(),
-                "text": f"Lorem ipsum {i}",
-                "metadata": {"source": "test_suite"}
+                "id": hashlib.md5(f"abcded-{i}".encode("utf-8")).hexdigest(),
+                "text": f"Lorem ipsum {i}",
+                "source": "test_suite",
             }
             for i in range(4, 7)
         ]
         self.documents_no_ids = [
             {
-                "text": f"Lorem ipsum {i}",
-                "metadata": {"source": "test_suite_no_ids"}
+                "text": f"Lorem ipsum {i}",
+                "source": "test_suite_no_ids",
             }
             for i in range(1, 4)
         ]
-
+
+        self.documents_with_metadata = [
+            {
+                "text": f"Lorem ipsum metadata",
+                "source": f"url {i}",
+                "url": f"/home {i}",
+                "user": f"John Doe-{i+1}",
+            }
+            for i in range(8, 12)
+        ]
+
+        self.documents_with_reviews = [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+            }
+            for i in range(20, 25)
+        ]
+
+        self.documents_with_reviews_metadata = [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+                "source": "amazon",
+                "user": "John Doe",
+            }
+            for i in range(20, 25)
+        ]
+
+        self.documents_with_reviews_metadata += [
+            {
+                "text": f"product is abc {i}",
+                "reviews": i * 2,
+                "source": "ebay",
+            }
+            for i in range(20, 25)
+        ]
+
         self.collection = self.db.create_or_get_collection(self.collection_name)
-
+
     def test_create_collection(self):
-        assert isinstance(self.collection, Collection)
-
+        assert isinstance(self.collection, Collection)
+
     def test_documents_upsert(self):
         self.collection.upsert_documents(self.documents)
         conn = self.db.pool.getconn()
-        results = run_select_statement(conn, "SELECT id FROM %s" % self.collection.documents_table)
+        results = run_select_statement(
+            conn, "SELECT id FROM %s" % self.collection.documents_table
+        )
         self.db.pool.putconn(conn)
         assert len(results) >= len(self.documents)
-
+
     def test_documents_upsert_no_ids(self):
         self.collection.upsert_documents(self.documents_no_ids)
         conn = self.db.pool.getconn()
-        results = run_select_statement(conn, "SELECT id FROM %s" % self.collection.documents_table)
+        results = run_select_statement(
+            conn, "SELECT id FROM %s" % self.collection.documents_table
+        )
         self.db.pool.putconn(conn)
         assert len(results) >= len(self.documents_no_ids)
 
@@ -52,23 +93,25 @@ def test_default_text_splitter(self):
 
         assert splitter_id == 1
         assert splitters[0]["name"] == "RecursiveCharacterTextSplitter"
-
+
     def test_default_embeddings_model(self):
         model_id = self.collection.register_model()
         models = self.collection.get_models()
-
+
         assert model_id == 1
         assert models[0]["name"] == "intfloat/e5-small"
-
+
     def test_generate_chunks(self):
         self.collection.upsert_documents(self.documents)
         self.collection.upsert_documents(self.documents_no_ids)
         splitter_id = self.collection.register_text_splitter()
         self.collection.generate_chunks(splitter_id=splitter_id)
-        splitter_params = {"chunk_size": 100, "chunk_overlap": 20}
-        splitter_id = self.collection.register_text_splitter(splitter_params=splitter_params)
+        splitter_params = {"chunk_size": 100, "chunk_overlap": 20}
+        splitter_id = self.collection.register_text_splitter(
+            splitter_params=splitter_params
+        )
         self.collection.generate_chunks(splitter_id=splitter_id)
-
+
     def test_generate_embeddings(self):
         self.collection.upsert_documents(self.documents)
         self.collection.upsert_documents(self.documents_no_ids)
@@ -84,10 +127,42 @@ def test_vector_search(self):
         self.collection.generate_embeddings()
         results = self.collection.vector_search("Lorem ipsum 1", top_k=2)
         assert results[0]["score"] == 1.0
-
-    # def tearDown(self) -> None:
-    #     self.db.archive_collection(self.collection_name)
 
+    def test_vector_search_metadata_filter(self):
+        self.collection.upsert_documents(self.documents)
+        self.collection.upsert_documents(self.documents_no_ids)
+        self.collection.upsert_documents(self.documents_with_metadata)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "Lorem ipsum metadata",
+            top_k=2,
+            metadata_filter={"url": "/home 8", "source": "url 8"},
+        )
+        assert results[0]["metadata"]["user"] == "John Doe-9"
+
+    def test_vector_search_generic_filter(self):
+        self.collection.upsert_documents(self.documents_with_reviews)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "product is abc 21",
+            top_k=2,
+            generic_filter="(documents.metadata->>'reviews')::int < 45",
+        )
+        assert results[0]["metadata"]["reviews"] == 42
 
-
-
+    def test_vector_search_generic_and_metadata_filter(self):
+        self.collection.upsert_documents(self.documents_with_reviews_metadata)
+        self.collection.generate_chunks()
+        self.collection.generate_embeddings()
+        results = self.collection.vector_search(
+            "product is abc 21",
+            top_k=2,
+            generic_filter="(documents.metadata->>'reviews')::int < 45",
+            metadata_filter={"source": "amazon"},
+        )
+        assert results[0]["metadata"]["user"] == "John Doe"
+
+    # def tearDown(self) -> None:
+    #     self.db.archive_collection(self.collection_name)
