Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Raise error if disk is full before downloading weights#1903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Draft
rasbt wants to merge8 commits intomain
base:main
Choose a base branch
Loading
fromdisk-full-message
Draft
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletionslitgpt/scripts/download.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -5,6 +5,7 @@
from contextlib import contextmanager
import importlib.util
from pathlib import Path
import shutil
from typing import List, Optional, Tuple

import torch
Expand DownExpand Up@@ -62,7 +63,40 @@ def download_from_hub(

download_files = ["tokenizer*", "generation_config.json", "config.json"]
if not tokenizer_only:
bins, safetensors = find_weight_files(repo_id, access_token)
bins, safetensors, info = find_weight_files(repo_id, access_token)

total_weight_size_bytes = 0
if bins:
total_weight_size_bytes = sum(
(file.size or 0)
for file in info.siblings
if file.rfilename.endswith(".bin") or file.rfilename.endswith(".bin.index.json")
)
elif safetensors:
total_weight_size_bytes = sum(
(file.size or 0)
for file in info.siblings
if file.rfilename.endswith(".safetensors")
)
else:
raise ValueError(f"Couldn't find weight files for {repo_id}")

weight_size_gb = total_weight_size_bytes / (1024**3)
free_space_bytes = shutil.disk_usage(str(checkpoint_dir)).free
free_space_gb = free_space_bytes / (1024**3)

# 2x because we create lit_model.pth before deleting the downloaded weights,
# so we intermittenly have 2 sets of weights on disk
if weight_size_gb > 2*free_space_gb:
if os.getenv("LIGHTNING_CLOUD_SPACE_ID") is not None:
studio_text = " Please switch to a larger Studio with more disk space."
else:
studio_text = ""
raise RuntimeError(
f"Not enough disk space to download {repo_id} weights. "
f"Needed: ~{2*weight_size_gb:.2f} GB, free: ~{free_space_gb:.2f} GB.{studio_text}"
)

if bins:
# covers `.bin` files and `.bin.index.json`
download_files.append("*.bin*")
Expand DownExpand Up@@ -104,11 +138,11 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s
from huggingface_hub.utils import filter_repo_objects

with gated_repo_catcher(repo_id, access_token):
info = repo_info(repo_id, token=access_token)
info = repo_info(repo_id, token=access_token, files_metadata=True)
filenames = [f.rfilename for f in info.siblings]
bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"]))
safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"]))
return bins, safetensors
return bins, safetensors, info


@contextmanager
Expand Down
140 changes: 110 additions & 30 deletionstests/test_rope.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from dataclasses import dataclass

import torch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb_llama
from transformers.models.llama.configuration_llama import LlamaConfig

from litgpt.model import apply_rope, build_rope_cache

Expand All@@ -17,7 +18,23 @@ def test_rope_gptneox():
x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
position_ids = torch.arange(seq_len).unsqueeze(0)

theirs_rot_emb = GPTNeoXRotaryEmbedding(head_size, seq_len)
@dataclass
class RoPEConfig:
dim: int
max_position_embeddings: int
rope_theta: int
hidden_size: int
num_attention_heads: int

config = RoPEConfig(
dim=head_size,
max_position_embeddings=seq_len,
rope_theta=10_000,
hidden_size=head_size * n_head,
num_attention_heads=n_head
)

theirs_rot_emb = GPTNeoXRotaryEmbedding(config)
theirs_cos, theirs_sin = theirs_rot_emb(x, position_ids)

ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device)
Expand All@@ -35,13 +52,32 @@ def test_rope_gptneox():
def test_rope_llama_2():
head_dim = 64
rope_theta = 10_000
num_heads = 4
batch_size, seq_len = 1, 10

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
batch_size, seq_len = 1, 10

@dataclass
class RoPEConfig:
dim: int
max_position_embeddings: int
rope_theta: int
hidden_size: int
num_attention_heads: int

config = RoPEConfig(
dim=head_dim,
max_position_embeddings=seq_len,
rope_theta=rope_theta,
hidden_size=head_dim * num_heads,
num_attention_heads=num_heads
)

rot_emb = LlamaRotaryEmbedding(config)

qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
Expand All@@ -56,8 +92,6 @@ def test_rope_llama_2():
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
Expand All@@ -76,13 +110,33 @@ def test_rope_llama_2():
def test_rope_llama_3():
head_dim = 64
rope_theta = 50_000
num_heads = 4
batch_size, seq_len = 1, 10

##################################
# Compare cos and sin
##################################

@dataclass
class RoPEConfig:
dim: int
max_position_embeddings: int
rope_theta: int
hidden_size: int
num_attention_heads: int
scaling_factor: float

config = RoPEConfig(
dim=head_dim,
max_position_embeddings=seq_len,
rope_theta=rope_theta,
hidden_size=head_dim * num_heads,
num_attention_heads=num_heads,
scaling_factor=None
)

# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
batch_size, seq_len = 1, 10
rot_emb = LlamaRotaryEmbedding(config)
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
Expand All@@ -97,8 +151,6 @@ def test_rope_llama_3():
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
Expand All@@ -117,6 +169,8 @@ def test_rope_llama_3():
def test_rope_llama_3_1():
head_dim = 32
rope_theta = 50_000
num_heads = 4
batch_size, seq_len = 1, 131_072

their_rope_config = {
"factor": 8.0,
Expand All@@ -133,18 +187,32 @@ def test_rope_llama_3_1():
"original_max_seq_len": 8192
}

config = LlamaConfig(
rope_theta=rope_theta,
rope_scaling=their_rope_config,
head_dim=head_dim
)

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
batch_size, seq_len = 1, 131_072

@dataclass
class RoPEConfig:
dim: int
max_position_embeddings: int
rope_theta: int
hidden_size: int
num_attention_heads: int
rope_type: str
rope_scaling: dict

config = RoPEConfig(
dim=head_dim,
max_position_embeddings=seq_len,
rope_theta=rope_theta,
hidden_size=head_dim * num_heads,
num_attention_heads=num_heads,
rope_type="llama3",
rope_scaling=their_rope_config
)

rot_emb = LlamaRotaryEmbedding(config=config)
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
Expand All@@ -159,8 +227,6 @@ def test_rope_llama_3_1():
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
Expand All@@ -179,6 +245,8 @@ def test_rope_llama_3_1():
def test_rope_llama_3_2():
head_dim = 32
rope_theta = 50_000
batch_size, seq_len = 1, 131_072
num_heads = 4

their_rope_config = {
"factor": 32.0,
Expand All@@ -195,18 +263,32 @@ def test_rope_llama_3_2():
"original_max_seq_len": 8192
}

config = LlamaConfig(
rope_theta=rope_theta,
rope_scaling=their_rope_config,
head_dim=head_dim
)

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
batch_size, seq_len = 1, 131_072
@dataclass
class RoPEConfig:
dim: int
max_position_embeddings: int
rope_theta: int
hidden_size: int
num_attention_heads: int
rope_type: str
rope_scaling: dict

config = RoPEConfig(
dim=head_dim,
max_position_embeddings=seq_len,
rope_theta=rope_theta,
hidden_size=head_dim * num_heads,
num_attention_heads=num_heads,
rope_type="llama3",
rope_scaling=their_rope_config
)

rot_emb = LlamaRotaryEmbedding(config)

qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
Expand All@@ -221,8 +303,6 @@ def test_rope_llama_3_2():
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
Expand Down
Loading

[8]ページ先頭

©2009-2025 Movatter.jp