
[example] Added gemma support #115

Open

Chillee wants to merge 2 commits into main from gemma

Conversation

@Chillee
Contributor

No description provided.

@facebook-github-bot added the CLA Signed label (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.) on Feb 29, 2024
@Chillee changed the title from "Added gemma support" to "[example] Added gemma support" on Mar 1, 2024
@shaahji

Can you extend support for gemma-7b as well?

@shaahji commented on Mar 3, 2024 (edited)

convert_hf_checkpoint.py fails with the following error for gemma-7b:

```
Model config {'block_size': 2048, 'vocab_size': 256000, 'n_layer': 28, 'n_head': 16, 'dim': 3072, 'intermediate_size': 24576, 'n_local_heads': 16, 'head_dim': 192, 'rope_base': 10000, 'norm_eps': 1e-05}
Traceback (most recent call last):
  File "scripts/convert_hf_checkpoint.py", line 111, in <module>
    convert_hf_checkpoint(
  File "......../site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "scripts/convert_hf_checkpoint.py", line 91, in convert_hf_checkpoint
    q = permute(q, config.n_head)
  File "scripts/convert_hf_checkpoint.py", line 62, in permute
    w.view(n_head, 2, config.head_dim // 2, dim)
RuntimeError: shape '[16, 2, 96, 3072]' is invalid for input of size 12582912
```

The issue is with ModelArgs.head_dim being computed as 192, while the HF config dictates it to be 256. I tried forcing it to 256, but then inference fails with the following error for each layer:

```
size mismatch for layers.0.attention.wo.weight: copying a param with shape torch.Size([3072, 4096]) from checkpoint, the shape in current model is torch.Size([3072, 3072]).
```
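(For reference, that shape pair is consistent with the checkpoint sizing the attention output projection from n_head * head_dim while the model sizes it from dim; a minimal sketch of the arithmetic, assuming wo maps the concatenated heads back to the hidden size.)

```python
dim, n_head, head_dim = 3072, 16, 256     # gemma-7b values from the HF config

ckpt_wo_shape = (dim, n_head * head_dim)  # (3072, 4096) -- what the checkpoint holds
model_wo_shape = (dim, dim)               # (3072, 3072) -- what a dim-sized model builds

print(ckpt_wo_shape, model_wo_shape)      # the size mismatch reported above
```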

@guangy10

Same error as @shaahji saw above. I tried with config {'dim': 3072, 'vocab_size': 256000, 'n_layer': 28, 'n_head': 16, 'n_local_heads': 16, 'intermediate_size': 24576} according to gemma-7b/config.json:

{  "architectures": [    "GemmaForCausalLM"  ],  "attention_bias": false,  "attention_dropout": 0.0,  "bos_token_id": 2,  "eos_token_id": 1,  "head_dim": 256,  "hidden_act": "gelu",  "hidden_size": 3072,  "initializer_range": 0.02,  "intermediate_size": 24576,  "max_position_embeddings": 8192,  "model_type": "gemma",  "num_attention_heads": 16,  "num_hidden_layers": 28,  "num_key_value_heads": 16,  "pad_token_id": 0,  "rms_norm_eps": 1e-06,  "rope_scaling": null,  "rope_theta": 10000.0,  "torch_dtype": "bfloat16",  "transformers_version": "4.38.0.dev0",  "use_cache": true,  "vocab_size": 256000}

It seems like if we force 'head_dim': 256, the 'dim' gets bumped to 4096. Also, I have no idea where the number "12582912" comes from. It's like a puzzle to figure out how those numbers are mapped and determined. @Chillee could you elaborate?
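(A quick way to see where 12582912 comes from, assuming the q tensor being permuted is the HF checkpoint's query-projection weight of shape [n_head * head_dim, dim]; the numbers below are taken from the configs quoted above.)

```python
n_head, dim = 16, 3072
hf_head_dim = 256                 # explicit "head_dim" in gemma-7b config.json
derived_head_dim = dim // n_head  # 192, what ModelArgs computes

# Element count of the checkpoint's q projection weight, (n_head * head_dim, dim):
print(n_head * hf_head_dim * dim)                   # 12582912

# permute() tries to view it as [n_head, 2, derived_head_dim // 2, dim]:
print(n_head * 2 * (derived_head_dim // 2) * dim)   # 9437184 != 12582912 -> RuntimeError
```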

@Chillee force-pushed the gemma branch 2 times, most recently from 2b7f921 to fc64185 on March 7, 2024 18:57
@Chillee
Contributor, Author

I added support for gemma-7b. The main non-trivial component here was that head_dim * n_heads != dim, so some parts of the model definition needed to be patched.

I'm getting 83 tok/s for fp16

cc: @guangy10 @shaahji
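(For readers following along: the kind of adjustment this implies, sketched under the assumption that the attention projections were previously sized purely from dim. This is an illustration, not the actual diff in this PR.)

```python
import torch.nn as nn

class Attention(nn.Module):
    """Illustrative attention stub where n_head * head_dim may differ from dim."""
    def __init__(self, dim: int, n_head: int, n_local_heads: int, head_dim: int):
        super().__init__()
        # Size the projections from n_head * head_dim rather than from dim:
        # for gemma-7b, n_head * head_dim = 16 * 256 = 4096 while dim = 3072.
        q_size = n_head * head_dim
        kv_size = n_local_heads * head_dim
        self.wqkv = nn.Linear(dim, q_size + 2 * kv_size, bias=False)
        self.wo = nn.Linear(q_size, dim, bias=False)
        self.n_head, self.n_local_heads, self.head_dim = n_head, n_local_heads, head_dim
```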


@facebook-github-bot

Hi @Chillee!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA Signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us atcla@meta.com. Thanks!


Reviewers

@msaroufim left review comments

+1 more reviewer

@guangy10 approved these changes


Labels

CLA Signed (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.)


6 participants

@Chillee, @shaahji, @guangy10, @facebook-github-bot, @msaroufim
