11use anyhow:: Context ;
2- use futures:: Stream ;
32use rust_bridge:: { alias, alias_methods} ;
4- use sqlx:: { postgres:: PgRow , Row } ;
5- use sqlx:: { Postgres , Transaction } ;
6- use std:: collections:: VecDeque ;
7- use std:: future:: Future ;
8- use std:: pin:: Pin ;
9- use std:: task:: Poll ;
3+ use sqlx:: Row ;
104use tracing:: instrument;
115
126/// Provides access to builtin database methods
@@ -22,99 +16,6 @@ use crate::{get_or_initialize_pool, types::Json};
2216#[ cfg( feature ="python" ) ]
2317use crate :: types:: { GeneralJsonAsyncIteratorPython , JsonPython } ;
2418
25- #[ allow( clippy:: type_complexity) ]
26- struct TransformerStream {
27- transaction : Option < Transaction < ' static , Postgres > > ,
28- future : Option < Pin < Box < dyn Future < Output =Result < Vec < PgRow > , sqlx:: Error > > +Send +' static > > > ,
29- commit : Option < Pin < Box < dyn Future < Output =Result < ( ) , sqlx:: Error > > +Send +' static > > > ,
30- done : bool ,
31- query : String ,
32- db_batch_size : i32 ,
33- results : VecDeque < PgRow > ,
34- }
35-
36- impl std:: fmt:: Debug for TransformerStream {
37- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
38- f. debug_struct ( "TransformerStream" ) . finish ( )
39- }
40- }
41-
42- impl TransformerStream {
43- fn new ( transaction : Transaction < ' static , Postgres > , db_batch_size : i32 ) ->Self {
44- let query =format ! ( "FETCH {} FROM c" , db_batch_size) ;
45- Self {
46- transaction : Some ( transaction) ,
47- future : None ,
48- commit : None ,
49- done : false ,
50- query,
51- db_batch_size,
52- results : VecDeque :: new ( ) ,
53- }
54- }
55- }
56-
57- impl Stream for TransformerStream {
58- type Item = anyhow:: Result < Json > ;
59-
60- fn poll_next (
61- mut self : Pin < & mut Self > ,
62- cx : & mut std:: task:: Context < ' _ > ,
63- ) ->Poll < Option < Self :: Item > > {
64- if self . done {
65- if let Some ( c) =self . commit . as_mut ( ) {
66- if c. as_mut ( ) . poll ( cx) . is_ready ( ) {
67- self . commit =None ;
68- }
69- }
70- } else {
71- if self . future . is_none ( ) {
72- unsafe {
73- let s =self . as_mut ( ) . get_unchecked_mut ( ) ;
74- let s: * mut Self = s;
75- let s =Box :: leak ( Box :: from_raw ( s) ) ;
76- s. future =Some ( Box :: pin (
77- sqlx:: query ( & s. query ) . fetch_all ( & mut * * s. transaction . as_mut ( ) . unwrap ( ) ) ,
78- ) ) ;
79- }
80- }
81-
82- if let Poll :: Ready ( o) =self . as_mut ( ) . future . as_mut ( ) . unwrap ( ) . as_mut ( ) . poll ( cx) {
83- let rows = o?;
84- if rows. len ( ) <self . db_batch_size as usize {
85- self . done =true ;
86- unsafe {
87- let s =self . as_mut ( ) . get_unchecked_mut ( ) ;
88- let transaction = std:: mem:: take ( & mut s. transaction ) . unwrap ( ) ;
89- s. commit =Some ( Box :: pin ( transaction. commit ( ) ) ) ;
90- }
91- } else {
92- unsafe {
93- let s =self . as_mut ( ) . get_unchecked_mut ( ) ;
94- let s: * mut Self = s;
95- let s =Box :: leak ( Box :: from_raw ( s) ) ;
96- s. future =Some ( Box :: pin (
97- sqlx:: query ( & s. query ) . fetch_all ( & mut * * s. transaction . as_mut ( ) . unwrap ( ) ) ,
98- ) ) ;
99- }
100- }
101- for rin rows. into_iter ( ) {
102- self . results . push_back ( r)
103- }
104- }
105- }
106-
107- if !self . results . is_empty ( ) {
108- let r =self . results . pop_front ( ) . unwrap ( ) ;
109- Poll :: Ready ( Some ( Ok ( r. get :: < Json , _ > ( 0 ) ) ) )
110- } else if self . done {
111- Poll :: Ready ( None )
112- } else {
113- Poll :: Pending
114- }
115- }
116- }
117-
11819#[ alias_methods( new, transform, transform_stream) ]
11920impl TransformerPipeline {
12021/// Creates a new [TransformerPipeline]
@@ -200,7 +101,7 @@ impl TransformerPipeline {
200101) -> anyhow:: Result < GeneralJsonAsyncIterator > {
201102let pool =get_or_initialize_pool ( & self . database_url ) . await ?;
202103let args = args. unwrap_or_default ( ) ;
203- let batch_size = batch_size. unwrap_or ( 10 ) ;
104+ let batch_size = batch_size. unwrap_or ( 1 ) ;
204105
205106let mut transaction = pool. begin ( ) . await ?;
206107// We set the task in the new constructor so we can unwrap here
@@ -234,10 +135,37 @@ impl TransformerPipeline {
234135. await ?;
235136}
236137
237- Ok ( GeneralJsonAsyncIterator ( Box :: pin ( TransformerStream :: new (
238- transaction,
239- batch_size,
240- ) ) ) )
138+ let s = futures:: stream:: try_unfold ( transaction, move |mut transaction|async move {
139+ let query =format ! ( "FETCH {} FROM c" , batch_size) ;
140+ let mut res: Vec < Json > = sqlx:: query_scalar ( & query)
141+ . fetch_all ( & mut * transaction)
142+ . await ?;
143+ if !res. is_empty ( ) {
144+ if batch_size >1 {
145+ let res: Vec < String > = res
146+ . into_iter ( )
147+ . map ( |v|{
148+ v. 0 . as_array ( )
149+ . context ( "internal SDK error - cannot parse db value as array. Please post a new github issue" )
150+ . map ( |v|{
151+ v[ 0 ] . as_str ( )
152+ . context (
153+ "internal SDK error - cannot parse db value as string. Please post a new github issue" ,
154+ )
155+ . map ( |v| v. to_owned ( ) )
156+ } )
157+ } )
158+ . collect :: < anyhow:: Result < anyhow:: Result < Vec < String > > > > ( ) ??;
159+ Ok ( Some ( ( serde_json:: json!( res) . into ( ) , transaction) ) )
160+ } else {
161+ Ok ( Some ( ( std:: mem:: take ( & mut res[ 0 ] ) , transaction) ) )
162+ }
163+ } else {
164+ transaction. commit ( ) . await ?;
165+ Ok ( None )
166+ }
167+ } ) ;
168+ Ok ( GeneralJsonAsyncIterator ( Box :: pin ( s) ) )
241169}
242170}
243171
@@ -305,7 +233,7 @@ mod tests {
305233 serde_json:: json!( "AI is going to" ) . into ( ) ,
306234Some (
307235 serde_json:: json!( {
308- "max_new_tokens" : 10
236+ "max_new_tokens" : 30
309237} )
310238. into ( ) ,
311239) ,