Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

SDK - Added re-ranking into vector search#1516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
SilasMarvin merged 2 commits intomasterfromsilas-sdk-add-reranking
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletionspgml-sdks/pgml/Cargo.lock
View file
Open in desktop

Some generated files are not rendered by default. Learn more abouthow customized files appear on GitHub.

21 changes: 13 additions & 8 deletionspgml-sdks/pgml/src/collection.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -1051,6 +1051,7 @@ impl Collection {
/// }).into(), &mut pipeline).await?;
/// Ok(())
/// }
#[allow(clippy::type_complexity)]
#[instrument(skip(self))]
pub async fn vector_search(
&mut self,
Expand All@@ -1061,7 +1062,7 @@ impl Collection {

let (built_query, values) =
build_vector_search_query(query.clone(), self, pipeline).await?;
let results: Result<Vec<(Json, String, f64)>, _> =
let results: Result<Vec<(Json, String, f64, Option<f64>)>, _> =
sqlx::query_as_with(&built_query, values)
.fetch_all(&pool)
.await;
Expand All@@ -1072,7 +1073,8 @@ impl Collection {
serde_json::json!({
"document": v.0,
"chunk": v.1,
"score": v.2
"score": v.2,
"rerank_score": v.3
})
.into()
})
Expand All@@ -1087,7 +1089,7 @@ impl Collection {
.await?;
let (built_query, values) =
build_vector_search_query(query, self, pipeline).await?;
let results: Vec<(Json, String, f64)> =
let results: Vec<(Json, String, f64, Option<f64>)> =
sqlx::query_as_with(&built_query, values)
.fetch_all(&pool)
.await?;
Expand All@@ -1097,7 +1099,8 @@ impl Collection {
serde_json::json!({
"document": v.0,
"chunk": v.1,
"score": v.2
"score": v.2,
"rerank_score": v.3
})
.into()
})
Expand All@@ -1121,16 +1124,18 @@ impl Collection {
let pool = get_or_initialize_pool(&self.database_url).await?;
let (built_query, values) =
build_vector_search_query(query.clone(), self, pipeline).await?;
let results: Vec<(Json, String, f64)> = sqlx::query_as_with(&built_query, values)
.fetch_all(&pool)
.await?;
let results: Vec<(Json, String, f64, Option<f64>)> =
sqlx::query_as_with(&built_query, values)
.fetch_all(&pool)
.await?;
Ok(results
.into_iter()
.map(|v| {
serde_json::json!({
"document": v.0,
"chunk": v.1,
"score": v.2
"score": v.2,
"rerank_score": v.3
})
.into()
})
Expand Down
87 changes: 87 additions & 0 deletionspgml-sdks/pgml/src/lib.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -1553,6 +1553,88 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn can_vector_search_with_local_embeddings_and_rerank() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test r_c_cvswlear_1";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(10);
collection.upsert_documents(documents.clone(), None).await?;
let pipeline_name = "0";
let mut pipeline = Pipeline::new(
pipeline_name,
Some(
json!({
"title": {
"semantic_search": {
"model": "intfloat/e5-small-v2",
"parameters": {
"prompt": "passage: "
}
},
"full_text_search": {
"configuration": "english"
}
},
"body": {
"splitter": {
"model": "recursive_character"
},
"semantic_search": {
"model": "intfloat/e5-small-v2",
"parameters": {
"prompt": "passage: "
}
},
},
})
.into(),
),
)?;
collection.add_pipeline(&mut pipeline).await?;
let results = collection
.vector_search(
json!({
"query": {
"fields": {
"title": {
"query": "Test document: 2",
"parameters": {
"prompt": "passage: "
},
"full_text_filter": "test",
"boost": 1.2
},
"body": {
"query": "Test document: 2",
"parameters": {
"prompt": "passage: "
},
"boost": 1.0
},
}
},
Copy link
ContributorAuthor

@SilasMarvinSilasMarvinJun 10, 2024
edited
Loading

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

@montanalow How does this "rerank" key look?

query is the text to compare against.
model is the model to use
num_documents_to_rerank are the number of results to return from vector search and rerank against before limiting it to thelimit parameter defined in the next section

montanalow reacted with thumbs up emoji
"rerank": {
"query": "Test document 2",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Seems likequery is being repeated a few places in this example, which may be pretty typical. One enhancement would be to move the query string out and reuse it everywhere, and make passing specific sub clause query strings optional. Not a launch blocker though.

SilasMarvin reacted with thumbs up emoji
Copy link
ContributorAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Got it, I will think more on making that optional and reusing it, but will merge this and get it out in the meantime.

"model": "mixedbread-ai/mxbai-rerank-base-v1",
"num_documents_to_rerank": 100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

What about calling this justlimit. Does llamaindex or transformers have a similarly named parameter name?

Copy link
ContributorAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Oh sorry missed this before merging. I think it might be a little confusing if we make it limit as we already have a limit key, and this isn't actually the limit. We already defined limit with llama index to mean the final number of items returned, but I'm not sure if they or langchain use it elsewhere.

},
"limit": 5
})
.into(),
&mut pipeline,
)
.await?;
assert!(results[0]["rerank_score"].as_f64().is_some());
let ids: Vec<u64> = results
.into_iter()
.map(|r| r["document"]["id"].as_u64().unwrap())
.collect();
assert_eq!(ids, vec![2, 1, 3, 8, 6]);
collection.archive().await?;
Ok(())
}

///////////////////////////////
// Working With Documents /////
///////////////////////////////
Expand DownExpand Up@@ -2207,6 +2289,11 @@ mod tests {
"id"
]
},
"rerank": {
"query": "Test document 2",
"model": "mixedbread-ai/mxbai-rerank-base-v1",
"num_documents_to_rerank": 100
},
"limit": 5
},
"aggregate": {
Expand Down
4 changes: 1 addition & 3 deletionspgml-sdks/pgml/src/rag_query_builder.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -212,9 +212,7 @@ pub async fn build_rag_query(
r#"(SELECT string_agg(chunk, '{}') FROM "{var_name}")"#,
vector_search.aggregate.join
),
format!(
r#"(SELECT json_agg(jsonb_build_object('chunk', chunk, 'document', document, 'score', score)) FROM "{var_name}")"#
),
format!(r#"(SELECT json_agg(j) FROM "{var_name}" j)"#),
)
}
ValidVariable::RawSQL(sql) => (format!("({})", sql.sql), format!("({})", sql.sql)),
Expand Down
99 changes: 96 additions & 3 deletionspgml-sdks/pgml/src/vector_search_query_builder.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -41,6 +41,20 @@ struct ValidDocument {
keys: Option<Vec<String>>,
}

const fn default_num_documents_to_rerank() -> u64 {
10
}

#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(deny_unknown_fields)]
struct ValidRerank {
query: String,
model: String,
#[serde(default = "default_num_documents_to_rerank")]
num_documents_to_rerank: u64,
parameters: Option<Json>,
}

const fn default_limit() -> u64 {
10
}
Expand All@@ -56,6 +70,8 @@ pub struct ValidQuery {
limit: u64,
// Document related items
document: Option<ValidDocument>,
// Rerank related items
rerank: Option<ValidRerank>,
}

pub async fn build_sqlx_query(
Expand All@@ -66,9 +82,14 @@ pub async fn build_sqlx_query(
prefix: Option<&str>,
) -> anyhow::Result<(SelectStatement, Vec<CommonTableExpression>)> {
let valid_query: ValidQuery = serde_json::from_value(query.0)?;
let limit = valid_query.limit;
let fields = valid_query.query.fields.unwrap_or_default();

let search_limit = if let Some(rerank) = valid_query.rerank.as_ref() {
rerank.num_documents_to_rerank
} else {
valid_query.limit
};

let prefix = prefix.unwrap_or("");

if fields.is_empty() {
Expand DownExpand Up@@ -209,7 +230,7 @@ pub async fn build_sqlx_query(
Expr::col((SIden::Str("documents"), SIden::Str("id")))
.equals((SIden::Str("chunks"), SIden::Str("document_id"))),
)
.limit(limit);
.limit(search_limit);

if let Some(filter) = &valid_query.query.filter {
let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?;
Expand DownExpand Up@@ -272,7 +293,79 @@ pub async fn build_sqlx_query(
// Resort and limit
query
.order_by(SIden::Str("score"), Order::Desc)
.limit(limit);
.limit(search_limit);

// Rerank
let query = if let Some(rerank) = &valid_query.rerank {
// Add our vector_search CTE
let mut vector_search_cte = CommonTableExpression::from_select(query);
vector_search_cte.table_name(Alias::new(format!("{prefix}_vector_search")));
ctes.push(vector_search_cte);

// Add our row_number_vector_search CTE
let mut row_number_vector_search = Query::select();
row_number_vector_search
.columns([
SIden::Str("document"),
SIden::Str("chunk"),
SIden::Str("score"),
])
.from(SIden::String(format!("{prefix}_vector_search")));
row_number_vector_search
.expr_as(Expr::cust("ROW_NUMBER() OVER ()"), Alias::new("row_number"));
let mut row_number_vector_search_cte =
CommonTableExpression::from_select(row_number_vector_search);
row_number_vector_search_cte
.table_name(Alias::new(format!("{prefix}_row_number_vector_search")));
ctes.push(row_number_vector_search_cte);

// Our actual select statement
let mut query = Query::select();
query.columns([
SIden::Str("document"),
SIden::Str("chunk"),
SIden::Str("score"),
]);
query.expr_as(Expr::cust("(rank).score"), Alias::new("rank_score"));

// Build the actual select statement sub query
let mut sub_query_rank_call = Query::select();
let model_expr = Expr::cust_with_values("$1", [rerank.model.clone()]);
let query_expr = Expr::cust_with_values("$1", [rerank.query.clone()]);
let parameters_expr =
Expr::cust_with_values("$1", [rerank.parameters.clone().unwrap_or_default().0]);
sub_query_rank_call.expr_as(Expr::cust_with_exprs(
format!(r#"pgml.rank($1, $2, array_agg("chunk"), '{{"return_documents": false, "top_k": {}}}'::jsonb || $3)"#, valid_query.limit),
[model_expr, query_expr, parameters_expr],
), Alias::new("rank"))
.from(SIden::String(format!("{prefix}_row_number_vector_search")));

let mut sub_query = Query::select();
sub_query
.columns([
SIden::Str("document"),
SIden::Str("chunk"),
SIden::Str("score"),
SIden::Str("rank"),
])
.from_as(
SIden::String(format!("{prefix}_row_number_vector_search")),
Alias::new("rnsv1"),
)
.join_subquery(
JoinType::InnerJoin,
sub_query_rank_call,
Alias::new("rnsv2"),
Expr::cust("((rank).corpus_id + 1) = rnsv1.row_number"),
);

// Query from the sub query
query.from_subquery(sub_query, Alias::new("sub_query"));

query
} else {
query
};

Ok((query, ctes))
}
Expand Down

[8]ページ先頭

©2009-2025 Movatter.jp