Commit d31b6f4

SDK - Allow parallel batch uploads (#1465)

1 parent 6d061ed · commit d31b6f4

2 files changed: +224 −102 lines changed

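The commit splits the per-batch upsert work out into a private `_upsert_documents` method and schedules batches on a tokio `JoinSet`, capped by a new `parallel_batches` option. For context, a minimal caller-side sketch of the new options, assuming the pgml Rust SDK's `Collection` as exercised by the test in the second file below; the collection handle, document shapes, and function name here are illustrative:

    // Hedged sketch: `batch_size` bounds how many documents share one
    // transaction; `parallel_batches` caps how many batches are in flight.
    use serde_json::json;

    async fn upsert_example(mut collection: pgml::Collection) -> anyhow::Result<()> {
        let documents = vec![
            json!({"id": "doc_1", "title": "Hello", "body": "first body"}).into(),
            json!({"id": "doc_2", "title": "World", "body": "second body"}).into(),
        ];
        collection
            .upsert_documents(
                documents,
                Some(json!({"batch_size": 100, "parallel_batches": 10}).into()),
            )
            .await?;
        Ok(())
    }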

pgml-sdks/pgml/src/collection.rs

Lines changed: 144 additions & 102 deletions
@@ -6,19 +6,21 @@ use rust_bridge::{alias, alias_methods};
 use sea_query::Alias;
 use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
 use sea_query_binder::SqlxBinder;
-use serde_json::json;
-use sqlx::Executor;
+use serde_json::{json, Value};
 use sqlx::PgConnection;
+use sqlx::{Executor, Pool, Postgres};
 use std::borrow::Cow;
 use std::collections::HashMap;
 use std::path::Path;
 use std::time::SystemTime;
 use std::time::UNIX_EPOCH;
+use tokio::task::JoinSet;
 use tracing::{instrument, warn};
 use walkdir::WalkDir;

 use crate::debug_sqlx_query;
 use crate::filter_builder::FilterBuilder;
+use crate::pipeline::FieldAction;
 use crate::search_query_builder::build_search_query;
 use crate::vector_search_query_builder::build_vector_search_query;
 use crate::{
@@ -496,28 +498,80 @@ impl Collection {
         // -> Insert the document
         // -> Foreach pipeline check if we need to resync the document and if so sync the document
         // -> Commit the transaction
+        let mut args = args.unwrap_or_default();
+        let args = args.as_object_mut().context("args must be a JSON object")?;
+
         self.verify_in_database(false).await?;
         let mut pipelines = self.get_pipelines().await?;

         let pool = get_or_initialize_pool(&self.database_url).await?;

-        let mut parsed_schemas = vec![];
         let project_info = &self.database_data.as_ref().unwrap().project_info;
+        let mut parsed_schemas = vec![];
         for pipeline in &mut pipelines {
             let parsed_schema = pipeline
                 .get_parsed_schema(project_info, &pool)
                 .await
                 .expect("Error getting parsed schema for pipeline");
             parsed_schemas.push(parsed_schema);
         }
-        let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect();
+        let pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)> =
+            pipelines.into_iter().zip(parsed_schemas).collect();

-        let args = args.unwrap_or_default();
-        let args = args.as_object().context("args must be a JSON object")?;
+        let batch_size = args
+            .remove("batch_size")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(100))?;
+
+        let parallel_batches = args
+            .get("parallel_batches")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(1))? as usize;

         let progress_bar = utils::default_progress_bar(documents.len() as u64);
         progress_bar.println("Upserting Documents...");

+        let mut set = JoinSet::new();
+        for batch in documents.chunks(batch_size as usize) {
+            if set.len() < parallel_batches {
+                let local_self = self.clone();
+                let local_batch = batch.to_owned();
+                let local_args = args.clone();
+                let local_pipelines = pipelines.clone();
+                let local_pool = pool.clone();
+                set.spawn(async move {
+                    local_self
+                        ._upsert_documents(local_batch, local_args, local_pipelines, local_pool)
+                        .await
+                });
+            } else {
+                if let Some(res) = set.join_next().await {
+                    res??;
+                    progress_bar.inc(batch_size);
+                }
+            }
+        }
+
+        while let Some(res) = set.join_next().await {
+            res??;
+            progress_bar.inc(batch_size);
+        }
+
+        progress_bar.println("Done Upserting Documents\n");
+        progress_bar.finish();
+
+        Ok(())
+    }
+
+    async fn _upsert_documents(
+        self,
+        batch: Vec<Json>,
+        args: serde_json::Map<String, Value>,
+        mut pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)>,
+        pool: Pool<Postgres>,
+    ) -> anyhow::Result<()> {
+        let project_info = &self.database_data.as_ref().unwrap().project_info;
+
         let query = if args
             .get("merge")
             .map(|v| v.as_bool().unwrap_or(false))
@@ -539,111 +593,99 @@ impl Collection {
             )
         };

-        let batch_size = args
-            .get("batch_size")
-            .map(TryToNumeric::try_to_u64)
-            .unwrap_or(Ok(100))?;
-
-        for batch in documents.chunks(batch_size as usize) {
-            let mut transaction = pool.begin().await?;
-
-            let mut query_values = String::new();
-            let mut binding_parameter_counter = 1;
-            for _ in 0..batch.len() {
-                query_values = format!(
-                    "{query_values}, (${}, ${}, ${})",
-                    binding_parameter_counter,
-                    binding_parameter_counter + 1,
-                    binding_parameter_counter + 2
-                );
-                binding_parameter_counter += 3;
-            }
+        let mut transaction = pool.begin().await?;

-            let query = query.replace(
-                "{values_parameters}",
-                &query_values.chars().skip(1).collect::<String>(),
-            );
-            let query = query.replace(
-                "{binding_parameter}",
-                &format!("${binding_parameter_counter}"),
+        let mut query_values = String::new();
+        let mut binding_parameter_counter = 1;
+        for _ in 0..batch.len() {
+            query_values = format!(
+                "{query_values}, (${}, ${}, ${})",
+                binding_parameter_counter,
+                binding_parameter_counter + 1,
+                binding_parameter_counter + 2
            );
+            binding_parameter_counter += 3;
+        }

-            let mut query = sqlx::query_as(&query);
-
-            let mut source_uuids = vec![];
-            for document in batch {
-                let id = document
-                    .get("id")
-                    .context("`id` must be a key in document")?
-                    .to_string();
-                let md5_digest = md5::compute(id.as_bytes());
-                let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
-                source_uuids.push(source_uuid);
-
-                let start = SystemTime::now();
-                let timestamp = start
-                    .duration_since(UNIX_EPOCH)
-                    .expect("Time went backwards")
-                    .as_millis();
-
-                let versions: HashMap<String, serde_json::Value> = document
-                    .as_object()
-                    .context("document must be an object")?
-                    .iter()
-                    .try_fold(HashMap::new(), |mut acc, (key, value)| {
-                        let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
-                        let md5_digest = format!("{md5_digest:x}");
-                        acc.insert(
-                            key.to_owned(),
-                            serde_json::json!({
-                                "last_updated": timestamp,
-                                "md5": md5_digest
-                            }),
-                        );
-                        anyhow::Ok(acc)
-                    })?;
-                let versions = serde_json::to_value(versions)?;
-
-                query = query.bind(source_uuid).bind(document).bind(versions);
-            }
+        let query = query.replace(
+            "{values_parameters}",
+            &query_values.chars().skip(1).collect::<String>(),
+        );
+        let query = query.replace(
+            "{binding_parameter}",
+            &format!("${binding_parameter_counter}"),
+        );

-            let results: Vec<(i64, Option<Json>)> = query
-                .bind(source_uuids)
-                .fetch_all(&mut *transaction)
-                .await?;
+        let mut query = sqlx::query_as(&query);
+
+        let mut source_uuids = vec![];
+        for document in &batch {
+            let id = document
+                .get("id")
+                .context("`id` must be a key in document")?
+                .to_string();
+            let md5_digest = md5::compute(id.as_bytes());
+            let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
+            source_uuids.push(source_uuid);
+
+            let start = SystemTime::now();
+            let timestamp = start
+                .duration_since(UNIX_EPOCH)
+                .expect("Time went backwards")
+                .as_millis();
+
+            let versions: HashMap<String, serde_json::Value> = document
+                .as_object()
+                .context("document must be an object")?
+                .iter()
+                .try_fold(HashMap::new(), |mut acc, (key, value)| {
+                    let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
+                    let md5_digest = format!("{md5_digest:x}");
+                    acc.insert(
+                        key.to_owned(),
+                        serde_json::json!({
+                            "last_updated": timestamp,
+                            "md5": md5_digest
+                        }),
+                    );
+                    anyhow::Ok(acc)
+                })?;
+            let versions = serde_json::to_value(versions)?;

-            let dp: Vec<(i64, Json, Option<Json>)> = results
-                .into_iter()
-                .zip(batch)
-                .map(|((id, previous_document), document)| {
-                    (id, document.to_owned(), previous_document)
+            query = query.bind(source_uuid).bind(document).bind(versions);
+        }
+
+        let results: Vec<(i64, Option<Json>)> = query
+            .bind(source_uuids)
+            .fetch_all(&mut *transaction)
+            .await?;
+
+        let dp: Vec<(i64, Json, Option<Json>)> = results
+            .into_iter()
+            .zip(batch)
+            .map(|((id, previous_document), document)| (id, document.to_owned(), previous_document))
+            .collect();
+
+        for (pipeline, parsed_schema) in &mut pipelines {
+            let ids_to_run_on: Vec<i64> = dp
+                .iter()
+                .filter(|(_, document, previous_document)| match previous_document {
+                    Some(previous_document) => parsed_schema
+                        .iter()
+                        .any(|(key, _)| document[key] != previous_document[key]),
+                    None => true,
                 })
+                .map(|(document_id, _, _)| *document_id)
                 .collect();
-
-            for (pipeline, parsed_schema) in &mut pipelines {
-                let ids_to_run_on: Vec<i64> = dp
-                    .iter()
-                    .filter(|(_, document, previous_document)| match previous_document {
-                        Some(previous_document) => parsed_schema
-                            .iter()
-                            .any(|(key, _)| document[key] != previous_document[key]),
-                        None => true,
-                    })
-                    .map(|(document_id, _, _)| *document_id)
-                    .collect();
-
-                if !ids_to_run_on.is_empty() {
-                    pipeline
-                        .sync_documents(ids_to_run_on, project_info, &mut transaction)
-                        .await
-                        .expect("Failed to execute pipeline");
-                }
+            if !ids_to_run_on.is_empty() {
+                pipeline
+                    .sync_documents(ids_to_run_on, project_info, &mut transaction)
+                    .await
+                    .expect("Failed to execute pipeline");
             }
-
-            transaction.commit().await?;
-            progress_bar.inc(batch_size);
         }
-        progress_bar.println("Done Upserting Documents\n");
-        progress_bar.finish();
+
+        transaction.commit().await?;
         Ok(())
     }

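The scheduling core of the new `upsert_documents` is tokio's `JoinSet`: clone `self`, the args, and the pipelines into each task so batches can run independently, spawn up to `parallel_batches` tasks, then drain the set. Below is a self-contained sketch of that bounded-concurrency pattern, assuming the tokio runtime (with the multi-thread/macros features) and the anyhow crate; `process_batch` is a hypothetical stand-in for `_upsert_documents`, and unlike the loop above, this sketch joins before spawning so every batch is spawned even when the cap is hit:

    use tokio::task::JoinSet;

    async fn process_batch(batch: Vec<u64>) -> anyhow::Result<()> {
        // Stand-in for the real per-batch upsert work.
        println!("processed {} items", batch.len());
        Ok(())
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let items: Vec<u64> = (0..100).collect();
        let (batch_size, parallel_batches) = (10, 4);

        let mut set = JoinSet::new();
        for batch in items.chunks(batch_size) {
            // At the cap, wait for one in-flight task before spawning the next.
            if set.len() >= parallel_batches {
                if let Some(res) = set.join_next().await {
                    res??; // first ? surfaces the JoinError, second the task's own error
                }
            }
            set.spawn(process_batch(batch.to_vec()));
        }
        // Drain whatever is still running.
        while let Some(res) = set.join_next().await {
            res??;
        }
        Ok(())
    }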
pgml-sdks/pgml/src/lib.rs

Lines changed: 80 additions & 0 deletions
@@ -431,6 +431,86 @@ mod tests {
         Ok(())
     }

+    #[tokio::test]
+    async fn can_add_pipeline_and_upsert_documents_with_parallel_batches() -> anyhow::Result<()> {
+        internal_init_logger(None, None).ok();
+        let collection_name = "test_r_c_capaud_107";
+        let pipeline_name = "test_r_p_capaud_6";
+        let mut pipeline = Pipeline::new(
+            pipeline_name,
+            Some(
+                json!({
+                    "title": {
+                        "semantic_search": {
+                            "model": "intfloat/e5-small"
+                        }
+                    },
+                    "body": {
+                        "splitter": {
+                            "model": "recursive_character",
+                            "parameters": {
+                                "chunk_size": 1000,
+                                "chunk_overlap": 40
+                            }
+                        },
+                        "semantic_search": {
+                            "model": "hkunlp/instructor-base",
+                            "parameters": {
+                                "instruction": "Represent the Wikipedia document for retrieval"
+                            }
+                        },
+                        "full_text_search": {
+                            "configuration": "english"
+                        }
+                    }
+                })
+                .into(),
+            ),
+        )?;
+        let mut collection = Collection::new(collection_name, None)?;
+        collection.add_pipeline(&mut pipeline).await?;
+        let documents = generate_dummy_documents(20);
+        collection
+            .upsert_documents(
+                documents.clone(),
+                Some(
+                    json!({
+                        "batch_size": 4,
+                        "parallel_batches": 5
+                    })
+                    .into(),
+                ),
+            )
+            .await?;
+        let pool = get_or_initialize_pool(&None).await?;
+        let documents_table = format!("{}.documents", collection_name);
+        let queried_documents: Vec<models::Document> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(queried_documents.len() == 20);
+        let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name);
+        let title_chunks: Vec<models::Chunk> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(title_chunks.len() == 20);
+        let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name);
+        let body_chunks: Vec<models::Chunk> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(body_chunks.len() == 120);
+        let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name);
+        let tsvectors: Vec<models::TSVector> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(tsvectors.len() == 120);
+        collection.archive().await?;
+        Ok(())
+    }
+
     #[tokio::test]
     async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> {
         internal_init_logger(None, None).ok();