Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit7ca7c47

Browse files
authored
Make quote style consistent (#891)
1 parent9276edb commit7ca7c47

File tree

24 files changed

+240
-82
lines changed

24 files changed

+240
-82
lines changed
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2+
# Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3+
# Code repository: https://github.com/rasbt/reasoning-from-scratch
4+
5+
# Verify that Python source files (and optionally notebooks) use double quotes for strings.
6+
7+
importargparse
8+
importast
9+
importio
10+
importjson
11+
importsys
12+
importtokenize
13+
frompathlibimportPath
14+
15+
EXCLUDED_DIRS= {
16+
".git",
17+
".hg",
18+
".mypy_cache",
19+
".pytest_cache",
20+
".ruff_cache",
21+
".svn",
22+
".tox",
23+
".venv",
24+
"__pycache__",
25+
"build",
26+
"dist",
27+
"node_modules",
28+
}
29+
30+
PREFIX_CHARS= {"r","u","f","b"}
31+
SINGLE_QUOTE="'"
32+
DOUBLE_QUOTE="\""
33+
TRIPLE_SINGLE=SINGLE_QUOTE*3
34+
TRIPLE_DOUBLE=DOUBLE_QUOTE*3
35+
36+
37+
defshould_skip(path):
38+
parts=set(path.parts)
39+
returnbool(EXCLUDED_DIRS&parts)
40+
41+
42+
defcollect_fstring_expr_string_positions(source):
43+
"""
44+
Return set of (lineno, col_offset) for string literals that appear inside
45+
formatted expressions of f-strings. These should be exempt from the double
46+
quote check, since enforcing double quotes there is unnecessarily strict.
47+
"""
48+
try:
49+
tree=ast.parse(source)
50+
exceptSyntaxError:
51+
returnset()
52+
53+
positions=set()
54+
55+
classCollector(ast.NodeVisitor):
56+
defvisit_JoinedStr(self,node):
57+
forvalueinnode.values:
58+
ifisinstance(value,ast.FormattedValue):
59+
self._collect_from_expr(value.value)
60+
# Continue walking to catch nested f-strings within expressions
61+
self.generic_visit(node)
62+
63+
def_collect_from_expr(self,node):
64+
ifisinstance(node,ast.Constant)andisinstance(node.value,str):
65+
positions.add((node.lineno,node.col_offset))
66+
elifisinstance(node,ast.Str):# Python <3.8 compatibility
67+
positions.add((node.lineno,node.col_offset))
68+
else:
69+
forchildinast.iter_child_nodes(node):
70+
self._collect_from_expr(child)
71+
72+
Collector().visit(tree)
73+
returnpositions
74+
75+
76+
defcheck_quotes_in_source(source,path):
77+
violations= []
78+
ignored_positions=collect_fstring_expr_string_positions(source)
79+
tokens=tokenize.generate_tokens(io.StringIO(source).readline)
80+
fortok_type,tok_str,start,_,_intokens:
81+
iftok_type==tokenize.STRING:
82+
ifstartinignored_positions:
83+
continue
84+
lowered=tok_str.lower()
85+
# ignore triple-quoted strings
86+
iflowered.startswith((TRIPLE_DOUBLE,TRIPLE_SINGLE)):
87+
continue
88+
89+
# find the prefix and quote type
90+
# prefix = ""
91+
forcinPREFIX_CHARS:
92+
iflowered.startswith(c):
93+
# prefix = c
94+
lowered=lowered[1:]
95+
break
96+
97+
# report if not using double quotes
98+
iflowered.startswith(SINGLE_QUOTE):
99+
line,col=start
100+
violations.append(f"{path}:{line}:{col}: uses single quotes")
101+
returnviolations
102+
103+
104+
defcheck_file(path):
105+
try:
106+
ifpath.suffix==".ipynb":
107+
returncheck_notebook(path)
108+
else:
109+
text=path.read_text(encoding="utf-8")
110+
returncheck_quotes_in_source(text,path)
111+
exceptExceptionase:
112+
return [f"{path}: failed to check ({e})"]
113+
114+
115+
defcheck_notebook(path):
116+
violations= []
117+
withopen(path,encoding="utf-8")asf:
118+
nb=json.load(f)
119+
forcellinnb.get("cells", []):
120+
ifcell.get("cell_type")=="code":
121+
src="".join(cell.get("source", []))
122+
violations.extend(check_quotes_in_source(src,path))
123+
returnviolations
124+
125+
126+
defparse_args():
127+
parser=argparse.ArgumentParser(description="Verify double-quoted string literals.")
128+
parser.add_argument(
129+
"--include-notebooks",
130+
action="store_true",
131+
help="Also scan Jupyter notebooks (.ipynb files) for single-quoted strings.",
132+
)
133+
returnparser.parse_args()
134+
135+
136+
defmain():
137+
args=parse_args()
138+
project_root=Path(".").resolve()
139+
py_files=sorted(project_root.rglob("*.py"))
140+
notebook_files=sorted(project_root.rglob("*.ipynb"))ifargs.include_notebookselse []
141+
142+
violations= []
143+
forpathinpy_files+notebook_files:
144+
ifshould_skip(path):
145+
continue
146+
violations.extend(check_file(path))
147+
148+
ifviolations:
149+
print("\n".join(violations))
150+
print(f"\n{len(violations)} violations found.")
151+
return1
152+
153+
print("All files use double quotes correctly.")
154+
return0
155+
156+
157+
if__name__=="__main__":
158+
sys.exit(main())

‎appendix-D/01_main-chapter-code/previous_chapters.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
7373
self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
7474
self.out_proj=nn.Linear(d_out,d_out)# Linear layer to combine head outputs
7575
self.dropout=nn.Dropout(dropout)
76-
self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
76+
self.register_buffer("mask",torch.triu(torch.ones(context_length,context_length),diagonal=1))
7777

7878
defforward(self,x):
7979
b,num_tokens,d_in=x.shape

‎appendix-E/01_main-chapter-code/previous_chapters.py‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
8080
self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
8181
self.out_proj=nn.Linear(d_out,d_out)# Linear layer to combine head outputs
8282
self.dropout=nn.Dropout(dropout)
83-
self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
83+
self.register_buffer("mask",torch.triu(torch.ones(context_length,context_length),diagonal=1))
8484

8585
defforward(self,x):
8686
b,num_tokens,d_in=x.shape
@@ -257,8 +257,8 @@ def assign(left, right):
257257

258258

259259
defload_weights_into_gpt(gpt,params):
260-
gpt.pos_emb.weight=assign(gpt.pos_emb.weight,params['wpe'])
261-
gpt.tok_emb.weight=assign(gpt.tok_emb.weight,params['wte'])
260+
gpt.pos_emb.weight=assign(gpt.pos_emb.weight,params["wpe"])
261+
gpt.tok_emb.weight=assign(gpt.tok_emb.weight,params["wte"])
262262

263263
forbinrange(len(params["blocks"])):
264264
q_w,k_w,v_w=np.split(
@@ -318,7 +318,7 @@ def load_weights_into_gpt(gpt, params):
318318

319319

320320
deftext_to_token_ids(text,tokenizer):
321-
encoded=tokenizer.encode(text,allowed_special={'<|endoftext|>'})
321+
encoded=tokenizer.encode(text,allowed_special={"<|endoftext|>"})
322322
encoded_tensor=torch.tensor(encoded).unsqueeze(0)# add batch dimension
323323
returnencoded_tensor
324324

‎ch02/02_bonus_bytepair-encoder/bpe_openai_gpt2.py‎

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def get_pairs(word):
7070

7171

7272
classEncoder:
73-
def__init__(self,encoder,bpe_merges,errors='replace'):
73+
def__init__(self,encoder,bpe_merges,errors="replace"):
7474
self.encoder=encoder
7575
self.decoder= {v:kfork,vinself.encoder.items()}
7676
self.errors=errors# how to handle errors in decoding
@@ -92,7 +92,7 @@ def bpe(self, token):
9292
returntoken
9393

9494
whileTrue:
95-
bigram=min(pairs,key=lambdapair:self.bpe_ranks.get(pair,float('inf')))
95+
bigram=min(pairs,key=lambdapair:self.bpe_ranks.get(pair,float("inf")))
9696
ifbigramnotinself.bpe_ranks:
9797
break
9898
first,second=bigram
@@ -119,43 +119,43 @@ def bpe(self, token):
119119
break
120120
else:
121121
pairs=get_pairs(word)
122-
word=' '.join(word)
122+
word=" ".join(word)
123123
self.cache[token]=word
124124
returnword
125125

126126
defencode(self,text):
127127
bpe_tokens= []
128128
fortokeninre.findall(self.pat,text):
129-
token=''.join(self.byte_encoder[b]forbintoken.encode('utf-8'))
130-
bpe_tokens.extend(self.encoder[bpe_token]forbpe_tokeninself.bpe(token).split(' '))
129+
token="".join(self.byte_encoder[b]forbintoken.encode("utf-8"))
130+
bpe_tokens.extend(self.encoder[bpe_token]forbpe_tokeninself.bpe(token).split(" "))
131131
returnbpe_tokens
132132

133133
defdecode(self,tokens):
134-
text=''.join([self.decoder[token]fortokenintokens])
135-
text=bytearray([self.byte_decoder[c]forcintext]).decode('utf-8',errors=self.errors)
134+
text="".join([self.decoder[token]fortokenintokens])
135+
text=bytearray([self.byte_decoder[c]forcintext]).decode("utf-8",errors=self.errors)
136136
returntext
137137

138138

139139
defget_encoder(model_name,models_dir):
140-
withopen(os.path.join(models_dir,model_name,'encoder.json'),'r')asf:
140+
withopen(os.path.join(models_dir,model_name,"encoder.json"),"r")asf:
141141
encoder=json.load(f)
142-
withopen(os.path.join(models_dir,model_name,'vocab.bpe'),'r',encoding="utf-8")asf:
142+
withopen(os.path.join(models_dir,model_name,"vocab.bpe"),"r",encoding="utf-8")asf:
143143
bpe_data=f.read()
144-
bpe_merges= [tuple(merge_str.split())formerge_strinbpe_data.split('\n')[1:-1]]
144+
bpe_merges= [tuple(merge_str.split())formerge_strinbpe_data.split("\n")[1:-1]]
145145
returnEncoder(encoder=encoder,bpe_merges=bpe_merges)
146146

147147

148148
defdownload_vocab():
149149
# Modified code from
150-
subdir='gpt2_model'
150+
subdir="gpt2_model"
151151
ifnotos.path.exists(subdir):
152152
os.makedirs(subdir)
153-
subdir=subdir.replace('\\','/')# needed for Windows
153+
subdir=subdir.replace("\\","/")# needed for Windows
154154

155-
forfilenamein ['encoder.json','vocab.bpe']:
155+
forfilenamein ["encoder.json","vocab.bpe"]:
156156
r=requests.get("https://openaipublic.blob.core.windows.net/gpt-2/models/117M/"+filename,stream=True)
157157

158-
withopen(os.path.join(subdir,filename),'wb')asf:
158+
withopen(os.path.join(subdir,filename),"wb")asf:
159159
file_size=int(r.headers["content-length"])
160160
chunk_size=1000
161161
withtqdm(ncols=100,desc="Fetching "+filename,total=file_size,unit_scale=True)aspbar:

‎ch04/01_main-chapter-code/previous_chapters.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
6060
self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
6161
self.out_proj=nn.Linear(d_out,d_out)# Linear layer to combine head outputs
6262
self.dropout=nn.Dropout(dropout)
63-
self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
63+
self.register_buffer("mask",torch.triu(torch.ones(context_length,context_length),diagonal=1))
6464

6565
defforward(self,x):
6666
b,num_tokens,d_in=x.shape

‎ch04/01_main-chapter-code/tests.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def test_main(capsys):
3333
captured=capsys.readouterr()
3434

3535
# Normalize line endings and strip trailing whitespace from each line
36-
normalized_expected='\n'.join(line.rstrip()forlineinexpected.splitlines())
37-
normalized_output='\n'.join(line.rstrip()forlineincaptured.out.splitlines())
36+
normalized_expected="\n".join(line.rstrip()forlineinexpected.splitlines())
37+
normalized_output="\n".join(line.rstrip()forlineincaptured.out.splitlines())
3838

3939
# Compare normalized strings
4040
assertnormalized_output==normalized_expected

‎ch05/01_main-chapter-code/previous_chapters.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
7171
self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
7272
self.out_proj=nn.Linear(d_out,d_out)# Linear layer to combine head outputs
7373
self.dropout=nn.Dropout(dropout)
74-
self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
74+
self.register_buffer("mask",torch.triu(torch.ones(context_length,context_length),diagonal=1))
7575

7676
defforward(self,x):
7777
b,num_tokens,d_in=x.shape

‎ch05/03_bonus_pretraining_on_gutenberg/prepare_dataset.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftex
4343
content=strip_headers(content)
4444

4545
# Regular expression to replace multiple blank lines with a single blank line
46-
content=re.sub(r'\n\s*\n','\n\n',content)
46+
content=re.sub(r"\n\s*\n","\n\n",content)
4747
estimated_size=len(content.encode("utf-8"))
4848

4949
ifcurrent_size+estimated_size>max_size_mb*1024*1024:

‎ch05/03_bonus_pretraining_on_gutenberg/pretraining_simple.py‎

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -148,26 +148,26 @@ def train_model_simple(model, optimizer, device, n_epochs,
148148

149149
if__name__=="__main__":
150150

151-
parser=argparse.ArgumentParser(description='GPT Model Training Configuration')
152-
153-
parser.add_argument('--data_dir',type=str,default='gutenberg/data',
154-
help='Directory containing the training data')
155-
parser.add_argument('--output_dir',type=str,default='model_checkpoints',
156-
help='Directory where the model checkpoints will be saved')
157-
parser.add_argument('--n_epochs',type=int,default=1,
158-
help='Number of epochs to train the model')
159-
parser.add_argument('--print_sample_iter',type=int,default=1000,
160-
help='Iterations between printing sample outputs')
161-
parser.add_argument('--eval_freq',type=int,default=100,
162-
help='Frequency of evaluations during training')
163-
parser.add_argument('--save_ckpt_freq',type=int,default=100_000,
164-
help='Frequency of saving model checkpoints during training')
165-
parser.add_argument('--lr',type=float,default=5e-4,
166-
help='Learning rate for the optimizer')
167-
parser.add_argument('--batch_size',type=int,default=4,
168-
help='Batch size for training')
169-
parser.add_argument('--debug',type=bool,default=False,
170-
help='Uses a very small model for debugging purposes')
151+
parser=argparse.ArgumentParser(description="GPT Model Training Configuration")
152+
153+
parser.add_argument("--data_dir",type=str,default="gutenberg/data",
154+
help="Directory containing the training data")
155+
parser.add_argument("--output_dir",type=str,default="model_checkpoints",
156+
help="Directory where the model checkpoints will be saved")
157+
parser.add_argument("--n_epochs",type=int,default=1,
158+
help="Number of epochs to train the model")
159+
parser.add_argument("--print_sample_iter",type=int,default=1000,
160+
help="Iterations between printing sample outputs")
161+
parser.add_argument("--eval_freq",type=int,default=100,
162+
help="Frequency of evaluations during training")
163+
parser.add_argument("--save_ckpt_freq",type=int,default=100_000,
164+
help="Frequency of saving model checkpoints during training")
165+
parser.add_argument("--lr",type=float,default=5e-4,
166+
help="Learning rate for the optimizer")
167+
parser.add_argument("--batch_size",type=int,default=4,
168+
help="Batch size for training")
169+
parser.add_argument("--debug",type=bool,default=False,
170+
help="Uses a very small model for debugging purposes")
171171

172172
args=parser.parse_args()
173173

‎ch05/05_bonus_hparam_tuning/hparam_search.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def train_model(model, train_loader, val_loader, optimizer, device,
118118
print(f"Total hyperparameter configurations:{total_combinations}")
119119

120120
# Placeholder for the best loss and best hyperparameters
121-
best_val_loss=float('inf')
121+
best_val_loss=float("inf")
122122
best_hparams= {}
123123

124124
script_path=os.path.abspath(__file__)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp