@@ -6,19 +6,21 @@
use rust_bridge::{alias, alias_methods};
use sea_query::Alias;
use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
use sea_query_binder::SqlxBinder;
use serde_json::{json, Value};
use sqlx::PgConnection;
use sqlx::{Executor, Pool, Postgres};
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::Path;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use tokio::task::JoinSet;
use tracing::{instrument, warn};
use walkdir::WalkDir;

use crate::debug_sqlx_query;
use crate::filter_builder::FilterBuilder;
use crate::pipeline::FieldAction;
use crate::search_query_builder::build_search_query;
use crate::vector_search_query_builder::build_vector_search_query;
use crate::{
@@ -496,28 +498,80 @@ impl Collection {
        // -> Insert the document
        // -> Foreach pipeline check if we need to resync the document and if so sync the document
        // -> Commit the transaction
        let mut args = args.unwrap_or_default();
        let args = args.as_object_mut().context("args must be a JSON object")?;

        self.verify_in_database(false).await?;
        let mut pipelines = self.get_pipelines().await?;

        let pool = get_or_initialize_pool(&self.database_url).await?;

        let project_info = &self.database_data.as_ref().unwrap().project_info;
        let mut parsed_schemas = vec![];
        for pipeline in &mut pipelines {
            let parsed_schema = pipeline
                .get_parsed_schema(project_info, &pool)
                .await
                .expect("Error getting parsed schema for pipeline");
            parsed_schemas.push(parsed_schema);
        }
        let pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)> =
            pipelines.into_iter().zip(parsed_schemas).collect();

        let batch_size = args
            .remove("batch_size")
            .map(|x| x.try_to_u64())
            .unwrap_or(Ok(100))?;

        let parallel_batches = args
            .get("parallel_batches")
            .map(|x| x.try_to_u64())
            .unwrap_or(Ok(1))? as usize;

        let progress_bar = utils::default_progress_bar(documents.len() as u64);
        progress_bar.println("Upserting Documents...");

        let mut set = JoinSet::new();
        for batch in documents.chunks(batch_size as usize) {
            // If `parallel_batches` tasks are already in flight, wait for one to
            // finish before spawning this batch so that no batch is skipped.
            if set.len() >= parallel_batches {
                if let Some(res) = set.join_next().await {
                    res??;
                    progress_bar.inc(batch_size);
                }
            }
            let local_self = self.clone();
            let local_batch = batch.to_owned();
            let local_args = args.clone();
            let local_pipelines = pipelines.clone();
            let local_pool = pool.clone();
            set.spawn(async move {
                local_self
                    ._upsert_documents(local_batch, local_args, local_pipelines, local_pool)
                    .await
            });
        }

        while let Some(res) = set.join_next().await {
            res??;
            progress_bar.inc(batch_size);
        }

        progress_bar.println("Done Upserting Documents\n");
        progress_bar.finish();
        Ok(())
    }
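    // Upserts one batch of documents inside a single transaction: it builds a
    // multi-row VALUES insert, records a per-field `md5` + `last_updated`
    // "versions" object for change tracking, and re-runs each pipeline only on
    // documents whose schema-tracked fields actually changed.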
    async fn _upsert_documents(
        self,
        batch: Vec<Json>,
        args: serde_json::Map<String, Value>,
        mut pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)>,
        pool: Pool<Postgres>,
    ) -> anyhow::Result<()> {
        let project_info = &self.database_data.as_ref().unwrap().project_info;

        let query = if args
            .get("merge")
            .map(|v| v.as_bool().unwrap_or(false))
@@ -539,111 +593,99 @@ impl Collection {
            )
        };

        let mut transaction = pool.begin().await?;

        let mut query_values = String::new();
        let mut binding_parameter_counter = 1;
        for _ in 0..batch.len() {
            query_values = format!(
                "{query_values}, (${}, ${}, ${})",
                binding_parameter_counter,
                binding_parameter_counter + 1,
                binding_parameter_counter + 2
            );
            binding_parameter_counter += 3;
        }

        let query = query.replace(
            "{values_parameters}",
            &query_values.chars().skip(1).collect::<String>(),
        );
        let query = query.replace(
            "{binding_parameter}",
            &format!("${binding_parameter_counter}"),
        );

        let mut query = sqlx::query_as(&query);

        let mut source_uuids = vec![];
        for document in &batch {
            let id = document
                .get("id")
                .context("`id` must be a key in document")?
                .to_string();
            let md5_digest = md5::compute(id.as_bytes());
            let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
            source_uuids.push(source_uuid);

            let start = SystemTime::now();
            let timestamp = start
                .duration_since(UNIX_EPOCH)
                .expect("Time went backwards")
                .as_millis();

            let versions: HashMap<String, serde_json::Value> = document
                .as_object()
                .context("document must be an object")?
                .iter()
                .try_fold(HashMap::new(), |mut acc, (key, value)| {
                    let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
                    let md5_digest = format!("{md5_digest:x}");
                    acc.insert(
                        key.to_owned(),
                        serde_json::json!({
                            "last_updated": timestamp,
                            "md5": md5_digest
                        }),
                    );
                    anyhow::Ok(acc)
                })?;
            let versions = serde_json::to_value(versions)?;

            query = query.bind(source_uuid).bind(document).bind(versions);
        }

        let results: Vec<(i64, Option<Json>)> = query
            .bind(source_uuids)
            .fetch_all(&mut *transaction)
            .await?;

        let dp: Vec<(i64, Json, Option<Json>)> = results
            .into_iter()
            .zip(batch)
            .map(|((id, previous_document), document)| (id, document.to_owned(), previous_document))
            .collect();

        for (pipeline, parsed_schema) in &mut pipelines {
            let ids_to_run_on: Vec<i64> = dp
                .iter()
                .filter(|(_, document, previous_document)| match previous_document {
                    Some(previous_document) => parsed_schema
                        .iter()
                        .any(|(key, _)| document[key] != previous_document[key]),
                    None => true,
                })
                .map(|(document_id, _, _)| *document_id)
                .collect();
            if !ids_to_run_on.is_empty() {
                pipeline
                    .sync_documents(ids_to_run_on, project_info, &mut transaction)
                    .await
                    .expect("Failed to execute pipeline");
            }
        }

        transaction.commit().await?;
        Ok(())
    }
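    // A rough caller-side sketch of how the new options are meant to be used.
    // It assumes the surrounding public method is
    // `upsert_documents(&mut self, documents: Vec<Json>, args: Option<Json>)`
    // and that `Json` converts from a `serde_json::Value` via `.into()`;
    // neither is shown in this diff, so treat those exact names as assumptions.
    //
    //     collection
    //         .upsert_documents(
    //             documents,
    //             Some(
    //                 json!({
    //                     // documents per INSERT statement / transaction (default 100)
    //                     "batch_size": 100,
    //                     // batches upserted concurrently via the JoinSet (default 1)
    //                     "parallel_batches": 10,
    //                     // merge document metadata instead of replacing it
    //                     "merge": false
    //                 })
    //                 .into(),
    //             ),
    //         )
    //         .await?;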