Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Jan 21, 2025. It is now read-only.
/meshPublic archive

Commitcfc7a67

Browse files
HyperparticleMesh TensorFlow Team
authored and
Mesh TensorFlow Team
committed
Add utility to save score predictions to TFRecords for scoring large datasets.
PiperOrigin-RevId: 396705745
1 parentf08b18e commitcfc7a67

File tree

1 file changed

+143
-9
lines changed

1 file changed

+143
-9
lines changed

‎mesh_tensorflow/transformer/utils.py‎

Lines changed: 143 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626

2727
importfunctools
2828
importitertools
29+
importmath
2930
importos
3031
importrandom
3132
importre
33+
importtime
3234

3335
importgin
3436
importgin.tf
@@ -1654,6 +1656,54 @@ def get_sequence_length(tokens, pad_id=0):
16541656
returnscores
16551657

16561658

1659+
@gin.configurable
def save_scores_to_tfrecords(
    results, vocabulary, scores_filename, shard_idx=0, save_ids_only=False):
  """Processes results from scoring examples and saves them to a TFRecord file.

  One `tf.train.Example` is written per result, with the features:
    "input":  bytes, UTF-8 encoded input text (or only its leading ID, see
      `save_ids_only`).
    "target": bytes, UTF-8 encoded target text.
    "score":  float, the example's score.

  The records are written to "<scores_filename>_<shard_idx>.tfrecord".

  Args:
    results: list of dictionaries containing the results for each scored
      example. The keys read are "scores" (defaults to 0.0 when absent),
      "targets_pretokenized" (falling back to "targets") and
      "targets_neg_pretokenized" (falling back to "inputs", then to "").
    vocabulary: forwarded unchanged to `_maybe_add_pretokenized_features`,
      which uses it to decode token ids into text. NOTE(review): presumably a
      vocabulary.Vocabulary or an (inputs_vocabulary, targets_vocabulary)
      tuple, as for the scoring functions in this file -- confirm.
    scores_filename: a string (path prefix of the file to write scores to).
      It has no default; callers that only pass `results`, `vocabulary` and
      `shard_idx` presumably bind it through gin.
    shard_idx: an integer indicating the current index of the file for
      sharding; it becomes part of the output file name.
    save_ids_only: if true, save only the ID that is prepended to the inputs,
      i.e. the text before the first space.
  """
  def _as_text(value):
    # Text features may arrive as raw bytes; decode them so the str-only
    # operations below (`split(" ")`, `bytes(..., "utf8")`) do not raise
    # TypeError. Anything else is passed through untouched.
    return value.decode("utf-8") if isinstance(value, bytes) else value

  results = _maybe_add_pretokenized_features(results, vocabulary)

  scores = []
  targets = []
  inputs = []
  for r in results:
    scores.append(r.get("scores", 0.0))
    # Only touch r["targets"] when the pretokenized text is really missing,
    # so results that carry just "targets_pretokenized" do not raise KeyError.
    if "targets_pretokenized" in r:
      target = r["targets_pretokenized"]
    else:
      target = r["targets"]
    targets.append(_as_text(target))
    inputs.append(
        _as_text(r.get("targets_neg_pretokenized", r.get("inputs", ""))))

  if save_ids_only:
    # Keep only the leading, space-delimited ID of every input.
    inputs = [text.split(" ", 1)[0] for text in inputs]

  table_path = "{}_{}.tfrecord".format(scores_filename, shard_idx)
  tf.logging.info("Saving results to {}".format(table_path))

  with tf.io.TFRecordWriter(table_path) as file_writer:
    for input_, target, score in zip(inputs, targets, scores):
      feature = {
          "input":
              tf.train.Feature(
                  bytes_list=tf.train.BytesList(
                      value=[bytes(input_, "utf8")])),
          "target":
              tf.train.Feature(
                  bytes_list=tf.train.BytesList(
                      value=[bytes(target, "utf8")])),
          "score":
              tf.train.Feature(
                  float_list=tf.train.FloatList(value=[score])),
      }
      example = tf.train.Example(features=tf.train.Features(feature=feature))
      file_writer.write(example.SerializeToString())
1704+
1705+
1706+
@gin.configurable
16571707
defscore_with_estimator(estimator,input_fn,eval_checkpoint_step,model_dir,
16581708
vocabulary,score_postprocess_fn=save_scores,
16591709
num_examples=None):
@@ -1691,6 +1741,74 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
16911741
returnscore_postprocess_fn(results,vocabulary)
16921742

16931743

1744+
@gin.configurable
def score_with_estimator_lazy(
    estimator, input_fn, eval_checkpoint_step, model_dir,
    vocabulary, score_postprocess_fn=save_scores_to_tfrecords,
    num_examples=None, num_examples_per_shard=100000):
  """Score each example returned by input_fn lazily, one shard at a time.

  Predictions are pulled from the estimator incrementally. Every time
  `num_examples_per_shard` results have accumulated they are handed to
  `score_postprocess_fn` and dropped from memory; any remaining results are
  flushed as a final, smaller shard.

  Args:
    estimator: a TPUEstimator
    input_fn: a function that returns a tf.data.Dataset with examples
      containing the string field 'targets' and optionally the field 'inputs'
    eval_checkpoint_step: int, list of ints, or None, see `eval_model`
      docstring. It must resolve to exactly one checkpoint, because the
      result of `get_checkpoint_iterator` is unpacked into a single path.
    model_dir: string, estimator model_dir
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple
    score_postprocess_fn: a function called as
      `score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)` that
      post-processes the model outputs of one shard and saves them.
    num_examples: int, the total # of examples being scored, None if unknown.
      When given, at most this many predictions are consumed.
    num_examples_per_shard: int, the number of examples per file shard. If
      None or not positive, all results are passed on as a single shard.

  Returns:
    None. The scores are only handed to `score_postprocess_fn`, so this
    function cannot be used where the caller unpacks a return value.
  """
  sharding_enabled = (
      num_examples_per_shard is not None and num_examples_per_shard > 0)

  # The shard count is informational only (it is just logged).
  if num_examples is None:
    num_shards = None
  elif sharding_enabled:
    num_shards = math.ceil(num_examples / num_examples_per_shard)
  else:
    num_shards = 1 if num_examples else 0
  tf.logging.info(
      "Scoring {} examples with {} shards at {} examples per shard".format(
          num_examples, num_shards, num_examples_per_shard))

  checkpoint_path, = get_checkpoint_iterator(
      eval_checkpoint_step, model_dir)
  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)

  def _flush(shard_results, idx, started):
    """Post-processes one shard and logs its throughput."""
    score_postprocess_fn(shard_results, vocabulary, shard_idx=idx)
    num_results = len(shard_results)
    elapsed = time.time() - started
    # Guard against a zero-length interval on coarse clocks.
    rate = num_results / elapsed if elapsed > 0 else float("inf")
    tf.logging.info(
        "Scored {} results in {} s, {} examples/s for shard {}".format(
            num_results, elapsed, rate, idx))

  start = time.time()
  results = []
  shard_idx = 0

  # islice stops after exactly `num_examples` predictions (or never, when
  # num_examples is None). Comparing the 0-based enumerate index against
  # num_examples instead would let one extra prediction through.
  for result in itertools.islice(result_iter, num_examples):
    results.append(result)
    if sharding_enabled and len(results) >= num_examples_per_shard:
      _flush(results, shard_idx, start)
      results = []
      shard_idx += 1
      start = time.time()

  # Flush the final, partially filled shard.
  if results:
    _flush(results, shard_idx, start)
1810+
1811+
16941812
def_maybe_add_pretokenized_features(examples,vocabulary):
16951813
"""Ensures decoded versions of "inputs" and "targets" exist in each example.
16961814
@@ -1712,9 +1830,19 @@ def _maybe_add_pretokenized_features(examples, vocabulary):
17121830
forexampleinexamples:
17131831
forfeature_namein ["inputs","targets"]:
17141832
pretokenized_feature_name=feature_name+"_pretokenized"
1833+
neg_pretokenized_feature_name=feature_name+"_neg_pretokenized"
17151834
iffeature_nameinexampleandpretokenized_feature_namenotinexample:
1716-
s=vocabulary[feature_name].decode(example[feature_name].tolist())
1717-
example[pretokenized_feature_name]=s
1835+
ids=example[feature_name].tolist()
1836+
1837+
neg_ids= [abs(i)foriinidsifi<0]
1838+
ids= [iforiinidsifi>0]
1839+
1840+
decoded_string=vocabulary[feature_name].decode(ids)
1841+
example[pretokenized_feature_name]=decoded_string
1842+
1843+
ifneg_ids:
1844+
neg_decoded_string=vocabulary[feature_name].decode(neg_ids)
1845+
example[neg_pretokenized_feature_name]=neg_decoded_string
17181846

17191847
ifnotadded_pretokenized[feature_name]:
17201848
added_pretokenized[feature_name]=True
@@ -1730,7 +1858,8 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
17301858
sequence_length,model_dir,eval_checkpoint_step,
17311859
inputs=gin.REQUIRED,targets=gin.REQUIRED,
17321860
score_postprocess_fn=gin.REQUIRED,eos_id=1,
1733-
score_eos=True):
1861+
score_eos=True,
1862+
score_with_estimator_fn=score_with_estimator):
17341863
"""Compute log likelihoods per example and write to a text file.
17351864
17361865
inputs & targets must either be the same length (in lines) or have inputs
@@ -1761,6 +1890,7 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
17611890
score_eos: a boolean - whether to score the final eos token of each line
17621891
If this is set to false, the scores can be interpreted as prefix
17631892
log-likelihoods
1893+
score_with_estimator_fn: a function to run scoring with the estimator.
17641894
Returns:
17651895
a list of floats
17661896
"""
@@ -1806,7 +1936,7 @@ def input_fn(params):
18061936
dataset=dataset.batch(batch_size,drop_remainder=True)
18071937
returndataset.prefetch(tf.data.experimental.AUTOTUNE)
18081938

1809-
returnscore_with_estimator(
1939+
returnscore_with_estimator_fn(
18101940
estimator,input_fn,eval_checkpoint_step,model_dir,
18111941
vocabulary,score_postprocess_fn,len(targets))
18121942

@@ -1815,7 +1945,8 @@ def input_fn(params):
18151945
defscore_from_dataset(estimator,vocabulary,batch_size,sequence_length,
18161946
model_dir,eval_checkpoint_step,dataset_split,
18171947
score_dataset_fn=None,
1818-
score_postprocess_fn=gin.REQUIRED):
1948+
score_postprocess_fn=gin.REQUIRED,
1949+
score_with_estimator_fn=score_with_estimator):
18191950
"""Compute log likelihoods per example and write to a text file.
18201951
18211952
The function returns a list of floats representing the log-likelihood of the
@@ -1837,6 +1968,7 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
18371968
See `eval_dataset_fn` argument to `eval_model` for details.
18381969
score_postprocess_fn: Function that takes in model outputs and
18391970
post-processes them, then returns them.
1971+
score_with_estimator_fn: a function to run scoring with the estimator.
18401972
18411973
Returns:
18421974
scores: a list of floats, the log likelihood scores
@@ -1850,9 +1982,9 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
18501982
input_fn=_get_combined_dataset_input_fn(
18511983
scoring_datasets,batch_size,sequence_length)
18521984

1853-
returnscore_with_estimator(
1985+
returnscore_with_estimator_fn(
18541986
estimator,input_fn,eval_checkpoint_step,model_dir,
1855-
vocabulary,score_postprocess_fn,None)
1987+
vocabulary,score_postprocess_fn)
18561988

18571989

18581990
defget_estimator(model_type,vocabulary,mesh_shape,
@@ -2093,7 +2225,8 @@ def eval_model(estimator,
20932225
eval_checkpoint_step,
20942226
eval_with_score=False,
20952227
output_eval_examples=True,
2096-
eval_dir_suffix=None):
2228+
eval_dir_suffix=None,
2229+
score_with_estimator_fn=score_with_estimator):
20972230
"""Eval a Mesh-TF model.
20982231
20992232
Args:
@@ -2137,6 +2270,7 @@ def eval_model(estimator,
21372270
of the eval examples in plaintext to eval_summary_dir.
21382271
eval_dir_suffix: string, if not None then will be appended to the
21392272
eval_summary_dir.
2273+
score_with_estimator_fn: a function to run scoring with the estimator.
21402274
"""
21412275
ifeval_dataset_fnisNone:
21422276
raiseValueError("Must provide eval_dataset_fn through gin for eval.")
@@ -2248,7 +2382,7 @@ def eval_model(estimator,
22482382
tf.logging.info("Checkpoint path %s"%checkpoint_path)
22492383
global_step=int(get_step_from_checkpoint_path(checkpoint_path))
22502384
ifeval_with_score:
2251-
outputs,_=score_with_estimator(
2385+
outputs,_=score_with_estimator_fn(
22522386
estimator,input_fn,global_step,model_dir,vocabulary,
22532387
num_examples=sum(len(cex)forcexincached_examples.values()))
22542388
else:

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp