SDK - Allow parallel batch uploads #1465


Merged
SilasMarvin merged 1 commit into master from silas-allow-parallel-batch-uploads on May 14, 2024
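
This change lets callers pass a new `parallel_batches` option through the same JSON args object as the existing `batch_size` option on `upsert_documents`, as exercised by the test added in this PR. A minimal usage sketch, assuming the `pgml` Rust SDK's `Collection` and `Json` types plus `tokio`, `anyhow`, and `serde_json`; the collection name, documents, and import paths here are illustrative assumptions, not taken from the diff:

    use pgml::{types::Json, Collection};
    use serde_json::json;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // "demo_collection" and the documents below are placeholders.
        let mut collection = Collection::new("demo_collection", None)?;
        let documents: Vec<Json> = vec![
            json!({ "id": "doc_1", "title": "First", "body": "..." }).into(),
            json!({ "id": "doc_2", "title": "Second", "body": "..." }).into(),
        ];
        collection
            .upsert_documents(
                documents,
                Some(
                    json!({
                        // Documents per transaction (pre-existing option).
                        "batch_size": 100,
                        // New: how many batches may be upserted concurrently.
                        "parallel_batches": 4
                    })
                    .into(),
                ),
            )
            .await?;
        Ok(())
    }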
246 changes: 144 additions & 102 deletions pgml-sdks/pgml/src/collection.rs
@@ -6,19 +6,21 @@ use rust_bridge::{alias, alias_methods};
use sea_query::Alias;
use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
use sea_query_binder::SqlxBinder;
use serde_json::json;
use sqlx::Executor;
use serde_json::{json, Value};
use sqlx::PgConnection;
use sqlx::{Executor, Pool, Postgres};
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::Path;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use tokio::task::JoinSet;
use tracing::{instrument, warn};
use walkdir::WalkDir;

use crate::debug_sqlx_query;
use crate::filter_builder::FilterBuilder;
use crate::pipeline::FieldAction;
use crate::search_query_builder::build_search_query;
use crate::vector_search_query_builder::build_vector_search_query;
use crate::{
@@ -496,28 +498,80 @@ impl Collection {
// -> Insert the document
// -> Foreach pipeline check if we need to resync the document and if so sync the document
// -> Commit the transaction
let mut args = args.unwrap_or_default();
let args = args.as_object_mut().context("args must be a JSON object")?;

self.verify_in_database(false).await?;
let mut pipelines = self.get_pipelines().await?;

let pool = get_or_initialize_pool(&self.database_url).await?;

let mut parsed_schemas = vec![];
let project_info = &self.database_data.as_ref().unwrap().project_info;
let mut parsed_schemas = vec![];
for pipeline in &mut pipelines {
let parsed_schema = pipeline
.get_parsed_schema(project_info, &pool)
.await
.expect("Error getting parsed schema for pipeline");
parsed_schemas.push(parsed_schema);
}
let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect();
let pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)> =
pipelines.into_iter().zip(parsed_schemas).collect();

let args = args.unwrap_or_default();
let args = args.as_object().context("args must be a JSON object")?;
let batch_size = args
.remove("batch_size")
.map(|x| x.try_to_u64())
.unwrap_or(Ok(100))?;

let parallel_batches = args
.get("parallel_batches")
.map(|x| x.try_to_u64())
.unwrap_or(Ok(1))? as usize;

let progress_bar = utils::default_progress_bar(documents.len() as u64);
progress_bar.println("Upserting Documents...");

let mut set = JoinSet::new();
for batch in documents.chunks(batch_size as usize) {
if set.len() < parallel_batches {
let local_self = self.clone();
let local_batch = batch.to_owned();
let local_args = args.clone();
let local_pipelines = pipelines.clone();
let local_pool = pool.clone();
set.spawn(async move {
local_self
._upsert_documents(local_batch, local_args, local_pipelines, local_pool)
.await
});
} else {
if let Some(res) = set.join_next().await {
res??;
progress_bar.inc(batch_size);
}
}
}

while let Some(res) = set.join_next().await {
res??;
progress_bar.inc(batch_size);
}

progress_bar.println("Done Upserting Documents\n");
progress_bar.finish();

Ok(())
}

async fn _upsert_documents(
self,
batch: Vec<Json>,
args: serde_json::Map<String, Value>,
mut pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)>,
pool: Pool<Postgres>,
) -> anyhow::Result<()> {
let project_info = &self.database_data.as_ref().unwrap().project_info;

let query = if args
.get("merge")
.map(|v| v.as_bool().unwrap_or(false))
@@ -539,111 +593,99 @@ impl Collection {
)
};

let batch_size = args
.get("batch_size")
.map(TryToNumeric::try_to_u64)
.unwrap_or(Ok(100))?;

for batch in documents.chunks(batch_size as usize) {
let mut transaction = pool.begin().await?;

let mut query_values = String::new();
let mut binding_parameter_counter = 1;
for _ in 0..batch.len() {
query_values = format!(
"{query_values}, (${}, ${}, ${})",
binding_parameter_counter,
binding_parameter_counter + 1,
binding_parameter_counter + 2
);
binding_parameter_counter += 3;
}
let mut transaction = pool.begin().await?;

let query = query.replace(
"{values_parameters}",
&query_values.chars().skip(1).collect::<String>(),
);
let query = query.replace(
"{binding_parameter}",
&format!("${binding_parameter_counter}"),
let mut query_values = String::new();
let mut binding_parameter_counter = 1;
for _ in 0..batch.len() {
query_values = format!(
"{query_values}, (${}, ${}, ${})",
binding_parameter_counter,
binding_parameter_counter + 1,
binding_parameter_counter + 2
);
binding_parameter_counter += 3;
}

let mut query = sqlx::query_as(&query);

let mut source_uuids = vec![];
for document in batch {
let id = document
.get("id")
.context("`id` must be a key in document")?
.to_string();
let md5_digest = md5::compute(id.as_bytes());
let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
source_uuids.push(source_uuid);

let start = SystemTime::now();
let timestamp = start
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_millis();

let versions: HashMap<String, serde_json::Value> = document
.as_object()
.context("document must be an object")?
.iter()
.try_fold(HashMap::new(), |mut acc, (key, value)| {
let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
let md5_digest = format!("{md5_digest:x}");
acc.insert(
key.to_owned(),
serde_json::json!({
"last_updated": timestamp,
"md5": md5_digest
}),
);
anyhow::Ok(acc)
})?;
let versions = serde_json::to_value(versions)?;

query = query.bind(source_uuid).bind(document).bind(versions);
}
let query = query.replace(
"{values_parameters}",
&query_values.chars().skip(1).collect::<String>(),
);
let query = query.replace(
"{binding_parameter}",
&format!("${binding_parameter_counter}"),
);

let results: Vec<(i64, Option<Json>)> = query
.bind(source_uuids)
.fetch_all(&mut *transaction)
.await?;
let mut query = sqlx::query_as(&query);

let mut source_uuids = vec![];
for document in &batch {
let id = document
.get("id")
.context("`id` must be a key in document")?
.to_string();
let md5_digest = md5::compute(id.as_bytes());
let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
source_uuids.push(source_uuid);

let start = SystemTime::now();
let timestamp = start
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_millis();

let versions: HashMap<String, serde_json::Value> = document
.as_object()
.context("document must be an object")?
.iter()
.try_fold(HashMap::new(), |mut acc, (key, value)| {
let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
let md5_digest = format!("{md5_digest:x}");
acc.insert(
key.to_owned(),
serde_json::json!({
"last_updated": timestamp,
"md5": md5_digest
}),
);
anyhow::Ok(acc)
})?;
let versions = serde_json::to_value(versions)?;

let dp: Vec<(i64, Json, Option<Json>)> = results
.into_iter()
.zip(batch)
.map(|((id, previous_document), document)| {
(id, document.to_owned(), previous_document)
query = query.bind(source_uuid).bind(document).bind(versions);
}

let results: Vec<(i64, Option<Json>)> = query
.bind(source_uuids)
.fetch_all(&mut *transaction)
.await?;

let dp: Vec<(i64, Json, Option<Json>)> = results
.into_iter()
.zip(batch)
.map(|((id, previous_document), document)| (id, document.to_owned(), previous_document))
.collect();

for (pipeline, parsed_schema) in &mut pipelines {
let ids_to_run_on: Vec<i64> = dp
.iter()
.filter(|(_, document, previous_document)| match previous_document {
Some(previous_document) => parsed_schema
.iter()
.any(|(key, _)| document[key] != previous_document[key]),
None => true,
})
.map(|(document_id, _, _)| *document_id)
.collect();

for (pipeline, parsed_schema) in &mut pipelines {
let ids_to_run_on: Vec<i64> = dp
.iter()
.filter(|(_, document, previous_document)| match previous_document {
Some(previous_document) => parsed_schema
.iter()
.any(|(key, _)| document[key] != previous_document[key]),
None => true,
})
.map(|(document_id, _, _)| *document_id)
.collect();
if !ids_to_run_on.is_empty() {
pipeline
.sync_documents(ids_to_run_on, project_info, &mut transaction)
.await
.expect("Failed to execute pipeline");
}
if !ids_to_run_on.is_empty() {
pipeline
.sync_documents(ids_to_run_on, project_info, &mut transaction)
.await
.expect("Failed to execute pipeline");
}

transaction.commit().await?;
progress_bar.inc(batch_size);
}
progress_bar.println("Done Upserting Documents\n");
progress_bar.finish();

transaction.commit().await?;
Ok(())
}

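The new upsert path above caps concurrency with `tokio::task::JoinSet`: at most `parallel_batches` per-batch upsert tasks are kept in flight, one is awaited whenever the set is full, and the remaining tasks are drained after the loop. A simplified, self-contained sketch of that bounded-concurrency pattern (assuming `tokio` with the `macros`, `rt-multi-thread`, and `time` features, plus `anyhow`; the batch payload and the sleep are stand-ins, not the SDK's actual `_upsert_documents` logic):

    use std::time::Duration;
    use tokio::task::JoinSet;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let batches: Vec<Vec<u64>> = (0..10).map(|i| vec![i; 4]).collect();
        let parallel_batches = 3usize;

        let mut set = JoinSet::new();
        for batch in batches {
            // Keep at most `parallel_batches` tasks in flight: if the set is
            // full, wait for one task to complete before spawning the next.
            if set.len() >= parallel_batches {
                set.join_next().await.expect("set is not empty")??;
            }
            set.spawn(async move {
                // Stand-in for upserting one batch of documents.
                tokio::time::sleep(Duration::from_millis(10)).await;
                anyhow::Ok(batch.len())
            });
        }

        // Drain whatever is still running, surfacing join errors and task errors.
        while let Some(res) = set.join_next().await {
            res??;
        }
        Ok(())
    }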
80 changes: 80 additions & 0 deletions pgml-sdks/pgml/src/lib.rs
@@ -431,6 +431,86 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn can_add_pipeline_and_upsert_documents_with_parallel_batches() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test_r_c_capaud_107";
let pipeline_name = "test_r_p_capaud_6";
let mut pipeline = Pipeline::new(
pipeline_name,
Some(
json!({
"title": {
"semantic_search": {
"model": "intfloat/e5-small"
}
},
"body": {
"splitter": {
"model": "recursive_character",
"parameters": {
"chunk_size": 1000,
"chunk_overlap": 40
}
},
"semantic_search": {
"model": "hkunlp/instructor-base",
"parameters": {
"instruction": "Represent the Wikipedia document for retrieval"
}
},
"full_text_search": {
"configuration": "english"
}
}
})
.into(),
),
)?;
let mut collection = Collection::new(collection_name, None)?;
collection.add_pipeline(&mut pipeline).await?;
let documents = generate_dummy_documents(20);
collection
.upsert_documents(
documents.clone(),
Some(
json!({
"batch_size": 4,
"parallel_batches": 5
})
.into(),
),
)
.await?;
let pool = get_or_initialize_pool(&None).await?;
let documents_table = format!("{}.documents", collection_name);
let queried_documents: Vec<models::Document> =
sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table))
.fetch_all(&pool)
.await?;
assert!(queried_documents.len() == 20);
let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name);
let title_chunks: Vec<models::Chunk> =
sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
.fetch_all(&pool)
.await?;
assert!(title_chunks.len() == 20);
let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name);
let body_chunks: Vec<models::Chunk> =
sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
.fetch_all(&pool)
.await?;
assert!(body_chunks.len() == 120);
let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name);
let tsvectors: Vec<models::TSVector> =
sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table))
.fetch_all(&pool)
.await?;
assert!(tsvectors.len() == 120);
collection.archive().await?;
Ok(())
}

#[tokio::test]
async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
