Expand Up @@ -25,13 +25,15 @@ struct ValidSemanticSearchAction { query: String, parameters: Option<Json>, boost: Option<f32>, rerank: Option<ValidRerank>, } #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] struct ValidFullTextSearchAction { query: String, boost: Option<f32>, rerank: Option<ValidRerank>, } #[derive(Debug, Deserialize)] Expand All @@ -42,6 +44,20 @@ struct ValidQueryActions { filter: Option<Json>, } const fn default_num_documents_to_rerank() -> u64 { 10 } #[derive(Debug, Deserialize, Clone)] #[serde(deny_unknown_fields)] struct ValidRerank { query: String, model: String, #[serde(default = "default_num_documents_to_rerank")] num_documents_to_rerank: u64, parameters: Option<Json>, } const fn default_limit() -> u64 { 10 } Expand Down Expand Up @@ -106,7 +122,11 @@ pub async fn build_search_query( // Build the CTE we actually use later let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); let cte_name = format!("{key}_embedding_score"); let cte_name = if vsa.rerank.is_some() { format!("pre_rerank_{key}_embedding_score") } else { format!("{key}_embedding_score") }; let boost = vsa.boost.unwrap_or(1.); let mut score_cte_non_recursive = Query::select(); let mut score_cte_recurisive = Query::select(); Expand All @@ -131,6 +151,7 @@ pub async fn build_search_query( score_cte_non_recursive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("documents"), SIden::Str("id"))) .column((SIden::Str("chunks"), SIden::Str("chunk"))) .join_as( JoinType::InnerJoin, chunks_table.to_table_tuple(), Expand All @@ -157,6 +178,7 @@ pub async fn build_search_query( score_cte_recurisive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("documents"), SIden::Str("id"))) .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || documents.id"#))) .expr(Expr::cust(format!( r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"# Expand Down Expand Up @@ -213,6 +235,7 @@ pub async fn build_search_query( score_cte_non_recursive .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) .column((SIden::Str("documents"), SIden::Str("id"))) .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr(Expr::cust("ARRAY[documents.id] as previous_document_ids")) .expr(Expr::cust_with_values( format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"), Expand Down Expand Up @@ -249,6 +272,7 @@ pub async fn build_search_query( Expr::cust("1 = 1"), ) .column((SIden::Str("documents"), SIden::Str("id"))) .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr(Expr::cust(format!( r#""{cte_name}".previous_document_ids || documents.id"# ))) Expand Down Expand Up @@ -295,18 +319,75 @@ pub async fn build_search_query( .from_subquery(score_cte_non_recursive, Alias::new("non_recursive")) .union(sea_query::UnionType::All, score_cte_recurisive) .to_owned(); let mut score_cte = CommonTableExpression::from_select(score_cte); score_cte.table_name(Alias::new(&cte_name)); with_clause.cte(score_cte); if let Some(rerank) = vsa.rerank { // Add our row_number_pre_rerank CTE let mut row_number_pre_rerank = Query::select(); row_number_pre_rerank .column(SIden::Str("id")) .column(SIden::Str("chunk")) .from(SIden::String(cte_name.clone())) .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")) .limit(rerank.num_documents_to_rerank); let mut row_number_pre_rerank_cte = CommonTableExpression::from_select(row_number_pre_rerank); row_number_pre_rerank_cte.table_name(Alias::new(format!("row_number_{cte_name}"))); with_clause.cte(row_number_pre_rerank_cte); // Our actual CTE let mut query = Query::select(); query.column(SIden::Str("id")); query.expr_as( Expr::cust(format!("(rank).score * {boost}")), Alias::new("score"), ); // Build the actual CTE let mut sub_query_rank_call = Query::select(); let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]); let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]); let parameters_expr = Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]); sub_query_rank_call.expr_as(Expr::cust_with_exprs( format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), [model_expr, query_expr, parameters_expr], ), Alias::new("rank")) .from(SIden::String(format!("row_number_{cte_name}"))); let mut sub_query = Query::select(); sub_query .columns([SIden::Str("id"), SIden::Str("rank")]) .from_as( SIden::String(format!("row_number_{cte_name}")), Alias::new("rnsv1"), ) .join_subquery( JoinType::InnerJoin, sub_query_rank_call, Alias::new("rnsv2"), Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"), ); query.from_subquery(sub_query, Alias::new("sub_query")); let mut query_cte = CommonTableExpression::from_select(query); query_cte.table_name(Alias::new(format!("{key}_embedding_score"))); with_clause.cte(query_cte); } // Add to the sum expression sum_expression = if let Some(expr) = sum_expression { Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))) Some(expr.add(Expr::cust(format!( r#"COALESCE("{key}_embedding_score".score, 0.0)"# )))) } else { Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))) Some(Expr::cust(format!( r#"COALESCE("{key}_embedding_score".score, 0.0)"# ))) }; score_table_names.push(cte_name ); score_table_names.push(format!("{key}_embedding_score") ); } for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() { Expand All @@ -315,10 +396,15 @@ pub async fn build_search_query( let boost = vma.boost.unwrap_or(1.0); // Build the score CTE let cte_name = format!("{key}_tsvectors_score"); let cte_name = if vma.rerank.is_some() { format!("pre_rerank_{key}_tsvectors_score") } else { format!("{key}_tsvectors_score") }; let mut score_cte_non_recursive = Query::select() .column((SIden::Str("documents"), SIden::Str("id"))) .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr_as( Expr::cust_with_values( format!( Expand Down Expand Up @@ -361,6 +447,7 @@ pub async fn build_search_query( let mut score_cte_recursive = Query::select() .column((SIden::Str("documents"), SIden::Str("id"))) .column((SIden::Str("chunks"), SIden::Str("chunk"))) .expr_as( Expr::cust_with_values( format!( Expand Down Expand Up @@ -425,13 +512,71 @@ pub async fn build_search_query( score_cte.table_name(Alias::new(&cte_name)); with_clause.cte(score_cte); if let Some(rerank) = vma.rerank { // Add our row_number_pre_rerank CTE let mut row_number_pre_rerank = Query::select(); row_number_pre_rerank .column(SIden::Str("id")) .column(SIden::Str("chunk")) .from(SIden::String(cte_name.clone())) .expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number")) .limit(rerank.num_documents_to_rerank); let mut row_number_pre_rerank_cte = CommonTableExpression::from_select(row_number_pre_rerank); row_number_pre_rerank_cte.table_name(Alias::new(format!("row_number_{cte_name}"))); with_clause.cte(row_number_pre_rerank_cte); // Our actual CTE let mut query = Query::select(); query.column(SIden::Str("id")); query.expr_as( Expr::cust(format!("(rank).score * {boost}")), Alias::new("score"), ); // Build the actual CTE let mut sub_query_rank_call = Query::select(); let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]); let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]); let parameters_expr = Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]); sub_query_rank_call.expr_as(Expr::cust_with_exprs( format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit), [model_expr, query_expr, parameters_expr], ), Alias::new("rank")) .from(SIden::String(format!("row_number_{cte_name}"))); let mut sub_query = Query::select(); sub_query .columns([SIden::Str("id"), SIden::Str("rank")]) .from_as( SIden::String(format!("row_number_{cte_name}")), Alias::new("rnsv1"), ) .join_subquery( JoinType::InnerJoin, sub_query_rank_call, Alias::new("rnsv2"), Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"), ); query.from_subquery(sub_query, Alias::new("sub_query")); let mut query_cte = CommonTableExpression::from_select(query); query_cte.table_name(Alias::new(format!("{key}_tsvectors_score"))); with_clause.cte(query_cte); } // Add to the sum expression sum_expression = if let Some(expr) = sum_expression { Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))) Some(expr.add(Expr::cust(format!( r#"COALESCE("{key}_tsvectors_score".score, 0.0)"# )))) } else { Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))) Some(Expr::cust(format!( r#"COALESCE("{key}_tsvectors_score".score, 0.0)"# ))) }; score_table_names.push(cte_name ); score_table_names.push(format!("{key}_tsvectors_score") ); } let query = if let Some(select_from) = score_table_names.first() { Expand All @@ -440,9 +585,9 @@ pub async fn build_search_query( .into_iter() .map(|t| Expr::col((SIden::String(t), SIden::Str("id"))).into()) .collect(); let mutmain_query = Query::select(); let mutjoined_query = Query::select(); for i in 1..score_table_names_e.len() { main_query .full_outer_join(joined_query .full_outer_join( SIden::String(score_table_names[i].to_string()), Expr::col(( SIden::String(score_table_names[i].to_string()), Expand All @@ -455,7 +600,8 @@ pub async fn build_search_query( let sum_expression = sum_expression .context("query requires some scoring through full_text_search or semantic_search")?; main_query joined_query .expr_as(Expr::expr(id_select_expression.clone()), Alias::new("id")) .expr_as(sum_expression, Alias::new("score")) .column(SIden::Str("document")) Expand All @@ -468,10 +614,9 @@ pub async fn build_search_query( ) .order_by(SIden::Str("score"), Order::Desc) .limit(limit); let mut main_query = CommonTableExpression::from_select(main_query); main_query.table_name(Alias::new("main")); with_clause.cte(main_query); let mut joined_query = CommonTableExpression::from_select(joined_query); joined_query.table_name(Alias::new("main")); with_clause.cte(joined_query); // Insert into searches table let searches_table = format!("{}_{}.searches", collection.name, pipeline.name); Expand Down