Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit fb5b502

Browse files
authored
Added TransformerPipeline (#1128)
1 parent f2e4517 commit fb5b502

File tree

3 files changed

+106
-72
lines changed

3 files changed

+106
-72
lines changed

‎pgml-sdks/pgml/src/lib.rs‎

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mod query_builder;
2626
mod query_runner;
2727
mod remote_embeddings;
2828
mod splitter;
29+
mod transformer_pipeline;
2930
pub mod types;
3031
mod utils;
3132

@@ -35,6 +36,7 @@ pub use collection::Collection;
3536
pub use model::Model;
3637
pub use pipeline::Pipeline;
3738
pub use splitter::Splitter;
39+
pub use transformer_pipeline::TransformerPipeline;
3840

3941
// This is used when inserting collections to set the sdk_version used during creation
4042
static SDK_VERSION: &str = "0.9.2";
@@ -149,6 +151,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> {
149151
m.add_class::<model::ModelPython>()?;
150152
m.add_class::<splitter::SplitterPython>()?;
151153
m.add_class::<builtins::BuiltinsPython>()?;
154+
m.add_class::<transformer_pipeline::TransformerPipelinePython>()?;
152155
Ok(())
153156
}
154157

@@ -193,6 +196,10 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> {
193196
cx.export_function("newModel", model::ModelJavascript::new)?;
194197
cx.export_function("newSplitter", splitter::SplitterJavascript::new)?;
195198
cx.export_function("newBuiltins", builtins::BuiltinsJavascript::new)?;
199+
cx.export_function(
200+
"newTransformerPipeline",
201+
transformer_pipeline::TransformerPipelineJavascript::new,
202+
)?;
196203
cx.export_function("newPipeline", pipeline::PipelineJavascript::new)?;
197204
Ok(())
198205
}
@@ -448,7 +455,6 @@ mod tests {
448455
Some("text-embedding-ada-002".to_string()),
449456
Some("openai".to_string()),
450457
None,
451-
None,
452458
);
453459
let splitter = Splitter::default();
454460
let mut pipeline = Pipeline::new(
@@ -527,7 +533,6 @@ mod tests {
527533
Some("hkunlp/instructor-base".to_string()),
528534
Some("python".to_string()),
529535
Some(json!({"instruction":"Represent the Wikipedia document for retrieval: "}).into()),
530-
None,
531536
);
532537
let splitter = Splitter::default();
533538
let mut pipeline = Pipeline::new(
@@ -579,7 +584,6 @@ mod tests {
579584
Some("text-embedding-ada-002".to_string()),
580585
Some("openai".to_string()),
581586
None,
582-
None,
583587
);
584588
let splitter = Splitter::default();
585589
let mut pipeline = Pipeline::new(
@@ -660,7 +664,6 @@ mod tests {
660664
Some("text-embedding-ada-002".to_string()),
661665
Some("openai".to_string()),
662666
None,
663-
None,
664667
);
665668
let splitter = Splitter::default();
666669
let mut pipeline = Pipeline::new(

‎pgml-sdks/pgml/src/model.rs‎

Lines changed: 2 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
use anyhow::Context;
22
use rust_bridge::{alias, alias_methods};
3-
use serde_json::json;
43
use sqlx::postgres::PgPool;
5-
use sqlx::Row;
64
use tracing::instrument;
75

86
use crate::{
@@ -61,14 +59,11 @@ pub struct Model {
6159
pub parameters: Json,
6260
project_info: Option<ProjectInfo>,
6361
pub(crate) database_data: Option<ModelDatabaseData>,
64-
// This database_url is specifically used only for the model when calling transform and other
65-
// one-off methods
66-
database_url: Option<String>,
6762
}
6863

6964
impl Default for Model {
7065
fn default() -> Self {
71-
Self::new(None, None, None, None)
66+
Self::new(None, None, None)
7267
}
7368
}
7469

@@ -88,12 +83,7 @@ impl Model {
8883
/// use pgml::Model;
8984
/// let model = Model::new(Some("intfloat/e5-small".to_string()), None, None, None);
9085
/// ```
91-
pub fn new(
92-
name: Option<String>,
93-
source: Option<String>,
94-
parameters: Option<Json>,
95-
database_url: Option<String>,
96-
) -> Self {
86+
pub fn new(name: Option<String>, source: Option<String>, parameters: Option<Json>) -> Self {
9787
let name = name.unwrap_or("intfloat/e5-small".to_string());
9888
let parameters = parameters.unwrap_or(Json(serde_json::json!({})));
9989
let source = source.unwrap_or("pgml".to_string());
@@ -105,7 +95,6 @@ impl Model {
10595
parameters,
10696
project_info: None,
10697
database_data: None,
108-
database_url,
10998
}
11099
}
111100

@@ -191,30 +180,6 @@ impl Model {
191180
.database_url;
192181
get_or_initialize_pool(database_url).await
193182
}
194-
195-
pub async fn transform(
196-
&self,
197-
task: &str,
198-
inputs: Vec<String>,
199-
args: Option<Json>,
200-
) -> anyhow::Result<Json> {
201-
let pool = get_or_initialize_pool(&self.database_url).await?;
202-
let task = json!({
203-
"task": task,
204-
"model": self.name,
205-
"trust_remote_code": true
206-
});
207-
let args = args.unwrap_or_default();
208-
let query = sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)");
209-
let results = query
210-
.bind(task)
211-
.bind(inputs)
212-
.bind(&args)
213-
.fetch_all(&pool)
214-
.await?;
215-
let results = results.get(0).unwrap().get::<serde_json::Value, _>(0);
216-
Ok(Json(results))
217-
}
218183
}
219184

220185
impl From<models::PipelineWithModelAndSplitter> for Model {
@@ -228,7 +193,6 @@ impl From<models::PipelineWithModelAndSplitter> for Model {
228193
id: x.model_id,
229194
created_at: x.model_created_at,
230195
}),
231-
database_url: None,
232196
}
233197
}
234198
}
@@ -244,36 +208,6 @@ impl From<models::Model> for Model {
244208
id: model.id,
245209
created_at: model.created_at,
246210
}),
247-
database_url: None,
248211
}
249212
}
250213
}
251-
252-
#[cfg(test)]
253-
mod tests {
254-
use super::*;
255-
use crate::internal_init_logger;
256-
257-
#[sqlx::test]
258-
async fn model_can_transform() -> anyhow::Result<()> {
259-
internal_init_logger(None, None).ok();
260-
let model = Model::new(
261-
Some("Helsinki-NLP/opus-mt-en-fr".to_string()),
262-
Some("pgml".to_string()),
263-
None,
264-
None,
265-
);
266-
let results = model
267-
.transform(
268-
"translation",
269-
vec![
270-
"How are you doing today?".to_string(),
271-
"What is a good song?".to_string(),
272-
],
273-
None,
274-
)
275-
.await?;
276-
assert!(results.as_array().is_some());
277-
Ok(())
278-
}
279-
}
pgml-sdks/pgml/src/transformer_pipeline.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
use rust_bridge::{alias, alias_methods};
2+
use sqlx::Row;
3+
use tracing::instrument;
4+
5+
/// Runs `pgml.transform` in the database with a preconfigured task and optional model
6+
#[derive(alias, Debug, Clone)]
7+
pub struct TransformerPipeline {
8+
task: Json,
9+
database_url: Option<String>,
10+
}
11+
12+
use crate::{get_or_initialize_pool, types::Json};
13+
14+
#[cfg(feature = "python")]
15+
use crate::types::JsonPython;
16+
17+
#[alias_methods(new, transform)]
18+
impl TransformerPipeline {
19+
pub fn new(
20+
task: &str,
21+
model: Option<String>,
22+
args: Option<Json>,
23+
database_url: Option<String>,
24+
) -> Self {
25+
let mut args = args.unwrap_or_default();
26+
let a = args.as_object_mut().expect("args must be an object");
27+
a.insert("task".to_string(), task.to_string().into());
28+
if let Some(m) = model {
29+
a.insert("model".to_string(), m.into());
30+
}
31+
32+
Self {
33+
task: args,
34+
database_url,
35+
}
36+
}
37+
38+
#[instrument(skip(self))]
39+
pub async fn transform(&self, inputs: Vec<String>, args: Option<Json>) -> anyhow::Result<Json> {
40+
let pool = get_or_initialize_pool(&self.database_url).await?;
41+
let args = args.unwrap_or_default();
42+
43+
let results = sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)")
44+
.bind(&self.task)
45+
.bind(inputs)
46+
.bind(&args)
47+
.fetch_all(&pool)
48+
.await?;
49+
let results = results.get(0).unwrap().get::<serde_json::Value, _>(0);
50+
Ok(Json(results))
51+
}
52+
}
53+
54+
#[cfg(test)]
55+
mod tests {
56+
use super::*;
57+
use crate::internal_init_logger;
58+
59+
#[sqlx::test]
60+
async fn transformer_pipeline_can_transform() -> anyhow::Result<()> {
61+
internal_init_logger(None, None).ok();
62+
let t = TransformerPipeline::new(
63+
"translation_en_to_fr",
64+
Some("t5-base".to_string()),
65+
None,
66+
None,
67+
);
68+
let results = t
69+
.transform(
70+
vec![
71+
"How are you doing today?".to_string(),
72+
"What is a good song?".to_string(),
73+
],
74+
None,
75+
)
76+
.await?;
77+
assert!(results.as_array().is_some());
78+
Ok(())
79+
}
80+
81+
#[sqlx::test]
82+
async fn transformer_pipeline_can_transform_with_default_model() -> anyhow::Result<()> {
83+
internal_init_logger(None, None).ok();
84+
let t = TransformerPipeline::new("translation_en_to_fr", None, None, None);
85+
let results = t
86+
.transform(
87+
vec![
88+
"How are you doing today?".to_string(),
89+
"What is a good song?".to_string(),
90+
],
91+
None,
92+
)
93+
.await?;
94+
assert!(results.as_array().is_some());
95+
Ok(())
96+
}
97+
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp