Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

SDK - Added re-ranking into document search#1527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Open
SilasMarvin wants to merge2 commits intomaster
base:master
Choose a base branch
Loading
fromsilas-add-re-ranking-into-search
Open
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletionpgml-sdks/pgml/Cargo.lock
View file
Open in desktop

Some generated files are not rendered by default. Learn more abouthow customized files appear on GitHub.

20 changes: 15 additions & 5 deletionspgml-sdks/pgml/src/lib.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -980,7 +980,7 @@ mod tests {
#[tokio::test]
async fn can_search_with_local_embeddings() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test_r_c_cswle_123";
let collection_name = "test_r_c_cswle_126";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(10);
collection.upsert_documents(documents.clone(), None).await?;
Expand DownExpand Up@@ -1038,7 +1038,12 @@ mod tests {
"full_text_search": {
"title": {
"query": "test 9",
"boost": 4.0
"boost": 4.0,
"rerank": {
"query": "Test document 2",
"model": "mixedbread-ai/mxbai-rerank-base-v1",
"num_documents_to_rerank": 100
}
},
"body": {
"query": "Test",
Expand All@@ -1051,7 +1056,12 @@ mod tests {
"parameters": {
"prompt": "query: ",
},
"boost": 2.0
"boost": 2.0,
"rerank": {
"query": "Test document 2",
"model": "mixedbread-ai/mxbai-rerank-base-v1",
"num_documents_to_rerank": 100
}
},
"body": {
"query": "This is the body test",
Expand DownExpand Up@@ -1086,7 +1096,7 @@ mod tests {
.iter()
.map(|r| r["document"]["id"].as_u64().unwrap())
.collect();
assert_eq!(ids, vec![9, 3,4, 7, 5]);
assert_eq!(ids, vec![2,9, 3,8, 4]);

let pool = get_or_initialize_pool(&None).await?;

Expand All@@ -1111,7 +1121,7 @@ mod tests {
// Document ids are 1 based in the db not 0 based like they are here
assert_eq!(
search_results.iter().map(|sr| sr.2).collect::<Vec<i64>>(),
vec![10, 4,5, 8, 6]
vec![3,10, 4,9, 5]
);

let event = json!({"clicked": true});
Expand Down
177 changes: 161 additions & 16 deletionspgml-sdks/pgml/src/search_query_builder.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -25,13 +25,15 @@ struct ValidSemanticSearchAction {
query: String,
parameters: Option<Json>,
boost: Option<f32>,
rerank: Option<ValidRerank>,
}

#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct ValidFullTextSearchAction {
query: String,
boost: Option<f32>,
rerank: Option<ValidRerank>,
}

#[derive(Debug, Deserialize)]
Expand All@@ -42,6 +44,20 @@ struct ValidQueryActions {
filter: Option<Json>,
}

const fn default_num_documents_to_rerank() -> u64 {
10
}

#[derive(Debug, Deserialize, Clone)]
#[serde(deny_unknown_fields)]
struct ValidRerank {
query: String,
model: String,
#[serde(default = "default_num_documents_to_rerank")]
num_documents_to_rerank: u64,
parameters: Option<Json>,
}

const fn default_limit() -> u64 {
10
}
Expand DownExpand Up@@ -106,7 +122,11 @@ pub async fn build_search_query(
// Build the CTE we actually use later
let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key);
let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key);
let cte_name = format!("{key}_embedding_score");
let cte_name = if vsa.rerank.is_some() {
format!("pre_rerank_{key}_embedding_score")
} else {
format!("{key}_embedding_score")
};
let boost = vsa.boost.unwrap_or(1.);
let mut score_cte_non_recursive = Query::select();
let mut score_cte_recurisive = Query::select();
Expand All@@ -131,6 +151,7 @@ pub async fn build_search_query(
score_cte_non_recursive
.from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings"))
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.join_as(
JoinType::InnerJoin,
chunks_table.to_table_tuple(),
Expand All@@ -157,6 +178,7 @@ pub async fn build_search_query(
score_cte_recurisive
.from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings"))
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || documents.id"#)))
.expr(Expr::cust(format!(
r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"#
Expand DownExpand Up@@ -213,6 +235,7 @@ pub async fn build_search_query(
score_cte_non_recursive
.from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings"))
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.expr(Expr::cust("ARRAY[documents.id] as previous_document_ids"))
.expr(Expr::cust_with_values(
format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"),
Expand DownExpand Up@@ -249,6 +272,7 @@ pub async fn build_search_query(
Expr::cust("1 = 1"),
)
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.expr(Expr::cust(format!(
r#""{cte_name}".previous_document_ids || documents.id"#
)))
Expand DownExpand Up@@ -295,18 +319,75 @@ pub async fn build_search_query(
.from_subquery(score_cte_non_recursive, Alias::new("non_recursive"))
.union(sea_query::UnionType::All, score_cte_recurisive)
.to_owned();

let mut score_cte = CommonTableExpression::from_select(score_cte);
score_cte.table_name(Alias::new(&cte_name));
with_clause.cte(score_cte);

if let Some(rerank) = vsa.rerank {
// Add our row_number_pre_rerank CTE
let mut row_number_pre_rerank = Query::select();
row_number_pre_rerank
.column(SIden::Str("id"))
.column(SIden::Str("chunk"))
.from(SIden::String(cte_name.clone()))
.expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number"))
.limit(rerank.num_documents_to_rerank);
let mut row_number_pre_rerank_cte =
CommonTableExpression::from_select(row_number_pre_rerank);
row_number_pre_rerank_cte.table_name(Alias::new(format!("row_number_{cte_name}")));
with_clause.cte(row_number_pre_rerank_cte);

// Our actual CTE
let mut query = Query::select();
query.column(SIden::Str("id"));
query.expr_as(
Expr::cust(format!("(rank).score * {boost}")),
Alias::new("score"),
);

// Build the actual CTE
let mut sub_query_rank_call = Query::select();
let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]);
let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]);
let parameters_expr =
Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]);
sub_query_rank_call.expr_as(Expr::cust_with_exprs(
format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
[model_expr, query_expr, parameters_expr],
), Alias::new("rank"))
.from(SIden::String(format!("row_number_{cte_name}")));

let mut sub_query = Query::select();
sub_query
.columns([SIden::Str("id"), SIden::Str("rank")])
.from_as(
SIden::String(format!("row_number_{cte_name}")),
Alias::new("rnsv1"),
)
.join_subquery(
JoinType::InnerJoin,
sub_query_rank_call,
Alias::new("rnsv2"),
Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"),
);

query.from_subquery(sub_query, Alias::new("sub_query"));
let mut query_cte = CommonTableExpression::from_select(query);
query_cte.table_name(Alias::new(format!("{key}_embedding_score")));
with_clause.cte(query_cte);
}

// Add to the sum expression
sum_expression = if let Some(expr) = sum_expression {
Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))))
Some(expr.add(Expr::cust(format!(
r#"COALESCE("{key}_embedding_score".score, 0.0)"#
))))
} else {
Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))
Some(Expr::cust(format!(
r#"COALESCE("{key}_embedding_score".score, 0.0)"#
)))
};
score_table_names.push(cte_name);
score_table_names.push(format!("{key}_embedding_score"));
}

for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() {
Expand All@@ -315,10 +396,15 @@ pub async fn build_search_query(
let boost = vma.boost.unwrap_or(1.0);

// Build the score CTE
let cte_name = format!("{key}_tsvectors_score");
let cte_name = if vma.rerank.is_some() {
format!("pre_rerank_{key}_tsvectors_score")
} else {
format!("{key}_tsvectors_score")
};

let mut score_cte_non_recursive = Query::select()
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.expr_as(
Expr::cust_with_values(
format!(
Expand DownExpand Up@@ -361,6 +447,7 @@ pub async fn build_search_query(

let mut score_cte_recursive = Query::select()
.column((SIden::Str("documents"), SIden::Str("id")))
.column((SIden::Str("chunks"), SIden::Str("chunk")))
.expr_as(
Expr::cust_with_values(
format!(
Expand DownExpand Up@@ -425,13 +512,71 @@ pub async fn build_search_query(
score_cte.table_name(Alias::new(&cte_name));
with_clause.cte(score_cte);

if let Some(rerank) = vma.rerank {
// Add our row_number_pre_rerank CTE
let mut row_number_pre_rerank = Query::select();
row_number_pre_rerank
.column(SIden::Str("id"))
.column(SIden::Str("chunk"))
.from(SIden::String(cte_name.clone()))
.expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number"))
.limit(rerank.num_documents_to_rerank);
let mut row_number_pre_rerank_cte =
CommonTableExpression::from_select(row_number_pre_rerank);
row_number_pre_rerank_cte.table_name(Alias::new(format!("row_number_{cte_name}")));
with_clause.cte(row_number_pre_rerank_cte);

// Our actual CTE
let mut query = Query::select();
query.column(SIden::Str("id"));
query.expr_as(
Expr::cust(format!("(rank).score * {boost}")),
Alias::new("score"),
);

// Build the actual CTE
let mut sub_query_rank_call = Query::select();
let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]);
let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]);
let parameters_expr =
Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]);
sub_query_rank_call.expr_as(Expr::cust_with_exprs(
format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
[model_expr, query_expr, parameters_expr],
), Alias::new("rank"))
.from(SIden::String(format!("row_number_{cte_name}")));

let mut sub_query = Query::select();
sub_query
.columns([SIden::Str("id"), SIden::Str("rank")])
.from_as(
SIden::String(format!("row_number_{cte_name}")),
Alias::new("rnsv1"),
)
.join_subquery(
JoinType::InnerJoin,
sub_query_rank_call,
Alias::new("rnsv2"),
Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"),
);

query.from_subquery(sub_query, Alias::new("sub_query"));
let mut query_cte = CommonTableExpression::from_select(query);
query_cte.table_name(Alias::new(format!("{key}_tsvectors_score")));
with_clause.cte(query_cte);
}

// Add to the sum expression
sum_expression = if let Some(expr) = sum_expression {
Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))))
Some(expr.add(Expr::cust(format!(
r#"COALESCE("{key}_tsvectors_score".score, 0.0)"#
))))
} else {
Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))
Some(Expr::cust(format!(
r#"COALESCE("{key}_tsvectors_score".score, 0.0)"#
)))
};
score_table_names.push(cte_name);
score_table_names.push(format!("{key}_tsvectors_score"));
}

