- Notifications
You must be signed in to change notification settings - Fork19.7k
Orbax Loading and Sharding Support feature#21903
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Open
amitsrivastava78 wants to merge18 commits intokeras-team:masterChoose a base branch fromamitsrivastava78:orbax-pending-features
base:master
Could not load branches
Branch not found:{{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline, and old review comments may become outdated.
+1,207 −286
Open
Changes fromall commits
Commits
Show all changes
18 commits Select commitHold shift + click to select a range
bc9060c Added Load method for orbax
amitsrivastava7843d45d0 Added Sharding Support
amitsrivastava784125ae0 Fix memory corruption in Model.load() by simplifying checkpoint loading
amitsrivastava7877689d9 Fix bare except clauses in orbax_checkpoint_test.py
amitsrivastava78ece275d Refactor duplicated tensor-to-numpy conversion code
amitsrivastava780464c3b Multi-host feature support
amitsrivastava78d8a86e8 Fixed CI failure
amitsrivastava7843fbecd Implement JAX-only multi-host checkpointing with proper Orbax APIs
amitsrivastava7882b345f Re-run CI
amitsrivastava789e35729 Re-run CI
amitsrivastava785a5c810 Fixed review comments
amitsrivastava78b503b79 Fixed review comments
amitsrivastava78e45c186 Fixed review comments
amitsrivastava78f027ba9 Refactor Orbax integration to use LazyModule consistently
amitsrivastava78802d785 Fix Orbax checkpoint preservation policy and asset functionality
amitsrivastava78f98e27c Fix Orbax checkpoint preservation policy and asset functionality
amitsrivastava785ff095a Integrate assets into Orbax pytree as checkpointables
amitsrivastava781f28e74 Implement asset loading support in model.load_weights for Orbax check…
amitsrivastava78File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Jump to file
Failed to load files.
Loading
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
140 changes: 113 additions & 27 deletionskeras/src/callbacks/orbax_checkpoint.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -8,7 +8,6 @@ | ||
| from keras.src.callbacks.monitor_callback import ( | ||
| MonitorCallback, # For metric monitoring logic | ||
| ) | ||
| from keras.src.utils.module_utils import ocp | ||
| # Context and AsyncOptions are accessed through the lazy-loaded ocp module | ||
| @@ -62,6 +61,11 @@ class OrbaxCheckpoint(MonitorCallback): | ||
| This callback saves the model's weights and optimizer state asynchronously | ||
| using Orbax, allowing training to continue without blocking for I/O. | ||
| **Multi-host Support**: When running in a multi-host distributed training | ||
| environment with JAX backend, this callback automatically coordinates | ||
| checkpointing across all hosts to ensure consistency and proper | ||
| synchronization. Multi-host checkpointing is only supported on JAX. | ||
| Example: | ||
| ```python | ||
| @@ -138,6 +142,9 @@ def __init__( | ||
| self._current_epoch = 0 # Keep track of epoch | ||
| self._total_batches_seen = 0 # Global batch counter for step tracking | ||
| # Multi-host support | ||
| self._multihost_initialized = self._is_multihost_initialized() | ||
| if self.save_freq != "epoch" and not isinstance(self.save_freq, int): | ||
| raise ValueError( | ||
| f"Unrecognized save_freq: {self.save_freq}. " | ||
| @@ -151,14 +158,18 @@ def __init__( | ||
| ocp.training.preservation_policies.LatestN(max_to_keep) | ||
| ) | ||
| # Use AnyPreservationPolicy to combine them, or use directly | ||
| # if single policy | ||
| preservation_policy = None | ||
| if policies: | ||
| if len(policies) == 1: | ||
| preservation_policy = policies[0] | ||
| else: | ||
| preservation_policy = ( | ||
| ocp.training.preservation_policies.AnyPreservationPolicy( | ||
| policies | ||
| ) | ||
| ) | ||
| # Create the V1 Checkpointer with direct parameter passing | ||
| # Orbax will handle directory creation on all processes as needed | ||
| @@ -167,6 +178,54 @@ def __init__( | ||
| preservation_policy=preservation_policy, | ||
| ) | ||
| def _is_multihost_initialized(self): | ||
| """Check if multi-host environment is initialized.""" | ||
| # Multi-host checkpointing is only supported on JAX backend | ||
| if backend.backend() != "jax": | ||
| return False | ||
| multihost = ocp.multihost | ||
| # Check if JAX distributed client is initialized | ||
| # (indicates multihost setup) | ||
| return multihost.is_jax_distributed_client_initialized() | ||
| def _sync_processes(self, key=None): | ||
| """Synchronize all processes across hosts.""" | ||
| if not self._multihost_initialized: | ||
| return # No-op for single host | ||
| multihost = ocp.multihost | ||
| sync_key = key or "orbax_checkpoint_sync" | ||
| multihost.sync_global_processes(sync_key) | ||
| def is_multihost_enabled(self): | ||
| """Return True if multi-host checkpointing is enabled and initialized. | ||
| This method can be used to check if the callback is operating in | ||
| a multi-host distributed training environment. Multi-host checkpointing | ||
| is only supported on JAX backend. | ||
| Returns: | ||
| bool: True if multi-host support is active, False otherwise. | ||
| """ | ||
| return self._multihost_initialized | ||
| def is_primary_host(self): | ||
| """Return True if this process is the primary host in multi-host setup. | ||
| In multi-host environments, only the primary host typically handles | ||
| logging and coordination tasks. Multi-host checkpointing is only | ||
| supported on JAX backend. | ||
| Returns: | ||
| bool: True if this is the primary host, False otherwise. | ||
| Always returns True in single-host environments. | ||
| """ | ||
| if not self._multihost_initialized: | ||
| return True # Single host is always primary | ||
| multihost = ocp.multihost | ||
| return multihost.is_primary_host() | ||
| def _should_save_on_batch(self, batch): | ||
| """Check if we should save on this batch.""" | ||
| if self.save_freq == "epoch": | ||
| @@ -186,7 +245,7 @@ def _should_save_on_batch(self, batch): | ||
| return False | ||
| def _save_checkpoint(self, step, logs=None): | ||
| """Save a checkpoint at the given step with multi-host coordination.""" | ||
| # --- Prepare Composite State (Backend-Agnostic) --- | ||
| state_tree = _get_state_tree(self.model) | ||
| @@ -201,17 +260,46 @@ def _save_checkpoint(self, step, logs=None): | ||
| composite_state["non_trainable_variables"] = state_tree[ | ||
| "non_trainable_variables" | ||
| ] | ||
| # Include assets even for weights-only checkpoints | ||
| assets_tree = {} | ||
| for layer in self.model.layers: | ||
| if hasattr(layer, "asset_data"): | ||
| # Convert TrackedDict to dict, handle bytes as base64 | ||
| asset_dict = {} | ||
| for key, value in layer.asset_data.items(): | ||
| if isinstance(value, bytes): | ||
| import base64 | ||
| asset_dict[key] = base64.b64encode(value).decode( | ||
| "ascii" | ||
| ) | ||
| else: | ||
| asset_dict[key] = value | ||
| assets_tree[layer.name] = asset_dict | ||
| composite_state["assets"] = assets_tree | ||
| else: | ||
| composite_state = { | ||
| "model_config": self.model.get_config(), | ||
amitsrivastava78 marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| **state_tree, | ||
| } | ||
| # Include assets as part of the tree | ||
| if not self.save_weights_only: | ||
| assets_tree = {} | ||
| for layer in self.model.layers: | ||
| if hasattr(layer, "asset_data"): | ||
| # Convert TrackedDict to dict, handle bytes as base64 | ||
| asset_dict = {} | ||
| for key, value in layer.asset_data.items(): | ||
| if isinstance(value, bytes): | ||
| import base64 | ||
| asset_dict[key] = base64.b64encode( | ||
| value | ||
| ).decode("ascii") | ||
| else: | ||
| asset_dict[key] = value | ||
| assets_tree[layer.name] = asset_dict | ||
| composite_state["assets"] = assets_tree | ||
| # Use a single with statement. If context_options is empty, | ||
| # Context() uses defaults. | ||
| @@ -282,18 +370,16 @@ def on_train_end(self, logs=None): | ||
| except Exception: | ||
| pass # Ignore errors during cleanup | ||
| # Multi-host synchronization: ensure all hosts complete cleanup | ||
| self._sync_processes("checkpoint_cleanup") | ||
| def wait_until_finished(self): | ||
| """Wait for any in-progress checkpoint operations to complete. | ||
| This method blocks until all asynchronous checkpoint save operations | ||
| have completed across all hosts in a multi-host setup. | ||
| """ | ||
| # Wait for any async operations to complete on this host | ||
| self.checkpointer.wait() | ||
| # Multi-host synchronization: ensure all hosts complete | ||
| self._sync_processes("checkpoint_wait_complete") | ||
amitsrivastava78 marked this conversation as resolved. Show resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
Oops, something went wrong.
Uh oh!
There was an error while loading.Please reload this page.
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.