Conflict between Lightning and Huggingface Transformers (device_map) #17878

richarddwang started this conversation in General

Hugging Face will add an AlignDevicesHook in certain cases, intercept nn.Module's forward, and send input tensors to devices that are automatically decided by its own logic. In many cases this conflicts with the devices we set for L.Trainer.

For example, when you

  • use quantization, e.g.

    AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
        device_map='auto',
    )

  • or use a PeftModel, e.g.

    from peft import get_peft_config, get_peft_model, TaskType

    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            model = AutoModelForCausalLM.from_pretrained(...)
            self.model = get_peft_model(model, peft_config)

In these cases Hugging Face silently sends input tensors to cuda:0 or cuda:1, which conflicts with L.Trainer(devices=[2]) sending model weights to cuda:2. All of this results in an error that tensors are on different devices.
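For illustration, one way to see what the hook decided (the model name below is just a placeholder, and `_hf_hook` / `execution_device` are accelerate internals, so treat this as a sketch):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Placeholder model name, for illustration only.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)

# accelerate stores its hook on each submodule as `_hf_hook`; the hook's
# `execution_device` is where inputs will be moved, regardless of the
# devices passed to L.Trainer.
for name, module in model.named_modules():
    hook = getattr(module, "_hf_hook", None)
    if hook is not None:
        print(name, type(hook).__name__, getattr(hook, "execution_device", None))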

This conflict makes it hard to train Hugging Face transformers with Lightning in many cases. Currently I manually remove Hugging Face's hook that sends tensors to a different device with:

import torch
from accelerate.hooks import AlignDevicesHook

def remove_hook_from_module(module: torch.nn.Module, recurse=False, hook_cls=AlignDevicesHook):
    if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, hook_cls):
        module._hf_hook.detach_hook(module)
        delattr(module, "_hf_hook")
        if hasattr(module, "_old_forward"):
            module.forward = module._old_forward
            delattr(module, "_old_forward")
    if recurse:
        for child in module.children():
            remove_hook_from_module(child, recurse)
    return module
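For instance, a minimal sketch of calling it when the model is created inside a LightningModule (the class name here is made up for illustration; remove_hook_from_module is the function defined above):

import lightning as L
from transformers import AutoModelForCausalLM

class LitCausalLM(L.LightningModule):  # illustrative name, not from this thread
    def __init__(self, model_name):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
        # Stripping the accelerate hooks lets Lightning alone decide where
        # inputs and weights live.
        self.model = remove_hook_from_module(model, recurse=True)

    def forward(self, **batch):
        return self.model(**batch)

With the hooks removed, L.Trainer(devices=[2]) can then place the whole model on cuda:2 without the hook moving inputs elsewhere.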

I wonder if there is a better way to use Lightning and Hugging Face together...


Replies: 4 comments 1 reply


How can I use this function?

1 reply
@richarddwang

import torch
from accelerate.hooks import AlignDevicesHook

def remove_hook_from_module(module: torch.nn.Module, recurse=False, hook_cls=AlignDevicesHook):
    if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, hook_cls):
        module._hf_hook.detach_hook(module)
        delattr(module, "_hf_hook")
        if hasattr(module, "_old_forward"):
            module.forward = module._old_forward
            delattr(module, "_old_forward")
    if recurse:
        for child in module.children():
            remove_hook_from_module(child, recurse)
    return module

model = auto_cls.from_pretrained(model_name, **config_kwargs)
remove_hook_from_module(model, recurse=True)

Legend


Question: how can I make this work with a PEFT model, as in:

model = auto_cls.from_pretrained(model_name, **config_kwargs)
model = get_peft_model(model, peft_config)
remove_hook_from_module(model, recurse=True)

For now I am getting this error:

AttributeError: 'PeftModelForCausalLM' object has no attribute '_hf_hook'
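A possible workaround, not confirmed in this thread: the PEFT wrapper forwards attribute lookups to the wrapped model, which may confuse the hasattr/delattr logic above, so stripping the hooks from the plain transformers model before wrapping it (or targeting the wrapped base model via PeftModel.get_base_model()) might avoid the error. A sketch under that assumption:

from peft import get_peft_model

# Sketch only: remove the accelerate hooks before wrapping with PEFT, so
# remove_hook_from_module never sees the PeftModel wrapper itself.
model = auto_cls.from_pretrained(model_name, **config_kwargs)
remove_hook_from_module(model, recurse=True)
model = get_peft_model(model, peft_config)

# If the PeftModel already exists, targeting the wrapped base model may work:
# remove_hook_from_module(model.get_base_model(), recurse=True)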


So if we specify the device with a number instead of a list, this problem can be avoided?

Category: General
Labels: None yet
5 participants: @richarddwang, @qvuer7, @leeglg, @GoldenStain, @AndriiZelenko
