@@ -6,19 +6,21 @@ use rust_bridge::{alias, alias_methods};
 use sea_query::Alias;
 use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
 use sea_query_binder::SqlxBinder;
-use serde_json::json;
-use sqlx::Executor;
+use serde_json::{json, Value};
 use sqlx::PgConnection;
+use sqlx::{Executor, Pool, Postgres};
 use std::borrow::Cow;
 use std::collections::HashMap;
 use std::path::Path;
 use std::time::SystemTime;
 use std::time::UNIX_EPOCH;
+use tokio::task::JoinSet;
 use tracing::{instrument, warn};
 use walkdir::WalkDir;
 
 use crate::debug_sqlx_query;
 use crate::filter_builder::FilterBuilder;
+use crate::pipeline::FieldAction;
 use crate::search_query_builder::build_search_query;
 use crate::vector_search_query_builder::build_vector_search_query;
 use crate::{
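
The newly imported `tokio::task::JoinSet` drives the parallel upsert loop added in the next hunk. As a standalone illustration of that bounded-concurrency pattern — a minimal sketch assuming only tokio itself, with `limit` and the printed message as placeholders — capping the number of in-flight tasks looks like this:

    use tokio::task::JoinSet;

    #[tokio::main]
    async fn main() -> Result<(), tokio::task::JoinError> {
        let limit = 4;
        let mut set = JoinSet::new();
        for i in 0..32u64 {
            // Drain one finished task whenever we are at capacity
            while set.len() >= limit {
                set.join_next().await.expect("set is non-empty")?;
            }
            set.spawn(async move { println!("processed batch {i}") });
        }
        // Wait for the stragglers
        while let Some(res) = set.join_next().await {
            res?;
        }
        Ok(())
    }
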
@@ -496,28 +498,80 @@ impl Collection {
         // -> Insert the document
         // -> Foreach pipeline check if we need to resync the document and if so sync the document
         // -> Commit the transaction
+        let mut args = args.unwrap_or_default();
+        let args = args.as_object_mut().context("args must be a JSON object")?;
+
         self.verify_in_database(false).await?;
         let mut pipelines = self.get_pipelines().await?;
 
         let pool = get_or_initialize_pool(&self.database_url).await?;
 
-        let mut parsed_schemas = vec![];
         let project_info = &self.database_data.as_ref().unwrap().project_info;
+        let mut parsed_schemas = vec![];
         for pipeline in &mut pipelines {
             let parsed_schema = pipeline
                 .get_parsed_schema(project_info, &pool)
                 .await
                 .expect("Error getting parsed schema for pipeline");
             parsed_schemas.push(parsed_schema);
         }
-        let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect();
+        let pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)> =
+            pipelines.into_iter().zip(parsed_schemas).collect();
 
-        let args = args.unwrap_or_default();
-        let args = args.as_object().context("args must be a JSON object")?;
+        let batch_size = args
+            .remove("batch_size")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(100))?;
+
+        let parallel_batches = args
+            .get("parallel_batches")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(1))? as usize;
 
         let progress_bar = utils::default_progress_bar(documents.len() as u64);
         progress_bar.println("Upserting Documents...");
 
+        let mut set = JoinSet::new();
+        for batch in documents.chunks(batch_size as usize) {
+            // Wait for a free slot so at most `parallel_batches` tasks run at once
+            while set.len() >= parallel_batches {
+                if let Some(res) = set.join_next().await {
+                    res??;
+                    progress_bar.inc(batch_size);
+                }
+            }
+            let local_self = self.clone();
+            let local_batch = batch.to_owned();
+            let local_args = args.clone();
+            let local_pipelines = pipelines.clone();
+            let local_pool = pool.clone();
+            set.spawn(async move {
+                local_self
+                    ._upsert_documents(local_batch, local_args, local_pipelines, local_pool)
+                    .await
+            });
+        }
+
+        while let Some(res) = set.join_next().await {
+            res??;
+            progress_bar.inc(batch_size);
+        }
+
+        progress_bar.println("Done Upserting Documents\n");
+        progress_bar.finish();
+
+        Ok(())
+    }
+
+    async fn _upsert_documents(
+        self,
+        batch: Vec<Json>,
+        args: serde_json::Map<String, Value>,
+        mut pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)>,
+        pool: Pool<Postgres>,
+    ) -> anyhow::Result<()> {
+        let project_info = &self.database_data.as_ref().unwrap().project_info;
+
         let query = if args
             .get("merge")
             .map(|v| v.as_bool().unwrap_or(false))
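
The hunk below moves the per-batch SQL construction into `_upsert_documents`. One detail worth calling out: each document's row identity comes from hashing its `id`, so re-upserting the same id always targets the same row. A minimal sketch of that derivation, assuming the same `md5`, `uuid`, and `anyhow` crates used above (the `source_uuid` helper name is hypothetical):

    /// Stable UUID for a document id: an md5 digest is 16 bytes,
    /// exactly the size of a UUID, so the digest maps directly onto one.
    fn source_uuid(id: &str) -> anyhow::Result<uuid::Uuid> {
        let digest = md5::compute(id.as_bytes());
        Ok(uuid::Uuid::from_slice(&digest.0)?)
    }

The same trick is reused per field: each value's md5 is stored in `versions`, which is what the pipeline resync check compares against later.
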
@@ -539,111 +593,99 @@ impl Collection {
             )
         };
 
-        let batch_size = args
-            .get("batch_size")
-            .map(TryToNumeric::try_to_u64)
-            .unwrap_or(Ok(100))?;
-
-        for batch in documents.chunks(batch_size as usize) {
-            let mut transaction = pool.begin().await?;
-
-            let mut query_values = String::new();
-            let mut binding_parameter_counter = 1;
-            for _ in 0..batch.len() {
-                query_values = format!(
-                    "{query_values}, (${}, ${}, ${})",
-                    binding_parameter_counter,
-                    binding_parameter_counter + 1,
-                    binding_parameter_counter + 2
-                );
-                binding_parameter_counter += 3;
-            }
+        let mut transaction = pool.begin().await?;
 
-            let query = query.replace(
-                "{values_parameters}",
-                &query_values.chars().skip(1).collect::<String>(),
-            );
-            let query = query.replace(
-                "{binding_parameter}",
-                &format!("${binding_parameter_counter}"),
+        let mut query_values = String::new();
+        let mut binding_parameter_counter = 1;
+        for _ in 0..batch.len() {
+            query_values = format!(
+                "{query_values}, (${}, ${}, ${})",
+                binding_parameter_counter,
+                binding_parameter_counter + 1,
+                binding_parameter_counter + 2
             );
+            binding_parameter_counter += 3;
+        }
 
-            let mut query = sqlx::query_as(&query);
-
-            let mut source_uuids = vec![];
-            for document in batch {
-                let id = document
-                    .get("id")
-                    .context("`id` must be a key in document")?
-                    .to_string();
-                let md5_digest = md5::compute(id.as_bytes());
-                let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
-                source_uuids.push(source_uuid);
-
-                let start = SystemTime::now();
-                let timestamp = start
-                    .duration_since(UNIX_EPOCH)
-                    .expect("Time went backwards")
-                    .as_millis();
-
-                let versions: HashMap<String, serde_json::Value> = document
-                    .as_object()
-                    .context("document must be an object")?
-                    .iter()
-                    .try_fold(HashMap::new(), |mut acc, (key, value)| {
-                        let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
-                        let md5_digest = format!("{md5_digest:x}");
-                        acc.insert(
-                            key.to_owned(),
-                            serde_json::json!({
-                                "last_updated": timestamp,
-                                "md5": md5_digest
-                            }),
-                        );
-                        anyhow::Ok(acc)
-                    })?;
-                let versions = serde_json::to_value(versions)?;
-
-                query = query.bind(source_uuid).bind(document).bind(versions);
-            }
+        let query = query.replace(
+            "{values_parameters}",
+            &query_values.chars().skip(1).collect::<String>(),
+        );
+        let query = query.replace(
+            "{binding_parameter}",
+            &format!("${binding_parameter_counter}"),
+        );
 
-            let results: Vec<(i64, Option<Json>)> = query
-                .bind(source_uuids)
-                .fetch_all(&mut *transaction)
-                .await?;
+        let mut query = sqlx::query_as(&query);
+
+        let mut source_uuids = vec![];
+        for document in &batch {
+            let id = document
+                .get("id")
+                .context("`id` must be a key in document")?
+                .to_string();
+            let md5_digest = md5::compute(id.as_bytes());
+            let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
+            source_uuids.push(source_uuid);
+
+            let start = SystemTime::now();
+            let timestamp = start
+                .duration_since(UNIX_EPOCH)
+                .expect("Time went backwards")
+                .as_millis();
+
+            let versions: HashMap<String, serde_json::Value> = document
+                .as_object()
+                .context("document must be an object")?
+                .iter()
+                .try_fold(HashMap::new(), |mut acc, (key, value)| {
+                    let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
+                    let md5_digest = format!("{md5_digest:x}");
+                    acc.insert(
+                        key.to_owned(),
+                        serde_json::json!({
+                            "last_updated": timestamp,
+                            "md5": md5_digest
+                        }),
+                    );
+                    anyhow::Ok(acc)
+                })?;
+            let versions = serde_json::to_value(versions)?;
 
-            let dp: Vec<(i64, Json, Option<Json>)> = results
-                .into_iter()
-                .zip(batch)
-                .map(|((id, previous_document), document)| {
-                    (id, document.to_owned(), previous_document)
+            query = query.bind(source_uuid).bind(document).bind(versions);
+        }
+
+        let results: Vec<(i64, Option<Json>)> = query
+            .bind(source_uuids)
+            .fetch_all(&mut *transaction)
+            .await?;
+
+        let dp: Vec<(i64, Json, Option<Json>)> = results
+            .into_iter()
+            .zip(batch)
+            .map(|((id, previous_document), document)| (id, document.to_owned(), previous_document))
+            .collect();
+
+        for (pipeline, parsed_schema) in &mut pipelines {
+            let ids_to_run_on: Vec<i64> = dp
+                .iter()
+                .filter(|(_, document, previous_document)| match previous_document {
+                    Some(previous_document) => parsed_schema
+                        .iter()
+                        .any(|(key, _)| document[key] != previous_document[key]),
+                    None => true,
                 })
+                .map(|(document_id, _, _)| *document_id)
                 .collect();
-
-            for (pipeline, parsed_schema) in &mut pipelines {
-                let ids_to_run_on: Vec<i64> = dp
-                    .iter()
-                    .filter(|(_, document, previous_document)| match previous_document {
-                        Some(previous_document) => parsed_schema
-                            .iter()
-                            .any(|(key, _)| document[key] != previous_document[key]),
-                        None => true,
-                    })
-                    .map(|(document_id, _, _)| *document_id)
-                    .collect();
-                if !ids_to_run_on.is_empty() {
-                    pipeline
-                        .sync_documents(ids_to_run_on, project_info, &mut transaction)
-                        .await
-                        .expect("Failed to execute pipeline");
-                }
+            if !ids_to_run_on.is_empty() {
+                pipeline
+                    .sync_documents(ids_to_run_on, project_info, &mut transaction)
+                    .await
+                    .expect("Failed to execute pipeline");
             }
-
-            transaction.commit().await?;
-            progress_bar.inc(batch_size);
         }
-        progress_bar.println("Done Upserting Documents\n");
-        progress_bar.finish();
+
+        transaction.commit().await?;
         Ok(())
     }
 
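
For reference, a hypothetical call site exercising the new knobs — `collection` and `documents` are assumed bindings, and this assumes the SDK's `Json` wrapper converts from `serde_json::Value` as it does elsewhere in this crate — running four 50-document batches concurrently:

    let args = serde_json::json!({
        "batch_size": 50,
        "parallel_batches": 4,
        "merge": false
    });
    collection
        .upsert_documents(documents, Some(args.into()))
        .await?;

Since each spawned batch commits its own transaction, a failure in one batch aborts the whole call via `res??`, but batches that already committed stay committed.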