Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

LLM fine-tuning#1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
santiatpml merged 36 commits into master from santi-llm-fine-tuning
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
Show all changes
36 commits
Select commitHold shift + click to select a range
e3bea27
fine-tuning text classification in progress
santiatpmlJan 31, 2024
c4cf332
More commit messages
santiatpmlFeb 1, 2024
fb7cc2a
Working text classification with dataset args and training args
santiatpmlFeb 6, 2024
5584487
finetuing with text dataset enum to handle different tasks
santiatpmlFeb 7, 2024
82cb4f7
text pair classification task support
santiatpmlFeb 7, 2024
c10de47
saving model after training
santiatpmlFeb 7, 2024
63ee09b
removed device to cpu
santiatpmlFeb 7, 2024
865ae28
updated transforemrs
Feb 8, 2024
097a8cf
Working e2e finetunig for two tasks
Feb 8, 2024
2dd50e6
Integration with huggingface hub and wandb
Feb 9, 2024
6ac8722
Conversation dataset + training placeholder
Feb 13, 2024
1e40cd8
Updated rust to fix failing tests
Feb 13, 2024
312d893
working version of conversation with lora + load 8bit + hf hub
Feb 13, 2024
afc2e93
Tested llama2-7b finetuning
Feb 22, 2024
22ee5c7
pypgrx first working version
Feb 27, 2024
97d455d
refactoring finetuning code to add callbacks
santiatpmlFeb 27, 2024
b700944
fixed merge conflicts
santiatpmlMar 5, 2024
65d2f8b
Refactored finetuning + conversation + pgml callbacks
Mar 2, 2024
5f1b5f4
removed wandb dependency
Mar 4, 2024
08084bf
removed local pypgrx from requirements
Mar 4, 2024
dc0c6ee
removed maturin from requirements
Mar 4, 2024
421af8f
removed flash attn
Mar 4, 2024
4bbca96
Added indent for info display
Mar 5, 2024
3db857c
Updated readme with LLM fine-tuning for text classification
santiatpmlMar 7, 2024
7cbee43
README updates
santiatpmlMar 7, 2024
9284cf1
Added a tutorial for 9 classes - draft 1
santiatpmlMar 8, 2024
66c65c8
README updates
santiatpmlMar 8, 2024
5759ee3
Moved python functions (#1374)
SilasMarvinMar 18, 2024
b539168
README updates
santiatpmlMar 19, 2024
31215b8
migrations and removed pypgrx
santiatpmlMar 20, 2024
dae6b74
Added r_log to take log level and message
santiatpmlMar 20, 2024
dae5ffc
Updated version and requirements
Mar 22, 2024
435f5bd
Changed version 2.8.3
Mar 22, 2024
aeb2683
README updates for conversation task fine-tuning using lora
santiatpmlMar 22, 2024
e5221cc
minor readme updates
santiatpmlMar 26, 2024
6db147e
added new line
santiatpmlMar 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
737 changes: 737 additions & 0 deletionsREADME.md
View file
Open in desktop

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletionspgml-extension/.gitignore
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -14,3 +14,5 @@
.DS_Store


# venv
pgml-venv
56 changes: 38 additions & 18 deletionspgml-extension/Cargo.lock
View file
Open in desktop

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletionspgml-extension/Cargo.toml
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
[package]
name = "pgml"
version = "2.8.2"
version = "2.8.3"
edition = "2021"

[lib]
Expand DownExpand Up@@ -39,8 +39,8 @@ openblas-src = { version = "0.10", features = ["cblas", "system"] }
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
ndarray-stats = "0.5.1"
parking_lot = "0.12"
pgrx = "=0.11.2"
pgrx-pg-sys = "=0.11.2"
pgrx = "=0.11.3"
pgrx-pg-sys = "=0.11.3"
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rand = "0.8"
rmp-serde = { version = "1.1" }
Expand All@@ -51,7 +51,7 @@ typetag = "0.2"
xgboost = { git = "https://github.com/postgresml/rust-xgboost", branch = "master" }

[dev-dependencies]
pgrx-tests = "=0.11.2"
pgrx-tests = "=0.11.3"

[build-dependencies]
vergen = { version = "8", features = ["build", "git", "gitcl"] }
Expand Down
25 changes: 23 additions & 2 deletionspgml-extension/requirements.linux.txt
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
accelerate==0.25.0
accelerate==0.27.2
aiohttp==3.9.1
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.2.0
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.1.0
auto-gptq==0.6.0
bitsandbytes==0.41.3.post2
black==24.1.1
catboost==1.2.2
certifi==2023.11.17
charset-normalizer==3.3.2
Expand All@@ -20,13 +22,18 @@ dataclasses-json==0.6.3
datasets==2.15.0
deepspeed==0.12.5
dill==0.3.7
docker-pycreds==0.4.0
docstring-parser==0.15
einops==0.7.0
evaluate==0.4.1
exceptiongroup==1.2.0
filelock==3.13.1
fonttools==4.47.0
frozenlist==1.4.1
fsspec==2023.10.0
gekko==1.0.6
gitdb==4.0.11
GitPython==3.1.41
graphviz==0.20.1
greenlet==3.0.2
hjson==3.1.0
Expand All@@ -45,9 +52,11 @@ langchain-core==0.1.1
langsmith==0.0.72
lightgbm==4.1.0
lxml==4.9.3
markdown-it-py==3.0.0
MarkupSafe==2.1.3
marshmallow==3.20.1
matplotlib==3.8.2
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
Expand All@@ -72,8 +81,10 @@ optimum==1.16.1
orjson==3.9.10
packaging==23.2
pandas==2.1.4
pathspec==0.12.1
peft==0.7.1
Pillow==10.1.0
platformdirs==4.2.0
plotly==5.18.0
portalocker==2.8.2
protobuf==4.25.1
Expand All@@ -83,13 +94,16 @@ pyarrow==11.0.0
pyarrow-hotfix==0.6
pydantic==2.5.2
pydantic_core==2.14.5
Pygments==2.17.2
pynvml==11.5.0
pyparsing==3.1.1
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
regex==2023.10.3
requests==2.31.0
responses==0.18.0
rich==13.7.1
rouge==1.0.1
sacrebleu==2.4.0
sacremoses==0.1.1
Expand All@@ -98,23 +112,30 @@ scikit-learn==1.3.2
scipy==1.11.4
sentence-transformers==2.2.2
sentencepiece==0.1.99
sentry-sdk==1.40.2
setproctitle==1.3.3
shtab==1.6.5
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
SQLAlchemy==2.0.23
sympy==1.12
tabulate==0.9.0
tenacity==8.2.3
threadpoolctl==3.2.0
tokenizers==0.15.0
tomli==2.0.1
torch==2.1.2
torchaudio==2.1.2
torchvision==0.16.2
tqdm==4.66.1
transformers==4.38.0
transformers==4.38.2
transformers-stream-generator==0.0.4
triton==2.1.0
trl==0.7.10
typing-inspect==0.9.0
typing_extensions==4.9.0
tyro==0.7.2
tzdata==2023.3
urllib3==2.1.0
xformers==0.0.23.post1
Expand Down
12 changes: 12 additions & 0 deletionspgml-extension/sql/pgml--2.8.2--2.8.3.sql
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
-- Add conversation, text-pair-classification task type
-- IF NOT EXISTS makes each statement a no-op when the enum value is already present.
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text-pair-classification';

-- Create pgml.logs table
-- IF NOT EXISTS leaves an already-existing pgml.logs table untouched.
-- NOTE(review): model_id and project_id are plain BIGINT columns with no
-- foreign-key constraint declared here; presumably they refer to pgml models
-- and projects -- confirm against the code that writes to this table.
CREATE TABLE IF NOT EXISTS pgml.logs (
id SERIAL PRIMARY KEY,
model_id BIGINT,
project_id BIGINT,
-- Filled with the current timestamp when the insert supplies no value.
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
logs JSONB
);
34 changes: 17 additions & 17 deletionspgml-extension/src/api.rs
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -803,7 +803,7 @@ fn tune(
project_name: &str,
task: default!(Option<&str>, "NULL"),
relation_name: default!(Option<&str>, "NULL"),
y_column_name: default!(Option<&str>, "NULL"),
_y_column_name: default!(Option<&str>, "NULL"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Why the underscore? Is it because it's not used?

Copy link
ContributorAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

That's correct.

model_name: default!(Option<&str>, "NULL"),
hyperparams: default!(JsonB, "'{}'"),
test_size: default!(f32, 0.25),
Expand DownExpand Up@@ -861,9 +861,7 @@ fn tune(

let snapshot = Snapshot::create(
relation_name,
Some(vec![y_column_name
.expect("You must pass a `y_column_name` when you pass a `relation_name`")
.to_string()]),
None,
test_size,
test_sampling,
materialize_snapshot,
Expand All@@ -885,13 +883,14 @@ fn tune(
// algorithm will be transformers, stash the model_name in a hyperparam for v1 compatibility.
let mut hyperparams = hyperparams.0.as_object().unwrap().clone();
hyperparams.insert(String::from("model_name"), json!(model_name));
hyperparams.insert(String::from("project_name"), json!(project_name));
let hyperparams = JsonB(json!(hyperparams));

// # Default repeatable random state when possible
// let algorithm = Model.algorithm_from_name_and_task(algorithm, task);
// if "random_state" in algorithm().get_params() and "random_state" not in hyperparams:
// hyperparams["random_state"] = 0
let model = Model::tune(&project, &mut snapshot, &hyperparams);
let model = Model::finetune(&project, &mut snapshot, &hyperparams);
let new_metrics: &serde_json::Value = &model.metrics.unwrap().0;
let new_metrics = new_metrics.as_object().unwrap();

Expand All@@ -915,18 +914,19 @@ fn tune(
Some(true) | None => {
if let Ok(Some(deployed_metrics)) = deployed_metrics {
let deployed_metrics = deployed_metrics.0.as_object().unwrap();
if project.task.value_is_better(
deployed_metrics
.get(&project.task.default_target_metric())
.unwrap()
.as_f64()
.unwrap(),
new_metrics
.get(&project.task.default_target_metric())
.unwrap()
.as_f64()
.unwrap(),
) {

let deployed_value = deployed_metrics
.get(&project.task.default_target_metric())
.and_then(|value| value.as_f64())
.unwrap_or_default(); // Default to 0.0 if the key is not present or conversion fails

// Get the value for the default target metric from new_metrics or provide a default value
let new_value = new_metrics
.get(&project.task.default_target_metric())
.and_then(|value| value.as_f64())
.unwrap_or_default(); // Default to 0.0 if the key is not present or conversion fails

if project.task.value_is_better(deployed_value, new_value) {
deploy = false;
}
}
Expand Down
Loading

[8]ページ先頭

©2009-2025 Movatter.jp