@@ -522,7 +522,9 @@ def tpu_estimator_model_fn(model_type,
522522ensemble_inputs = None ,
523523mesh_devices = None ,
524524model_info_file = None ,
525- hierarchical_tiling_spec = None ):
525+ hierarchical_tiling_spec = None ,
526+ weight_decay_checkpoint = None # GOOGLE-INTERNAL,
527+ ):
526528"""Create a TPUEstimator model function.
527529
528530 Args:
@@ -564,14 +566,17 @@ def tpu_estimator_model_fn(model_type,
564566 if empty string (default), all variables from the checkpoint are loaded.
565567 ensemble_inputs: an optional integer - pass the size of the ensemble to
566568 train an ensemble where each model gets different inputs.
567- You also need to configure Unitransformer.ensemble to the right size.
569+ You also need to configure Unitransformer.ensemble to the right size.
568570 If None, then all models are trained on the same inputs.
569571 mesh_devices: a list of strings, the device names to use for each mesh
570572 slice. Only required for GPU.
571573 model_info_file: an optional string, information about variables and
572574 operations will be logged to this file during the TRAIN mode.
573575 hierarchical_tiling_spec: an optional list that can be passed as the
574576 spec argument to simd_mesh_impl.HierarchicalTiling
577+ weight_decay_checkpoint: an optional checkpoint dir to weight decay from. #
578+ GOOGLE-INTERNAL
579+
575580 Returns:
576581 a function to be passed to TPUEstimator
577582 """
@@ -663,7 +668,7 @@ def my_model_fn(features, labels, mode, params=None, config=None):
663668x = tf .cast (features [key ],tf .int32 )
664669x = tf .reshape (x ,feature_shape .to_integer_list )
665670if not use_tpu :
666- tf .logging .info ("feature %s : %s" % ( key ,x ) )
671+ tf .logging .info ("feature %s : %s" , key ,x )
667672mtf_features [key ]= mtf .import_fully_replicated (
668673mesh ,x ,feature_shape ,name = key )
669674
@@ -886,6 +891,7 @@ def serialized_fn(mtf_features):
886891var_grads = mtf .gradients (
887892 [loss ], [v .outputs [0 ]for v in graph .trainable_variables ])
888893
894+
889895if tpu_summaries :
890896mtf .scalar_summary ("loss" ,loss )
891897
@@ -919,11 +925,8 @@ def serialized_fn(mtf_features):
919925tf .logging .info ("Variables not being trained:" )
920926tf .logging .info ([v .name for v in graph .trainable_variables
921927if not variable_filter_fn (v )])
922-
923- update_ops = optimizer (learning_rate = learning_rate ).apply_grads (
924- trainable_var_grads ,trainable_vars
925- )
926-
928+ opt = optimizer (learning_rate = learning_rate )
929+ update_ops = opt .apply_grads (trainable_var_grads ,trainable_vars )
927930lowering = mtf .Lowering (
928931graph , {mesh :mesh_impl },
929932autostack = autostack ,
@@ -980,6 +983,7 @@ def serialized_fn(mtf_features):
980983 {init_checkpoint_variable_mapping (v ):v for v in restore_vars }
981984 )
982985
986+
983987# Copy master variables to slices. Must be called first.
984988restore_hook = mtf .MtfRestoreHook (lowering )
985989saver = tf .train .Saver (
@@ -1348,8 +1352,8 @@ def _maybe_detokenize(value, vocab):
13481352yield output_string
13491353if i & (i - 1 )== 0 :
13501354# LOG every power of 2.
1351- tf .logging .info ("decoded{}: {}" . format ( i ,input_string ) )
1352- tf .logging .info (" ->{}" . format ( output_string ) )
1355+ tf .logging .info ("decoded%s: %s" , i ,input_string )
1356+ tf .logging .info (" ->%s" , output_string )
13531357
13541358
13551359@gin .configurable
@@ -1681,7 +1685,7 @@ def save_scores_to_tfrecords(
16811685inputs = [r .split (" " ,1 )[0 ]for r in inputs ]
16821686
16831687table_path = "{}_{}.tfrecord" .format (scores_filename ,shard_idx )
1684- tf .logging .info ("Saving results to{}" . format ( table_path ) )
1688+ tf .logging .info ("Saving results to%s" , table_path )
16851689
16861690with tf .io .TFRecordWriter (table_path )as file_writer :
16871691for input_ ,target ,score in zip (inputs ,targets ,scores ):
@@ -1769,12 +1773,10 @@ def score_with_estimator_lazy(
17691773num_shards = math .ceil (num_examples / num_examples_per_shard )
17701774else :
17711775num_shards = None
1772- tf .logging .info (
1773- "Scoring {} examples with {} shards at {} examples per shard" .format (
1774- num_examples ,num_shards ,num_examples_per_shard ))
1776+ tf .logging .info ("Scoring %s examples with %s shards at %s examples per shard" ,
1777+ num_examples ,num_shards ,num_examples_per_shard )
17751778
1776- checkpoint_path ,= get_checkpoint_iterator (
1777- eval_checkpoint_step ,model_dir )
1779+ checkpoint_path ,= get_checkpoint_iterator (eval_checkpoint_step ,model_dir )
17781780result_iter = estimator .predict (input_fn ,checkpoint_path = checkpoint_path )
17791781
17801782start = time .time ()
@@ -1794,9 +1796,8 @@ def score_with_estimator_lazy(
17941796score_postprocess_fn (results ,vocabulary ,shard_idx = shard_idx )
17951797
17961798elapsed = time .time ()- start
1797- tf .logging .info (
1798- "Scored {} results in {} s, {} examples/s for shard {}" .format (
1799- num_results ,elapsed ,num_results / elapsed ,shard_idx ))
1799+ tf .logging .info ("Scored %s results in %s s, %s examples/s for shard %s" ,
1800+ num_results ,elapsed ,num_results / elapsed ,shard_idx )
18001801
18011802results = []
18021803shard_idx += 1
@@ -2379,7 +2380,7 @@ def eval_model(estimator,
23792380
23802381checkpoint_paths = get_checkpoint_iterator (eval_checkpoint_step ,model_dir )
23812382for checkpoint_path in checkpoint_paths :
2382- tf .logging .info ("Checkpoint path %s" % checkpoint_path )
2383+ tf .logging .info ("Checkpoint path %s" , checkpoint_path )
23832384global_step = int (get_step_from_checkpoint_path (checkpoint_path ))
23842385if eval_with_score :
23852386outputs ,_ = score_with_estimator_fn (
@@ -2907,15 +2908,15 @@ def run(tpu_job_name,
29072908learning_rate_schedule = functools .partial (
29082909learning_rate_schedule ,total_train_steps = total_run_steps )
29092910
2910- tf .logging .info ("model_type=%s" % model_type ,)
2911- tf .logging .info ("mode=%s" % mode ,)
2912- tf .logging .info ("sequence_length=%s" % sequence_length ,)
2913- tf .logging .info ("batch_size=%s" % batch_size ,)
2914- tf .logging .info ("train_steps=%s" % train_steps ,)
2911+ tf .logging .info ("model_type=%s" , model_type ,)
2912+ tf .logging .info ("mode=%s" , mode ,)
2913+ tf .logging .info ("sequence_length=%s" , sequence_length ,)
2914+ tf .logging .info ("batch_size=%s" , batch_size ,)
2915+ tf .logging .info ("train_steps=%s" , train_steps ,)
29152916if total_run_steps is not None :
2916- tf .logging .info ("total_run_steps=%s" % total_run_steps ,)
2917- tf .logging .info ("mesh_shape=%s" % mesh_shape ,)
2918- tf .logging .info ("layout_rules=%s" % layout_rules ,)
2917+ tf .logging .info ("total_run_steps=%s" , total_run_steps ,)
2918+ tf .logging .info ("mesh_shape=%s" , mesh_shape ,)
2919+ tf .logging .info ("layout_rules=%s" , layout_rules ,)
29192920
29202921if mode == "train" and dataset_split != "train" :
29212922raise ValueError ("mode==\" train\" requires dataset_split==\" train\" " )
@@ -2929,9 +2930,7 @@ def run(tpu_job_name,
29292930cluster = tf .distribute .cluster_resolver .TPUClusterResolver (
29302931tpu ,zone = tpu_zone ,project = gcp_project )if tpu else None
29312932
2932- tf .logging .info (
2933- "Building TPUConfig with tpu_job_name={}" .format (tpu_job_name )
2934- )
2933+ tf .logging .info ("Building TPUConfig with tpu_job_name=%s" ,tpu_job_name )
29352934
29362935score_in_predict_mode = "score" in mode
29372936estimator_fn = functools .partial (