1
1
use anyhow:: Context ;
2
- use futures:: Stream ;
3
2
use rust_bridge:: { alias, alias_methods} ;
4
- use sqlx:: { postgres:: PgRow , Row } ;
5
- use sqlx:: { Postgres , Transaction } ;
6
- use std:: collections:: VecDeque ;
7
- use std:: future:: Future ;
8
- use std:: pin:: Pin ;
9
- use std:: task:: Poll ;
3
+ use sqlx:: Row ;
10
4
use tracing:: instrument;
11
5
12
6
/// Provides access to builtin database methods
@@ -22,99 +16,6 @@ use crate::{get_or_initialize_pool, types::Json};
22
16
#[ cfg( feature ="python" ) ]
23
17
use crate :: types:: { GeneralJsonAsyncIteratorPython , JsonPython } ;
24
18
25
- #[ allow( clippy:: type_complexity) ]
26
- struct TransformerStream {
27
- transaction : Option < Transaction < ' static , Postgres > > ,
28
- future : Option < Pin < Box < dyn Future < Output =Result < Vec < PgRow > , sqlx:: Error > > +Send +' static > > > ,
29
- commit : Option < Pin < Box < dyn Future < Output =Result < ( ) , sqlx:: Error > > +Send +' static > > > ,
30
- done : bool ,
31
- query : String ,
32
- db_batch_size : i32 ,
33
- results : VecDeque < PgRow > ,
34
- }
35
-
36
- impl std:: fmt:: Debug for TransformerStream {
37
- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
38
- f. debug_struct ( "TransformerStream" ) . finish ( )
39
- }
40
- }
41
-
42
- impl TransformerStream {
43
- fn new ( transaction : Transaction < ' static , Postgres > , db_batch_size : i32 ) ->Self {
44
- let query =format ! ( "FETCH {} FROM c" , db_batch_size) ;
45
- Self {
46
- transaction : Some ( transaction) ,
47
- future : None ,
48
- commit : None ,
49
- done : false ,
50
- query,
51
- db_batch_size,
52
- results : VecDeque :: new ( ) ,
53
- }
54
- }
55
- }
56
-
57
- impl Stream for TransformerStream {
58
- type Item = anyhow:: Result < Json > ;
59
-
60
- fn poll_next (
61
- mut self : Pin < & mut Self > ,
62
- cx : & mut std:: task:: Context < ' _ > ,
63
- ) ->Poll < Option < Self :: Item > > {
64
- if self . done {
65
- if let Some ( c) =self . commit . as_mut ( ) {
66
- if c. as_mut ( ) . poll ( cx) . is_ready ( ) {
67
- self . commit =None ;
68
- }
69
- }
70
- } else {
71
- if self . future . is_none ( ) {
72
- unsafe {
73
- let s =self . as_mut ( ) . get_unchecked_mut ( ) ;
74
- let s: * mut Self = s;
75
- let s =Box :: leak ( Box :: from_raw ( s) ) ;
76
- s. future =Some ( Box :: pin (
77
- sqlx:: query ( & s. query ) . fetch_all ( & mut * * s. transaction . as_mut ( ) . unwrap ( ) ) ,
78
- ) ) ;
79
- }
80
- }
81
-
82
- if let Poll :: Ready ( o) =self . as_mut ( ) . future . as_mut ( ) . unwrap ( ) . as_mut ( ) . poll ( cx) {
83
- let rows = o?;
84
- if rows. len ( ) <self . db_batch_size as usize {
85
- self . done =true ;
86
- unsafe {
87
- let s =self . as_mut ( ) . get_unchecked_mut ( ) ;
88
- let transaction = std:: mem:: take ( & mut s. transaction ) . unwrap ( ) ;
89
- s. commit =Some ( Box :: pin ( transaction. commit ( ) ) ) ;
90
- }
91
- } else {
92
- unsafe {
93
- let s =self . as_mut ( ) . get_unchecked_mut ( ) ;
94
- let s: * mut Self = s;
95
- let s =Box :: leak ( Box :: from_raw ( s) ) ;
96
- s. future =Some ( Box :: pin (
97
- sqlx:: query ( & s. query ) . fetch_all ( & mut * * s. transaction . as_mut ( ) . unwrap ( ) ) ,
98
- ) ) ;
99
- }
100
- }
101
- for rin rows. into_iter ( ) {
102
- self . results . push_back ( r)
103
- }
104
- }
105
- }
106
-
107
- if !self . results . is_empty ( ) {
108
- let r =self . results . pop_front ( ) . unwrap ( ) ;
109
- Poll :: Ready ( Some ( Ok ( r. get :: < Json , _ > ( 0 ) ) ) )
110
- } else if self . done {
111
- Poll :: Ready ( None )
112
- } else {
113
- Poll :: Pending
114
- }
115
- }
116
- }
117
-
118
19
#[ alias_methods( new, transform, transform_stream) ]
119
20
impl TransformerPipeline {
120
21
/// Creates a new [TransformerPipeline]
@@ -200,7 +101,7 @@ impl TransformerPipeline {
200
101
) -> anyhow:: Result < GeneralJsonAsyncIterator > {
201
102
let pool =get_or_initialize_pool ( & self . database_url ) . await ?;
202
103
let args = args. unwrap_or_default ( ) ;
203
- let batch_size = batch_size. unwrap_or ( 10 ) ;
104
+ let batch_size = batch_size. unwrap_or ( 1 ) ;
204
105
205
106
let mut transaction = pool. begin ( ) . await ?;
206
107
// We set the task in the new constructor so we can unwrap here
@@ -234,10 +135,37 @@ impl TransformerPipeline {
234
135
. await ?;
235
136
}
236
137
237
- Ok ( GeneralJsonAsyncIterator ( Box :: pin ( TransformerStream :: new (
238
- transaction,
239
- batch_size,
240
- ) ) ) )
138
+ let s = futures:: stream:: try_unfold ( transaction, move |mut transaction|async move {
139
+ let query =format ! ( "FETCH {} FROM c" , batch_size) ;
140
+ let mut res: Vec < Json > = sqlx:: query_scalar ( & query)
141
+ . fetch_all ( & mut * transaction)
142
+ . await ?;
143
+ if !res. is_empty ( ) {
144
+ if batch_size >1 {
145
+ let res: Vec < String > = res
146
+ . into_iter ( )
147
+ . map ( |v|{
148
+ v. 0 . as_array ( )
149
+ . context ( "internal SDK error - cannot parse db value as array. Please post a new github issue" )
150
+ . map ( |v|{
151
+ v[ 0 ] . as_str ( )
152
+ . context (
153
+ "internal SDK error - cannot parse db value as string. Please post a new github issue" ,
154
+ )
155
+ . map ( |v| v. to_owned ( ) )
156
+ } )
157
+ } )
158
+ . collect :: < anyhow:: Result < anyhow:: Result < Vec < String > > > > ( ) ??;
159
+ Ok ( Some ( ( serde_json:: json!( res) . into ( ) , transaction) ) )
160
+ } else {
161
+ Ok ( Some ( ( std:: mem:: take ( & mut res[ 0 ] ) , transaction) ) )
162
+ }
163
+ } else {
164
+ transaction. commit ( ) . await ?;
165
+ Ok ( None )
166
+ }
167
+ } ) ;
168
+ Ok ( GeneralJsonAsyncIterator ( Box :: pin ( s) ) )
241
169
}
242
170
}
243
171
@@ -305,7 +233,7 @@ mod tests {
305
233
serde_json:: json!( "AI is going to" ) . into ( ) ,
306
234
Some (
307
235
serde_json:: json!( {
308
- "max_new_tokens" : 10
236
+ "max_new_tokens" : 30
309
237
} )
310
238
. into ( ) ,
311
239
) ,