Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit384921f

Browse files
authored
separate embed model creation and usage (#1022)
1 parent3b088a4 commit384921f

File tree

6 files changed

+33
-32
lines changed

6 files changed

+33
-32
lines changed

‎pgml-extension/src/bindings/langchain/mod.rs‎

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
use anyhow::Result;
2-
use once_cell::sync::Lazy;
32
use pgrx::*;
43
use pyo3::prelude::*;
54
use pyo3::types::PyTuple;
65

7-
usecrate::{bindings::TracebackError,create_pymodule};
6+
usecrate::create_pymodule;
87

98
create_pymodule!("/src/bindings/langchain/langchain.py");
109

‎pgml-extension/src/bindings/python/mod.rs‎

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
//! Use virtualenv.
22
33
use anyhow::Result;
4-
use once_cell::sync::Lazy;
54
use pgrx::iter::TableIterator;
65
use pgrx::*;
76
use pyo3::prelude::*;
87
use pyo3::types::PyTuple;
98

109
usecrate::config::get_config;
11-
usecrate::{bindings::TracebackError,create_pymodule};
10+
usecrate::create_pymodule;
1211

1312
staticCONFIG_NAME:&str ="pgml.venv";
1413

‎pgml-extension/src/bindings/sklearn/mod.rs‎

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,10 @@ use pgrx::*;
1111
use std::collections::HashMap;
1212

1313
use anyhow::Result;
14-
use once_cell::sync::Lazy;
1514
use pyo3::prelude::*;
1615
use pyo3::types::PyTuple;
1716

18-
usecrate::{
19-
bindings::{Bindings,TracebackError},
20-
create_pymodule,
21-
orm::*,
22-
};
17+
usecrate::{bindings::Bindings, create_pymodule, orm::*};
2318

2419
create_pymodule!("/src/bindings/sklearn/sklearn.py");
2520

‎pgml-extension/src/bindings/transformers/mod.rs‎

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use std::str::FromStr;
44
use std::{collections::HashMap, path::Path};
55

66
use anyhow::{anyhow, bail,Context,Result};
7-
use once_cell::sync::Lazy;
87
use pgrx::*;
98
use pyo3::prelude::*;
109
use pyo3::types::PyTuple;
@@ -47,22 +46,22 @@ pub fn transform(
4746
)
4847
.format_traceback(py)?;
4948

50-
Ok(output.extract(py).format_traceback(py)?)
49+
output.extract(py).format_traceback(py)
5150
})?;
5251

5352
Ok(serde_json::from_str(&results)?)
5453
}
5554

5655
pubfnget_model_from(task:&Value) ->Result<String>{
57-
Ok(Python::with_gil(|py| ->Result<String>{
56+
Python::with_gil(|py| ->Result<String>{
5857
let get_model_from =get_module!(PY_MODULE)
5958
.getattr(py,"get_model_from")
6059
.format_traceback(py)?;
6160
let model = get_model_from
6261
.call1(py,PyTuple::new(py,&[task.to_string().into_py(py)]))
6362
.format_traceback(py)?;
64-
Ok(model.extract(py).format_traceback(py)?)
65-
})?)
63+
model.extract(py).format_traceback(py)
64+
})
6665
}
6766

6867
pubfnembed(
@@ -91,7 +90,7 @@ pub fn embed(
9190
)
9291
.format_traceback(py)?;
9392

94-
Ok(output.extract(py).format_traceback(py)?)
93+
output.extract(py).format_traceback(py)
9594
})
9695
}
9796

@@ -126,7 +125,7 @@ pub fn tune(
126125
)
127126
.format_traceback(py)?;
128127

129-
Ok(output.extract(py).format_traceback(py)?)
128+
output.extract(py).format_traceback(py)
130129
})
131130
}
132131

@@ -176,7 +175,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
176175
}
177176
Ok(o) => o,
178177
};
179-
Ok(result.extract(py).format_traceback(py)?)
178+
result.extract(py).format_traceback(py)
180179
})
181180
}
182181

@@ -227,7 +226,7 @@ pub fn load_dataset(
227226
let load_dataset:Py<PyAny> =get_module!(PY_MODULE)
228227
.getattr(py,"load_dataset")
229228
.format_traceback(py)?;
230-
Ok(load_dataset
229+
load_dataset
231230
.call1(
232231
py,
233232
PyTuple::new(
@@ -242,7 +241,7 @@ pub fn load_dataset(
242241
)
243242
.format_traceback(py)?
244243
.extract(py)
245-
.format_traceback(py)?)
244+
.format_traceback(py)
246245
})?;
247246

248247
let table_name =format!("pgml.\"{}\"", name);

‎pgml-extension/src/bindings/transformers/transformers.py‎

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,29 +241,38 @@ def transform(task, args, inputs):
241241
returnorjson.dumps(pipe(inputs,**args),default=orjson_default).decode()
242242

243243

244-
defembed(transformer,inputs,kwargs):
245-
kwargs=orjson.loads(kwargs)
244+
defcreate_embedding(transformer):
245+
instructor=transformer.startswith("hkunlp/instructor")
246+
klass=INSTRUCTORifinstructorelseSentenceTransformer
247+
returnklass(transformer)
248+
249+
250+
defembed_using(model,transformer,inputs,kwargs):
251+
ifisinstance(kwargs,str):
252+
kwargs=orjson.loads(kwargs)
246253

247-
ensure_device(kwargs)
248254
instructor=transformer.startswith("hkunlp/instructor")
249-
250255
ifinstructor:
251-
klass=INSTRUCTOR
252-
253256
texts_with_instructions= []
254257
instruction=kwargs.pop("instruction")
255258
fortextininputs:
256259
texts_with_instructions.append([instruction,text])
257260

258261
inputs=texts_with_instructions
259-
else:
260-
klass=SentenceTransformer
262+
263+
returnmodel.encode(inputs,**kwargs)
264+
265+
266+
defembed(transformer,inputs,kwargs):
267+
kwargs=orjson.loads(kwargs)
268+
269+
ensure_device(kwargs)
261270

262271
iftransformernotin__cache_sentence_transformer_by_name:
263-
__cache_sentence_transformer_by_name[transformer]=klass(transformer)
272+
__cache_sentence_transformer_by_name[transformer]=create_embedding(transformer)
264273
model=__cache_sentence_transformer_by_name[transformer]
265274

266-
returnmodel.encode(inputs,**kwargs)
275+
returnembed_using(model,transformer,inputs,kwargs)
267276

268277

269278
defclear_gpu_cache(memory_usage:None):

‎pgml-extension/src/orm/model.rs‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,12 @@ impl Model {
378378
Ok(())
379379
})?;
380380

381-
Ok(model.ok_or_else(||{
381+
model.ok_or_else(||{
382382
anyhow!(
383383
"pgml.models WHERE id = {:?} could not be loaded. Does it exist?",
384384
id
385385
)
386-
})?)
386+
})
387387
}
388388

389389
pubfnfind_cached(id:i64) ->Result<Arc<Model>>{

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp