
jax.example_libraries.optimizers module

Examples of how to write optimizers with JAX.

You likely do not mean to import this module! The optimizers in this library are intended as examples only. If you are looking for a fully featured optimizer library, consider Optax.

This module contains some convenient optimizer definitions, specifically initialization and update functions, which can be used with ndarrays or arbitrarily-nested tuple/list/dicts of ndarrays.

An optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions, where the component functions have these signatures:

init_fun(params)

Args:
  params: pytree representing the initial parameters.

Returns:
  A pytree representing the initial optimizer state, which includes the
  initial parameters and may also include auxiliary values like initial
  momentum. The optimizer state pytree structure generally differs from that
  of `params`.

update_fun(step, grads, opt_state)

Args:
  step: integer representing the step index.
  grads: a pytree with the same structure as `get_params(opt_state)`
    representing the gradients to be used in updating the optimizer state.
  opt_state: a pytree representing the optimizer state to be updated.

Returns:
  A pytree with the same structure as the `opt_state` argument representing
  the updated optimizer state.

get_params(opt_state)

Args:
  opt_state: pytree representing an optimizer state.

Returns:
  A pytree representing the parameters extracted from `opt_state`, such that
  the invariant `params == get_params(init_fun(params))` holds true.

Notice that an optimizer implementation has a lot of flexibility in the form of opt_state: it just has to be a pytree of JaxTypes (so that it can be passed to the JAX transforms defined in api.py) and it has to be consumable by update_fun and get_params.
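For concreteness, a minimal hand-written triple satisfying this contract might look as follows (an illustrative sketch of plain SGD, not the library's implementation; the name sgd_triple is invented for this example):

import jax

def sgd_triple(step_size):
  def init_fun(params):
    return params                  # here the optimizer state is just the params
  def update_fun(step, grads, opt_state):
    # Same pytree structure in and out, as required of update_fun.
    return jax.tree_util.tree_map(lambda p, g: p - step_size * g,
                                  opt_state, grads)
  def get_params(opt_state):
    return opt_state               # invariant: params == get_params(init_fun(params))
  return init_fun, update_fun, get_params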

Example Usage:

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

def step(step, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  opt_state = opt_update(step, grads, opt_state)
  return value, opt_state

for i in range(num_steps):
  value, opt_state = step(i, opt_state)
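The snippet above assumes loss_fn, params, learning_rate, and num_steps are defined elsewhere. A self-contained variant of the same loop, with a toy least-squares loss invented for illustration:

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

xs = jnp.linspace(-1.0, 1.0, 32)
ys = 3.0 * xs + 1.0                        # toy targets with w=3, b=1

def loss_fn(params):
  w, b = params
  return jnp.mean((w * xs + b - ys) ** 2)

opt_init, opt_update, get_params = optimizers.sgd(0.1)
opt_state = opt_init((jnp.array(0.0), jnp.array(0.0)))

@jax.jit
def step(i, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  return value, opt_update(i, grads, opt_state)

for i in range(200):
  value, opt_state = step(i, opt_state)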
class jax.example_libraries.optimizers.JoinPoint(subtree)

Bases: object

Marks the boundary between two joined (nested) pytrees.

class jax.example_libraries.optimizers.Optimizer(init_fn, update_fn, params_fn)

Bases: NamedTuple

Parameters:
  • init_fn (InitFn)

  • update_fn (UpdateFn)

  • params_fn (ParamsFn)

init_fn: InitFn

Alias for field number 0

params_fn: ParamsFn

Alias for field number 2

update_fn: UpdateFn

Alias for field number 1
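Because the optimizer constructors in this module return this NamedTuple, the triple can be unpacked positionally or accessed by field name:

from jax.example_libraries import optimizers

opt = optimizers.sgd(1e-2)               # an Optimizer NamedTuple
opt_init, opt_update, get_params = opt   # positional unpacking
assert opt.init_fn is opt_init and opt.params_fn is get_params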

class jax.example_libraries.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)

Bases: tuple

packed_state

Alias for field number 0

subtree_defs

Alias for field number 2

tree_def

Alias for field number 1

jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)

Construct optimizer triple for Adagrad.

Adaptive Subgradient Methods for Online Learning and Stochastic Optimization: http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • momentum – optional, a positive scalar value for momentum.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)

Construct optimizer triple for Adam.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).

  • b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).

  • eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).

Returns:

An (init_fun, update_fun, get_params) triple.
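For example, constructing Adam with its defaults written out explicitly (the step size of 1e-3 is an arbitrary choice; the remaining values restate the defaults above):

from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.adam(
    step_size=1e-3, b1=0.9, b2=0.999, eps=1e-8)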

jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)

Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • b1 – optional, a positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9).

  • b2 – optional, a positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999).

  • eps – optional, a positive scalar value for epsilon, a small constant for numerical stability (default 1e-8).

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)

Clip gradients stored as a pytree of arrays to maximum norm max_norm.

jax.example_libraries.optimizers.constant(step_size)

Return type:

Schedule

jax.example_libraries.optimizers.exponential_decay(step_size, decay_steps, decay_rate)

jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)
jax.example_libraries.optimizers.l2_norm(tree)

Compute the l2 norm of a pytree of arrays. Useful for weight decay.
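These two helpers compose naturally: l2_norm reports the global norm across all leaves, and clip_grads rescales the whole pytree when that norm exceeds max_norm. A small sketch with a made-up gradient pytree:

import jax.numpy as jnp
from jax.example_libraries import optimizers

grads = {"w": jnp.full((3,), 10.0), "b": jnp.array(4.0)}
print(optimizers.l2_norm(grads))         # global norm over all leaves (~17.8)

clipped = optimizers.clip_grads(grads, max_norm=1.0)
print(optimizers.l2_norm(clipped))       # ~1.0 after rescaling
# Typically applied just before the update: opt_update(i, clipped, opt_state)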

jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)

Parameters:

scalar_or_schedule (float | Schedule)

Return type:

Schedule
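A Schedule is simply a function from the iteration index to a step size, and every optimizer constructor in this module accepts a schedule wherever it accepts a scalar step_size. A brief sketch (the particular numbers are arbitrary):

from jax.example_libraries import optimizers

# Halve the step size every 1000 steps.
schedule = optimizers.exponential_decay(
    step_size=1e-2, decay_steps=1000, decay_rate=0.5)

opt_init, opt_update, get_params = optimizers.momentum(schedule, mass=0.9)

# make_schedule normalizes a scalar or a schedule to a Schedule callable.
assert optimizers.make_schedule(1e-2)(123) == 1e-2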

jax.example_libraries.optimizers.momentum(step_size, mass)

Construct optimizer triple for SGD with momentum.

Parameters:
  • step_size (Schedule) – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • mass (float) – positive scalar representing the momentum coefficient.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.nesterov(step_size, mass)

Construct optimizer triple for SGD with Nesterov momentum.

Parameters:
  • step_size (Schedule) – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • mass (float) – positive scalar representing the momentum coefficient.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.optimizer(opt_maker)

Decorator to make an optimizer defined for arrays generalize to containers.

With this decorator, you can write init, update, and get_params functions that each operate only on single arrays, and convert them to corresponding functions that operate on pytrees of parameters. See the optimizers defined in optimizers.py for examples.

Parameters:

opt_maker (Callable[..., tuple[Callable[[Params], State], Callable[[Step, Updates, Params], Params], Callable[[State], Params]]]) –

a function that returns an (init_fun, update_fun, get_params) triple of functions that might only work with ndarrays, as per

init_fun :: ndarray -> OptStatePytree ndarray
update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
get_params :: OptStatePytree ndarray -> ndarray

Returns:

An (init_fun, update_fun, get_params) triple of functions that work on arbitrary pytrees, as per

init_fun :: ParameterPytree ndarray -> OptimizerState
update_fun :: OptimizerState -> OptimizerState
get_params :: OptimizerState -> ParameterPytree ndarray

The OptimizerState pytree type used by the returned functions is isomorphic to ParameterPytree (OptStatePytree ndarray), but may store the state instead as e.g. a partially-flattened data structure for performance.

Return type:

Callable[..., Optimizer]
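As an illustration, here is a hypothetical per-array optimizer (scaled_sgd is an invented name, patterned on the library's own sgd) whose three functions see only single ndarrays; the decorator lifts them to work on whole pytrees:

from jax.example_libraries.optimizers import optimizer, make_schedule

@optimizer
def scaled_sgd(step_size, scale=1.0):
  step_size = make_schedule(step_size)
  def init(x0):
    return x0                             # per-array state: just the array
  def update(i, g, x):
    return x - scale * step_size(i) * g   # plain ndarray arithmetic
  def get_params(x):
    return x
  return init, update, get_params

# The decorated maker returns an Optimizer that operates on pytrees:
opt_init, opt_update, get_params = scaled_sgd(1e-2)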

jax.example_libraries.optimizers.pack_optimizer_state(marked_pytree)

Converts a marked pytree to an OptimizerState.

The inverse of unpack_optimizer_state. Converts a marked pytree with the leaves of the outer pytree represented as JoinPoints back into an OptimizerState. This function is intended to be useful when deserializing optimizer states.

Parameters:

marked_pytree – A pytree containing JoinPoint leaves that hold more pytrees.

Returns:

An equivalent OptimizerState to the input argument.

jax.example_libraries.optimizers.piecewise_constant(boundaries, values)
Parameters:
  • boundaries (Any)

  • values (Any)

jax.example_libraries.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)
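Neither of the two schedule constructors above carries a docstring; as a sketch of how they are called (the numbers are arbitrary):

from jax.example_libraries import optimizers

# piecewise_constant: values needs exactly one more entry than boundaries.
pc = optimizers.piecewise_constant(boundaries=[1000, 2000],
                                   values=[1e-2, 1e-3, 1e-4])

# polynomial_decay: anneals from step_size to final_step_size over
# decay_steps, with curvature set by power (1.0 gives linear decay).
pd = optimizers.polynomial_decay(step_size=1e-2, decay_steps=1000,
                                 final_step_size=1e-4, power=1.0)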
jax.example_libraries.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)

Construct optimizer triple for RMSProp.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • gamma – Decay parameter.

  • eps – Epsilon parameter.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)

Construct optimizer triple for RMSProp with momentum.

This optimizer is separate from the rmsprop optimizer because it needs to keep track of additional parameters.

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • gamma – Decay parameter.

  • eps – Epsilon parameter.

  • momentum – Momentum parameter.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.sgd(step_size)

Construct optimizer triple for stochastic gradient descent.

Parameters:

step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.sm3(step_size, momentum=0.9)

Construct optimizer triple for SM3.

Memory-Efficient Adaptive Optimization for Large-Scale Learning: https://arxiv.org/abs/1901.11150

Parameters:
  • step_size – positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar.

  • momentum – optional, a positive scalar value for momentum.

Returns:

An (init_fun, update_fun, get_params) triple.

jax.example_libraries.optimizers.unpack_optimizer_state(opt_state)

Converts an OptimizerState to a marked pytree.

Converts an OptimizerState to a marked pytree with the leaves of the outer pytree represented as JoinPoints to avoid losing information. This function is intended to be useful when serializing optimizer states.

Parameters:

opt_state – An OptimizerState

Returns:

A pytree with JoinPoint leaves that contain a second level of pytrees.
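Together with pack_optimizer_state, this enables a simple serialization round trip, for example via pickle (a sketch; the choice of optimizer and parameters is arbitrary):

import pickle

import jax.numpy as jnp
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.momentum(0.1, mass=0.9)
opt_state = opt_init({"w": jnp.ones((2, 2)), "b": jnp.zeros(2)})

# Serialize: convert to a plain marked pytree first, then pickle it.
blob = pickle.dumps(optimizers.unpack_optimizer_state(opt_state))

# Deserialize: unpickle, then pack back into an OptimizerState.
restored = optimizers.pack_optimizer_state(pickle.loads(blob))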

