Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 16a7799

Browse files
committed
Finalized re-ranking in document search
1 parent d364962 commit 16a7799

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

‎pgml-sdks/pgml/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -980,7 +980,7 @@ mod tests {
980980
#[tokio::test]
981981
async fn can_search_with_local_embeddings() -> anyhow::Result<()> {
982982
internal_init_logger(None,None).ok();
983-
let collection_name ="test_r_c_cswle_123";
983+
let collection_name ="test_r_c_cswle_126";
984984
let mut collection = Collection::new(collection_name, None)?;
985985
let documents =generate_dummy_documents(10);
986986
collection.upsert_documents(documents.clone(),None).await?;
@@ -1096,7 +1096,7 @@ mod tests {
10961096
.iter()
10971097
.map(|r| r["document"]["id"].as_u64().unwrap())
10981098
.collect();
1099-
assert_eq!(ids, vec![9,3,4,5,6]);
1099+
assert_eq!(ids, vec![2,9,3,8,4]);
11001100

11011101
let pool =get_or_initialize_pool(&None).await?;
11021102

@@ -1121,7 +1121,7 @@ mod tests {
11211121
// Document ids are 1 based in the db not 0 based like they are here
11221122
assert_eq!(
11231123
search_results.iter().map(|sr| sr.2).collect::<Vec<i64>>(),
1124-
vec![10,4,5,8,6]
1124+
vec![3,10,4,9,5]
11251125
);
11261126

11271127
let event =json!({"clicked":true});

‎pgml-sdks/pgml/src/search_query_builder.rs

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -151,6 +151,7 @@ pub async fn build_search_query(
151151
score_cte_non_recursive
152152
.from_as(embeddings_table.to_table_tuple(),Alias::new("embeddings"))
153153
.column((SIden::Str("documents"),SIden::Str("id")))
154+
.column((SIden::Str("chunks"),SIden::Str("chunk")))
154155
.join_as(
155156
JoinType::InnerJoin,
156157
chunks_table.to_table_tuple(),
@@ -177,6 +178,7 @@ pub async fn build_search_query(
177178
score_cte_recurisive
178179
.from_as(embeddings_table.to_table_tuple(),Alias::new("embeddings"))
179180
.column((SIden::Str("documents"),SIden::Str("id")))
181+
.column((SIden::Str("chunks"),SIden::Str("chunk")))
180182
.expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || documents.id"#)))
181183
.expr(Expr::cust(format!(
182184
r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"#
@@ -233,6 +235,7 @@ pub async fn build_search_query(
233235
score_cte_non_recursive
234236
.from_as(embeddings_table.to_table_tuple(),Alias::new("embeddings"))
235237
.column((SIden::Str("documents"),SIden::Str("id")))
238+
.column((SIden::Str("chunks"),SIden::Str("chunk")))
236239
.expr(Expr::cust("ARRAY[documents.id] as previous_document_ids"))
237240
.expr(Expr::cust_with_values(
238241
format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"),
@@ -269,6 +272,7 @@ pub async fn build_search_query(
269272
Expr::cust("1 = 1"),
270273
)
271274
.column((SIden::Str("documents"),SIden::Str("id")))
275+
.column((SIden::Str("chunks"),SIden::Str("chunk")))
272276
.expr(Expr::cust(format!(
273277
r#""{cte_name}".previous_document_ids || documents.id"#
274278
)))
@@ -324,6 +328,7 @@ pub async fn build_search_query(
324328
let mut row_number_pre_rerank = Query::select();
325329
row_number_pre_rerank
326330
.column(SIden::Str("id"))
331+
.column(SIden::Str("chunk"))
327332
.from(SIden::String(cte_name.clone()))
328333
.expr_as(Expr::cust("ROW_NUMBER() OVER ()"),Alias::new("row_number"))
329334
.limit(rerank.num_documents_to_rerank);
@@ -335,7 +340,10 @@ pub async fn build_search_query(
335340
// Our actual CTE
336341
let mut query = Query::select();
337342
query.column(SIden::Str("id"));
338-
query.expr_as(Expr::cust("(rank).score"),Alias::new("score"));
343+
query.expr_as(
344+
Expr::cust(format!("(rank).score * {boost}")),
345+
Alias::new("score"),
346+
);
339347

340348
// Build the actual CTE
341349
let mut sub_query_rank_call = Query::select();
@@ -347,14 +355,7 @@ pub async fn build_search_query(
347355
format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
348356
[model_expr, query_expr, parameters_expr],
349357
),Alias::new("rank"))
350-
.from(SIden::String(format!("row_number_{cte_name}")))
351-
.join_as(
352-
JoinType::InnerJoin,
353-
chunks_table.to_table_tuple(),
354-
Alias::new("chunks"),
355-
Expr::col((SIden::Str("chunks"),SIden::Str("id")))
356-
.equals((SIden::String(format!("row_number_{cte_name}")),SIden::Str("id"))),
357-
);
358+
.from(SIden::String(format!("row_number_{cte_name}")));
358359

359360
let mut sub_query = Query::select();
360361
sub_query
@@ -403,6 +404,7 @@ pub async fn build_search_query(
403404

404405
let mut score_cte_non_recursive = Query::select()
405406
.column((SIden::Str("documents"),SIden::Str("id")))
407+
.column((SIden::Str("chunks"),SIden::Str("chunk")))
406408
.expr_as(
407409
Expr::cust_with_values(
408410
format!(
@@ -445,6 +447,7 @@ pub async fn build_search_query(
445447

446448
let mut score_cte_recursive = Query::select()
447449
.column((SIden::Str("documents"),SIden::Str("id")))
450+
.column((SIden::Str("chunks"),SIden::Str("chunk")))
448451
.expr_as(
449452
Expr::cust_with_values(
450453
format!(
@@ -514,6 +517,7 @@ pub async fn build_search_query(
514517
let mut row_number_pre_rerank = Query::select();
515518
row_number_pre_rerank
516519
.column(SIden::Str("id"))
520+
.column(SIden::Str("chunk"))
517521
.from(SIden::String(cte_name.clone()))
518522
.expr_as(Expr::cust("ROW_NUMBER() OVER ()"),Alias::new("row_number"))
519523
.limit(rerank.num_documents_to_rerank);
@@ -525,7 +529,10 @@ pub async fn build_search_query(
525529
// Our actual CTE
526530
let mut query = Query::select();
527531
query.column(SIden::Str("id"));
528-
query.expr_as(Expr::cust("(rank).score"),Alias::new("score"));
532+
query.expr_as(
533+
Expr::cust(format!("(rank).score * {boost}")),
534+
Alias::new("score"),
535+
);
529536

530537
// Build the actual CTE
531538
let mut sub_query_rank_call = Query::select();
@@ -537,14 +544,7 @@ pub async fn build_search_query(
537544
format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
538545
[model_expr, query_expr, parameters_expr],
539546
),Alias::new("rank"))
540-
.from(SIden::String(format!("row_number_{cte_name}")))
541-
.join_as(
542-
JoinType::InnerJoin,
543-
chunks_table.to_table_tuple(),
544-
Alias::new("chunks"),
545-
Expr::col((SIden::Str("chunks"),SIden::Str("id")))
546-
.equals((SIden::String(format!("row_number_{cte_name}")),SIden::Str("id"))),
547-
);
547+
.from(SIden::String(format!("row_number_{cte_name}")));
548548

549549
let mut sub_query = Query::select();
550550
sub_query

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp