Movatterモバイル変換


[0]ホーム

URL:


Introducing Keras 3.0

After five months of extensive public beta testing,we're excited to announce the official release of Keras 3.0.Keras 3 is a full rewrite of Keras that enables you torun your Keras workflows on top of either JAX, TensorFlow, PyTorch, or OpenVINO (for inference-only),and that unlocks brand new large-scale model training and deployment capabilities.You can pick the framework that suits you best,and switch from one to another based on your current goals.You can also use Keras as a low-level cross-framework languageto develop custom components such as layers, models, or metricsthat can be used in native workflows in JAX, TensorFlow, or PyTorch — with one codebase.


Welcome to multi-framework machine learning.

You're already familiar with the benefits of using Keras — it enableshigh-velocity development via an obsessive focus on great UX, API design,and debuggability. It's also a battle-tested framework that has been chosenby over 2.5M developers and that powers some of the most sophisticated,largest-scale ML systems in the world,such as the Waymo self-driving fleet and the YouTube recommendation engine.But what are the additional benefits of using the new multi-backend Keras 3?

  • Always get the best performance for your models. In our benchmarks,we found that JAX typically delivers the best training and inference performanceon GPU, TPU, and CPU — but results vary from model to model, as non-XLATensorFlow is occasionally faster on GPU. The ability to dynamically selectthe backend that will deliver the best performance for your modelwithout having to change anything to your code means you're guaranteedto train and serve with the highest achievable efficiency.
  • Unlock ecosystem optionality for your models. Any Keras 3model can be instantiated as a PyTorchModule, can be exported as a TensorFlowSavedModel, or can be instantiated as a stateless JAX function. That meansthat you can use your Keras 3 models with PyTorch ecosystem packages,with the full range of TensorFlow deployment & production tools(like TF-Serving, TF.js and TFLite), and with JAX large-scaleTPU training infrastructure. Write onemodel.py usingKeras 3 APIs, and get access to everything the ML world has to offer.
  • Leverage large-scale model parallelism & data parallelism with JAX. Keras 3 includesa brand new distribution API, thekeras.distribution namespace,currently implemented for the JAX backend (coming soon to the TensorFlow and PyTorch backends).It makes it easy to do model parallelism, data parallelism, and combinations of both —at arbitrary model scales and cluster scales.Because it keeps the model definition, training logic,and sharding configuration all separate from each other,it makes your distribution workflow easy to develop and easy to maintain.See ourstarter guide.
  • Maximize reach for your open-source model releases. Want torelease a pretrained model? Want as many people as possibleto be able to use it? If you implement it in pure TensorFlow or PyTorch,it will be usable by roughly half of the community.If you implement it in Keras 3, it is instantly usable by anyone regardlessof their framework of choice (even if they're not Keras users themselves).Twice the impact at no added development cost.
  • Use data pipelines from any source. The Keras 3fit()/evaluate()/predict() routines are compatible withtf.data.Dataset objects,with PyTorchDataLoader objects, with NumPy arrays, Pandas dataframes —regardless of the backend you're using. You can train a Keras 3 + TensorFlowmodel on a PyTorchDataLoader or train a Keras 3 + PyTorch model on atf.data.Dataset.

The full Keras API, available for JAX, TensorFlow, and PyTorch.

Keras 3 implements the full Keras API and makes it availablewith TensorFlow, JAX, and PyTorch — over a hundred layers, dozens of metrics,loss functions, optimizers, and callbacks, the Keras training and evaluationloops, and the Keras saving & serialization infrastructure. All the APIs youknow and love are here.

Any Keras model that only uses built-in layers will immediately work withall supported backends. In fact, your existingtf.keras modelsthat only use built-in layers can start running in JAX and PyTorchright away!That's right, your codebase just gained a whole new set of capabilities.


Author multi-framework layers, models, metrics...

Keras 3 enables you to create components(like arbitrary custom layers or pretrained models) that will work the samein any framework. In particular, Keras 3 gives you accessto thekeras.ops namespace that works across all backends. It contains:

  • A full implementation of the NumPy API.Not something "NumPy-like" — just literally theNumPy API, with the same functions and the same arguments.You getops.matmul,ops.sum,ops.stack,ops.einsum, etc.
  • A set of neural network-specific functions that are absent from NumPy,such asops.softmax,ops.binary_crossentropy,ops.conv, etc.

As long as you only use ops fromkeras.ops, your custom layers,custom losses, custom metrics, and custom optimizerswill work with JAX, PyTorch, and TensorFlow — with the same code.That means that you can maintain only onecomponent implementation (e.g. a singlemodel.pytogether with a single checkpoint file), and you can use it in all frameworks,with the exact same numerics.


...that works seamlessly with any JAX, TensorFlow, and PyTorch workflow.

Keras 3 is not just intended for Keras-centric workflowswhere you define a Keras model, a Keras optimizer, a Keras loss and metrics,and you callfit(),evaluate(), andpredict().It's also meant to work seamlessly with low-level backend-native workflows:you can take a Keras model (or any other component, such as a loss or metric)and start using it in a JAX training loop, a TensorFlow training loop,or a PyTorch training loop, or as part of a JAX or PyTorch model,with zero friction. Keras 3 provides exactlythe same degree of low-level implementation flexibility in JAX and PyTorchastf.keras previously did in TensorFlow.

You can:

  • Write a low-level JAX training loop to train a Keras modelusing anoptax optimizer,jax.grad,jax.jit,jax.pmap.
  • Write a low-level TensorFlow training loop to train a Keras modelusingtf.GradientTape andtf.distribute.
  • Write a low-level PyTorch training loop to train a Keras modelusing atorch.optim optimizer, atorch loss function,and thetorch.nn.parallel.DistributedDataParallel wrapper.
  • Use Keras layers in a PyTorchModule (because they areModule instances too!)
  • Use any PyTorchModule in a Keras model as if it were a Keras layer.
  • etc.


A new distribution API for large-scale data parallelism and model parallelism.

The models we've been working with have been getting larger and larger, so we wantedto provide a Kerasic solution to the multi-device model sharding problem. The API we designedkeeps the model definition, the training logic, and the sharding configuration entirely separate from eachother, meaning that your models can be written as if they were going to run on a single device. Youcan then add arbitrary sharding configurations to arbitrary models when it's time to train them.

Data parallelism (replicating a small model identically on multiple devices) can be handled in just two lines:

Model parallelism lets you specify sharding layouts for model variables and intermediate output tensors,along multiple named dimensions. In the typical case, you would organize available devices as a 2D grid(called adevice mesh), where the first dimension is used for data parallelism and the second dimensionis used for model parallelism. You would then configure your model to be sharded along the model dimensionand replicated along the data dimension.

The API lets you configure the layout of every variable and every output tensor via regular expressions.This makes it easy to quickly specify the same layout for entire categories of variables.

The new distribution API is intended to be multi-backend, but is only available for the JAX backend for the timebeing. TensorFlow and PyTorch support is coming soon. Get started withthis guide!


Pretrained models.

There's a wide range of pretrained models thatyou can start using today with Keras 3.

All 40 Keras Applications models (thekeras.applications namespace)are available in all backends.The vast array of pretrained models inKerasCVandKerasHub also work with all backends. This includes:

  • BERT
  • OPT
  • Whisper
  • T5
  • StableDiffusion
  • YOLOv8
  • SegmentAnything
  • etc.

Support for cross-framework data pipelines with all backends.

Multi-framework ML also means multi-framework data loading and preprocessing.Keras 3 models can be trained using a wide range ofdata pipelines — regardless of whether you're using the JAX, PyTorch, orTensorFlow backends. It just works.

  • tf.data.Dataset pipelines: the reference for scalable production ML.
  • torch.utils.data.DataLoader objects.
  • NumPy arrays and Pandas dataframes.
  • Keras's ownkeras.utils.PyDataset objects.

Progressive disclosure of complexity.

Progressive disclosure of complexity is the design principle at the heartof the Keras API. Keras doesn't force you to followa single "true" way of building and training models. Instead, it enablesa wide range of different workflows, from the very high-level to the verylow-level, corresponding to different user profiles.

That means that you can start out with simple workflows — such as usingSequential andFunctional models and training them withfit() — and whenyou need more flexibility, you can easily customize different components whilereusing most of your prior code. As your needs become more specific,you don't suddenly fall off a complexity cliff and you don't need to switchto a different set of tools.

We've brought this principle to all of our backends. For instance,you can customize what happens in your training loop while stillleveraging the power offit(), without having to write your own training loopfrom scratch — just by overriding thetrain_step method.

Here's how it works in PyTorch and TensorFlow:

Andhere's the link to the JAX version.


A new stateless API for layers, models, metrics, and optimizers.

Do you enjoyfunctional programming?You're in for a treat.

All stateful objects in Keras (i.e. objects that own numerical variables thatget updated during training or evaluation) now have a stateless API, making itpossible to use them in JAX functions (which are required to be fully stateless):

  • All layers and models have astateless_call() method which mirrors__call__().
  • All optimizers have astateless_apply() method which mirrorsapply().
  • All metrics have astateless_update_state() method which mirrorsupdate_state()and astateless_result() method which mirrorsresult().

These methods have no side-effects whatsoever: they take as input the current valueof the state variables of the target object, and return the update values as partof their outputs, e.g.:

outputs,updated_non_trainable_variables=layer.stateless_call(trainable_variables,non_trainable_variables,inputs,)

You never have to implement these methods yourself — they're automatically availableas long as you've implemented the stateful version (e.g.call() orupdate_state()).


Run inference with the OpenVINO backend.

Starting with release 3.8, Keras introduces the OpenVINO backend that is an inference-only backend,meaning it is designed only for running model predictions usingpredict() method.This backend enables to leverage OpenVINO performance optimizations directlywithin the Keras workflow, enabling faster inference on OpenVINO supported hardware.

To switch to the OpenVINO backend, set the KERAS_BACKEND environment variableto"openvino" or specify the backend in the local configuration file at~/.keras/keras.json.Here is an example of how to infer a model (trained with PyTorch, JAX, or TensorFlow backends),using the OpenVINO backend:

importosos.environ["KERAS_BACKEND"]="openvino"importkerasloaded_model=keras.saving.load_model(...)predictions=loaded_model.predict(...)

Note that the OpenVINO backend may currently lack support for some operations.This will be addressed in upcoming Keras releases as operation coverage is being expanded.


Moving from Keras 2 to Keras 3

Keras 3 is highly backwards compatible with Keras 2:it implements the full public API surface of Keras 2,with a limited number of exceptions, listedhere.Most users will not have to make any code changeto start running their Keras scripts on Keras 3.

Larger codebases are likely to require some code changes,since they are more likely to run into one of the exceptions listed above,and are more likely to have been using private APIs or deprecated APIs(tf.compat.v1.keras namespace,experimental namespace,keras.src private namespace).To help you move to Keras 3, we are releasing a completemigration guidewith quick fixes for all issues you might encounter.

You also have the option to ignore the changes in Keras 3 and just keep using Keras 2 with TensorFlow —this can be a good option for projects that are not actively developedbut need to keep running with updated dependencies.You have two possibilities:

  1. If you were accessingkeras as a standalone package,just switch to using the Python packagetf_keras instead,which you can install viapip install tf_keras.The code and API are wholly unchanged — it's Keras 2.15 with a different package name.We will keep fixing bugs intf_keras and we will keep regularly releasing new versions.However, no new features or performance improvements will be added,since the package is now in maintenance mode.
  2. If you were accessingkeras viatf.keras,there are no immediate changes until TensorFlow 2.16.TensorFlow 2.16+ will use Keras 3 by default.In TensorFlow 2.16+, to keep using Keras 2, you can first installtf_keras,and then export the environment variableTF_USE_LEGACY_KERAS=1.This will direct TensorFlow 2.16+ to resolve tf.keras to the locally-installedtf_keras package.Note that this may affect more than your own code, however:it will affect any package importingtf.keras in your Python process.To make sure your changes only affect your own code, you should use thetf_keras package.

Enjoy the library!

We're excited for you to try out the new Keras and improve your workflows by leveraging multi-framework ML.Let us know how it goes: issues, points of friction, feature requests, or success stories —we're eager to hear from you!


FAQ

Q: Is Keras 3 compatible with legacy Keras 2?

Code developed withtf.keras can generally be run as-is with Keras 3(with the TensorFlow backend). There's a limited number of incompatibilities you should be mindfulof, all addressed inthis migration guide.

When it comes to using APIs fromtf.keras and Keras 3 side by side,that isnot possible — they're different packages, running on entirely separate engines.

Q: Do pretrained models developed in legacy Keras 2 work with Keras 3?

Generally, yes. Anytf.keras model should work out of the box with Keras 3with the TensorFlow backend (make sure to save it in the.keras v3 format).In addition, if the model onlyuses built-in Keras layers, then it will also work out of the boxwith Keras 3 with the JAX and PyTorch backends.

If the model contains custom layers written using TensorFlow APIs,it is usually easy to convert the code to be backend-agnostic.For instance, it only took us a few hours to convert all 40legacytf.keras models from Keras Applications to be backend-agnostic.

Q: Can I save a Keras 3 model in one backend and reload it in another backend?

Yes, you can. There is no backend specialization in saved.keras files whatsoever.Your saved Keras models are framework-agnostic and can be reloaded with any backend.

However, note that reloading a model that contains custom componentswith a different backend requires your custom components to be implementedusing backend-agnostic APIs, e.g.keras.ops.

Q: Can I use Keras 3 components insidetf.data pipelines?

With the TensorFlow backend, Keras 3 is fully compatible withtf.data(e.g. you can.map() aSequential model into atf.data pipeline).

With a different backend, Keras 3 has limited support fortf.data.You won't be able to.map() arbitrary layers or models into atf.datapipeline. However, you will be able to use specific Keras 3preprocessing layers withtf.data, such asIntegerLookup orCategoryEncoding.

When it comes to using atf.data pipeline (that does not use Keras)to feed your call to.fit(),.evaluate() or.predict() —that works out of the box with all backends.

Q: Do Keras 3 models behave the same when run with different backends?

Yes, numerics are identical across backends.However, keep in mind the following caveats:

  • RNG behavior is different across different backends(even after seeding — your results will be deterministic in each backendbut will differ across backends). So random weight initializationsvalues and dropout values will differ across backends.
  • Due to the nature of floating-point implementations,results are only identical up to1e-7 precision in float32,per function execution. So when training a model for a long time,small numerical differences will accumulate and may end up resultingin noticeable numerical differences.
  • Due to lack of support for average pooling with asymmetric paddingin PyTorch, average pooling layers withpadding="same"may result in different numerics on border rows/columns.This doesn't happen very often in practice —out of 40 Keras Applications vision models, only one was affected.

Q: Does Keras 3 support distributed training?

Data-parallel distribution is supported out of the box in JAX, TensorFlow,and PyTorch. Model parallel distribution is supported out of the box for JAXwith thekeras.distribution API.

With TensorFlow:

Keras 3 is compatible withtf.distribute —just open a Distribution Strategy scope and create / train your model within it.Here's an example.

With PyTorch:

Keras 3 is compatible with PyTorch'sDistributedDataParallel utility.Here's an example.

With JAX:

You can do both data parallel and model parallel distribution in JAX using thekeras.distribution API.For instance, to do data parallel distribution, you only need the following code snippet:

distribution=keras.distribution.DataParallel(devices=keras.distribution.list_devices())keras.distribution.set_distribution(distribution)

For model parallel distribution, seethe following guide.

You can also distribute training yourself via JAX APIs such asjax.sharding.Here's an example.

Q: Can my custom Keras layers be used in native PyTorchModules or with FlaxModules?

If they are only written using Keras APIs (e.g. thekeras.ops namespace), then yes, yourKeras layers will work out of the box with native PyTorch and JAX code.In PyTorch, just use your Keras layer like any other PyTorchModule.In JAX, make sure to use the stateless layer API, i.e.layer.stateless_call().

Q: Will you add more backends in the future? What about framework XYZ?

We're open to adding new backends as long as the target framework has a large user baseor otherwise has some unique technical benefits to bring to the table.However, adding and maintaining a new backend is a large burden,so we're going to carefully consider each new backend candidate on a case by case basis,and we're not likely to add many new backends. We will not add any new frameworksthat aren't yet well-established.We are now potentially considering adding a backend written inMojo.If that's something you might find useful, please let the Mojo team know.


[8]ページ先頭

©2009-2025 Movatter.jp