Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit79f9833

Browse files
authored
organize python related modules (#962)
1 parent6bdcf00 commit79f9833

File tree

14 files changed

+127
-120
lines changed

14 files changed

+127
-120
lines changed

‎pgml-extension/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ sacremoses==0.0.53
1717
scikit-learn==1.3.0
1818
sentencepiece==0.1.99
1919
sentence-transformers==2.2.2
20+
tokenizers==0.13.3
2021
torch==2.0.1
2122
torchaudio==2.0.2
2223
torchvision==0.15.2

‎pgml-extension/src/api.rs

Lines changed: 14 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ use pgrx::iter::{SetOfIterator, TableIterator};
66
use pgrx::*;
77

88
#[cfg(feature ="python")]
9-
use pyo3::prelude::*;
109
use serde_json::json;
1110

1211
#[cfg(feature ="python")]
13-
usecrate::bindings::sklearn::package_version;
1412
usecrate::orm::*;
1513

1614
macro_rules! unwrap_or_error{
@@ -25,38 +23,13 @@ macro_rules! unwrap_or_error {
2523
#[cfg(feature ="python")]
2624
#[pg_extern]
2725
pubfnactivate_venv(venv:&str) ->bool{
28-
unwrap_or_error!(crate::bindings::venv::activate_venv(venv))
26+
unwrap_or_error!(crate::bindings::python::activate_venv(venv))
2927
}
3028

3129
#[cfg(feature ="python")]
3230
#[pg_extern(immutable, parallel_safe)]
3331
pubfnvalidate_python_dependencies() ->bool{
34-
unwrap_or_error!(crate::bindings::venv::activate());
35-
36-
Python::with_gil(|py|{
37-
let sys =PyModule::import(py,"sys").unwrap();
38-
let version:String = sys.getattr("version").unwrap().extract().unwrap();
39-
info!("Python version: {version}");
40-
for modulein["xgboost","lightgbm","numpy","sklearn"]{
41-
match py.import(module){
42-
Ok(_) =>(),
43-
Err(e) =>{
44-
panic!(
45-
"The {module} package is missing. Install it with `sudo pip3 install {module}`\n{e}"
46-
);
47-
}
48-
}
49-
}
50-
});
51-
52-
let sklearn =unwrap_or_error!(package_version("sklearn"));
53-
let xgboost =unwrap_or_error!(package_version("xgboost"));
54-
let lightgbm =unwrap_or_error!(package_version("lightgbm"));
55-
let numpy =unwrap_or_error!(package_version("numpy"));
56-
57-
info!("Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}",);
58-
59-
true
32+
unwrap_or_error!(crate::bindings::python::validate_dependencies())
6033
}
6134

6235
#[cfg(not(feature ="python"))]
@@ -66,8 +39,7 @@ pub fn validate_python_dependencies() {}
6639
#[cfg(feature ="python")]
6740
#[pg_extern]
6841
pubfnpython_package_version(name:&str) ->String{
69-
unwrap_or_error!(crate::bindings::venv::activate());
70-
unwrap_or_error!(package_version(name))
42+
unwrap_or_error!(crate::bindings::python::package_version(name))
7143
}
7244

7345
#[cfg(not(feature ="python"))]
@@ -79,13 +51,19 @@ pub fn python_package_version(name: &str) {
7951
#[cfg(feature ="python")]
8052
#[pg_extern]
8153
pubfnpython_pip_freeze() ->TableIterator<'static,(name!(package,String),)>{
82-
unwrap_or_error!(crate::bindings::venv::activate());
54+
unwrap_or_error!(crate::bindings::python::pip_freeze())
55+
}
8356

84-
let packages =unwrap_or_error!(crate::bindings::venv::freeze())
85-
.into_iter()
86-
.map(|package|(package,));
57+
#[cfg(feature ="python")]
58+
#[pg_extern]
59+
fnpython_version() ->String{
60+
unwrap_or_error!(crate::bindings::python::version())
61+
}
8762

88-
TableIterator::new(packages)
63+
#[cfg(not(feature ="python"))]
64+
#[pg_extern]
65+
pubfnpython_version() ->String{
66+
String::from("Python is not installed, recompile with `--features python`")
8967
}
9068

9169
#[pg_extern]
@@ -104,26 +82,6 @@ pub fn validate_shared_library() {
10482
}
10583
}
10684

107-
#[cfg(feature ="python")]
108-
#[pg_extern]
109-
fnpython_version() ->String{
110-
unwrap_or_error!(crate::bindings::venv::activate());
111-
letmut version =String::new();
112-
113-
Python::with_gil(|py|{
114-
let sys =PyModule::import(py,"sys").unwrap();
115-
version = sys.getattr("version").unwrap().extract().unwrap();
116-
});
117-
118-
version
119-
}
120-
121-
#[cfg(not(feature ="python"))]
122-
#[pg_extern]
123-
pubfnpython_version() ->String{
124-
String::from("Python is not installed, recompile with `--features python`")
125-
}
126-
12785
#[pg_extern(immutable, parallel_safe)]
12886
fnversion() ->String{
12987
crate::VERSION.to_string()

‎pgml-extension/src/bindings/langchain.rsrenamed to ‎pgml-extension/src/bindings/langchain/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ use pyo3::types::PyTuple;
66

77
usecrate::{bindings::TracebackError, create_pymodule};
88

9-
create_pymodule!("/src/bindings/langchain.py");
9+
create_pymodule!("/src/bindings/langchain/langchain.py");
1010

1111
pubfnchunk(splitter:&str,text:&str,kwargs:&serde_json::Value) ->Result<Vec<String>>{
12-
crate::bindings::venv::activate()?;
12+
crate::bindings::python::activate()?;
1313

1414
let kwargs = serde_json::to_string(kwargs).unwrap();
1515

‎pgml-extension/src/bindings/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ pub mod langchain;
3838
pubmod lightgbm;
3939
pubmod linfa;
4040
#[cfg(feature ="python")]
41+
pubmod python;
42+
#[cfg(feature ="python")]
4143
pubmod sklearn;
4244
#[cfg(feature ="python")]
4345
pubmod transformers;
44-
#[cfg(feature ="python")]
45-
pubmod venv;
4646
pubmod xgboost;
4747

4848
pubtypeFit =fn(dataset:&Dataset,hyperparams:&Hyperparams) ->Result<Box<dynBindings>>;
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//! Use virtualenv.
2+
3+
use anyhow::Result;
4+
use once_cell::sync::Lazy;
5+
use pgrx::iter::TableIterator;
6+
use pgrx::*;
7+
use pyo3::prelude::*;
8+
use pyo3::types::PyTuple;
9+
10+
usecrate::config::get_config;
11+
usecrate::{bindings::TracebackError, create_pymodule};
12+
13+
staticCONFIG_NAME:&str ="pgml.venv";
14+
15+
create_pymodule!("/src/bindings/python/python.py");
16+
17+
pubfnactivate_venv(venv:&str) ->Result<bool>{
18+
Python::with_gil(|py|{
19+
let activate_venv:Py<PyAny> =get_module!(PY_MODULE).getattr(py,"activate_venv")?;
20+
let result:Py<PyAny> =
21+
activate_venv.call1(py,PyTuple::new(py,&[venv.to_string().into_py(py)]))?;
22+
23+
Ok(result.extract(py)?)
24+
})
25+
}
26+
27+
pubfnactivate() ->Result<bool>{
28+
matchget_config(CONFIG_NAME){
29+
Some(venv) =>activate_venv(&venv),
30+
None =>Ok(false),
31+
}
32+
}
33+
34+
pubfnpip_freeze() ->Result<TableIterator<'static,(name!(package,String),)>>{
35+
activate()?;
36+
let packages =Python::with_gil(|py| ->Result<Vec<String>>{
37+
let freeze =get_module!(PY_MODULE).getattr(py,"freeze")?;
38+
let result = freeze.call0(py)?;
39+
40+
Ok(result.extract(py)?)
41+
})?;
42+
43+
Ok(TableIterator::new(
44+
packages.into_iter().map(|package|(package,)),
45+
))
46+
}
47+
48+
pubfnvalidate_dependencies() ->Result<bool>{
49+
activate()?;
50+
Python::with_gil(|py|{
51+
let sys =PyModule::import(py,"sys").unwrap();
52+
let version:String = sys.getattr("version").unwrap().extract().unwrap();
53+
info!("Python version: {version}");
54+
for modulein["xgboost","lightgbm","numpy","sklearn"]{
55+
match py.import(module){
56+
Ok(_) =>(),
57+
Err(e) =>{
58+
panic!(
59+
"The {module} package is missing. Install it with `sudo pip3 install {module}`\n{e}"
60+
);
61+
}
62+
}
63+
}
64+
});
65+
66+
let sklearn =package_version("sklearn")?;
67+
let xgboost =package_version("xgboost")?;
68+
let lightgbm =package_version("lightgbm")?;
69+
let numpy =package_version("numpy")?;
70+
71+
info!("Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}",);
72+
73+
Ok(true)
74+
}
75+
76+
pubfnversion() ->Result<String>{
77+
activate()?;
78+
Python::with_gil(|py|{
79+
let sys =PyModule::import(py,"sys").unwrap();
80+
let version:String = sys.getattr("version").unwrap().extract().unwrap();
81+
Ok(version)
82+
})
83+
}
84+
85+
pubfnpackage_version(name:&str) ->Result<String>{
86+
activate()?;
87+
Python::with_gil(|py|{
88+
let package = py.import(name)?;
89+
Ok(package.getattr("__version__")?.extract()?)
90+
})
91+
}

‎pgml-extension/src/bindings/sklearn.rsrenamed to ‎pgml-extension/src/bindings/sklearn/mod.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ use once_cell::sync::Lazy;
1515
use pyo3::prelude::*;
1616
use pyo3::types::PyTuple;
1717

18-
usecrate::bindings::Bindings;
18+
usecrate::{
19+
bindings::{Bindings,TracebackError},
20+
create_pymodule,
21+
orm::*,
22+
};
1923

20-
usecrate::{bindings::TracebackError, create_pymodule, orm::*};
21-
22-
create_pymodule!("/src/bindings/sklearn.py");
24+
create_pymodule!("/src/bindings/sklearn/sklearn.py");
2325

2426
macro_rules! wrap_fit{
2527
($fn_name:tt, $task:literal) =>{
@@ -355,10 +357,3 @@ pub fn cluster_metrics(
355357
Ok(scores)
356358
})
357359
}
358-
359-
pubfnpackage_version(name:&str) ->Result<String>{
360-
Python::with_gil(|py|{
361-
let package = py.import(name)?;
362-
Ok(package.getattr("__version__")?.extract()?)
363-
})
364-
}

‎pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub fn transform(
2424
args:&serde_json::Value,
2525
inputs:Vec<&str>,
2626
) ->Result<serde_json::Value>{
27-
crate::bindings::venv::activate()?;
27+
crate::bindings::python::activate()?;
2828

2929
whitelist::verify_task(task)?;
3030

@@ -70,7 +70,7 @@ pub fn embed(
7070
inputs:Vec<&str>,
7171
kwargs:&serde_json::Value,
7272
) ->Result<Vec<Vec<f32>>>{
73-
crate::bindings::venv::activate()?;
73+
crate::bindings::python::activate()?;
7474

7575
let kwargs = serde_json::to_string(kwargs)?;
7676
Python::with_gil(|py| ->Result<Vec<Vec<f32>>>{
@@ -101,7 +101,7 @@ pub fn tune(
101101
hyperparams:&JsonB,
102102
path:&Path,
103103
) ->Result<HashMap<String,f64>>{
104-
crate::bindings::venv::activate()?;
104+
crate::bindings::python::activate()?;
105105

106106
let task = task.to_string();
107107
let hyperparams = serde_json::to_string(&hyperparams.0)?;
@@ -131,7 +131,7 @@ pub fn tune(
131131
}
132132

133133
pubfngenerate(model_id:i64,inputs:Vec<&str>,config:JsonB) ->Result<Vec<String>>{
134-
crate::bindings::venv::activate()?;
134+
crate::bindings::python::activate()?;
135135

136136
Python::with_gil(|py| ->Result<Vec<String>>{
137137
let generate =get_module!(PY_MODULE)
@@ -219,7 +219,7 @@ pub fn load_dataset(
219219
limit:Option<usize>,
220220
kwargs:&serde_json::Value,
221221
) ->Result<usize>{
222-
crate::bindings::venv::activate()?;
222+
crate::bindings::python::activate()?;
223223

224224
let kwargs = serde_json::to_string(kwargs)?;
225225

@@ -376,7 +376,7 @@ pub fn load_dataset(
376376
}
377377

378378
pubfnclear_gpu_cache(memory_usage:Option<f32>) ->Result<bool>{
379-
crate::bindings::venv::activate().unwrap();
379+
crate::bindings::python::activate().unwrap();
380380

381381
Python::with_gil(|py| ->Result<bool>{
382382
let clear_gpu_cache:Py<PyAny> =get_module!(PY_MODULE)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp