Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit c3a8514

Browse files
authored
SDK - Added re-ranking into vector search (#1516)
1 parent 34e64d8 commit c3a8514

File tree

5 files changed

+197
-17
lines changed

5 files changed

+197
-17
lines changed

‎pgml-sdks/pgml/Cargo.lock

Lines changed: 0 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pgml-sdks/pgml/src/collection.rs

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1051,6 +1051,7 @@ impl Collection {
10511051
/// }).into(), &mut pipeline).await?;
10521052
/// Ok(())
10531053
/// }
1054+
#[allow(clippy::type_complexity)]
10541055
#[instrument(skip(self))]
10551056
pub async fn vector_search(
10561057
&mut self,
@@ -1061,7 +1062,7 @@ impl Collection {
10611062

10621063
let(built_query, values) =
10631064
build_vector_search_query(query.clone(),self, pipeline).await?;
1064-
let results:Result<Vec<(Json,String,f64)>,_> =
1065+
let results:Result<Vec<(Json,String,f64,Option<f64>)>,_> =
10651066
sqlx::query_as_with(&built_query, values)
10661067
.fetch_all(&pool)
10671068
.await;
@@ -1072,7 +1073,8 @@ impl Collection {
10721073
serde_json::json!({
10731074
"document": v.0,
10741075
"chunk": v.1,
1075-
"score": v.2
1076+
"score": v.2,
1077+
"rerank_score": v.3
10761078
})
10771079
.into()
10781080
})
@@ -1087,7 +1089,7 @@ impl Collection {
10871089
.await?;
10881090
let(built_query, values) =
10891091
build_vector_search_query(query,self, pipeline).await?;
1090-
let results:Vec<(Json,String,f64)> =
1092+
let results:Vec<(Json,String,f64,Option<f64>)> =
10911093
sqlx::query_as_with(&built_query, values)
10921094
.fetch_all(&pool)
10931095
.await?;
@@ -1097,7 +1099,8 @@ impl Collection {
10971099
serde_json::json!({
10981100
"document": v.0,
10991101
"chunk": v.1,
1100-
"score": v.2
1102+
"score": v.2,
1103+
"rerank_score": v.3
11011104
})
11021105
.into()
11031106
})
@@ -1121,16 +1124,18 @@ impl Collection {
11211124
let pool =get_or_initialize_pool(&self.database_url).await?;
11221125
let(built_query, values) =
11231126
build_vector_search_query(query.clone(),self, pipeline).await?;
1124-
let results:Vec<(Json,String,f64)> = sqlx::query_as_with(&built_query, values)
1125-
.fetch_all(&pool)
1126-
.await?;
1127+
let results:Vec<(Json,String,f64,Option<f64>)> =
1128+
sqlx::query_as_with(&built_query, values)
1129+
.fetch_all(&pool)
1130+
.await?;
11271131
Ok(results
11281132
.into_iter()
11291133
.map(|v|{
11301134
serde_json::json!({
11311135
"document": v.0,
11321136
"chunk": v.1,
1133-
"score": v.2
1137+
"score": v.2,
1138+
"rerank_score": v.3
11341139
})
11351140
.into()
11361141
})

‎pgml-sdks/pgml/src/lib.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,6 +1553,88 @@ mod tests {
15531553
Ok(())
15541554
}
15551555

1556+
#[tokio::test]
1557+
async fn can_vector_search_with_local_embeddings_and_rerank() -> anyhow::Result<()> {
1558+
internal_init_logger(None,None).ok();
1559+
let collection_name ="test r_c_cvswlear_1";
1560+
let mut collection = Collection::new(collection_name, None)?;
1561+
let documents =generate_dummy_documents(10);
1562+
collection.upsert_documents(documents.clone(),None).await?;
1563+
let pipeline_name ="0";
1564+
letmut pipeline =Pipeline::new(
1565+
pipeline_name,
1566+
Some(
1567+
json!({
1568+
"title":{
1569+
"semantic_search":{
1570+
"model":"intfloat/e5-small-v2",
1571+
"parameters":{
1572+
"prompt":"passage: "
1573+
}
1574+
},
1575+
"full_text_search":{
1576+
"configuration":"english"
1577+
}
1578+
},
1579+
"body":{
1580+
"splitter":{
1581+
"model":"recursive_character"
1582+
},
1583+
"semantic_search":{
1584+
"model":"intfloat/e5-small-v2",
1585+
"parameters":{
1586+
"prompt":"passage: "
1587+
}
1588+
},
1589+
},
1590+
})
1591+
.into(),
1592+
),
1593+
)?;
1594+
collection.add_pipeline(&mut pipeline).await?;
1595+
let results = collection
1596+
.vector_search(
1597+
json!({
1598+
"query":{
1599+
"fields":{
1600+
"title":{
1601+
"query":"Test document: 2",
1602+
"parameters":{
1603+
"prompt":"passage: "
1604+
},
1605+
"full_text_filter":"test",
1606+
"boost":1.2
1607+
},
1608+
"body":{
1609+
"query":"Test document: 2",
1610+
"parameters":{
1611+
"prompt":"passage: "
1612+
},
1613+
"boost":1.0
1614+
},
1615+
}
1616+
},
1617+
"rerank":{
1618+
"query":"Test document 2",
1619+
"model":"mixedbread-ai/mxbai-rerank-base-v1",
1620+
"num_documents_to_rerank":100
1621+
},
1622+
"limit":5
1623+
})
1624+
.into(),
1625+
&mut pipeline,
1626+
)
1627+
.await?;
1628+
assert!(results[0]["rerank_score"].as_f64().is_some());
1629+
let ids:Vec<u64> = results
1630+
.into_iter()
1631+
.map(|r| r["document"]["id"].as_u64().unwrap())
1632+
.collect();
1633+
assert_eq!(ids, vec![2,1,3,8,6]);
1634+
collection.archive().await?;
1635+
Ok(())
1636+
}
1637+
15561638
///////////////////////////////
15571639
// Working With Documents /////
15581640
///////////////////////////////
@@ -2207,6 +2289,11 @@ mod tests {
22072289
"id"
22082290
]
22092291
},
2292+
"rerank":{
2293+
"query":"Test document 2",
2294+
"model":"mixedbread-ai/mxbai-rerank-base-v1",
2295+
"num_documents_to_rerank":100
2296+
},
22102297
"limit":5
22112298
},
22122299
"aggregate":{

‎pgml-sdks/pgml/src/rag_query_builder.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,7 @@ pub async fn build_rag_query(
212212
r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#,
213213
vector_search.aggregate.join
214214
),
215-
format!(
216-
r#"(SELECT json_agg(jsonb_build_object('chunk', chunk, 'document', document, 'score', score)) FROM "{var_name}")"#
217-
),
215+
format!(r#"(SELECT json_agg(j) FROM "{var_name}" j)"#),
218216
)
219217
}
220218
ValidVariable::RawSQL(sql) =>(format!("({})", sql.sql),format!("({})", sql.sql)),

‎pgml-sdks/pgml/src/vector_search_query_builder.rs

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ struct ValidDocument {
4141
keys:Option<Vec<String>>,
4242
}
4343

44+
const fn default_num_documents_to_rerank() -> u64 {
45+
10
46+
}
47+
48+
#[derive(Debug,Deserialize,Serialize,Clone)]
49+
#[serde(deny_unknown_fields)]
50+
struct ValidRerank {
51+
query:String,
52+
model:String,
53+
#[serde(default ="default_num_documents_to_rerank")]
54+
num_documents_to_rerank:u64,
55+
parameters:Option<Json>,
56+
}
57+
4458
const fn default_limit() -> u64 {
4559
10
4660
}
@@ -56,6 +70,8 @@ pub struct ValidQuery {
5670
limit:u64,
5771
// Document related items
5872
document:Option<ValidDocument>,
73+
// Rerank related items
74+
rerank:Option<ValidRerank>,
5975
}
6076

6177
pub async fn build_sqlx_query(
@@ -66,9 +82,14 @@ pub async fn build_sqlx_query(
6682
prefix:Option<&str>,
6783
) -> anyhow::Result<(SelectStatement,Vec<CommonTableExpression>)>{
6884
let valid_query:ValidQuery = serde_json::from_value(query.0)?;
69-
let limit = valid_query.limit;
7085
let fields = valid_query.query.fields.unwrap_or_default();
7186

87+
let search_limit = if let Some(rerank) = valid_query.rerank.as_ref() {
88+
rerank.num_documents_to_rerank
89+
}else{
90+
valid_query.limit
91+
};
92+
7293
let prefix = prefix.unwrap_or("");
7394

7495
if fields.is_empty(){
@@ -209,7 +230,7 @@ pub async fn build_sqlx_query(
209230
Expr::col((SIden::Str("documents"),SIden::Str("id")))
210231
.equals((SIden::Str("chunks"),SIden::Str("document_id"))),
211232
)
212-
.limit(limit);
233+
.limit(search_limit);
213234

214235
ifletSome(filter) =&valid_query.query.filter{
215236
let filter =FilterBuilder::new(filter.clone().0,"documents","document").build()?;
@@ -272,7 +293,79 @@ pub async fn build_sqlx_query(
272293
// Resort and limit
273294
query
274295
.order_by(SIden::Str("score"),Order::Desc)
275-
.limit(limit);
296+
.limit(search_limit);
297+
298+
// Rerank
299+
let query = if let Some(rerank) = &valid_query.rerank {
300+
// Add our vector_search CTE
301+
letmut vector_search_cte =CommonTableExpression::from_select(query);
302+
vector_search_cte.table_name(Alias::new(format!("{prefix}_vector_search")));
303+
ctes.push(vector_search_cte);
304+
305+
// Add our row_number_vector_search CTE
306+
letmut row_number_vector_search =Query::select();
307+
row_number_vector_search
308+
.columns([
309+
SIden::Str("document"),
310+
SIden::Str("chunk"),
311+
SIden::Str("score"),
312+
])
313+
.from(SIden::String(format!("{prefix}_vector_search")));
314+
row_number_vector_search
315+
.expr_as(Expr::cust("ROW_NUMBER() OVER ()"),Alias::new("row_number"));
316+
letmut row_number_vector_search_cte =
317+
CommonTableExpression::from_select(row_number_vector_search);
318+
row_number_vector_search_cte
319+
.table_name(Alias::new(format!("{prefix}_row_number_vector_search")));
320+
ctes.push(row_number_vector_search_cte);
321+
322+
// Our actual select statement
323+
letmut query =Query::select();
324+
query.columns([
325+
SIden::Str("document"),
326+
SIden::Str("chunk"),
327+
SIden::Str("score"),
328+
]);
329+
query.expr_as(Expr::cust("(rank).score"),Alias::new("rank_score"));
330+
331+
// Build the actual select statement sub query
332+
letmut sub_query_rank_call =Query::select();
333+
let model_expr =Expr::cust_with_values("$1",[rerank.model.clone()]);
334+
let query_expr =Expr::cust_with_values("$1",[rerank.query.clone()]);
335+
let parameters_expr =
336+
Expr::cust_with_values("$1",[rerank.parameters.clone().unwrap_or_default().0]);
337+
sub_query_rank_call.expr_as(Expr::cust_with_exprs(
338+
format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
339+
[model_expr, query_expr, parameters_expr],
340+
),Alias::new("rank"))
341+
.from(SIden::String(format!("{prefix}_row_number_vector_search")));
342+
343+
letmut sub_query =Query::select();
344+
sub_query
345+
.columns([
346+
SIden::Str("document"),
347+
SIden::Str("chunk"),
348+
SIden::Str("score"),
349+
SIden::Str("rank"),
350+
])
351+
.from_as(
352+
SIden::String(format!("{prefix}_row_number_vector_search")),
353+
Alias::new("rnsv1"),
354+
)
355+
.join_subquery(
356+
JoinType::InnerJoin,
357+
sub_query_rank_call,
358+
Alias::new("rnsv2"),
359+
Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"),
360+
);
361+
362+
// Query from the sub query
363+
query.from_subquery(sub_query,Alias::new("sub_query"));
364+
365+
query
366+
}else{
367+
query
368+
};
276369

277370
Ok((query, ctes))
278371
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp