44import hashlib
55import os
66
7- class TestCollection (unittest .TestCase ):
87
8+ class TestCollection (unittest .TestCase ):
99def setUp (self )-> None :
1010local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
11- conninfo = os .environ .get ("PGML_CONNECTION" ,local_pgml )
11+ conninfo = os .environ .get ("PGML_CONNECTION" ,local_pgml )
1212self .db = Database (conninfo )
1313self .collection_name = "test_collection_1"
1414self .documents = [
1515 {
16- "id" :hashlib .md5 (f"abcded-{ i } " .encode (' utf-8' )).hexdigest (),
17- "text" :f"Lorem ipsum{ i } " ,
18- "metadata" : { "source" :"test_suite" }
16+ "id" :hashlib .md5 (f"abcded-{ i } " .encode (" utf-8" )).hexdigest (),
17+ "text" :f"Lorem ipsum{ i } " ,
18+ "source" :"test_suite" ,
1919 }
2020for i in range (4 ,7 )
2121 ]
2222self .documents_no_ids = [
2323 {
24- "text" :f"Lorem ipsum{ i } " ,
25- "metadata" : { "source" :"test_suite_no_ids" }
24+ "text" :f"Lorem ipsum{ i } " ,
25+ "source" :"test_suite_no_ids" ,
2626 }
2727for i in range (1 ,4 )
2828 ]
29-
29+
30+ self .documents_with_metadata = [
31+ {
32+ "text" :f"Lorem ipsum metadata" ,
33+ "source" :f"url{ i } " ,
34+ "url" :f"/home{ i } " ,
35+ "user" :f"John Doe-{ i + 1 } " ,
36+ }
37+ for i in range (8 ,12 )
38+ ]
39+
40+ self .documents_with_reviews = [
41+ {
42+ "text" :f"product is abc{ i } " ,
43+ "reviews" :i * 2 ,
44+ }
45+ for i in range (20 ,25 )
46+ ]
47+
48+ self .documents_with_reviews_metadata = [
49+ {
50+ "text" :f"product is abc{ i } " ,
51+ "reviews" :i * 2 ,
52+ "source" :"amazon" ,
53+ "user" :"John Doe" ,
54+ }
55+ for i in range (20 ,25 )
56+ ]
57+
58+ self .documents_with_reviews_metadata += [
59+ {
60+ "text" :f"product is abc{ i } " ,
61+ "reviews" :i * 2 ,
62+ "source" :"ebay" ,
63+ }
64+ for i in range (20 ,25 )
65+ ]
66+
3067self .collection = self .db .create_or_get_collection (self .collection_name )
31-
68+
3269def test_create_collection (self ):
33- assert isinstance (self .collection ,Collection )
34-
70+ assert isinstance (self .collection ,Collection )
71+
3572def test_documents_upsert (self ):
3673self .collection .upsert_documents (self .documents )
3774conn = self .db .pool .getconn ()
38- results = run_select_statement (conn ,"SELECT id FROM %s" % self .collection .documents_table )
75+ results = run_select_statement (
76+ conn ,"SELECT id FROM %s" % self .collection .documents_table
77+ )
3978self .db .pool .putconn (conn )
4079assert len (results )>= len (self .documents )
41-
80+
4281def test_documents_upsert_no_ids (self ):
4382self .collection .upsert_documents (self .documents_no_ids )
4483conn = self .db .pool .getconn ()
45- results = run_select_statement (conn ,"SELECT id FROM %s" % self .collection .documents_table )
84+ results = run_select_statement (
85+ conn ,"SELECT id FROM %s" % self .collection .documents_table
86+ )
4687self .db .pool .putconn (conn )
4788assert len (results )>= len (self .documents_no_ids )
4889
@@ -52,23 +93,25 @@ def test_default_text_splitter(self):
5293
5394assert splitter_id == 1
5495assert splitters [0 ]["name" ]== "RecursiveCharacterTextSplitter"
55-
96+
5697def test_default_embeddings_model (self ):
5798model_id = self .collection .register_model ()
5899models = self .collection .get_models ()
59-
100+
60101assert model_id == 1
61102assert models [0 ]["name" ]== "intfloat/e5-small"
62-
103+
63104def test_generate_chunks (self ):
64105self .collection .upsert_documents (self .documents )
65106self .collection .upsert_documents (self .documents_no_ids )
66107splitter_id = self .collection .register_text_splitter ()
67108self .collection .generate_chunks (splitter_id = splitter_id )
68- splitter_params = {"chunk_size" :100 ,"chunk_overlap" :20 }
69- splitter_id = self .collection .register_text_splitter (splitter_params = splitter_params )
109+ splitter_params = {"chunk_size" :100 ,"chunk_overlap" :20 }
110+ splitter_id = self .collection .register_text_splitter (
111+ splitter_params = splitter_params
112+ )
70113self .collection .generate_chunks (splitter_id = splitter_id )
71-
114+
72115def test_generate_embeddings (self ):
73116self .collection .upsert_documents (self .documents )
74117self .collection .upsert_documents (self .documents_no_ids )
@@ -84,10 +127,42 @@ def test_vector_search(self):
84127self .collection .generate_embeddings ()
85128results = self .collection .vector_search ("Lorem ipsum 1" ,top_k = 2 )
86129assert results [0 ]["score" ]== 1.0
87-
88- # def tearDown(self) -> None:
89- # self.db.archive_collection(self.collection_name)
90130
131+ def test_vector_search_metadata_filter (self ):
132+ self .collection .upsert_documents (self .documents )
133+ self .collection .upsert_documents (self .documents_no_ids )
134+ self .collection .upsert_documents (self .documents_with_metadata )
135+ self .collection .generate_chunks ()
136+ self .collection .generate_embeddings ()
137+ results = self .collection .vector_search (
138+ "Lorem ipsum metadata" ,
139+ top_k = 2 ,
140+ metadata_filter = {"url" :"/home 8" ,"source" :"url 8" },
141+ )
142+ assert results [0 ]["metadata" ]["user" ]== "John Doe-9"
143+
144+ def test_vector_search_generic_filter (self ):
145+ self .collection .upsert_documents (self .documents_with_reviews )
146+ self .collection .generate_chunks ()
147+ self .collection .generate_embeddings ()
148+ results = self .collection .vector_search (
149+ "product is abc 21" ,
150+ top_k = 2 ,
151+ generic_filter = "(documents.metadata->>'reviews')::int < 45" ,
152+ )
153+ assert results [0 ]["metadata" ]["reviews" ]== 42
91154
92-
93-
155+ def test_vector_search_generic_and_metadata_filter (self ):
156+ self .collection .upsert_documents (self .documents_with_reviews_metadata )
157+ self .collection .generate_chunks ()
158+ self .collection .generate_embeddings ()
159+ results = self .collection .vector_search (
160+ "product is abc 21" ,
161+ top_k = 2 ,
162+ generic_filter = "(documents.metadata->>'reviews')::int < 45" ,
163+ metadata_filter = {"source" :"amazon" },
164+ )
165+ assert results [0 ]["metadata" ]["user" ]== "John Doe"
166+
167+ # def tearDown(self) -> None:
168+ # self.db.archive_collection(self.collection_name)