Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit3e8cc28

Browse files
authored
Add streaming (#1145)
1 parent37a888f commit3e8cc28

File tree

4 files changed

+278
-58
lines changed

4 files changed

+278
-58
lines changed

‎pgml-extension/src/api.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use std::str::FromStr;
44
use ndarray::Zip;
55
use pgrx::iter::{SetOfIterator,TableIterator};
66
use pgrx::*;
7+
use pyo3::prelude::*;
8+
use pyo3::types::{IntoPyDict,PyDict};
79

810
#[cfg(feature ="python")]
911
use serde_json::json;
@@ -632,6 +634,75 @@ pub fn transform_string(
632634
}
633635
}
634636

637+
structTransformStreamIterator{
638+
locals:Py<PyDict>,
639+
}
640+
641+
implTransformStreamIterator{
642+
fnnew(python_iter:Py<PyAny>) ->Self{
643+
let locals =Python::with_gil(|py| ->Result<Py<PyDict>,PyErr>{
644+
Ok([("python_iter", python_iter)].into_py_dict(py).into())
645+
})
646+
.map_err(|e|error!("{e}"))
647+
.unwrap();
648+
Self{ locals}
649+
}
650+
}
651+
652+
implIteratorforTransformStreamIterator{
653+
typeItem =String;
654+
fnnext(&mutself) ->Option<Self::Item>{
655+
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
656+
Python::with_gil(|py| ->Result<Option<String>,PyErr>{
657+
let code ="next(python_iter)";
658+
let res:&PyAny = py.eval(code,Some(self.locals.as_ref(py)),None)?;
659+
if res.is_none(){
660+
Ok(None)
661+
}else{
662+
let res:String = res.extract()?;
663+
Ok(Some(res))
664+
}
665+
})
666+
.map_err(|e|error!("{e}"))
667+
.unwrap()
668+
}
669+
}
670+
671+
#[cfg(all(feature ="python", not(feature ="use_as_lib")))]
672+
#[pg_extern(immutable, parallel_safe, name ="transform_stream")]
673+
#[allow(unused_variables)]// cache is maintained for api compatibility
674+
pubfntransform_stream_json(
675+
task:JsonB,
676+
args:default!(JsonB,"'{}'"),
677+
input:default!(&str,"''"),
678+
cache:default!(bool,false),
679+
) ->SetOfIterator<'static,String>{
680+
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
681+
let python_iter =crate::bindings::transformers::transform_stream(&task.0,&args.0, input)
682+
.map_err(|e|error!("{e}"))
683+
.unwrap();
684+
let res =TransformStreamIterator::new(python_iter);
685+
SetOfIterator::new(res)
686+
}
687+
688+
#[cfg(all(feature ="python", not(feature ="use_as_lib")))]
689+
#[pg_extern(immutable, parallel_safe, name ="transform_stream")]
690+
#[allow(unused_variables)]// cache is maintained for api compatibility
691+
pubfntransform_stream_string(
692+
task:String,
693+
args:default!(JsonB,"'{}'"),
694+
input:default!(&str,"''"),
695+
cache:default!(bool,false),
696+
) ->SetOfIterator<'static,String>{
697+
let task_json =json!({"task": task});
698+
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
699+
let python_iter =crate::bindings::transformers::transform_stream(&task_json,&args.0, input)
700+
.map_err(|e|error!("{e}"))
701+
.unwrap();
702+
let res =TransformStreamIterator::new(python_iter);
703+
SetOfIterator::new(res)
704+
}
705+
635706
#[cfg(feature ="python")]
636707
#[pg_extern(immutable, parallel_safe, name ="generate")]
637708
fngenerate(project_name:&str,inputs:&str,config:default!(JsonB,"'{}'")) ->String{

‎pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,10 @@ use super::TracebackError;
1616

1717
pubmod whitelist;
1818

19-
create_pymodule!("/src/bindings/transformers/transformers.py");
20-
21-
pubfntransform(
22-
task:&serde_json::Value,
23-
args:&serde_json::Value,
24-
inputs:Vec<&str>,
25-
) ->Result<serde_json::Value>{
26-
crate::bindings::python::activate()?;
27-
28-
whitelist::verify_task(task)?;
29-
30-
let task = serde_json::to_string(task)?;
31-
let args = serde_json::to_string(args)?;
32-
let inputs = serde_json::to_string(&inputs)?;
19+
mod transformers;
20+
pubuse transformers::*;
3321

34-
let results =Python::with_gil(|py| ->Result<String>{
35-
let transform:Py<PyAny> =get_module!(PY_MODULE)
36-
.getattr(py,"transform")
37-
.format_traceback(py)?;
38-
39-
let output = transform
40-
.call1(
41-
py,
42-
PyTuple::new(
43-
py,
44-
&[task.into_py(py), args.into_py(py), inputs.into_py(py)],
45-
),
46-
)
47-
.format_traceback(py)?;
48-
49-
output.extract(py).format_traceback(py)
50-
})?;
51-
52-
Ok(serde_json::from_str(&results)?)
53-
}
22+
create_pymodule!("/src/bindings/transformers/transformers.py");
5423

5524
pubfnget_model_from(task:&Value) ->Result<String>{
5625
Python::with_gil(|py| ->Result<String>{

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp