@@ -21,6 +21,7 @@
 )
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from pypika.queries import Schema, Table, Query, QueryBuilder
 
 FORMAT = "%(message)s"
 logging.basicConfig(
@@ -800,3 +801,82 @@ def vector_search(
         self.pool.putconn(conn)
 
         return search_results
+
+    def execute(self, sql_statement: QueryBuilder) -> List[Dict[str, Any]]:
+        conn = self.pool.getconn()
+        results = run_select_statement(conn, sql_statement.get_sql().replace("\"", ""))
+        self.pool.putconn(conn)
+        return results
+
+    def vector_recall(self,
+                      query: str,
+                      query_parameters: Optional[Dict[str, Any]] = {},
+                      top_k: int = 5,
+                      model_id: int = 1,
+                      splitter_id: int = 1) -> List[Dict[str, Any]]:
+        # Resolve the model name for this model_id, caching the lookup so we
+        # only hit the models table once per id.
+        if model_id in self._cache_model_names.keys():
+            model = self._cache_model_names[model_id]
+        else:
+            models = Table(self.models_table)
+            q = Query.from_(models).select('name').where(models.id == model_id)
+            results = self.execute(q)
+            model = results[0]["name"]
+            self._cache_model_names[model_id] = model
827+
828+ embeddings_table = ""
829+ if model_id in self ._cache_embeddings_table_names .keys ():
830+ if splitter_id in self ._cache_embeddings_table_names [model_id ].keys ():
831+ embeddings_table = self ._cache_embeddings_table_names [model_id ][
832+ splitter_id
833+ ]
+        # Cache miss: look up which embeddings table was generated for this model/splitter pair.
+        if not embeddings_table:
+            transforms_table = Table(self.transforms_table)
+            q = Query.from_(transforms_table).select('table_name').where(transforms_table.model_id == model_id).where(transforms_table.splitter_id == splitter_id)
+            embedding_table_results = self.execute(q)
+            if embedding_table_results:
+                embeddings_table = embedding_table_results[0]["table_name"]
+                self._cache_embeddings_table_names[model_id] = {
+                    splitter_id: embeddings_table
+                }
+            else:
+                rprint(
+                    "Embeddings for model id %d and splitter id %d do not exist.\nPlease run collection.generate_embeddings(model_id = %d, splitter_id = %d)"
+                    % (model_id, splitter_id, model_id, splitter_id)
+                )
+                return []
+        # Embed the query and score every chunk by cosine similarity in a single CTE.
+        conn = self.pool.getconn()
+        cte_select_statement = """
+        WITH query_cte AS (
+            SELECT pgml.embed(transformer => {model}, text => {query_text}, kwargs => {model_params}) AS query_embedding
+        ),
+        cte AS (
+            SELECT chunk_id, 1 - ({embeddings_table}.embedding <=> query_cte.query_embedding::float8[]::vector) AS score
+            FROM {embeddings_table}
+            CROSS JOIN query_cte
+            ORDER BY score DESC
+        )
+        SELECT cte.score, chunks.chunk, documents.metadata
+        FROM cte
+        INNER JOIN {chunks_table} chunks ON chunks.id = cte.chunk_id
+        INNER JOIN {documents_table} documents ON documents.id = chunks.document_id
+        """.format(
+            model=sql.Literal(model).as_string(conn),
+            query_text=sql.Literal(query).as_string(conn),  # escape the user query as a SQL literal
+            model_params=sql.Literal(json.dumps(query_parameters)).as_string(conn),
+            embeddings_table=embeddings_table,
+            chunks_table=self.chunks_table,
+            documents_table=self.documents_table,
+        )
+
+        cte_select_statement += " LIMIT {top_k}".format(top_k=top_k)
+
+        search_results = run_select_statement(
+            conn, cte_select_statement, order_by="score", ascending=False
+        )
+        self.pool.putconn(conn)
+
+        return search_results
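
For context, a hypothetical call to the new vector_recall method might look like the sketch below. It assumes `collection` is an instance of the SDK class this diff extends (the object that owns pool, models_table, chunks_table, and documents_table); how that handle is created is outside this diff, and the query text and ids are illustrative only.

# Sketch only: `collection` is assumed to be the object this diff extends.
results = collection.vector_recall(
    "How do I run a transformer model inside Postgres?",  # free-text query, embedded via pgml.embed()
    query_parameters={},   # extra kwargs forwarded to pgml.embed()
    top_k=3,               # number of chunks to return
    model_id=1,            # row id in the models table
    splitter_id=1,         # row id in the splitters table
)

# Each result row carries the cosine-similarity score, the chunk text, and the
# parent document's metadata, matching the SELECT list in the CTE above.
for row in results:
    print(round(row["score"], 3), row["metadata"], row["chunk"][:80])

Because the generated SQL calls pgml.embed() inside Postgres, the query is embedded on the database side and no embedding model needs to be loaded in the Python process; the pypika-built lookups go through the new execute() helper, which strips pypika's double-quoted identifiers before passing the SQL to run_select_statement().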