metadata and generic filters in vector search #689

Merged: santiatpml merged 1 commit into master from santi-pgml-sdk-metadata-filter on Jun 6, 2023
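
This PR replaces vector_search's old **kwargs-based metadata matching with two explicit parameters: metadata_filter, a dict matched against documents.metadata using JSONB containment, and generic_filter, a raw SQL predicate over the joined tables. A minimal usage sketch against a local PostgresML instance (connection string and collection name are illustrative; the import is assumed to follow the SDK's examples):

from pgml import Database

db = Database("postgres://postgres@127.0.0.1:5433/pgml_development")
collection = db.create_or_get_collection("test_collection_1")

# Exact key/value matching against the JSONB metadata column:
results = collection.vector_search(
    "Who won 20 grammy awards?",
    top_k=5,
    metadata_filter={"title": "Beyoncé"},
)

# Arbitrary SQL predicate over documents.metadata:
results = collection.vector_search(
    "product is abc 21",
    top_k=2,
    generic_filter="(documents.metadata->>'reviews')::int < 45",
)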
pgml-sdks/python/pgml/examples/question_answering.py (4 changes: 2 additions & 2 deletions)
@@ -33,10 +33,10 @@

start = time()
query = "Who won 20 grammy awards?"
- results = collection.vector_search(query, top_k=5, title="Beyoncé")
+ results = collection.vector_search(query, top_k=5, metadata_filter={"title" : "Beyoncé"})
_end = time()
console.print("\nResults for '%s'" % (query), style="bold")
console.print(results)
console.print("Query time = %0.3f" % (_end - start))

- db.archive_collection(collection_name)
+ #db.archive_collection(collection_name)
pgml-sdks/python/pgml/pgml/collection.py (39 changes: 23 additions & 16 deletions)
@@ -298,18 +298,23 @@ def upsert_documents(
)
continue

- metadata = document

_uuid = ""
if id_key not in list(document.keys()):
log.info("id key is not present.. hashing")
- source_uuid = hashlib.md5(text.encode("utf-8")).hexdigest()
+ source_uuid = hashlib.md5(
+ (text + " " + json.dumps(document)).encode("utf-8")
+ ).hexdigest()
else:
_uuid = document.pop(id_key)
try:
source_uuid = str(uuid.UUID(_uuid))
except Exception:
- source_uuid = hashlib.md5(text.encode("utf-8")).hexdigest()
+ source_uuid = hashlib.md5(str(_uuid).encode("utf-8")).hexdigest()

+ metadata = document
+ if _uuid:
+ document[id_key] = source_uuid

upsert_statement = "INSERT INTO {documents_table} (text, source_uuid, metadata) VALUES ({text}, {source_uuid}, {metadata}) \
ON CONFLICT (source_uuid) \
@@ -323,9 +328,6 @@

# put the text and id back in document
document[text_key] = text
- if _uuid:
- document[id_key] = source_uuid

self.pool.putconn(conn)

def register_text_splitter(
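
One effect of the hashing change in the hunk above, shown in isolation: because the remaining document metadata is mixed into the MD5 input, two documents with identical text but different metadata no longer collide on source_uuid, where previously the second upsert would hit ON CONFLICT (source_uuid) and overwrite the first. A standalone sketch, not SDK code:

import hashlib
import json

def source_uuid(document):
    # Mirrors the new hashing path for documents without an explicit id key.
    doc = dict(document)
    text = doc.pop("text")
    return hashlib.md5((text + " " + json.dumps(doc)).encode("utf-8")).hexdigest()

a = {"text": "product is abc 20", "source": "amazon"}
b = {"text": "product is abc 20", "source": "ebay"}
assert source_uuid(a) != source_uuid(b)  # same text, two distinct rows after this change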
@@ -683,7 +685,8 @@ def vector_search(
top_k: int = 5,
model_id: int = 1,
splitter_id: int = 1,
- **kwargs: Any,
+ metadata_filter: Optional[Dict[str, Any]] = {},
+ generic_filter: Optional[str] = "",
) -> List[Dict[str, Any]]:
"""
This function performs a vector search on a database using a query and returns the top matching
@@ -753,13 +756,6 @@
% (model_id, splitter_id, model_id, splitter_id)
)
return []

- if kwargs:
- metadata_filter = [f"documents.metadata->>'{k}' = '{v}'" if isinstance(v, str) else f"documents.metadata->>'{k}' = {v}" for k, v in kwargs.items()]
- metadata_filter = " AND ".join(metadata_filter)
- metadata_filter = f"AND {metadata_filter}"
- else:
- metadata_filter = ""

cte_select_statement = """
WITH query_cte AS (
@@ -775,7 +771,7 @@
SELECT cte.score, chunks.chunk, documents.metadata
FROM cte
INNER JOIN {chunks_table} chunks ON chunks.id = cte.chunk_id
- INNER JOIN {documents_table} documents ON documents.id = chunks.document_id {metadata_filter}
+ INNER JOIN {documents_table} documents ON documents.id = chunks.document_id
""".format(
model=sql.Literal(model).as_string(conn),
query_text=query,
@@ -784,9 +780,20 @@
top_k=top_k,
chunks_table=self.chunks_table,
documents_table=self.documents_table,
- metadata_filter=metadata_filter,
)

+ if metadata_filter:
+ cte_select_statement += (
+ " AND documents.metadata @> {metadata_filter}".format(
[Inline review comment from a project contributor: "This is nice. We should also add a GIN index to documents.metadata." santiatpml reacted with a thumbs up. A sketch of such an index follows this hunk.]
+ metadata_filter=sql.Literal(json.dumps(metadata_filter)).as_string(
+ conn
+ )
+ )
+ )

+ if generic_filter:
+ cte_select_statement += " AND " + generic_filter
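
# Net effect of the two branches above: metadata_filter appends a JSONB
# containment predicate (e.g. documents.metadata @> '{"source": "amazon"}'
# matches any document whose metadata includes at least those key/value
# pairs), while generic_filter is spliced into the SQL verbatim, allowing
# arbitrary predicates such as (documents.metadata->>'reviews')::int < 45
# but requiring input from a trusted caller.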

search_results = run_select_statement(
conn, cte_select_statement, order_by="score", ascending=False
)
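
Following up on the reviewer's GIN-index suggestion: no index is added in this PR, but here is a sketch of what one could look like, reusing the pool and table-name attributes seen in this file (the index name and the jsonb_path_ops operator class, which serves @> containment lookups, are assumptions):

# Hypothetical follow-up, not part of this PR.
conn = collection.pool.getconn()
with conn.cursor() as cur:
    cur.execute(
        "CREATE INDEX IF NOT EXISTS documents_metadata_gin "
        "ON %s USING GIN (metadata jsonb_path_ops)" % collection.documents_table
    )
conn.commit()
collection.pool.putconn(conn)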
pgml-sdks/python/pgml/tests/test_collection.py (125 changes: 100 additions & 25 deletions)
@@ -4,45 +4,86 @@
import hashlib
import os

- class TestCollection(unittest.TestCase):
+
+ class TestCollection(unittest.TestCase):
def setUp(self) -> None:
local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
- conninfo = os.environ.get("PGML_CONNECTION",local_pgml)
+ conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
self.db = Database(conninfo)
self.collection_name = "test_collection_1"
self.documents = [
{
"id": hashlib.md5(f"abcded-{i}".encode('utf-8')).hexdigest(),
"text":f"Lorem ipsum {i}",
"metadata": {"source": "test_suite"}
"id": hashlib.md5(f"abcded-{i}".encode("utf-8")).hexdigest(),
"text":f"Lorem ipsum {i}",
"source": "test_suite",
}
for i in range(4, 7)
]
self.documents_no_ids = [
{
"text":f"Lorem ipsum {i}",
"metadata": {"source": "test_suite_no_ids"}
"text":f"Lorem ipsum {i}",
"source": "test_suite_no_ids",
}
for i in range(1, 4)
]


+ self.documents_with_metadata = [
+ {
+ "text": f"Lorem ipsum metadata",
+ "source": f"url {i}",
+ "url": f"/home {i}",
+ "user": f"John Doe-{i+1}",
+ }
+ for i in range(8, 12)
+ ]
+
+ self.documents_with_reviews = [
+ {
+ "text": f"product is abc {i}",
+ "reviews": i * 2,
+ }
+ for i in range(20, 25)
+ ]
+
+ self.documents_with_reviews_metadata = [
+ {
+ "text": f"product is abc {i}",
+ "reviews": i * 2,
+ "source": "amazon",
+ "user": "John Doe",
+ }
+ for i in range(20, 25)
+ ]
+
+ self.documents_with_reviews_metadata += [
+ {
+ "text": f"product is abc {i}",
+ "reviews": i * 2,
+ "source": "ebay",
+ }
+ for i in range(20, 25)
+ ]

self.collection = self.db.create_or_get_collection(self.collection_name)

def test_create_collection(self):
- assert isinstance(self.collection,Collection)
+ assert isinstance(self.collection, Collection)

def test_documents_upsert(self):
self.collection.upsert_documents(self.documents)
conn = self.db.pool.getconn()
- results = run_select_statement(conn,"SELECT id FROM %s"%self.collection.documents_table)
+ results = run_select_statement(
+ conn, "SELECT id FROM %s" % self.collection.documents_table
+ )
self.db.pool.putconn(conn)
assert len(results) >= len(self.documents)

def test_documents_upsert_no_ids(self):
self.collection.upsert_documents(self.documents_no_ids)
conn = self.db.pool.getconn()
- results = run_select_statement(conn,"SELECT id FROM %s"%self.collection.documents_table)
+ results = run_select_statement(
+ conn, "SELECT id FROM %s" % self.collection.documents_table
+ )
self.db.pool.putconn(conn)
assert len(results) >= len(self.documents_no_ids)

@@ -52,23 +93,25 @@ def test_default_text_splitter(self):

assert splitter_id == 1
assert splitters[0]["name"] == "RecursiveCharacterTextSplitter"

def test_default_embeddings_model(self):
model_id = self.collection.register_model()
models = self.collection.get_models()

assert model_id == 1
assert models[0]["name"] == "intfloat/e5-small"

def test_generate_chunks(self):
self.collection.upsert_documents(self.documents)
self.collection.upsert_documents(self.documents_no_ids)
splitter_id = self.collection.register_text_splitter()
self.collection.generate_chunks(splitter_id=splitter_id)
splitter_params = {"chunk_size": 100, "chunk_overlap":20}
splitter_id = self.collection.register_text_splitter(splitter_params=splitter_params)
splitter_params = {"chunk_size": 100, "chunk_overlap": 20}
splitter_id = self.collection.register_text_splitter(
splitter_params=splitter_params
)
self.collection.generate_chunks(splitter_id=splitter_id)

def test_generate_embeddings(self):
self.collection.upsert_documents(self.documents)
self.collection.upsert_documents(self.documents_no_ids)
@@ -84,10 +127,42 @@ def test_vector_search(self):
self.collection.generate_embeddings()
results = self.collection.vector_search("Lorem ipsum 1", top_k=2)
assert results[0]["score"] == 1.0

- # def tearDown(self) -> None:
- #     self.db.archive_collection(self.collection_name)

+ def test_vector_search_metadata_filter(self):
+ self.collection.upsert_documents(self.documents)
+ self.collection.upsert_documents(self.documents_no_ids)
+ self.collection.upsert_documents(self.documents_with_metadata)
+ self.collection.generate_chunks()
+ self.collection.generate_embeddings()
+ results = self.collection.vector_search(
+ "Lorem ipsum metadata",
+ top_k=2,
+ metadata_filter={"url": "/home 8", "source": "url 8"},
+ )
+ assert results[0]["metadata"]["user"] == "John Doe-9"
+
+ def test_vector_search_generic_filter(self):
+ self.collection.upsert_documents(self.documents_with_reviews)
+ self.collection.generate_chunks()
+ self.collection.generate_embeddings()
+ results = self.collection.vector_search(
+ "product is abc 21",
+ top_k=2,
+ generic_filter="(documents.metadata->>'reviews')::int < 45",
+ )
+ assert results[0]["metadata"]["reviews"] == 42
+
+ def test_vector_search_generic_and_metadata_filter(self):
+ self.collection.upsert_documents(self.documents_with_reviews_metadata)
+ self.collection.generate_chunks()
+ self.collection.generate_embeddings()
+ results = self.collection.vector_search(
+ "product is abc 21",
+ top_k=2,
+ generic_filter="(documents.metadata->>'reviews')::int < 45",
+ metadata_filter={"source": "amazon"},
+ )
+ assert results[0]["metadata"]["user"] == "John Doe"
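
# For reference, the result shape implied by the SELECT list in collection.py
# (cte.score, chunks.chunk, documents.metadata) and the assertions in these
# tests; values are illustrative, not captured output:
# [
#     {"score": 1.0, "chunk": "product is abc 21",
#      "metadata": {"reviews": 42, "source": "amazon", "user": "John Doe"}},
#     ...
# ]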

+ # def tearDown(self) -> None:
+ #     self.db.archive_collection(self.collection_name)
