After five months of extensive public beta testing, we're excited to announce the official release of Keras 3.0. Keras 3 is a full rewrite of Keras that enables you to run your Keras workflows on top of either JAX, TensorFlow, PyTorch, or OpenVINO (for inference-only), and that unlocks brand new large-scale model training and deployment capabilities. You can pick the framework that suits you best, and switch from one to another based on your current goals. You can also use Keras as a low-level cross-framework language to develop custom components such as layers, models, or metrics that can be used in native workflows in JAX, TensorFlow, or PyTorch — with one codebase.
You're already familiar with the benefits of using Keras — it enables high-velocity development via an obsessive focus on great UX, API design, and debuggability. It's also a battle-tested framework that has been chosen by over 2.5M developers and that powers some of the most sophisticated, largest-scale ML systems in the world, such as the Waymo self-driving fleet and the YouTube recommendation engine. But what are the additional benefits of using the new multi-backend Keras 3?
Any Keras 3 model can be instantiated as a PyTorch `Module`, can be exported as a TensorFlow `SavedModel`, or can be instantiated as a stateless JAX function. That means that you can use your Keras 3 models with PyTorch ecosystem packages, with the full range of TensorFlow deployment & production tools (like TF-Serving, TF.js and TFLite), and with JAX large-scale TPU training infrastructure. Write one `model.py` using Keras 3 APIs, and get access to everything the ML world has to offer.

Keras 3 also includes a brand new distribution API, the `keras.distribution` namespace, currently implemented for the JAX backend (coming soon to the TensorFlow and PyTorch backends). It makes it easy to do model parallelism, data parallelism, and combinations of both — at arbitrary model scales and cluster scales. Because it keeps the model definition, training logic, and sharding configuration all separate from each other, it makes your distribution workflow easy to develop and easy to maintain. See our starter guide.

The Keras 3 `fit()`/`evaluate()`/`predict()` routines are compatible with `tf.data.Dataset` objects, with PyTorch `DataLoader` objects, with NumPy arrays, and with Pandas dataframes — regardless of the backend you're using. You can train a Keras 3 + TensorFlow model on a PyTorch `DataLoader` or train a Keras 3 + PyTorch model on a `tf.data.Dataset`.

Keras 3 implements the full Keras API and makes it available with TensorFlow, JAX, and PyTorch — over a hundred layers, dozens of metrics, loss functions, optimizers, and callbacks, the Keras training and evaluation loops, and the Keras saving & serialization infrastructure. All the APIs you know and love are here.
Any Keras model that only uses built-in layers will immediately work with all supported backends. In fact, your existing `tf.keras` models that only use built-in layers can start running in JAX and PyTorch right away! That's right, your codebase just gained a whole new set of capabilities.
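As a quick illustration (a minimal sketch with a hypothetical toy model and random data), switching backends is just a matter of setting the `KERAS_BACKEND` environment variable before importing Keras:

```python
import os

# Pick the backend before importing Keras: "jax", "tensorflow", or "torch".
os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np

# A model built only from built-in layers runs unchanged on any backend.
model = keras.Sequential([
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.rand(32, 4), np.random.rand(32, 1), epochs=1)
```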
Keras 3 enables you to create components (like arbitrary custom layers or pretrained models) that will work the same in any framework. In particular, Keras 3 gives you access to the `keras.ops` namespace that works across all backends. It contains:
- NumPy-style functions such as `ops.matmul`, `ops.sum`, `ops.stack`, `ops.einsum`, etc.
- Neural network-specific functions such as `ops.softmax`, `ops.binary_crossentropy`, `ops.conv`, etc.

As long as you only use ops from `keras.ops`, your custom layers, custom losses, custom metrics, and custom optimizers will work with JAX, PyTorch, and TensorFlow — with the same code. That means that you can maintain only one component implementation (e.g. a single `model.py` together with a single checkpoint file), and you can use it in all frameworks, with the exact same numerics.
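For instance, here is a minimal sketch of a backend-agnostic custom layer written only with `keras.ops` (the layer itself is a hypothetical example):

```python
import keras
from keras import ops

class RMSNorm(keras.layers.Layer):
    """A simple RMS normalization layer built purely from keras.ops."""

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(
            shape=(input_shape[-1],), initializer="ones", name="scale"
        )

    def call(self, x):
        # Only keras.ops calls here, so the same code runs on JAX,
        # TensorFlow, and PyTorch.
        variance = ops.mean(ops.square(x), axis=-1, keepdims=True)
        return x * ops.rsqrt(variance + self.epsilon) * self.scale
```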
Keras 3 is not just intended for Keras-centric workflows where you define a Keras model, a Keras optimizer, a Keras loss and metrics, and you call `fit()`, `evaluate()`, and `predict()`. It's also meant to work seamlessly with low-level backend-native workflows: you can take a Keras model (or any other component, such as a loss or metric) and start using it in a JAX training loop, a TensorFlow training loop, or a PyTorch training loop, or as part of a JAX or PyTorch model, with zero friction. Keras 3 provides exactly the same degree of low-level implementation flexibility in JAX and PyTorch as `tf.keras` previously did in TensorFlow.
You can:

- Write a low-level JAX training loop to train a Keras model using an `optax` optimizer, `jax.grad`, `jax.jit`, `jax.pmap`.
- Write a low-level TensorFlow training loop to train a Keras model using `tf.GradientTape` and `tf.distribute`.
- Write a low-level PyTorch training loop to train a Keras model using a `torch.optim` optimizer, a `torch` loss function, and the `torch.nn.parallel.DistributedDataParallel` wrapper.
- Use Keras layers in a PyTorch `Module` (because they are `Module` instances too!).
- Use any PyTorch `Module` in a Keras model as if it were a Keras layer.

The models we've been working with have been getting larger and larger, so we wanted to provide a Kerasic solution to the multi-device model sharding problem. The API we designed keeps the model definition, the training logic, and the sharding configuration entirely separate from each other, meaning that your models can be written as if they were going to run on a single device. You can then add arbitrary sharding configurations to arbitrary models when it's time to train them.
Data parallelism (replicating a small model identically on multiple devices) can be handled in just two lines:
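For reference, these are the two lines (the same `keras.distribution` snippet that appears again in the distribution FAQ near the end of this post):

```python
import keras

# Replicate the model across all available devices; each device gets a slice of the data.
distribution = keras.distribution.DataParallel(devices=keras.distribution.list_devices())
keras.distribution.set_distribution(distribution)
```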
Model parallelism lets you specify sharding layouts for model variables and intermediate output tensors, along multiple named dimensions. In the typical case, you would organize available devices as a 2D grid (called a device mesh), where the first dimension is used for data parallelism and the second dimension is used for model parallelism. You would then configure your model to be sharded along the model dimension and replicated along the data dimension.
The API lets you configure the layout of every variable and every output tensor via regular expressions. This makes it easy to quickly specify the same layout for entire categories of variables.
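Here's a hedged sketch of that workflow, assuming 8 devices and a model whose kernel variables match the regular expression `d[0-9]*.*kernel`; the exact `ModelParallel` constructor arguments have shifted slightly across Keras 3 releases, so treat this as illustrative and check the distribution guide for your version:

```python
import keras

devices = keras.distribution.list_devices()

# 2D device mesh: the first axis handles data parallelism,
# the second axis handles model parallelism.
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["batch", "model"], devices=devices
)

layout_map = keras.distribution.LayoutMap(mesh)
# Keys are regular expressions matched against variable paths, so one rule
# covers a whole category of variables.
layout_map["d[0-9]*.*kernel"] = (None, "model")
layout_map["d[0-9]*.*bias"] = ("model",)

distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)
```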
The new distribution API is intended to be multi-backend, but is only available for the JAX backend for the time being. TensorFlow and PyTorch support is coming soon. Get started with this guide!
There's a wide range of pretrained models that you can start using today with Keras 3.
All 40 Keras Applications models (the `keras.applications` namespace) are available in all backends. The vast array of pretrained models in KerasCV and KerasHub also works with all backends.
Multi-framework ML also means multi-framework data loading and preprocessing. Keras 3 models can be trained using a wide range of data pipelines — regardless of whether you're using the JAX, PyTorch, or TensorFlow backends. It just works.

- `tf.data.Dataset` pipelines: the reference for scalable production ML.
- `torch.utils.data.DataLoader` objects.
- `keras.utils.PyDataset` objects.

Progressive disclosure of complexity is the design principle at the heart of the Keras API. Keras doesn't force you to follow a single "true" way of building and training models. Instead, it enables a wide range of different workflows, from the very high-level to the very low-level, corresponding to different user profiles.
That means that you can start out with simple workflows — such as using `Sequential` and `Functional` models and training them with `fit()` — and when you need more flexibility, you can easily customize different components while reusing most of your prior code. As your needs become more specific, you don't suddenly fall off a complexity cliff and you don't need to switch to a different set of tools.
We've brought this principle to all of our backends. For instance, you can customize what happens in your training loop while still leveraging the power of `fit()`, without having to write your own training loop from scratch — just by overriding the `train_step` method.
Here's how it works in PyTorch and TensorFlow:
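For illustration, a minimal sketch of the pattern with the TensorFlow backend (the PyTorch variant has the same shape, computing gradients with `loss.backward()` instead of `tf.GradientTape`):

```python
import tensorflow as tf
import keras

class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data
        # Forward pass and loss under a gradient tape.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        # Compute and apply gradients with the compiled optimizer.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply(gradients, self.trainable_variables)
        # Update and report the compiled metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```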
And here's the link to the JAX version.
Do you enjoy functional programming? You're in for a treat.
All stateful objects in Keras (i.e. objects that own numerical variables that get updated during training or evaluation) now have a stateless API, making it possible to use them in JAX functions (which are required to be fully stateless):
- Layers and models have a `stateless_call()` method which mirrors `__call__()`.
- Optimizers have a `stateless_apply()` method which mirrors `apply()`.
- Metrics have a `stateless_update_state()` method which mirrors `update_state()` and a `stateless_result()` method which mirrors `result()`.

These methods have no side effects whatsoever: they take as input the current value of the state variables of the target object, and return the updated values as part of their outputs, e.g.:
```python
outputs, updated_non_trainable_variables = layer.stateless_call(
    trainable_variables,
    non_trainable_variables,
    inputs,
)
```
You never have to implement these methods yourself — they're automatically available as long as you've implemented the stateful version (e.g. `call()` or `update_state()`).
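As a small illustration of why this matters (a sketch assuming the JAX backend, with a hypothetical layer and random inputs), the stateless API is what lets a Keras layer be used inside a pure, `jax.jit`-compiled function:

```python
import jax
import numpy as np
import keras

layer = keras.layers.Dense(4)
x = np.random.rand(2, 3).astype("float32")
layer.build(x.shape)

# Gather the current state as plain values.
trainable = [v.value for v in layer.trainable_variables]
non_trainable = [v.value for v in layer.non_trainable_variables]

@jax.jit
def forward(trainable_variables, non_trainable_variables, inputs):
    # Pure function: state goes in as arguments and comes back as outputs.
    outputs, updated_non_trainable = layer.stateless_call(
        trainable_variables, non_trainable_variables, inputs
    )
    return outputs, updated_non_trainable

outputs, _ = forward(trainable, non_trainable, x)
```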
Starting with release 3.8, Keras introduces the OpenVINO backend, an inference-only backend designed exclusively for running model predictions via the `predict()` method. This backend lets you leverage OpenVINO performance optimizations directly within the Keras workflow, enabling faster inference on OpenVINO-supported hardware.
To switch to the OpenVINO backend, set the `KERAS_BACKEND` environment variable to `"openvino"` or specify the backend in the local configuration file at `~/.keras/keras.json`. Here is an example of how to run inference with a model (trained with the PyTorch, JAX, or TensorFlow backends) using the OpenVINO backend:
```python
import os

os.environ["KERAS_BACKEND"] = "openvino"

import keras

loaded_model = keras.saving.load_model(...)
predictions = loaded_model.predict(...)
```
Note that the OpenVINO backend may currently lack support for some operations. This will be addressed in upcoming Keras releases as operation coverage is expanded.
Keras 3 is highly backwards compatible with Keras 2: it implements the full public API surface of Keras 2, with a limited number of exceptions, listed here. Most users will not have to make any code change to start running their Keras scripts on Keras 3.
Larger codebases are likely to require some code changes, since they are more likely to run into one of the exceptions listed above, and are more likely to have been using private APIs or deprecated APIs (the `tf.compat.v1.keras` namespace, the `experimental` namespace, the `keras.src` private namespace). To help you move to Keras 3, we are releasing a complete migration guide with quick fixes for all issues you might encounter.
You also have the option to ignore the changes in Keras 3 and just keep using Keras 2 with TensorFlow — this can be a good option for projects that are not actively developed but need to keep running with updated dependencies. You have two possibilities:
1. If you were accessing `keras` as a standalone package, just switch to using the Python package `tf_keras` instead, which you can install via `pip install tf_keras`. The code and API are wholly unchanged — it's Keras 2.15 with a different package name. We will keep fixing bugs in `tf_keras` and we will keep regularly releasing new versions. However, no new features or performance improvements will be added, since the package is now in maintenance mode.
2. If you were accessing `keras` via `tf.keras`, there are no immediate changes until TensorFlow 2.16. TensorFlow 2.16+ will use Keras 3 by default. In TensorFlow 2.16+, to keep using Keras 2, you can first install `tf_keras`, and then export the environment variable `TF_USE_LEGACY_KERAS=1`. This will direct TensorFlow 2.16+ to resolve `tf.keras` to the locally-installed `tf_keras` package. Note that this may affect more than your own code, however: it will affect any package importing `tf.keras` in your Python process. To make sure your changes only affect your own code, you should use the `tf_keras` package.

We're excited for you to try out the new Keras and improve your workflows by leveraging multi-framework ML. Let us know how it goes: issues, points of friction, feature requests, or success stories — we're eager to hear from you!
Code developed with `tf.keras` can generally be run as-is with Keras 3 (with the TensorFlow backend). There's a limited number of incompatibilities you should be mindful of, all addressed in this migration guide.
When it comes to using APIs from `tf.keras` and Keras 3 side by side, that is not possible — they're different packages, running on entirely separate engines.
Generally, yes. Any `tf.keras` model should work out of the box with Keras 3 with the TensorFlow backend (make sure to save it in the `.keras` v3 format). In addition, if the model only uses built-in Keras layers, then it will also work out of the box with Keras 3 with the JAX and PyTorch backends.
If the model contains custom layers written using TensorFlow APIs, it is usually easy to convert the code to be backend-agnostic. For instance, it only took us a few hours to convert all 40 legacy `tf.keras` models from Keras Applications to be backend-agnostic.
Yes, you can. There is no backend specialization in saved `.keras` files whatsoever. Your saved Keras models are framework-agnostic and can be reloaded with any backend.
However, note that reloading a model that contains custom components with a different backend requires your custom components to be implemented using backend-agnostic APIs, e.g. `keras.ops`.
Can I use Keras 3 with `tf.data` pipelines?
(e.g. you can.map()
aSequential
model into atf.data
pipeline).
With a different backend, Keras 3 has limited support fortf.data
.You won't be able to.map()
arbitrary layers or models into atf.data
pipeline. However, you will be able to use specific Keras 3preprocessing layers withtf.data
, such asIntegerLookup
orCategoryEncoding
.
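For example (a minimal sketch with made-up vocabulary and data), a preprocessing layer like `IntegerLookup` can be mapped into a `tf.data` pipeline even if training happens on another backend:

```python
import tensorflow as tf
import keras

lookup = keras.layers.IntegerLookup(vocabulary=[10, 20, 30])

ds = tf.data.Dataset.from_tensor_slices([[10], [20], [30], [40]])
# The preprocessing runs inside the tf.data pipeline, on CPU,
# independently of the backend used for training.
ds = ds.map(lookup).batch(2)
```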
When it comes to using a `tf.data` pipeline (that does not use Keras) to feed your call to `.fit()`, `.evaluate()`, or `.predict()` — that works out of the box with all backends.
Yes, numerics are identical across backends. However, keep in mind the following caveats:
- Due to the nature of floating-point implementations, results are only identical up to `1e-7` precision in float32, per function execution. So when training a model for a long time, small numerical differences will accumulate and may end up resulting in noticeable numerical differences.
- Average pooling layers with `padding="same"` may result in different numerics on border rows/columns (due to the lack of support for asymmetric padding in PyTorch). This doesn't happen very often in practice — out of 40 Keras Applications vision models, only one was affected.

Data-parallel distribution is supported out of the box in JAX, TensorFlow, and PyTorch. Model parallel distribution is supported out of the box for JAX with the `keras.distribution` API.
With TensorFlow:
Keras 3 is compatible with `tf.distribute` — just open a Distribution Strategy scope and create / train your model within it. Here's an example.
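A minimal sketch of that pattern (hypothetical toy model and data):

```python
import numpy as np
import tensorflow as tf
import keras

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Create, compile, and train the model inside the strategy scope.
    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    model.fit(np.random.rand(64, 8), np.random.rand(64, 1), batch_size=16, epochs=1)
```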
With PyTorch:
Keras 3 is compatible with PyTorch's `DistributedDataParallel` utility. Here's an example.
With JAX:
You can do both data parallel and model parallel distribution in JAX using the `keras.distribution` API. For instance, to do data parallel distribution, you only need the following code snippet:
```python
distribution = keras.distribution.DataParallel(devices=keras.distribution.list_devices())
keras.distribution.set_distribution(distribution)
```
For model parallel distribution, see the following guide.
You can also distribute training yourself via JAX APIs such as `jax.sharding`. Here's an example.
Can I use Keras 3 components in native PyTorch `Modules` or with Flax `Modules`?

If they are only written using Keras APIs (e.g. the `keras.ops` namespace), then yes, your Keras layers will work out of the box with native PyTorch and JAX code. In PyTorch, just use your Keras layer like any other PyTorch `Module`. In JAX, make sure to use the stateless layer API, i.e. `layer.stateless_call()`.
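For the PyTorch case, a minimal sketch (assuming the PyTorch backend and a made-up wrapper module):

```python
import torch
import keras

class TorchWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # With the PyTorch backend, Keras layers are torch.nn.Module instances,
        # so they can be registered and trained like any other submodule.
        self.keras_dense = keras.layers.Dense(32, activation="relu")
        self.out = torch.nn.Linear(32, 1)

    def forward(self, x):
        x = self.keras_dense(x)
        return self.out(x)
```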
We're open to adding new backends as long as the target framework has a large user base or otherwise has some unique technical benefits to bring to the table. However, adding and maintaining a new backend is a large burden, so we're going to carefully consider each new backend candidate on a case-by-case basis, and we're not likely to add many new backends. We will not add any new frameworks that aren't yet well-established. We are now potentially considering adding a backend written in Mojo. If that's something you might find useful, please let the Mojo team know.