Orbax Loading and Sharding Support feature #21903


Open
amitsrivastava78 wants to merge 18 commits into keras-team:master from amitsrivastava78:orbax-pending-features

Changes from all commits (18 commits)
bc9060c
Added Load method for orbax
amitsrivastava78, Dec 5, 2025

43d45d0
Added Sharding Support
amitsrivastava78, Dec 8, 2025

4125ae0
Fix memory corruption in Model.load() by simplifying checkpoint loading
amitsrivastava78, Dec 8, 2025

77689d9
Fix bare except clauses in orbax_checkpoint_test.py
amitsrivastava78, Dec 8, 2025

ece275d
Refactor duplicated tensor-to-numpy conversion code
amitsrivastava78, Dec 8, 2025

0464c3b
Multi-host feature support
amitsrivastava78, Dec 8, 2025

d8a86e8
Fixed CI failure
amitsrivastava78, Dec 8, 2025

43fbecd
Implement JAX-only multi-host checkpointing with proper Orbax APIs
amitsrivastava78, Dec 10, 2025

82b345f
Re-run CI
amitsrivastava78, Dec 11, 2025

9e35729
Re-run CI
amitsrivastava78, Dec 11, 2025

5a5c810
Fixed review comments
amitsrivastava78, Dec 12, 2025

b503b79
Fixed review comments
amitsrivastava78, Dec 12, 2025

e45c186
Fixed review comments
amitsrivastava78, Dec 12, 2025

f027ba9
Refactor Orbax integration to use LazyModule consistently
amitsrivastava78, Dec 15, 2025

802d785
Fix Orbax checkpoint preservation policy and asset functionality
amitsrivastava78, Dec 15, 2025

f98e27c
Fix Orbax checkpoint preservation policy and asset functionality
amitsrivastava78, Dec 15, 2025

5ff095a
Integrate assets into Orbax pytree as checkpointables
amitsrivastava78, Dec 17, 2025

1f28e74
Implement asset loading support in model.load_weights for Orbax check…
amitsrivastava78, Dec 17, 2025
140 changes: 113 additions & 27 deletions in keras/src/callbacks/orbax_checkpoint.py
@@ -8,7 +8,6 @@
from keras.src.callbacks.monitor_callback import (
    MonitorCallback,  # For metric monitoring logic
)
from keras.src.utils.io_utils import print_msg
from keras.src.utils.module_utils import ocp

# Context and AsyncOptions are accessed through the lazy-loaded ocp module
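The comment above refers to Keras' lazy module wrapper. As a rough illustration of the pattern (a hypothetical minimal stand-in, not the actual `keras.src.utils.module_utils` implementation), lazy loading defers the `orbax.checkpoint` import until first attribute access:

```python
import importlib


class LazyModule:
    """Hypothetical minimal stand-in for Keras' lazy module wrapper."""

    def __init__(self, name):
        self._name = name
        self._module = None

    def __getattr__(self, attr):
        # The real import happens on first attribute access, so merely
        # importing the file does not require orbax to be installed.
        if self._module is None:
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)


ocp = LazyModule("orbax.checkpoint")  # ocp.Context resolves lazily
```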
@@ -62,6 +61,11 @@ class OrbaxCheckpoint(MonitorCallback):
    This callback saves the model's weights and optimizer state asynchronously
    using Orbax, allowing training to continue without blocking for I/O.

    **Multi-host Support**: When running in a multi-host distributed training
    environment with the JAX backend, this callback automatically coordinates
    checkpointing across all hosts to ensure consistency and proper
    synchronization. Multi-host checkpointing is only supported on JAX.

    Example:

    ```python
    # (example body collapsed in this diff view)
    ```

@@ -138,6 +142,9 @@ def __init__(
        self._current_epoch = 0  # Keep track of epoch
        self._total_batches_seen = 0  # Global batch counter for step tracking

        # Multi-host support
        self._multihost_initialized = self._is_multihost_initialized()

        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
            raise ValueError(
                f"Unrecognized save_freq: {self.save_freq}. "
@@ -151,14 +158,18 @@ def __init__(
                ocp.training.preservation_policies.LatestN(max_to_keep)
            )

        # Use AnyPreservationPolicy to combine them.
        # Use AnyPreservationPolicy to combine them, or use directly
        # if single policy
        preservation_policy = None
        if policies:
            preservation_policy = (
                ocp.training.preservation_policies.AnyPreservationPolicy(
                    policies
            if len(policies) == 1:
                preservation_policy = policies[0]
            else:
                preservation_policy = (
                    ocp.training.preservation_policies.AnyPreservationPolicy(
                        policies
                    )
                )
            )

        # Create the V1 Checkpointer with direct parameter passing
        # Orbax will handle directory creation on all processes as needed
@@ -167,6 +178,54 @@
            preservation_policy=preservation_policy,
        )

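The preservation-policy selection earlier in `__init__` avoids wrapping a single policy. A standalone sketch of that selection logic (the `ocp.training.preservation_policies` names are taken from this diff; their availability in a given Orbax release is an assumption):

```python
from keras.src.utils.module_utils import ocp  # lazy orbax.checkpoint module


def select_preservation_policy(policies):
    # Mirror of the diff's logic: use a single policy directly,
    # otherwise OR the policies together with AnyPreservationPolicy.
    if not policies:
        return None
    if len(policies) == 1:
        return policies[0]
    return ocp.training.preservation_policies.AnyPreservationPolicy(policies)


# Hypothetical call keeping only the three most recent checkpoints:
policy = select_preservation_policy(
    [ocp.training.preservation_policies.LatestN(3)]
)
```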
    def _is_multihost_initialized(self):
        """Check if the multi-host environment is initialized."""
        # Multi-host checkpointing is only supported on the JAX backend
        if backend.backend() != "jax":
            return False

        multihost = ocp.multihost
        # Check if the JAX distributed client is initialized
        # (this indicates a multi-host setup)
        return multihost.is_jax_distributed_client_initialized()
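The check above reports multi-host mode only when the JAX distributed client is up. A sketch of the setup that would make it return True (addresses and process counts are placeholders):

```python
import jax

# Run once per host, before constructing the callback; afterwards the
# callback's _is_multihost_initialized() check should report True.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # placeholder address
    num_processes=2,
    process_id=0,  # unique per host: 0 on the first host, 1 on the second
)
```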

    def _sync_processes(self, key=None):
        """Synchronize all processes across hosts."""
        if not self._multihost_initialized:
            return  # No-op for single host

        multihost = ocp.multihost
        sync_key = key or "orbax_checkpoint_sync"
        multihost.sync_global_processes(sync_key)

    def is_multihost_enabled(self):
        """Return True if multi-host checkpointing is enabled and initialized.

        This method can be used to check if the callback is operating in
        a multi-host distributed training environment. Multi-host
        checkpointing is only supported on the JAX backend.

        Returns:
            bool: True if multi-host support is active, False otherwise.
        """
        return self._multihost_initialized

    def is_primary_host(self):
        """Return True if this process is the primary host in a multi-host setup.

        In multi-host environments, only the primary host typically handles
        logging and coordination tasks. Multi-host checkpointing is only
        supported on the JAX backend.

        Returns:
            bool: True if this is the primary host, False otherwise.
            Always returns True in single-host environments.
        """
        if not self._multihost_initialized:
            return True  # Single host is always primary
        multihost = ocp.multihost
        return multihost.is_primary_host()
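A sketch of how these helpers might be used together in training code (the logging pattern is illustrative, not part of this PR; `callback` comes from the earlier usage sketch):

```python
# Only the primary host emits coordination logs; all hosts then
# rendezvous so no process races ahead.
if callback.is_primary_host():
    print("primary host: checkpoint directory prepared")
callback._sync_processes("directory_ready")  # no-op on a single host
```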

    def _should_save_on_batch(self, batch):
        """Check if we should save on this batch."""
        if self.save_freq == "epoch":
@@ -186,7 +245,7 @@ def _should_save_on_batch(self, batch):
            return False

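The middle of `_should_save_on_batch` is collapsed above; presumably it compares the global batch counter against an integer `save_freq`. A hypothetical reconstruction of that modulo-style check (the actual collapsed body may differ):

```python
def should_save_on_batch(total_batches_seen, save_freq):
    # Hypothetical sketch: save every `save_freq` global batches.
    if save_freq == "epoch":
        return False  # epoch saves are handled at epoch end instead
    return total_batches_seen % save_freq == 0


assert should_save_on_batch(200, 100) is True
assert should_save_on_batch(150, 100) is False
```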
    def _save_checkpoint(self, step, logs=None):
        """Save a checkpoint at the given step."""
        """Save a checkpoint at the given step with multi-host coordination."""

        # --- Prepare Composite State (Backend-Agnostic) ---
        state_tree = _get_state_tree(self.model)
@@ -201,17 +260,46 @@ def _save_checkpoint(self, step, logs=None):
            composite_state["non_trainable_variables"] = state_tree[
                "non_trainable_variables"
            ]
            # Include assets even for weights-only checkpoints
            assets_tree = {}
            for layer in self.model.layers:
                if hasattr(layer, "asset_data"):
                    # Convert TrackedDict to dict, handle bytes as base64
                    asset_dict = {}
                    for key, value in layer.asset_data.items():
                        if isinstance(value, bytes):
                            import base64

                            asset_dict[key] = base64.b64encode(value).decode(
                                "ascii"
                            )
                        else:
                            asset_dict[key] = value
                    assets_tree[layer.name] = asset_dict
            composite_state["assets"] = assets_tree
        else:
            composite_state = state_tree

        # --- Save Logic (V1 API) ---
        # All processes participate in distributed checkpointing
        # Checkpointer is configured to save unconditionally when
        # save_pytree is called
        if self.verbose > 0:
            print_msg(
                f"OrbaxCheckpoint: Triggering async save for step {step}..."
            )
        composite_state = {
            "model_config": self.model.get_config(),
            **state_tree,
        }
        # Include assets as part of the tree
        if not self.save_weights_only:
            assets_tree = {}
            for layer in self.model.layers:
                if hasattr(layer, "asset_data"):
                    # Convert TrackedDict to dict, handle bytes as base64
                    asset_dict = {}
                    for key, value in layer.asset_data.items():
                        if isinstance(value, bytes):
                            import base64

                            asset_dict[key] = base64.b64encode(
                                value
                            ).decode("ascii")
                        else:
                            asset_dict[key] = value
                    assets_tree[layer.name] = asset_dict
            composite_state["assets"] = assets_tree

        # Use a single with statement. If context_options is empty,
        # Context() uses defaults.
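Both asset branches above store raw `bytes` as base64 text so the value is safe inside the checkpoint pytree. A small round-trip demonstration of that conversion:

```python
import base64

raw = b"\x00\x01binary-vocab-data"
encoded = base64.b64encode(raw).decode("ascii")  # pytree-safe string
restored = base64.b64decode(encoded.encode("ascii"))
assert restored == raw
```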
@@ -282,18 +370,16 @@ def on_train_end(self, logs=None):
        except Exception:
            pass  # Ignore errors during cleanup

        # Multi-host synchronization: ensure all hosts complete cleanup
        self._sync_processes("checkpoint_cleanup")

    def wait_until_finished(self):
        """Wait for any in-progress checkpoint operations to complete.

        This method blocks until all asynchronous checkpoint save operations
        have completed. It should be called before attempting to load
        checkpoints if there might be pending save operations.
        have completed across all hosts in a multi-host setup.
        """
        # Wait for any async operations to complete
        if hasattr(self.checkpointer, "wait"):
            self.checkpointer.wait()
        else:
            # Fallback for older Orbax versions that don't have wait() method
            while self.checkpointer.is_saving_in_progress():
                import time
        # Wait for any async operations to complete on this host
        self.checkpointer.wait()

                time.sleep(0.1)
        # Multi-host synchronization: ensure all hosts complete
        self._sync_processes("checkpoint_wait_complete")