1+ use anyhow:: Context ;
12use rust_bridge:: { alias, alias_methods} ;
23use sqlx:: Row ;
34use tracing:: instrument;
@@ -13,7 +14,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json};
1314#[ cfg( feature ="python" ) ]
1415use crate :: { query_runner:: QueryRunnerPython , types:: JsonPython } ;
1516
16- #[ alias_methods( new, query, transform) ]
17+ #[ alias_methods( new, query, transform, embed , embed_batch ) ]
1718impl Builtins {
1819pub fn new ( database_url : Option < String > ) ->Self {
1920Self { database_url}
@@ -87,6 +88,55 @@ impl Builtins {
8788let results = results. first ( ) . unwrap ( ) . get :: < serde_json:: Value , _ > ( 0 ) ;
8889Ok ( Json ( results) )
8990}
91+
92+ /// Run the built-in `pgml.embed()` function.
93+ ///
94+ /// # Arguments
95+ ///
96+ /// * `model` - The model to use.
97+ /// * `text` - The text to embed.
98+ ///
99+ pub async fn embed ( & self , model : & str , text : & str ) -> anyhow:: Result < Json > {
100+ let pool =get_or_initialize_pool ( & self . database_url ) . await ?;
101+ let query = sqlx:: query ( "SELECT embed FROM pgml.embed($1, $2)" ) ;
102+ let result = query. bind ( model) . bind ( text) . fetch_one ( & pool) . await ?;
103+ let result = result. get :: < Vec < f32 > , _ > ( 0 ) ;
104+ let result = serde_json:: to_value ( result) ?;
105+ Ok ( Json ( result) )
106+ }
107+
108+ /// Run the built-in `pgml.embed()` function, but with handling for batch inputs and outputs.
109+ ///
110+ /// # Arguments
111+ ///
112+ /// * `model` - The model to use.
113+ /// * `texts` - The texts to embed.
114+ ///
115+ pub async fn embed_batch ( & self , model : & str , texts : Json ) -> anyhow:: Result < Json > {
116+ let texts = texts
117+ . 0
118+ . as_array ( )
119+ . with_context ( ||"embed_batch takes an array of strings" ) ?
120+ . into_iter ( )
121+ . map ( |v|{
122+ v. as_str ( )
123+ . with_context ( ||"only text embeddings are supported" )
124+ . unwrap ( )
125+ . to_string ( )
126+ } )
127+ . collect :: < Vec < String > > ( ) ;
128+ let pool =get_or_initialize_pool ( & self . database_url ) . await ?;
129+ let query = sqlx:: query ( "SELECT embed AS embed_batch FROM pgml.embed($1, $2)" ) ;
130+ let results = query
131+ . bind ( model)
132+ . bind ( texts)
133+ . fetch_all ( & pool)
134+ . await ?
135+ . into_iter ( )
136+ . map ( |embeddings| embeddings. get :: < Vec < f32 > , _ > ( 0 ) )
137+ . collect :: < Vec < Vec < f32 > > > ( ) ;
138+ Ok ( Json ( serde_json:: to_value ( results) ?) )
139+ }
90140}
91141
92142#[ cfg( test) ]
@@ -117,4 +167,28 @@ mod tests {
117167assert ! ( results. as_array( ) . is_some( ) ) ;
118168Ok ( ( ) )
119169}
170+
171+ #[ tokio:: test]
172+ async fn can_embed ( ) -> anyhow:: Result < ( ) > {
173+ internal_init_logger ( None , None ) . ok ( ) ;
174+ let builtins =Builtins :: new ( None ) ;
175+ let results = builtins. embed ( "intfloat/e5-small-v2" , "test" ) . await ?;
176+ assert ! ( results. as_array( ) . is_some( ) ) ;
177+ Ok ( ( ) )
178+ }
179+
180+ #[ tokio:: test]
181+ async fn can_embed_batch ( ) -> anyhow:: Result < ( ) > {
182+ internal_init_logger ( None , None ) . ok ( ) ;
183+ let builtins =Builtins :: new ( None ) ;
184+ let results = builtins
185+ . embed_batch (
186+ "intfloat/e5-small-v2" ,
187+ Json ( serde_json:: json!( [ "test" , "test2" , ] ) ) ,
188+ )
189+ . await ?;
190+ assert ! ( results. as_array( ) . is_some( ) ) ;
191+ assert_eq ! ( results. as_array( ) . unwrap( ) . len( ) , 2 ) ;
192+ Ok ( ( ) )
193+ }
120194}