2020run_select_statement ,
2121)
2222
23- from .queries import Embed ,CosineDistance
2423from langchain .text_splitter import RecursiveCharacterTextSplitter
25- from pypika .queries import Schema ,Table ,Query ,QueryBuilder
24+ from pypika import Query ,Table ,AliasedQuery ,Order ,Field
25+ from pypika .queries import QueryBuilder
26+ from pypika .functions import Cast
27+ from .queries import Embed ,CosineDistance
28+
2629
2730FORMAT = "%(message)s"
2831logging .basicConfig (
@@ -805,22 +808,21 @@ def vector_search(
805808
806809def execute (self ,sql_statement :QueryBuilder )-> List [Dict [str ,Any ]]:
807810conn = self .pool .getconn ()
808- results = run_select_statement (conn ,sql_statement .get_sql (). replace ( '"' , "" ) )
811+ results = run_select_statement (conn ,sql_statement .get_sql ())
809812self .pool .putconn (conn )
810813return results
811814
812815def vector_recall (
813816self ,
814817query :str ,
815818query_parameters :Optional [Dict [str ,Any ]]= {},
816- top_k :int = 5 ,
817819model_id :int = 1 ,
818820splitter_id :int = 1 ,
819821 )-> List [Dict [str ,Any ]]:
820822if model_id in self ._cache_model_names .keys ():
821823model = self ._cache_model_names [model_id ]
822824else :
823- models = Table (self .models_table )
825+ models = Table (self .models_table . split ( "." )[ 1 ], schema = self . name )
824826q = Query .from_ (models ).select ("name" ).where (models .id == model_id )
825827results = self .execute (q )
826828model = results [0 ]["name" ]
@@ -834,7 +836,9 @@ def vector_recall(
834836 ]
835837
836838if not embeddings_table :
837- transforms_table = Table (self .transforms_table )
839+ transforms_table = Table (
840+ self .transforms_table .split ("." )[1 ],schema = self .name
841+ )
838842q = (
839843Query .from_ (transforms_table )
840844 .select ("table_name" )
@@ -854,47 +858,43 @@ def vector_recall(
854858 )
855859return []
856860
857- conn = self .pool .getconn ()
861+ embeddings_table = embeddings_table .split ("." )[1 ]
862+ chunks_table = self .chunks_table .split ("." )[1 ]
863+ documents_table = self .documents_table .split ("." )[1 ]
864+
865+ embeddings_table = Table (embeddings_table ,schema = self .name )
866+ chunks_table = Table (chunks_table ,schema = self .name )
867+ documents_table = Table (documents_table ,schema = self .name )
858868
859- cte_query = Query .select (
869+ query_embed = Query () .select (
860870Embed (transformer = model ,text = query ,parameters = query_parameters )
861- ).with_ ()
862- table_embedding = (
863- Query .from_ (embeddings_table )
871+ )
872+ query_cte = AliasedQuery ("query_cte" )
873+ cte = AliasedQuery ("cte" )
874+ table_embed = (
875+ Query ()
876+ .from_ (embeddings_table )
864877 .select (
865878"chunk_id" ,
866- CosineDistance (embeddings_table .embedding ,query_embedding .cosine ),
879+ CosineDistance (
880+ embeddings_table .embedding ,Cast (query_cte .embedding ,"vector" )
881+ ).as_ ("score" ),
867882 )
868- .cross_join (query_embedding )
869- )
870- cte_select_statement = """
871- WITH query_cte AS (
872- SELECT pgml.embed(transformer => {model}, text => '{query_text}', kwargs => {model_params}) AS query_embedding
873- ),
874- cte AS (
875- SELECT chunk_id, 1 - ({embeddings_table}.embedding <=> query_cte.query_embedding::float8[]::vector) AS score
876- FROM {embeddings_table}
877- CROSS JOIN query_cte
878- ORDER BY score DESC
883+ .inner_join (AliasedQuery ("query_cte" ))
884+ .cross ()
879885 )
880- SELECT cte.score, chunks.chunk, documents.metadata
881- FROM cte
882- INNER JOIN {chunks_table} chunks ON chunks.id = cte.chunk_id
883- INNER JOIN {documents_table} documents ON documents.id = chunks.document_id
884- """ .format (
885- model = sql .Literal (model ).as_string (conn ),
886- query_text = query ,
887- model_params = sql .Literal (json .dumps (query_parameters )).as_string (conn ),
888- embeddings_table = embeddings_table ,
889- chunks_table = self .chunks_table ,
890- documents_table = self .documents_table ,
891- )
892-
893- cte_select_statement += " LIMIT {top_k}" .format (top_k = top_k )
894886
895- search_results = run_select_statement (
896- conn ,cte_select_statement ,order_by = "score" ,ascending = False
887+ query_cte = (
888+ Query ()
889+ .with_ (query_embed ,"query_cte" )
890+ .with_ (table_embed ,"cte" )
891+ .from_ ("cte" )
892+ .select (cte .score ,chunks_table .chunk ,documents_table .metadata )
893+ .orderby (cte .score ,order = Order .desc )
894+ .inner_join (chunks_table )
895+ .on (chunks_table .id == cte .chunk_id )
896+ .inner_join (documents_table )
897+ .on (documents_table .id == chunks_table .document_id )
897898 )
898- self .pool .putconn (conn )
899899
900- return search_results
900+ return query_cte