Serialize JAX, Flax, Haiku, or Objax model params with 🤗 `safetensors`


alvarobartt/safejax


`safejax` is a Python package to serialize JAX, Flax, Haiku, or Objax model params using `safetensors` as the tensor storage format, instead of relying on `pickle`. For more details on why `safetensors` is safer than `pickle`, please check huggingface/safetensors.

Note that `safejax` supports the serialization of `jax`, `flax`, `dm-haiku`, and `objax` model parameters and has been tested with all of those frameworks. However, as it is still in an early development phase, there may be cases where it does not work as expected; if you have any feedback or bug reports, please open an issue at safejax/issues.

🛠️ Requirements & Installation

`safejax` requires Python 3.7 or above.

pip install safejax --upgrade

💻 Usage

flax

  • Convert `params` to `bytes` in memory

    from safejax.flax import serialize, deserialize

    params = model.init(...)

    # Serialize the params tree into `bytes` in memory, then decode it back
    encoded_bytes = serialize(params)
    decoded_params = deserialize(encoded_bytes)

    model.apply(decoded_params, ...)

  • Convert `params` to `bytes` in a `params.safetensors` file

    from safejax.flax import serialize, deserialize

    params = model.init(...)

    # Write the params to `params.safetensors`, then read them back from disk
    encoded_bytes = serialize(params, filename="./params.safetensors")
    decoded_params = deserialize("./params.safetensors")

    model.apply(decoded_params, ...)

dm-haiku

  • If it just contains `params`

    from safejax.haiku import serialize, deserialize

    params = model.init(...)

    encoded_bytes = serialize(params)
    decoded_params = deserialize(encoded_bytes)

    model.apply(decoded_params, ...)

  • If it contains `params` and `state`, e.g. `ExponentialMovingAverage` in `BatchNorm`

    from safejax.haiku import serialize, deserialize

    params, state = model.init(...)
    # Bundle both trees into one dictionary so they are serialized together
    params_state = {"params": params, "state": state}

    encoded_bytes = serialize(params_state)
    decoded_params_state = deserialize(encoded_bytes)  # .keys() contains `params` and `state`

    model.apply(decoded_params_state["params"], decoded_params_state["state"], ...)

  • If it contains `params` and `state`, but we want to serialize them individually

    from safejax.haiku import serialize, deserialize

    params, state = model.init(...)

    # Serialize and deserialize each tree on its own
    encoded_bytes = serialize(params)
    decoded_params = deserialize(encoded_bytes)

    encoded_bytes = serialize(state)
    decoded_state = deserialize(encoded_bytes)

    model.apply(decoded_params, decoded_state, ...)

objax

  • Convert `params` to `bytes` in memory, and convert back to `VarCollection`

    from safejax.objax import serialize, deserialize

    params = model.vars()

    encoded_bytes = serialize(params=params)
    decoded_params = deserialize(encoded_bytes)

    # Copy each decoded value into the matching variable of the model
    for key, value in decoded_params.items():
        if key in model.vars():
            model.vars()[key].assign(value.value)

    model(...)

  • Convert `params` to `bytes` in a `params.safetensors` file

    from safejax.objax import serialize, deserialize

    params = model.vars()

    encoded_bytes = serialize(params=params, filename="./params.safetensors")
    decoded_params = deserialize("./params.safetensors")

    # Copy each decoded value into the matching variable of the model
    for key, value in decoded_params.items():
        if key in model.vars():
            model.vars()[key].assign(value.value)

    model(...)

  • Convert `params` to `bytes` in `params.safetensors` and assign during deserialization

    from safejax.objax import serialize, deserialize_with_assignment

    params = model.vars()

    encoded_bytes = serialize(params=params, filename="./params.safetensors")
    # Read the file and assign the decoded values to the matching `params` entries
    deserialize_with_assignment(filename="./params.safetensors", model_vars=params)

    model(...)

More detailed examples can be found at `examples/` for `flax`, `dm-haiku`, and `objax`.

🤔 Why `safejax`?

`safetensors` defines an easy and fast (zero-copy) format to store tensors, while `pickle` has some known weaknesses and security issues (most notably, unpickling untrusted data can execute arbitrary code). `safetensors` is also intended to be framework-agnostic, i.e. independent of the framework used to load the tensors. More in-depth information can be found at huggingface/safetensors.
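To make the "easy and fast" claim more concrete, here is a minimal sketch of the byte layout described in the huggingface/safetensors format spec: an 8-byte little-endian header length, a JSON header mapping each tensor name to its `dtype`, `shape`, and `data_offsets`, and then one raw byte buffer. The helpers `write_safetensors` and `read_safetensors` are hypothetical, stdlib-only illustrations (flat lists of numbers stand in for arrays, and only three dtypes are handled); they are not part of `safetensors` or `safejax`.

```python
import json
import struct

# safetensors dtype tags mapped to `struct` format characters.
# Assumption: this sketch only handles these three dtypes.
_DTYPES = {"F32": "f", "F64": "d", "I32": "i"}


def write_safetensors(tensors, metadata=None):
    """Encode {name: (dtype, shape, flat_values)} as safetensors-style bytes."""
    header = {}
    buffer = b""
    for name, (dtype, shape, values) in tensors.items():
        data = struct.pack("<" + _DTYPES[dtype] * len(values), *values)
        # data_offsets are relative to the start of the byte buffer
        header[name] = {
            "dtype": dtype,
            "shape": list(shape),
            "data_offsets": [len(buffer), len(buffer) + len(data)],
        }
        buffer += data
    if metadata is not None:
        header["__metadata__"] = metadata  # free-form str -> str mapping
    header_bytes = json.dumps(header).encode("utf-8")
    # Pad the JSON with spaces so the byte buffer starts 8-byte aligned
    header_bytes += b" " * (-len(header_bytes) % 8)
    # 8-byte little-endian header length, JSON header, then raw tensor bytes
    return struct.pack("<Q", len(header_bytes)) + header_bytes + buffer


def read_safetensors(blob):
    """Decode bytes from `write_safetensors`; only JSON is parsed, no code runs."""
    (header_len,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8 : 8 + header_len])
    buffer = blob[8 + header_len :]
    tensors = {}
    for name, info in header.items():
        if name == "__metadata__":
            continue
        begin, end = info["data_offsets"]
        code = _DTYPES[info["dtype"]]
        count = (end - begin) // struct.calcsize("<" + code)
        values = struct.unpack("<" + code * count, buffer[begin:end])
        tensors[name] = (info["dtype"], tuple(info["shape"]), list(values))
    return tensors


blob = write_safetensors(
    {
        "dense.kernel": ("F32", (2, 2), [1.0, 2.0, 3.5, -4.0]),
        "dense.bias": ("F32", (2,), [0.5, 0.0]),
    },
    metadata={"format": "sketch"},
)
decoded = read_safetensors(blob)
```

Because the header is plain JSON and each tensor is a contiguous slice of the buffer, a reader only ever parses data, and a real implementation can hand out views into a memory-mapped file instead of copying, which is where the zero-copy property comes from.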

`jax` uses pytrees to store the model parameters in memory, i.e. nested containers (typically dictionaries) whose leaves are `jnp.DeviceArray` tensors (`jax.Array` in newer JAX releases).

`dm-haiku` uses a custom dictionary whose keys are formatted as `<level_1>/~/<level_2>`, where the levels define the tree structure and `/~/` is the separator between them, e.g. `res_net50/~/initial_conv`. The value under such a key is not a `jnp.DeviceArray` but another dictionary of key-value pairs, e.g. `w` for the weights and `b` for the biases.
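The `/~/` convention means the module hierarchy lives in the key strings rather than in the dictionary nesting. The hypothetical helpers below (not part of `dm-haiku` or `safejax`) sketch how such keys expand into a fully nested tree and back. The keys are illustrative, nested Python lists stand in for `jnp` arrays so the snippet runs without JAX, and it assumes no parameter name (such as `w` or `b`) collides with a submodule name.

```python
# Separator between module levels in Haiku parameter keys.
HAIKU_SEP = "/~/"


def haiku_to_nested(params):
    """Expand {"a/~/b": {"w": x}} into {"a": {"b": {"w": x}}}.

    Assumes no parameter name equals a submodule name at the same level.
    """
    nested = {}
    for module_path, module_params in params.items():
        node = nested
        for level in module_path.split(HAIKU_SEP):
            node = node.setdefault(level, {})  # walk or create one tree level
        node.update(module_params)  # attach the module's params as leaves
    return nested


def nested_to_haiku(nested, path=()):
    """Inverse of `haiku_to_nested`: non-dict values are a module's params."""
    params = {}
    leaves = {k: v for k, v in nested.items() if not isinstance(v, dict)}
    if leaves and path:
        params[HAIKU_SEP.join(path)] = leaves
    for key, value in nested.items():
        if isinstance(value, dict):
            params.update(nested_to_haiku(value, path + (key,)))
    return params


# Illustrative keys; nested lists stand in for `jnp` arrays.
haiku_params = {
    "res_net50/~/initial_conv": {"w": [[0.1]]},
    "res_net50/~/logits": {"w": [[0.2]], "b": [0.0]},
}
nested = haiku_to_nested(haiku_params)
```

Here `nested` becomes `{"res_net50": {"initial_conv": {...}, "logits": {...}}}`, and `nested_to_haiku(nested)` rebuilds the original `/~/`-keyed dictionary.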

`objax` defines a custom dictionary-like class named `VarCollection` whose values are variables inheriting from `BaseVar`, another custom `objax` type.

`flax` defines a dictionary-like class named `FrozenDict` that is used to store the tensors in memory; it can be dumped either into `bytes` in MessagePack format or as a `state_dict`.

HuggingFace has no plans to extend `safetensors` to support anything other than tensors, e.g. `FrozenDict`s; see their response at huggingface/safetensors/discussions/138.

So the motivation behind `safejax` is to provide an easy way to serialize `FrozenDict`s using `safetensors` as the tensor storage format instead of `pickle`, as well as a common and easy way to serialize and deserialize any JAX model params (Flax, Haiku, or Objax) using the `safetensors` format.
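Since `safetensors` only stores a flat mapping from string names to tensors, a nested params tree presumably has to be flattened into path-like keys before saving and rebuilt after loading. Below is a minimal sketch of that step, assuming plain nested mappings and a `.` separator. `flatten_params` and `unflatten_params` are hypothetical names, not `safejax`'s actual API, and nested lists stand in for arrays. The `Mapping` check means any mapping type (Flax's `FrozenDict` is one) is walked, not only `dict`.

```python
from collections.abc import Mapping


def flatten_params(params, sep="."):
    """Flatten nested mappings into {"a.b.c": leaf}.

    Assumes no key contains `sep`; empty sub-mappings are dropped.
    """
    flat = {}

    def _walk(node, prefix):
        for key, value in node.items():
            name = f"{prefix}{sep}{key}" if prefix else str(key)
            if isinstance(value, Mapping):
                _walk(value, name)  # recurse into sub-trees (dict, FrozenDict, ...)
            else:
                flat[name] = value  # leaf tensor stored under its full path

    _walk(params, "")
    return flat


def unflatten_params(flat, sep="."):
    """Rebuild the nested dict produced by `flatten_params`."""
    nested = {}
    for name, value in flat.items():
        *parents, leaf = name.split(sep)
        node = nested
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return nested


params = {"params": {"Dense_0": {"kernel": [[1.0, 2.0]], "bias": [0.0]}}}
flat = flatten_params(params)
restored = unflatten_params(flat)
```

Here `flat` maps `"params.Dense_0.kernel"` and `"params.Dense_0.bias"` to their leaves, which is the flat shape a `safetensors` file can hold, and `restored` equals the original nested tree (returned as plain `dict`s).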

