2121)
2222
2323from langchain .text_splitter import RecursiveCharacterTextSplitter
24- from pypika import Query ,Table ,AliasedQuery ,Order ,Field
24+ from pypika import JSON ,Array ,Table ,AliasedQuery ,Order
25+ from pypika import Query
26+
2527from pypika .queries import QueryBuilder
2628from pypika .functions import Cast
2729from .queries import Embed ,CosineDistance
2830
31+ _cache_model_names = {}
32+ _cache_embeddings_table_names = {}
33+
2934
3035FORMAT = "%(message)s"
3136logging .basicConfig (
@@ -808,47 +813,64 @@ def vector_search(
808813
809814def execute (self ,sql_statement :QueryBuilder )-> List [Dict [str ,Any ]]:
810815conn = self .pool .getconn ()
811- results = run_select_statement (conn ,sql_statement . get_sql ( ))
816+ results = run_select_statement (conn ,str ( sql_statement ))
812817self .pool .putconn (conn )
813818return results
814819
820+ def query (self ):
821+ return PGMLQuery (self )
822+
823+
824+ class PGMLQuery (QueryBuilder ):
825+ def __init__ (self ,collection :Collection )-> None :
826+ self .collection = collection
827+
828+ def __str__ (self )-> str :
829+ return self .get_sql ()
830+
831+ def limit (self ,_limit :int ):
832+ self = self .limit (_limit )
833+ return self
834+
815835def vector_recall (
816836self ,
817837query :str ,
818838query_parameters :Optional [Dict [str ,Any ]]= {},
839+ top_k :int = 5 ,
819840model_id :int = 1 ,
820841splitter_id :int = 1 ,
821- ) -> List [ Dict [ str , Any ]] :
822- if model_id in self . _cache_model_names .keys ():
823- model = self . _cache_model_names [model_id ]
842+ ) :
843+ if model_id in _cache_model_names .keys ():
844+ model = _cache_model_names [model_id ]
824845else :
825- models = Table (self .models_table .split ("." )[1 ],schema = self .name )
846+ models = Table (
847+ self .collection .models_table .split ("." )[1 ],schema = self .collection .name
848+ )
826849q = Query .from_ (models ).select ("name" ).where (models .id == model_id )
827- results = self .execute (q )
850+ results = self .collection . execute (q )
828851model = results [0 ]["name" ]
829- self . _cache_model_names [model_id ]= model
852+ _cache_model_names [model_id ]= model
830853
831854embeddings_table = ""
832- if model_id in self ._cache_embeddings_table_names .keys ():
833- if splitter_id in self ._cache_embeddings_table_names [model_id ].keys ():
834- embeddings_table = self ._cache_embeddings_table_names [model_id ][
835- splitter_id
836- ]
855+ if model_id in _cache_embeddings_table_names .keys ():
856+ if splitter_id in _cache_embeddings_table_names [model_id ].keys ():
857+ embeddings_table = _cache_embeddings_table_names [model_id ][splitter_id ]
837858
838859if not embeddings_table :
839860transforms_table = Table (
840- self .transforms_table .split ("." )[1 ],schema = self .name
861+ self .collection .transforms_table .split ("." )[1 ],
862+ schema = self .collection .name ,
841863 )
842864q = (
843865Query .from_ (transforms_table )
844866 .select ("table_name" )
845867 .where (transforms_table .model_id == model_id )
846868 .where (transforms_table .splitter_id == splitter_id )
847869 )
848- embedding_table_results = self .execute (q )
870+ embedding_table_results = self .collection . execute (q )
849871if embedding_table_results :
850872embeddings_table = embedding_table_results [0 ]["table_name" ]
851- self . _cache_embeddings_table_names [model_id ]= {
873+ _cache_embeddings_table_names [model_id ]= {
852874splitter_id :embeddings_table
853875 }
854876else :
@@ -859,12 +881,12 @@ def vector_recall(
859881return []
860882
861883embeddings_table = embeddings_table .split ("." )[1 ]
862- chunks_table = self .chunks_table .split ("." )[1 ]
863- documents_table = self .documents_table .split ("." )[1 ]
884+ chunks_table = self .collection . chunks_table .split ("." )[1 ]
885+ documents_table = self .collection . documents_table .split ("." )[1 ]
864886
865- embeddings_table = Table (embeddings_table ,schema = self .name )
866- chunks_table = Table (chunks_table ,schema = self .name )
867- documents_table = Table (documents_table ,schema = self .name )
887+ embeddings_table = Table (embeddings_table ,schema = self .collection . name )
888+ chunks_table = Table (chunks_table ,schema = self .collection . name )
889+ documents_table = Table (documents_table ,schema = self .collection . name )
868890
869891query_embed = Query ().select (
870892Embed (transformer = model ,text = query ,parameters = query_parameters )
@@ -887,7 +909,7 @@ def vector_recall(
887909 .cross ()
888910 )
889911
890- query_cte = (
912+ self = (
891913Query ()
892914 .with_ (query_embed ,"query_cte" )
893915 .with_ (table_embed ,"cte" )
@@ -897,7 +919,7 @@ def vector_recall(
897919 .inner_join (chunks_table )
898920 .on (chunks_table .id == cte .chunk_id )
899921 .inner_join (documents_table )
900- .on (documents_table .id == chunks_table .document_id )
922+ .on (documents_table .id == chunks_table .document_id ). limit ( top_k )
901923 )
902924
903- return query_cte
925+ return self