Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Jan 21, 2025. It is now read-only.
/meshPublic archive

Commitead60ee

Browse files
William FedusMesh TensorFlow Team
William Fedus
authored and
Mesh TensorFlow Team
committed
Split out optimizer call for internal purposes.
PiperOrigin-RevId: 424207820
1 parenta32810e commitead60ee

File tree

2 files changed

+36
-36
lines changed

2 files changed

+36
-36
lines changed

‎mesh_tensorflow/optimize.py‎

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def lr(self):
8989

9090
defapply_grad(self,grad,var):
9191
ifgradisNone:
92-
tf.logging.warning("Gradient is None for variable %s"%var.name)
92+
tf.logging.warning("Gradient is None for variable %s",var.name)
9393
return []
9494
# It is critical to use assign_sub instead of mtf.assign(var - ...)
9595
# for the case of bfloat16 activations, so as to avoid repeatedly rounding
@@ -115,7 +115,7 @@ def momentum(self):
115115

116116
defapply_grad(self,grad,var):
117117
ifgradisNone:
118-
tf.logging.warning("Gradient is None for variable %s"%var.name)
118+
tf.logging.warning("Gradient is None for variable %s",var.name)
119119
return []
120120

121121
updates= []
@@ -153,7 +153,7 @@ def __init__(self,
153153
defapply_grad(self,grad,var):
154154
"""See base class."""
155155
ifgradisNone:
156-
tf.logging.warning("Gradient is None for variable %s"%var.name)
156+
tf.logging.warning("Gradient is None for variable %s",var.name)
157157
return []
158158
grad=mtf.to_float(grad)
159159

@@ -219,7 +219,8 @@ def __init__(self,
219219
epsilon2=1e-3,
220220
min_dim_size_to_factor=128,
221221
stacked_dim_names=None,
222-
exclude_from_parameter_scale=None):
222+
exclude_from_parameter_scale=None,
223+
):
223224
"""Construct a new Adafactor optimizer.
224225
225226
See class comment.
@@ -306,7 +307,7 @@ def _parameter_scale(self, var):
306307

307308
defapply_grad(self,grad,var):
308309
ifgradisNone:
309-
tf.logging.warning("Gradient is None for variable %s"%var.name)
310+
tf.logging.warning("Gradient is None for variable %s",var.name)
310311
return []
311312
# create slots
312313
grad=mtf.to_float(grad)

‎mesh_tensorflow/transformer/utils.py‎

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,9 @@ def tpu_estimator_model_fn(model_type,
522522
ensemble_inputs=None,
523523
mesh_devices=None,
524524
model_info_file=None,
525-
hierarchical_tiling_spec=None):
525+
hierarchical_tiling_spec=None,
526+
weight_decay_checkpoint=None# GOOGLE-INTERNAL,
527+
):
526528
"""Create a TPUEstimator model function.
527529
528530
Args:
@@ -564,14 +566,17 @@ def tpu_estimator_model_fn(model_type,
564566
if empty string (default), all variables from the checkpoint are loaded.
565567
ensemble_inputs: an optional integer - pass the size of the ensemble to
566568
train an ensemble where each model gets different inputs.
567-
You also need to configure Unitransformer.ensembleto the right size.
569+
You also need to configure Unitransformer.ensemble to the right size.
568570
If None, then all models are trained on the same inputs.
569571
mesh_devices: a list of strings, the device names to use for each mesh
570572
slice. Only required for GPU.
571573
model_info_file: an optional string, information about variables and
572574
operations will be logged to this file during the TRAIN mode.
573575
hierarchical_tiling_spec: an optional list that can be passed as the
574576
spec argument to simd_mesh_impl.HierarchicalTiling
577+
weight_decay_checkpoint: an optional checkpoint dir to weight decay from. #
578+
GOOGE-INTERNAL
579+
575580
Returns:
576581
a function to be passed to TPUEstimator
577582
"""
@@ -663,7 +668,7 @@ def my_model_fn(features, labels, mode, params=None, config=None):
663668
x=tf.cast(features[key],tf.int32)
664669
x=tf.reshape(x,feature_shape.to_integer_list)
665670
ifnotuse_tpu:
666-
tf.logging.info("feature %s : %s"% (key,x))
671+
tf.logging.info("feature %s : %s",key,x)
667672
mtf_features[key]=mtf.import_fully_replicated(
668673
mesh,x,feature_shape,name=key)
669674

@@ -886,6 +891,7 @@ def serialized_fn(mtf_features):
886891
var_grads=mtf.gradients(
887892
[loss], [v.outputs[0]forvingraph.trainable_variables])
888893

894+
889895
iftpu_summaries:
890896
mtf.scalar_summary("loss",loss)
891897

@@ -919,11 +925,8 @@ def serialized_fn(mtf_features):
919925
tf.logging.info("Variables not being trained:")
920926
tf.logging.info([v.nameforvingraph.trainable_variables
921927
ifnotvariable_filter_fn(v)])
922-
923-
update_ops=optimizer(learning_rate=learning_rate).apply_grads(
924-
trainable_var_grads,trainable_vars
925-
)
926-
928+
opt=optimizer(learning_rate=learning_rate)
929+
update_ops=opt.apply_grads(trainable_var_grads,trainable_vars)
927930
lowering=mtf.Lowering(
928931
graph, {mesh:mesh_impl},
929932
autostack=autostack,
@@ -980,6 +983,7 @@ def serialized_fn(mtf_features):
980983
{init_checkpoint_variable_mapping(v):vforvinrestore_vars}
981984
)
982985

986+
983987
# Copy master variables to slices. Must be called first.
984988
restore_hook=mtf.MtfRestoreHook(lowering)
985989
saver=tf.train.Saver(
@@ -1348,8 +1352,8 @@ def _maybe_detokenize(value, vocab):
13481352
yieldoutput_string
13491353
ifi& (i-1)==0:
13501354
# LOG every power of 2.
1351-
tf.logging.info("decoded{}: {}".format(i,input_string))
1352-
tf.logging.info(" ->{}".format(output_string))
1355+
tf.logging.info("decoded%s: %s",i,input_string)
1356+
tf.logging.info(" ->%s",output_string)
13531357

13541358

13551359
@gin.configurable
@@ -1681,7 +1685,7 @@ def save_scores_to_tfrecords(
16811685
inputs= [r.split(" ",1)[0]forrininputs]
16821686

16831687
table_path="{}_{}.tfrecord".format(scores_filename,shard_idx)
1684-
tf.logging.info("Saving results to{}".format(table_path))
1688+
tf.logging.info("Saving results to%s",table_path)
16851689

16861690
withtf.io.TFRecordWriter(table_path)asfile_writer:
16871691
forinput_,target,scoreinzip(inputs,targets,scores):
@@ -1769,12 +1773,10 @@ def score_with_estimator_lazy(
17691773
num_shards=math.ceil(num_examples/num_examples_per_shard)
17701774
else:
17711775
num_shards=None
1772-
tf.logging.info(
1773-
"Scoring {} examples with {} shards at {} examples per shard".format(
1774-
num_examples,num_shards,num_examples_per_shard))
1776+
tf.logging.info("Scoring %s examples with %s shards at %s examples per shard",
1777+
num_examples,num_shards,num_examples_per_shard)
17751778

1776-
checkpoint_path,=get_checkpoint_iterator(
1777-
eval_checkpoint_step,model_dir)
1779+
checkpoint_path,=get_checkpoint_iterator(eval_checkpoint_step,model_dir)
17781780
result_iter=estimator.predict(input_fn,checkpoint_path=checkpoint_path)
17791781

17801782
start=time.time()
@@ -1794,9 +1796,8 @@ def score_with_estimator_lazy(
17941796
score_postprocess_fn(results,vocabulary,shard_idx=shard_idx)
17951797

17961798
elapsed=time.time()-start
1797-
tf.logging.info(
1798-
"Scored {} results in {} s, {} examples/s for shard {}".format(
1799-
num_results,elapsed,num_results/elapsed,shard_idx))
1799+
tf.logging.info("Scored %s results in %s s, %s examples/s for shard %s",
1800+
num_results,elapsed,num_results/elapsed,shard_idx)
18001801

18011802
results= []
18021803
shard_idx+=1
@@ -2379,7 +2380,7 @@ def eval_model(estimator,
23792380

23802381
checkpoint_paths=get_checkpoint_iterator(eval_checkpoint_step,model_dir)
23812382
forcheckpoint_pathincheckpoint_paths:
2382-
tf.logging.info("Checkpoint path %s"%checkpoint_path)
2383+
tf.logging.info("Checkpoint path %s",checkpoint_path)
23832384
global_step=int(get_step_from_checkpoint_path(checkpoint_path))
23842385
ifeval_with_score:
23852386
outputs,_=score_with_estimator_fn(
@@ -2907,15 +2908,15 @@ def run(tpu_job_name,
29072908
learning_rate_schedule=functools.partial(
29082909
learning_rate_schedule,total_train_steps=total_run_steps)
29092910

2910-
tf.logging.info("model_type=%s"%model_type,)
2911-
tf.logging.info("mode=%s"%mode,)
2912-
tf.logging.info("sequence_length=%s"%sequence_length,)
2913-
tf.logging.info("batch_size=%s"%batch_size,)
2914-
tf.logging.info("train_steps=%s"%train_steps,)
2911+
tf.logging.info("model_type=%s",model_type,)
2912+
tf.logging.info("mode=%s",mode,)
2913+
tf.logging.info("sequence_length=%s",sequence_length,)
2914+
tf.logging.info("batch_size=%s",batch_size,)
2915+
tf.logging.info("train_steps=%s",train_steps,)
29152916
iftotal_run_stepsisnotNone:
2916-
tf.logging.info("total_run_steps=%s"%total_run_steps,)
2917-
tf.logging.info("mesh_shape=%s"%mesh_shape,)
2918-
tf.logging.info("layout_rules=%s"%layout_rules,)
2917+
tf.logging.info("total_run_steps=%s",total_run_steps,)
2918+
tf.logging.info("mesh_shape=%s",mesh_shape,)
2919+
tf.logging.info("layout_rules=%s",layout_rules,)
29192920

29202921
ifmode=="train"anddataset_split!="train":
29212922
raiseValueError("mode==\"train\" requires dataset_split==\"train\"")
@@ -2929,9 +2930,7 @@ def run(tpu_job_name,
29292930
cluster=tf.distribute.cluster_resolver.TPUClusterResolver(
29302931
tpu,zone=tpu_zone,project=gcp_project)iftpuelseNone
29312932

2932-
tf.logging.info(
2933-
"Building TPUConfig with tpu_job_name={}".format(tpu_job_name)
2934-
)
2933+
tf.logging.info("Building TPUConfig with tpu_job_name=%s",tpu_job_name)
29352934

29362935
score_in_predict_mode="score"inmode
29372936
estimator_fn=functools.partial(

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp