Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit ec03371

Browse files
authored
update ultravox model and config for v0.5 (#276)
1 parent 5c4c45e commit ec03371

File tree

3 files changed

+57
-26
lines changed

3 files changed

+57
-26
lines changed

‎ultravox/model/ultravox_config.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class LoraConfigSimplified:
1919
target_modules:Optional[List[str]]=dataclasses.field(
2020
default_factory=lambda: ["k_proj","q_proj","linear_k","linear_q"]
2121
)
22+
# A list of regex patterns, matched against parameter names, selecting what to unfreeze. Only used if r == 0.
23+
unfreeze_layers:Optional[List[str]]=None
2224

2325

2426
classLossFunction(str,Enum):
@@ -28,7 +30,7 @@ class LossFunction(str, Enum):
2830

2931
@dataclasses.dataclass
3032
classLossConfig:
31-
loss_function:LossFunction=LossFunction.KL_Divergence
33+
loss_function:LossFunction=LossFunction.CrossEntropy
3234
kl_temperature:float=2.0
3335

3436
@property
@@ -70,7 +72,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
7072
Example:
7173
7274
```python
73-
>>> from transformers importUltravoxForConditionalGeneration, Wav2Vec2Config, UltravoxConfig, LlamaConfig
75+
>>> from transformers import UltravoxModel, Wav2Vec2Config, UltravoxConfig, LlamaConfig
7476
7577
>>> # Initializing an audio encoder config
7678
>>> audio_config = Wav2Vec2Config()
@@ -82,7 +84,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
8284
>>> configuration = UltravoxConfig(audio_config, text_config)
8385
8486
>>> # Initializing a completely untrained model from the configuration
85-
>>> model =UltravoxForConditionalGeneration(configuration)
87+
>>> model = UltravoxModel(configuration)
8688
8789
>>> # Accessing the model configuration
8890
>>> configuration = model.config
@@ -105,6 +107,7 @@ def __init__(
105107
stack_factor:int=8,
106108
norm_init:float=0.4,
107109
projector_act:str="swiglu",
110+
projector_ln_mid:bool=False,# defaults to False for compatibility with v0.4.1 and below
108111
text_model_lora_config:Optional[LoraConfigSimplified]=None,
109112
audio_model_lora_config:Optional[LoraConfigSimplified]=None,
110113
audio_latency_block_size:Optional[int]=None,
@@ -119,7 +122,7 @@ def __init__(
119122
self.stack_factor=stack_factor
120123
self.norm_init=norm_init
121124
self.projector_act=projector_act
122-
125+
self.projector_ln_mid=projector_ln_mid
123126
iftext_model_idisnotNone:
124127
self.text_config:transformers.LlamaConfig= (
125128
transformers.AutoConfig.from_pretrained(text_model_id)

‎ultravox/model/ultravox_config_test.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@ def test_can_load_release(model_id: str):
1515
config_from_dict=ultravox_config.UltravoxConfig(**orig_config.to_dict())
1616
config_from_diff_dict=ultravox_config.UltravoxConfig(**orig_config.to_diff_dict())
1717
# To avoid inadvertently ignoring other keys, we explicitly list the keys we need to ignore.
18-
keys_to_ignore= ("audio_latency_block_size",)
19-
orig_values= {
20-
**{k:Noneforkinkeys_to_ignore},
21-
**orig_config.to_dict(),
22-
}
18+
new_keys_default= {"audio_latency_block_size":None,"projector_ln_mid":False}
19+
orig_values= {**new_keys_default,**orig_config.to_dict()}
2320

2421
assertconfig_from_dict.to_dict()==orig_values
2522
assertconfig_from_diff_dict.to_dict()==orig_values

‎ultravox/model/ultravox_model.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
importlogging
2+
importre
23
fromtypingimportAny,Dict,Optional,Set,Tuple,Union
34

45
importpeft
@@ -36,6 +37,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
3637
config:UltravoxConfig# for type hinting
3738
# Usually we load encoder and LLM weights from a pretrained model separately, so they are allowed to be missing
3839
_keys_to_ignore_on_load_missing= ["audio_tower.*","language_model.*"]
40+
# Since forward accepts kwargs, this must be set to False; otherwise grad_accum_steps will cause an incorrect train loss to be reported.
41+
# see https://github.com/huggingface/transformers/issues/35856 and https://github.com/huggingface/trl/pull/2615/files
42+
accepts_loss_kwargs=False
3943

4044
def__init__(self,config:UltravoxConfig):
4145
super().__init__(config)
@@ -283,7 +287,7 @@ def _create_audio_tower(
283287
cls,config:UltravoxConfig
284288
)->Union[transformers.Wav2Vec2Model,"ModifiedWhisperEncoder"]:
285289
ifconfig.audio_model_idisnotNone:
286-
if"whisper"inconfig.audio_model_idisnotNone:
290+
if"whisper"inconfig.audio_model_id.lower():
287291
audio_tower=ModifiedWhisperEncoder.from_pretrained(
288292
config.audio_model_id,torch_dtype=config.torch_dtype
289293
)
@@ -299,7 +303,7 @@ def _create_audio_tower(
299303
config.audio_model_id,torch_dtype=config.torch_dtype
300304
)
301305
else:
302-
if"whisper"inconfig.audio_config._name_or_path:
306+
if"whisper"inconfig.audio_config._name_or_path.lower():
303307
audio_tower=ModifiedWhisperEncoder(config.audio_config)
304308
audio_tower.init_latency_mask(
305309
config.audio_latency_block_size,dtype=config.torch_dtype
@@ -384,12 +388,11 @@ def merge_and_unload(self):
384388

385389
defpush_to_hub(self,*args,**kwargs):
386390
self.merge_and_unload()
387-
self.to(self.language_model.dtype)
388391
returnsuper().push_to_hub(*args,**kwargs)
389392

390-
defsave_pretrained(
391-
self,*args,state_dict:Optional[Dict[str,Any]]=None,**kwargs
392-
):
393+
defdiff_state_dict(
394+
self,state_dict:Optional[Dict[str,Any]]=None
395+
)->Dict[str,Any]:
393396
ifstate_dictisNone:
394397
state_dict=super().state_dict()
395398

@@ -402,6 +405,13 @@ def save_pretrained(
402405
or (kinnamed_paramsandnamed_params[k].requires_grad)
403406
}
404407

408+
returnstate_dict
409+
410+
defsave_pretrained(
411+
self,*args,state_dict:Optional[Dict[str,Any]]=None,**kwargs
412+
):
413+
state_dict=self.diff_state_dict(state_dict)
414+
405415
super().save_pretrained(*args,state_dict=state_dict,**kwargs)
406416

407417
def_pre_load_state_dict_hook(self,state_dict:Dict[str,Any],*args,**kwargs):
@@ -436,6 +446,7 @@ def print_trainable_parameters(self):
436446
)
437447

438448

449+
# TODO: refactor common parts to a shared module
439450
defis_cache_empty(
440451
past_key_values:Optional[Union[Tuple,transformers.cache_utils.Cache]]
441452
)->bool:
@@ -453,12 +464,18 @@ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
453464
"""
454465
Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead, except for parameters whose names match one of the `unfreeze_layers` regex patterns.
455466
"""
467+
unfreeze_layers=lora_config.pop("unfreeze_layers",None)
456468
lora_config=peft.LoraConfig(**lora_configor {})
457469

458470
iflora_config.r==0:
459-
# freeze the model entirely
460-
forparaminmodel.parameters():
461-
param.requires_grad=False
471+
# freeze the model entirely, except for the specified layers
472+
forname,paraminmodel.named_parameters():
473+
ifnotunfreeze_layersornotany(
474+
re.match(layer,name)forlayerinunfreeze_layers
475+
):
476+
param.requires_grad=False
477+
else:
478+
logging.info(f"Unfreezing layer:{name} with #{param.numel()} params")
462479
else:
463480
model=peft.get_peft_model(model,lora_config)
464481

@@ -502,25 +519,35 @@ def forward(self, x):
502519
returnF.silu(gate)*x
503520

504521

505-
classUltravoxProjector(nn.Sequential):
522+
classUltravoxProjector(nn.Module):
506523
def__init__(self,config:UltravoxConfig):
507524
super().__init__()
508525
self.hidden_dim=config.hidden_size
509526
self._pad_and_stack=StackAudioFrames(config.stack_factor)
510-
dim=config.audio_config.hidden_size*config.stack_factor
511-
self.ln_pre=RMSNorm(dim,init=config.norm_init)
512-
self.linear_1=nn.Linear(dim,self.hidden_dim,bias=False)
513-
dim=self.hidden_dim
527+
dim_in=config.audio_config.hidden_size*config.stack_factor
528+
self.ln_pre=RMSNorm(dim_in,init=config.norm_init)
529+
self.linear_1=nn.Linear(dim_in,self.hidden_dim,bias=False)
530+
dim_mid=self.hidden_dim
514531
self.act=transformers.activations.get_activation(config.projector_act)
515-
dim=dim//2ifconfig.projector_act=="swiglu"elsedim
516-
self.linear_2=nn.Linear(dim,config.text_config.hidden_size,bias=False)
517-
self.ln_post=RMSNorm(config.text_config.hidden_size,init=config.norm_init)
532+
dim_mid=dim_mid//2ifconfig.projector_act=="swiglu"elsedim_mid
533+
dim_out=config.text_config.hidden_size
534+
self.linear_2=nn.Linear(dim_mid,dim_out,bias=False)
535+
536+
# Ultravox v0.4.1 and below use layer_norm after the second linear layer,
537
# while v0.5.0 and above use layer_norm after the first linear layer and its activation.
538+
ifconfig.projector_ln_mid:
539+
self.ln_mid:nn.Module=RMSNorm(dim_mid,init=config.norm_init)
540+
self.ln_post:nn.Module=nn.Identity()
541+
else:
542+
self.ln_mid=nn.Identity()
543+
self.ln_post=RMSNorm(dim_out,init=config.norm_init)
518544

519545
defforward(self,audio_features:torch.Tensor)->torch.Tensor:
520546
audio_features=self._pad_and_stack(audio_features)
521547
audio_features=self.ln_pre(audio_features)
522548
hidden_states=self.linear_1(audio_features)
523549
hidden_states=self.act(hidden_states)
550+
hidden_states=self.ln_mid(hidden_states)
524551
hidden_states=self.linear_2(hidden_states)
525552
hidden_states=self.ln_post(hidden_states)
526553
returnhidden_states
@@ -544,6 +571,10 @@ class ModifiedWhisperEncoder(
544571
base_model_prefix="model.encoder"
545572
_no_split_modules= ["WhisperEncoderLayer"]
546573

574+
def__init__(self,config:transformers.WhisperConfig):
575+
super().__init__(config)
576+
self.config.is_decoder=False
577+
547578
definit_latency_mask(self,audio_latency_block_size:int,dtype:torch.dtype):
548579
ifaudio_latency_block_sizeisNone:
549580
self.audio_streaming_mask=None

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp