- Notifications
You must be signed in to change notification settings - Fork 2.2k
Description
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/trainer.py", line 1628, in train
return inner_training_loop(
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/trainer.py", line 1895, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/trainer.py", line 2637, in training_step
loss = self.compute_loss(model, inputs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/trainer.py", line 2669, in compute_loss
outputs = model(**inputs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
output.reraise()
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/_utils.py", line 543, in reraise
raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
output = module(*input, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/peft/peft_model.py", line 529, in forward
return self.base_model(
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/accelerate/hooks.py", line 158, in new_forward
output = old_forward(*args, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 852, in forward
outputs = self.model.decoder(
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/accelerate/hooks.py", line 158, in new_forward
output = old_forward(*args, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 616, in forward
layer_outputs = torch.utils.checkpoint.checkpoint(
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
outputs = run_function(*args)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 612, in custom_forward
return module(*inputs, output_attentions, None)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/accelerate/hooks.py", line 158, in new_forward
output = old_forward(*args, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 305, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/accelerate/hooks.py", line 158, in new_forward
output = old_forward(*args, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 167, in forward
value_states = self.v_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/peft/tuners/lora.py", line 522, in forward
result = super().forward(x)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 242, in forward
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 488, in matmul
return MatMul8bitLt.apply(A, B, out, bias, state)
File "/opt/miniconda3/envs/textgen/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 397, in forward
output += torch.matmul(subA, state.subB)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x7 and 8x4096)