@@ -2,12 +2,15 @@ use std::fmt::Write;
22use std:: str:: FromStr ;
33
44use ndarray:: Zip ;
5+ use once_cell:: sync:: OnceCell ;
56use pgrx:: iter:: { SetOfIterator , TableIterator } ;
67use pgrx:: * ;
8+ use serde_json:: Value ;
79
810#[ cfg( feature ="python" ) ]
911use serde_json:: json;
1012
13+ use crate :: bindings:: vllm:: { LLMBuilder , LLM } ;
1114#[ cfg( feature ="python" ) ]
1215use crate :: orm:: * ;
1316
@@ -610,7 +613,7 @@ pub fn transform_json(
610613inputs : default ! ( Vec <& str >, "ARRAY[]::TEXT[]" ) ,
611614cache : default ! ( bool , false ) ,
612615) ->JsonB {
613- match crate :: bindings :: transformers :: transform ( & task. 0 , & args. 0 , inputs) {
616+ match transform ( task. 0 , args. 0 , inputs) {
614617Ok ( output) =>JsonB ( output) ,
615618Err ( e) =>error ! ( "{e}" ) ,
616619}
@@ -632,6 +635,34 @@ pub fn transform_string(
632635}
633636}
634637
638+ fn transform ( mut task : Value , args : Value , inputs : Vec < & str > ) -> anyhow:: Result < Value > {
639+ // use vLLM if model present in task and backend is set to vllm
640+ let use_vllm = task. as_object_mut ( ) . is_some_and ( |obj|{
641+ obj. contains_key ( "model" ) &&matches ! ( obj. get( "backend" ) , Some ( Value :: String ( backend) ) if backend. to_string( ) . to_ascii_lowercase( ) =="vllm" )
642+ } ) ;
643+
644+ if use_vllm{
645+ crate :: bindings:: python:: activate ( ) . unwrap ( ) ;
646+
647+ static LAZY_LLM : OnceCell < LLM > =OnceCell :: new ( ) ;
648+ let llm =LAZY_LLM . get_or_init ( move ||{
649+ let builder =match LLMBuilder :: try_from ( task) {
650+ Ok ( b) => b,
651+ Err ( e) =>error ! ( "{e}" ) ,
652+ } ;
653+ builder. build ( ) . unwrap ( )
654+ } ) ;
655+
656+ Ok ( json ! ( llm. generate( & inputs, None ) ?) )
657+ } else {
658+ if let Some ( map) = task. as_object_mut ( ) {
659+ // pop backend keyword, if present
660+ let _ = map. remove ( "backend" ) ;
661+ }
662+ crate :: bindings:: transformers:: transform ( & task, & args, inputs)
663+ }
664+ }
665+
635666#[ cfg( feature ="python" ) ]
636667#[ pg_extern( immutable, parallel_safe, name ="generate" ) ]
637668fn generate ( project_name : & str , inputs : & str , config : default ! ( JsonB , "'{}'" ) ) ->String {