🔐 Serialize JAX, Flax, Haiku, or Objax model params with safetensors
safejax is a Python package to serialize JAX, Flax, Haiku, or Objax model params using safetensors as the tensor storage format, instead of relying on pickle. For more details on why safetensors is safer than pickle, please check huggingface/safetensors.
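One practical detail behind this: safetensors stores a flat mapping from string names to tensors, whereas Flax and Haiku params are usually nested dictionaries. The sketch below illustrates one way such a tree could be flattened before saving and rebuilt after loading. It is not safejax's actual implementation: the `flatten_params` and `unflatten_params` helpers and the `.` separator are hypothetical names chosen for illustration, and plain Python lists stand in for arrays so the snippet runs without JAX installed.

```python
from typing import Any, Dict

# Hypothetical separator; assumes no parameter or module name contains it.
SEPARATOR = "."


def flatten_params(params: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Turn a nested params dict into a flat {"a.b.c": leaf} dict."""
    flat: Dict[str, Any] = {}
    for name, value in params.items():
        key = f"{prefix}{SEPARATOR}{name}" if prefix else name
        if isinstance(value, dict):
            # Recurse into sub-modules, extending the key prefix.
            flat.update(flatten_params(value, key))
        else:
            # Leaf: in real use this would be an array.
            flat[key] = value
    return flat


def unflatten_params(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild the nested params dict from the flat representation."""
    nested: Dict[str, Any] = {}
    for key, value in flat.items():
        *parents, leaf = key.split(SEPARATOR)
        node = nested
        for parent in parents:
            node = node.setdefault(parent, {})
        node[leaf] = value
    return nested


# Stand-in for a small Flax-style params tree (lists instead of arrays).
params = {"params": {"Dense_0": {"kernel": [[0.1, 0.2]], "bias": [0.0, 0.0]}}}
flat = flatten_params(params)
restored = unflatten_params(flat)
```

With real arrays as leaves, a flat dictionary of this shape is the kind of input that safetensors' framework-specific save functions expect. The round trip only holds if no name contains the separator, which is why the separator is called out as an assumption.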
Note that safejax supports the serialization of jax, flax, dm-haiku, and objax model parameters and has been tested with all of those frameworks. It is still in an early development phase, so there may be cases where it does not work as expected. If you have any feedback or bug reports, please open an issue at safejax/issues.
Last update: 2023-01-19
Created: 2023-01-19