let query = if let Some(select_from) = score_table_names.first() {
Expand All@@ -440,9 +585,9 @@ pub async fn build_search_query(
.into_iter()
.map(|t| Expr::col((SIden::String(t), SIden::Str("id"))).into())
.collect();
let mutmain_query = Query::select();
let mutjoined_query = Query::select();
for i in 1..score_table_names_e.len() {
main_query.full_outer_join(
joined_query.full_outer_join(
SIden::String(score_table_names[i].to_string()),
Expr::col((
SIden::String(score_table_names[i].to_string()),
Expand All@@ -455,7 +600,8 @@ pub async fn build_search_query(

let sum_expression = sum_expression
.context("query requires some scoring through full_text_search or semantic_search")?;
main_query

joined_query
.expr_as(Expr::expr(id_select_expression.clone()), Alias::new("id"))
.expr_as(sum_expression, Alias::new("score"))
.column(SIden::Str("document"))
Expand All@@ -468,10 +614,9 @@ pub async fn build_search_query(
)
.order_by(SIden::Str("score"), Order::Desc)
.limit(limit);

let mut main_query = CommonTableExpression::from_select(main_query);
main_query.table_name(Alias::new("main"));
with_clause.cte(main_query);
let mut joined_query = CommonTableExpression::from_select(joined_query);
joined_query.table_name(Alias::new("main"));
with_clause.cte(joined_query);

// Insert into searches table
let searches_table = format!("{}_{}.searches", collection.name, pipeline.name);
Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp