Serialize JAX, Flax, Haiku, or Objax model params with 🤗`safetensors`
alvarobartt/safejax
`safejax` is a Python package to serialize JAX, Flax, Haiku, or Objax model params using `safetensors` as the tensor storage format, instead of relying on `pickle`. For more details on why `safetensors` is safer than `pickle`, please check huggingface/safetensors.
Note that `safejax` supports the serialization of `jax`, `flax`, `dm-haiku`, and `objax` model parameters and has been tested with all of those frameworks. However, since it is still in an early development phase, there may be cases where it does not work as expected. If you have any feedback or bug reports, please open an issue at safejax/issues.
`safejax` requires Python 3.7 or above.

```bash
pip install safejax --upgrade
```
Convert `params` to `bytes` in memory:

```python
from safejax.flax import serialize, deserialize

params = model.init(...)

encoded_bytes = serialize(params)
decoded_params = deserialize(encoded_bytes)

model.apply(decoded_params, ...)
```
Convert `params` to `bytes` and store them in a `params.safetensors` file:

```python
from safejax.flax import serialize, deserialize

params = model.init(...)

encoded_bytes = serialize(params, filename="./params.safetensors")
decoded_params = deserialize("./params.safetensors")

model.apply(decoded_params, ...)
```
If it just contains `params`:

```python
from safejax.haiku import serialize, deserialize

params = model.init(...)

encoded_bytes = serialize(params)
decoded_params = deserialize(encoded_bytes)

model.apply(decoded_params, ...)
```
If it contains `params` and `state`, e.g. the `ExponentialMovingAverage` statistics in `BatchNorm`:

```python
from safejax.haiku import serialize, deserialize

params, state = model.init(...)
params_state = {"params": params, "state": state}

encoded_bytes = serialize(params_state)
decoded_params_state = deserialize(encoded_bytes)  # .keys() contains `params` and `state`

model.apply(decoded_params_state["params"], decoded_params_state["state"], ...)
```
If it contains `params` and `state`, but we want to serialize them individually:

```python
from safejax.haiku import serialize, deserialize

params, state = model.init(...)

encoded_params = serialize(params)
decoded_params = deserialize(encoded_params)

encoded_state = serialize(state)
decoded_state = deserialize(encoded_state)

model.apply(decoded_params, decoded_state, ...)
```
Convert `params` to `bytes` in memory, and convert them back to a `VarCollection`:

```python
from safejax.objax import serialize, deserialize

params = model.vars()

encoded_bytes = serialize(params=params)
decoded_params = deserialize(encoded_bytes)

# Copy the deserialized values back into the model's variables
for key, value in decoded_params.items():
    if key in model.vars():
        model.vars()[key].assign(value.value)

model(...)
```
Convert `params` to `bytes` and store them in a `params.safetensors` file:

```python
from safejax.objax import serialize, deserialize

params = model.vars()

encoded_bytes = serialize(params=params, filename="./params.safetensors")
decoded_params = deserialize("./params.safetensors")

# Copy the deserialized values back into the model's variables
for key, value in decoded_params.items():
    if key in model.vars():
        model.vars()[key].assign(value.value)

model(...)
```
Convert `params` to `bytes`, store them in a `params.safetensors` file, and assign them during deserialization:

```python
from safejax.objax import serialize, deserialize_with_assignment

params = model.vars()

encoded_bytes = serialize(params=params, filename="./params.safetensors")
# Load the tensors from the file and assign them to the variables in `params`
deserialize_with_assignment(filename="./params.safetensors", model_vars=params)

model(...)
```
More detailed examples can be found at `examples/` for `flax`, `dm-haiku`, and `objax`.
`safetensors` defines a simple and fast (zero-copy) format to store tensors, while `pickle` has some known weaknesses and security issues; in particular, loading an untrusted pickle can execute arbitrary code. `safetensors` is also intended to be framework-agnostic, so it does not matter which framework is used to load the tensors. More in-depth information can be found at huggingface/safetensors.
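To make the format more concrete, here is a minimal, stdlib-only sketch of the layout described in the huggingface/safetensors README: an 8-byte little-endian header length, a JSON header mapping each tensor name to its dtype, shape, and byte offsets, and then the raw tensor bytes. The helpers `write_safetensors` and `read_safetensors` are hypothetical, support only `F32`, and are not part of `safejax` or `safetensors`; real code should use those packages.

```python
import json
import struct
from typing import Dict, List, Tuple

# Hypothetical helpers illustrating the safetensors layout with flat lists of
# float32 values. Not a replacement for the real `safetensors` package.
Tensor = Tuple[List[int], List[float]]  # (shape, flat values in C order)


def write_safetensors(tensors: Dict[str, Tensor]) -> bytes:
    """Pack {name: (shape, values)} as F32 tensors in the safetensors layout."""
    header = {}
    buffer = bytearray()
    for name, (shape, values) in tensors.items():
        begin = len(buffer)
        buffer += struct.pack(f"<{len(values)}f", *values)  # little-endian float32
        header[name] = {"dtype": "F32", "shape": shape, "data_offsets": [begin, len(buffer)]}
    header_bytes = json.dumps(header).encode("utf-8")
    # Pad with spaces so the tensor data starts at an 8-byte-aligned offset.
    header_bytes += b" " * (-len(header_bytes) % 8)
    return struct.pack("<Q", len(header_bytes)) + header_bytes + bytes(buffer)


def read_safetensors(data: bytes) -> Dict[str, Tensor]:
    """Parse the layout back; only JSON and raw bytes are read, no code runs."""
    (header_size,) = struct.unpack("<Q", data[:8])
    header = json.loads(data[8 : 8 + header_size])
    buffer = data[8 + header_size :]
    tensors = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional free-form string metadata
            continue
        begin, end = info["data_offsets"]  # offsets are relative to the buffer
        count = (end - begin) // 4  # 4 bytes per float32
        tensors[name] = (info["shape"], list(struct.unpack(f"<{count}f", buffer[begin:end])))
    return tensors
```

Because loading is just "parse JSON, then slice bytes", a reader never has to execute anything from the file, and the offsets let it map tensors directly from disk, which is where the zero-copy property comes from.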
`jax` uses pytrees to store the model parameters in memory, typically as a dictionary-like structure containing nested `jnp.DeviceArray` tensors (`jax.Array` in recent JAX versions).
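Since `safetensors` only stores a flat `{name: tensor}` mapping, a nested params pytree has to be flattened into string keys before saving and rebuilt after loading. The sketch below shows that idea with hypothetical helpers `flatten_params` and `unflatten_params` and an assumed `"."` separator; it is not `safejax`'s actual implementation, and plain lists stand in for arrays so it runs without JAX (in JAX itself, `jax.tree_util` provides the general pytree traversal utilities).

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_params(tree: Mapping, sep: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Flatten a nested mapping into {"a.b.c": leaf} (hypothetical helper)."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            # Sub-tree (dict, FrozenDict, ...): recurse with the extended prefix.
            flat.update(flatten_params(value, sep, name))
        else:
            # Leaf: in real params this would be an array.
            flat[name] = value
    return flat


def unflatten_params(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Rebuild the nested dict from flattened names (hypothetical helper)."""
    tree: Dict[str, Any] = {}
    for name, value in flat.items():
        *parents, leaf = name.split(sep)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate levels
        node[leaf] = value
    return tree
```

This simple scheme assumes no key contains the separator and drops empty sub-dicts, which a real serializer would need to handle.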
`dm-haiku` uses a custom dictionary whose keys are formatted as `<level_1>/~/<level_2>`, where the levels define the tree structure and `/~/` is the separator between them, e.g. `res_net50/~/initial_conv`. The value under each key is not a `jnp.DeviceArray` but another dictionary of key-value pairs, e.g. `w` for the weights and `b` for the biases.
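The following sketch shows how that two-level structure can be split into tree levels and turned into one flat key per tensor, as a flat format like `safetensors` requires. The module names are illustrative (modeled on the `res_net50/~/...` example), plain lists stand in for `jnp` arrays, and joining the module path and parameter name with `"."` is a hypothetical choice, not necessarily the key scheme `safejax` uses.

```python
from typing import Any, Dict, List

HAIKU_SEP = "/~/"  # separator Haiku places between nested module names

# Illustrative Haiku-style params: module path -> {param name -> tensor}.
haiku_params: Dict[str, Dict[str, Any]] = {
    "res_net50/~/initial_conv": {"w": [[0.1, 0.2]]},
    "res_net50/~/logits": {"w": [[0.3]], "b": [0.0]},
}


def module_levels(module_name: str) -> List[str]:
    """Split a Haiku module path into its tree levels."""
    return module_name.split(HAIKU_SEP)


def to_flat_names(params: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Join module path and param name with "." into one flat key per tensor."""
    return {
        f"{module}.{name}": value
        for module, inner in params.items()
        for name, value in inner.items()
    }


def from_flat_names(flat: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Invert to_flat_names by splitting each key on its last "." only."""
    params: Dict[str, Dict[str, Any]] = {}
    for key, value in flat.items():
        module, name = key.rsplit(".", 1)
        params.setdefault(module, {})[name] = value
    return params
```

Splitting on the last `"."` works here because the `/~/` separator never contains a dot, so the module path survives the round trip intact.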
`objax` defines a custom dictionary-like class named `VarCollection` that contains variables inheriting from `BaseVar`, another custom `objax` type.
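Because the values in a `VarCollection` are variable objects rather than raw tensors, a serializer has to unwrap each variable's `.value` before storing it and call `.assign(...)` when loading, which is the pattern the Objax usage examples follow. The sketch below uses a hypothetical `FakeVar` stand-in (not an `objax` class) that only mimics those two members so it runs without `objax` installed; the variable names are illustrative.

```python
from typing import Any, Dict


class FakeVar:
    """Hypothetical stand-in for an objax BaseVar: holds a tensor in `.value`."""

    def __init__(self, value: Any) -> None:
        self.value = value

    def assign(self, value: Any) -> None:
        self.value = value


def vars_to_tensors(var_collection: Dict[str, FakeVar]) -> Dict[str, Any]:
    """Unwrap each variable so only raw tensors get serialized."""
    return {name: var.value for name, var in var_collection.items()}


def assign_tensors(var_collection: Dict[str, FakeVar], tensors: Dict[str, Any]) -> None:
    """Write loaded tensors back into existing variables, skipping unknown names."""
    for name, tensor in tensors.items():
        if name in var_collection:
            var_collection[name].assign(tensor)
```

Assigning into the existing variables, instead of replacing them, keeps any references the model holds to those variable objects valid.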
`flax` defines a dictionary-like class named `FrozenDict` that is used to store the tensors in memory; it can be dumped either into `bytes` in MessagePack format or as a `state_dict`.
HuggingFace has no plans to extend `safetensors` to support anything beyond tensors (e.g. `FrozenDict`s); see their response at huggingface/safetensors/discussions/138.
The motivation behind `safejax` is therefore to provide an easy way to serialize `FrozenDict`s using `safetensors` as the tensor storage format instead of `pickle`, as well as a common way to serialize and deserialize any JAX model params (Flax, Haiku, or Objax) in the `safetensors` format.