Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit5759ee3

Browse files
authored
Moved python functions (#1374)
1 parent66c65c8 commit5759ee3

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

‎pgml-extension/src/bindings/mod.rs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,31 @@ use std::fmt::Debug;
33
use anyhow::{anyhow,Result};
44
#[allow(unused_imports)]// used for test macros
55
use pgrx::*;
6-
use pyo3::{PyResult,Python};
6+
use pyo3::{pyfunction,PyResult,Python};
77

88
usecrate::orm::*;
99

10+
#[pyfunction]
11+
fnr_insert_logs(project_id:i64,model_id:i64,logs:String) ->PyResult<String>{
12+
let id_value =Spi::get_one_with_args::<i64>(
13+
"INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;",
14+
vec![
15+
(PgBuiltInOids::INT8OID.oid(), project_id.into_datum()),
16+
(PgBuiltInOids::INT8OID.oid(), model_id.into_datum()),
17+
(PgBuiltInOids::TEXTOID.oid(), logs.into_datum()),
18+
],
19+
)
20+
.unwrap()
21+
.unwrap();
22+
Ok(format!("Inserted logs with id: {}", id_value))
23+
}
24+
25+
#[pyfunction]
26+
fnr_print_info(info:String) ->PyResult<String>{
27+
info!("{}", info);
28+
Ok(info)
29+
}
30+
1031
#[cfg(feature ="python")]
1132
#[macro_export]
1233
macro_rules! create_pymodule{
@@ -16,11 +37,11 @@ macro_rules! create_pymodule {
1637
pyo3::Python::with_gil(|py| -> anyhow::Result<pyo3::Py<pyo3::types::PyModule>>{
1738
use $crate::bindings::TracebackError;
1839
let src = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), $pyfile));
19-
Ok(
20-
pyo3::types::PyModule::from_code(py, src,"transformers.py","__main__")
21-
.format_traceback(py)?
22-
.into(),
23-
)
40+
let module = pyo3::types::PyModule::from_code(py, src,"transformers.py","__main__")
41+
.format_traceback(py)?;
42+
module.add_function(wrap_pyfunction!($crate::bindings::r_insert_logs, module)?)?;
43+
module.add_function(wrap_pyfunction!($crate::bindings::r_print_info, module)?)?;
44+
Ok(module.into())
2445
})
2546
});
2647
};

‎pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
fromtrlimportSFTTrainer,DataCollatorForCompletionOnlyLM
5656
fromtrl.trainerimportConstantLengthDataset
5757
frompeftimportLoraConfig,get_peft_model
58-
frompypgrximportprint_info,insert_logs
5958
fromabcimportabstractmethod
6059

6160
transformers.logging.set_verbosity_info()
@@ -1017,8 +1016,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
10171016
logs["step"]=state.global_step
10181017
logs["max_steps"]=state.max_steps
10191018
logs["timestamp"]=str(datetime.now())
1020-
print_info(json.dumps(logs,indent=4))
1021-
insert_logs(self.project_id,self.model_id,json.dumps(logs))
1019+
r_print_info(json.dumps(logs,indent=4))
10221020

10231021

10241022
classFineTuningBase:
@@ -1100,9 +1098,9 @@ def print_number_of_trainable_model_parameters(self, model):
11001098
trainable_model_params+=param.numel()
11011099

11021100
# Calculate and print the number and percentage of trainable parameters
1103-
print_info(f"Trainable model parameters:{trainable_model_params}")
1104-
print_info(f"All model parameters:{all_model_params}")
1105-
print_info(
1101+
r_print_info(f"Trainable model parameters:{trainable_model_params}")
1102+
r_print_info(f"All model parameters:{all_model_params}")
1103+
r_print_info(
11061104
f"Percentage of trainable model parameters:{100*trainable_model_params/all_model_params:.2f}%"
11071105
)
11081106

@@ -1398,7 +1396,7 @@ def __init__(
13981396
"bias":"none",
13991397
"task_type":"CAUSAL_LM",
14001398
}
1401-
print_info(
1399+
r_print_info(
14021400
"LoRA configuration are not set. Using default parameters"
14031401
+json.dumps(self.lora_config_params)
14041402
)
@@ -1465,7 +1463,7 @@ def formatting_prompts_func(example):
14651463
peft_config=LoraConfig(**self.lora_config_params),
14661464
callbacks=[PGMLCallback(self.project_id,self.model_id)],
14671465
)
1468-
print_info("Creating Supervised Fine Tuning trainer done. Training ... ")
1466+
r_print_info("Creating Supervised Fine Tuning trainer done. Training ... ")
14691467

14701468
# Train
14711469
self.trainer.train()

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp