LLM fine-tuning #1350

Merged
santiatpml merged 36 commits into master from santi-llm-fine-tuning on Mar 26, 2024
Conversation

santiatpml (Contributor) commented Mar 4, 2024 (edited)
  • Example: https://github.com/postgresml/postgresml/tree/santi-llm-fine-tuning?tab=readme-ov-file#llm-fine-tuning

  • Refactored TextDataSet to handle different NLP tasks

  • Three tasks: text classification, text pair classification, conversation

  • PEFT/LoRA for conversation task

  • Pypgrx for callbacks to print info statements and insert logs into pgml.logs table

  • New tasks have to be added to pgml.tasks:
    ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
    ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text_pair_classification';

  • New pgml.logs table has to be added:

CREATE TABLE pgml.logs (
    id SERIAL PRIMARY KEY,
    model_id BIGINT,
    project_id BIGINT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    logs JSONB
);
  • Text classification
SELECT pgml.tune(
    'financial_phrasebank_sentiment',
    task => 'text-classification',
    relation_name => 'pgml.financial_phrasebank_view',
    model_name => 'distilbert-base-uncased',
    test_size => 0.2,
    test_sampling => 'last',
    hyperparams => '{
        "training_args": {
            "learning_rate": 2e-5,
            "per_device_train_batch_size": 16,
            "per_device_eval_batch_size": 16,
            "num_train_epochs": 10,
            "weight_decay": 0.01,
            "hub_token": "token",
            "push_to_hub": true
        },
        "dataset_args": {
            "text_column": "sentence",
            "class_column": "class"
        }
    }');
  • Text pair classification.
    Note: Training is initialized using a previous run and model from HF Hub.
SELECT pgml.tune(
    'glue_mrpc_nli_2',
    task => 'text_pair_classification',
    relation_name => 'pgml.glue_view',
    model_name => 'santiadavani/glue_mrpc_nli_2',
    test_size => 0.5,
    test_sampling => 'last',
    hyperparams => '{
        "training_args": {
            "learning_rate": 2e-5,
            "per_device_train_batch_size": 16,
            "per_device_eval_batch_size": 16,
            "num_train_epochs": 1,
            "weight_decay": 0.01
        },
        "dataset_args": {
            "text1_column": "sentence1",
            "text2_column": "sentence2",
            "class_column": "class"
        }
    }');
  • Conversation
SELECT pgml.tune(
    'alpaca-gpt4-conversation-llama2-7b-chat',
    task => 'conversation',
    relation_name => 'pgml.chat_sample',
    model_name => 'meta-llama/Llama-2-7b-chat-hf',
    test_size => 0.8,
    test_sampling => 'last',
    hyperparams => '{
        "training_args": {
            "learning_rate": 2e-5,
            "per_device_train_batch_size": 4,
            "per_device_eval_batch_size": 4,
            "num_train_epochs": 1,
            "weight_decay": 0.01,
            "hub_token": "read_write_token",
            "push_to_hub": true,
            "optim": "adamw_bnb_8bit",
            "gradient_accumulation_steps": 4,
            "gradient_checkpointing": true
        },
        "dataset_args": {
            "system_column": "instruction",
            "user_column": "input",
            "assistant_column": "output"
        },
        "lora_config": {
            "r": 2,
            "lora_alpha": 4,
            "lora_dropout": 0.05,
            "bias": "none",
            "task_type": "CAUSAL_LM"
        },
        "load_in_8bit": false,
        "token": "read_token"
    }');
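The hyperparams argument to pgml.tune is a JSON string. A minimal Python sketch of assembling one before handing it to the SQL call; build_hyperparams is a hypothetical helper, not part of the PR, and the key names simply mirror the examples above:

```python
import json


def build_hyperparams(training_args, dataset_args, lora_config=None):
    """Assemble the JSON string passed as pgml.tune's hyperparams argument.

    Hypothetical helper for illustration; the keys mirror the PR examples.
    """
    params = {"training_args": training_args, "dataset_args": dataset_args}
    if lora_config is not None:
        # lora_config is only relevant for the conversation (PEFT/LoRA) task
        params["lora_config"] = lora_config
    return json.dumps(params)


hyperparams = build_hyperparams(
    training_args={"learning_rate": 2e-5, "num_train_epochs": 1},
    dataset_args={
        "system_column": "instruction",
        "user_column": "input",
        "assistant_column": "output",
    },
    lora_config={"r": 2, "lora_alpha": 4, "lora_dropout": 0.05,
                 "bias": "none", "task_type": "CAUSAL_LM"},
)
```

The resulting string can be interpolated into the hyperparams => '...' parameter of the tune calls shown above.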

fn insert_logs(project_id: i64, model_id: i64, logs: String) -> PyResult<String> {
    let id_value = Spi::get_one_with_args::<i64>(
        "INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;",
Contributor:

Did we include a migration for this table somewhere? We need to make sure it's created on all databases running PostgresML.

santiatpml (Author):

Yes, need to add the following three to our migration once we freeze on the version number.

ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text_pair_classification';

CREATE TABLE IF NOT EXISTS pgml.logs (
    id SERIAL PRIMARY KEY,
    model_id BIGINT,
    project_id BIGINT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    logs JSONB
);

MarkupSafe==2.1.3
marshmallow==3.20.1
matplotlib==3.8.2
maturin==1.4.0
Contributor:

Don't think you need maturin inside PostgresML deployments. This may be a "leak" from the pypgrx extension venv.

@@ -803,7 +803,7 @@ fn tune(
     project_name: &str,
     task: default!(Option<&str>, "NULL"),
     relation_name: default!(Option<&str>, "NULL"),
-    y_column_name: default!(Option<&str>, "NULL"),
+    _y_column_name: default!(Option<&str>, "NULL"),
Contributor:

Why the underscore? Is it because it's not used?

santiatpml (Author):

That's correct.

from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from trl.trainer import ConstantLengthDataset
from peft import LoraConfig, get_peft_model
from pypgrx import print_info, insert_logs
Contributor:

Need to make sure we import this conditionally (only for fine-tuning) and include it in requirements.linux.txt. I didn't see a Mac OS build for this, and for the M1/M2 architecture we've been doing releases manually from our Macs (GitHub Actions doesn't have M1 builders).

This makes me think we should start cross-compiling soon. Rust supports this pretty well; maturin may need a patch.

santiatpml (Author):

I couldn't get fine-tuning to work on Mac OS; it keeps crashing. How about I check for the operating system and bail out if it is macOS?
requirements.linux.txt is updated with trl and peft.
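The proposed OS check could be sketched like this; the helper name and shape are hypothetical, only the "bail out on macOS" behavior comes from the discussion:

```python
import platform


def finetuning_supported(system=None):
    """Return False on macOS, where fine-tuning was observed to crash.

    Hypothetical helper sketching the proposed OS check; `system`
    defaults to the current platform (platform.system() reports
    "Darwin" on macOS).
    """
    system = system or platform.system()
    return system != "Darwin"
```

The tune entry point could call this before importing trl/peft and raise a clear error instead of crashing.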

logs["step"] = state.global_step
logs["max_steps"] = state.max_steps
logs["timestamp"] = str(datetime.now())
print_info(json.dumps(logs))
Contributor:

If you use print(), this will appear in the Postgres logs. It won't be pretty, but we can add a function that formats it correctly.

santiatpml (Author):

I will add indent in json.dumps() to pretty print.
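The indent change, sketched as a standalone helper; the field names mirror the quoted on_log snippet, but format_log_entry itself is a hypothetical illustration, not code from the PR:

```python
import json
from datetime import datetime


def format_log_entry(logs, global_step, max_steps):
    """Enrich trainer logs and pretty-print them, mirroring the
    on_log snippet above (hypothetical helper for illustration)."""
    entry = dict(logs)
    entry["step"] = global_step
    entry["max_steps"] = max_steps
    entry["timestamp"] = str(datetime.now())
    # indent=2 turns the one-line JSON into readable multi-line output
    return json.dumps(entry, indent=2)
```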

trainable_model_params += param.numel()

# Calculate and print the number and percentage of trainable parameters
print_info(f"Trainable model parameters: {trainable_model_params}")
Contributor:

@kczimm This will require us to use the main thread for ML workloads in our cloud.

Contributor:

A PR with that is close. What's the reason we need main thread here?

Contributor:

We need logging visibility during fine tuning.

Contributor:

Thanks to a commit by @levkk, we should be able to log from any thread.
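The parameter-counting loop quoted earlier in this thread can be sketched framework-independently: anything exposing PyTorch-style numel() and requires_grad works. The helper and the stub class below are illustrative only, not code from the PR:

```python
def count_parameters(parameters):
    """Count total and trainable parameters, PyTorch-style.

    `parameters` yields objects with .numel() and .requires_grad,
    e.g. model.parameters() in torch (hypothetical helper).
    """
    total, trainable = 0, 0
    for param in parameters:
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return total, trainable


class _StubParam:
    # Minimal stand-in for a torch parameter, for illustration only
    def __init__(self, n, requires_grad):
        self._n = n
        self.requires_grad = requires_grad

    def numel(self):
        return self._n


total, trainable = count_parameters(
    [_StubParam(100, True), _StubParam(900, False)]
)
print(f"Trainable model parameters: {trainable} ({100 * trainable / total:.1f}%)")
```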

santiatpml and others added 19 commits on March 5, 2024 09:50
#######################


class PGMLCallback(TrainerCallback):
Contributor:

I wouldn't be opposed to this functionality living in its own file like tune.py, since transformers.py is getting a bit beefy.

santiatpml (Author):

transformers.py is hardcoded in several places. Moving the fine-tuning code to tune.py needs some more refactoring and testing. Will revisit this in the next iteration: #1378

self.model_id = model_id

def on_log(self, args, state, control, logs=None, **kwargs):
_ = logs.pop("total_flos", None)
Contributor:

Why throw away total_flos?

}

#[pyfunction]
fn print_info(info: String) -> PyResult<String> {
Contributor:

I think this would be more reusable as log(level, msg).

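The suggested log(level, msg) shape could look like this on the Python side; everything here is a hypothetical sketch (the real binding would dispatch to Postgres log levels via pgrx rather than return a string):

```python
# Hypothetical Python-side shape for a log(level, msg) binding.
LEVELS = ("debug", "info", "notice", "warning")


def log(level, msg):
    """Validate the level and format the message; a real binding would
    forward to the corresponding Postgres elog level instead."""
    if level not in LEVELS:
        raise ValueError(f"unknown log level: {level}")
    return f"[{level.upper()}] {msg}"
```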
else:
self.model_name = hyperparameters.pop("model_name")

if "token" in hyperparameters:
Contributor:

Isn't this a model init param, not a hyperparam, like many other things in this list? Maybe hyperparams covers everything?

santiatpml (Author):

That's correct. Moved all the parameters to hyperparams.

y_train,
x_test,
y_test,
Ok::<std::option::Option<()>, i64>(Some(())) // this return type is nonsense
Contributor:

:)


let text1_column_value = dataset_args
.0
.get("text1_column")
Contributor:

do we require these column names?

santiatpml (Author):

Yes, for text pair classification (natural language inference, QNLI, etc.), we need three columns: text1, text2, and the class.
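Extracting those three columns from relation rows can be sketched as follows; extract_text_pairs is a hypothetical helper, with column keys matching the dataset_args in the text pair classification example:

```python
def extract_text_pairs(rows, dataset_args):
    """Map relation rows to (text1, text2, class) tuples for the
    text pair classification task (illustrative sketch only)."""
    t1 = dataset_args["text1_column"]
    t2 = dataset_args["text2_column"]
    cls = dataset_args["class_column"]
    return [(row[t1], row[t2], row[cls]) for row in rows]


pairs = extract_text_pairs(
    [{"sentence1": "A man is eating.", "sentence2": "Someone eats.", "class": 1}],
    {"text1_column": "sentence1", "text2_column": "sentence2",
     "class_column": "class"},
)
```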


let system_column_value = dataset_args
.0
.get("system_column")
Contributor:

How standard are these names these days?

santiatpml (Author):

For the conversation task, system, user, and assistant have become standard keys.
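Mapping a row's columns onto those standard keys can be sketched like this; row_to_messages is a hypothetical helper, and the column names come from the conversation example's dataset_args:

```python
def row_to_messages(row, dataset_args):
    """Build a chat-format message list from one relation row,
    using the standard system/user/assistant roles (sketch only)."""
    return [
        {"role": "system", "content": row[dataset_args["system_column"]]},
        {"role": "user", "content": row[dataset_args["user_column"]]},
        {"role": "assistant", "content": row[dataset_args["assistant_column"]]},
    ]


messages = row_to_messages(
    {"instruction": "You are helpful.", "input": "Hi", "output": "Hello!"},
    {"system_column": "instruction", "user_column": "input",
     "assistant_column": "output"},
)
```

A list in this shape is what chat-template tokenizers typically consume during conversation fine-tuning.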

Ok(info)
}
/// A Python module implemented in Rust.
#[pymodule]
Contributor:

Since this crate is interdependent, what if we moved this whole pymodule into the main pgml-extension crate, under bindings/python/mod.rs instead of publishing it as a separate crate?

@@ -14,3 +14,5 @@
.DS_Store


# venv
pgml-venv
Contributor:

newline

santiatpml requested a review from levkk on March 26, 2024 19:54
project_id BIGINT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
logs JSONB
);
Contributor:

new line

santiatpml merged commit f75114b into master on Mar 26, 2024
santiatpml deleted the santi-llm-fine-tuning branch on March 26, 2024 20:31
Reviewers

@montanalow left review comments
@levkk left review comments
@kczimm awaiting requested review
@SilasMarvin awaiting requested review

5 participants: @santiatpml, @montanalow, @kczimm, @levkk, @SilasMarvin
