Orbax Loading and Sharding Support feature #21903

Open
amitsrivastava78 wants to merge 18 commits into keras-team:master from amitsrivastava78:orbax-pending-features
18 commits
bc9060c  Added Load method for orbax (amitsrivastava78, Dec 5, 2025)
43d45d0  Added Sharding Support (amitsrivastava78, Dec 8, 2025)
4125ae0  Fix memory corruption in Model.load() by simplifying checkpoint loading (amitsrivastava78, Dec 8, 2025)
77689d9  Fix bare except clauses in orbax_checkpoint_test.py (amitsrivastava78, Dec 8, 2025)
ece275d  Refactor duplicated tensor-to-numpy conversion code (amitsrivastava78, Dec 8, 2025)
0464c3b  Multi-host feature support (amitsrivastava78, Dec 8, 2025)
d8a86e8  Fixed CI failure (amitsrivastava78, Dec 8, 2025)
43fbecd  Implement JAX-only multi-host checkpointing with proper Orbax APIs (amitsrivastava78, Dec 10, 2025)
82b345f  Re-run CI (amitsrivastava78, Dec 11, 2025)
9e35729  Re-run CI (amitsrivastava78, Dec 11, 2025)
5a5c810  Fixed review comments (amitsrivastava78, Dec 12, 2025)
b503b79  Fixed review comments (amitsrivastava78, Dec 12, 2025)
e45c186  Fixed review comments (amitsrivastava78, Dec 12, 2025)
f027ba9  Refactor Orbax integration to use LazyModule consistently (amitsrivastava78, Dec 15, 2025)
802d785  Fix Orbax checkpoint preservation policy and asset functionality (amitsrivastava78, Dec 15, 2025)
f98e27c  Fix Orbax checkpoint preservation policy and asset functionality (amitsrivastava78, Dec 15, 2025)
5ff095a  Integrate assets into Orbax pytree as checkpointables (amitsrivastava78, Dec 17, 2025)
1f28e74  Implement asset loading support in model.load_weights for Orbax check… (amitsrivastava78, Dec 17, 2025)
Fix Orbax checkpoint preservation policy and asset functionality
- Remove manual directory cleanup in _save_checkpoint that interfered with Orbax preservation policies
- Simplify preservation policy setup to use LatestN directly instead of AnyPreservationPolicy wrapper
- Update asset directory structure to use checkpoint_dir/assets/step/ format
- Add comprehensive asset saving/loading tests for both sync and async modes
- Make test_save_freq_epoch more robust by checking for numeric checkpoint names rather than specific epoch
- Fix asset loading to handle new directory structure in saving_api.py

All Orbax checkpoint tests now pass on both JAX and TensorFlow backends.
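
The layout described in the commit message (assets under checkpoint_dir/assets/step/, checkpoints identified by numeric names) can be illustrated with a minimal sketch. The helper names `assets_dir_for_step` and `numeric_checkpoint_names` are hypothetical and do not appear in the PR, and the sketch assumes each checkpoint step is a directory whose name is a plain integer.

```python
import os
import tempfile


def assets_dir_for_step(checkpoint_dir, step):
    """Return checkpoint_dir/assets/<step> (hypothetical helper)."""
    return os.path.join(checkpoint_dir, "assets", str(step))


def numeric_checkpoint_names(checkpoint_dir):
    """List entries whose names are purely numeric, ordered by step."""
    names = [n for n in os.listdir(checkpoint_dir) if n.isdigit()]
    return sorted(names, key=int)


# Build a throwaway directory tree that mimics the assumed layout.
with tempfile.TemporaryDirectory() as root:
    for step in (0, 2, 10):
        os.makedirs(os.path.join(root, str(step)))    # checkpoint step dir
        os.makedirs(assets_dir_for_step(root, step))  # matching assets dir
    # "assets" is skipped because its name is not numeric.
    found = numeric_checkpoint_names(root)
```

Checking for any numeric name, rather than one specific epoch number, keeps a test independent of which steps the preservation policy has retained.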
amitsrivastava78 committed Dec 15, 2025
commit 802d7853e1e9604644c0a2d1c3e88654643c6897


43 changes: 38 additions & 5 deletions keras/src/callbacks/orbax_checkpoint.py
@@ -1,3 +1,4 @@
+import os
 import warnings

 import numpy as np
@@ -8,6 +9,7 @@
 from keras.src.callbacks.monitor_callback import (
     MonitorCallback,  # For metric monitoring logic
 )
+from keras.src.saving.saving_lib import DiskIOStore
 from keras.src.utils.module_utils import ocp

 # Context and AsyncOptions are accessed through the lazy-loaded ocp module
@@ -158,14 +160,18 @@ def __init__(
                 ocp.training.preservation_policies.LatestN(max_to_keep)
             )

-        # Use AnyPreservationPolicy to combine them.
+        # Use AnyPreservationPolicy to combine them, or use directly
+        # if single policy
         preservation_policy = None
         if policies:
-            preservation_policy = (
-                ocp.training.preservation_policies.AnyPreservationPolicy(
-                    policies
+            if len(policies) == 1:
+                preservation_policy = policies[0]
+            else:
+                preservation_policy = (
+                    ocp.training.preservation_policies.AnyPreservationPolicy(
+                        policies
+                    )
                 )
-            )

         # Create the V1 Checkpointer with direct parameter passing
         # Orbax will handle directory creation on all processes as needed
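
The branch added in this hunk can be sketched without Orbax installed. The classes `LatestN` and `AnyOf` below are stand-ins written for illustration only; they are not the Orbax classes, and `select_policy` is a hypothetical name for the logic that the diff places inline in `__init__`.

```python
class LatestN:
    """Stand-in for a single preservation policy (not the Orbax class)."""

    def __init__(self, n):
        self.n = n


class AnyOf:
    """Stand-in for a wrapper that combines several policies."""

    def __init__(self, policies):
        self.policies = list(policies)


def select_policy(policies):
    """Mirror the hunk: None, the lone policy itself, or a combining wrapper."""
    if not policies:
        return None
    if len(policies) == 1:
        # A single policy is passed through unwrapped.
        return policies[0]
    return AnyOf(policies)


only = LatestN(3)
single = select_policy([only])
combined = select_policy([LatestN(3), LatestN(5)])
```

With one policy the caller receives the original object, so whatever consumes it sees the policy directly and not a one-element wrapper.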
@@ -270,6 +276,33 @@ def _save_checkpoint(self, step, logs=None):
         else:
             self.checkpointer.save_pytree(step, composite_state)

+        # Save assets separately since PyTree can't handle binary data
+        if not self.save_weights_only:
+            self._save_assets(step)
+
+    def _save_assets(self, step):
+        """Save model assets to a separate directory."""
+        from keras.src.saving.saving_lib import _save_state
+
+        assets_dir = os.path.join(self.directory, "assets", str(step))
+        try:
+            assets_store = DiskIOStore(assets_dir, mode="w")
+        except FileExistsError:
+            # Directory already exists, skip asset saving
+            return
+        try:
+            # Use the same recursive saving logic as _save_state
+            visited = set()
+            _save_state(
+                self.model,
+                None,  # No weights store
+                assets_store,  # Assets store
+                "",  # Root path
+                visited,
+            )
+        finally:
+            assets_store.close()
+
     def on_train_batch_end(self, batch, logs=None):
         if self._should_save_on_batch(batch):
             # Handle save_best_only logic for batch-level saving
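
The control flow of `_save_assets` (skip when the step directory already exists, otherwise write and always close the store) can be sketched in isolation. `DirStore` and `save_assets` are hypothetical stand-ins; the sketch assumes the store raises `FileExistsError` for an existing directory, which the diff handles but does not itself demonstrate for `DiskIOStore`.

```python
import os
import tempfile


class DirStore:
    """Hypothetical stand-in for an asset store backed by a directory."""

    def __init__(self, path):
        os.makedirs(path)  # raises FileExistsError if path already exists
        self.path = path
        self.closed = False

    def write(self, name, data):
        with open(os.path.join(self.path, name), "wb") as f:
            f.write(data)

    def close(self):
        self.closed = True


def save_assets(root, step, assets):
    """Write assets under root/assets/<step>; return False if skipped."""
    target = os.path.join(root, "assets", str(step))
    try:
        store = DirStore(target)
    except FileExistsError:
        # Assets for this step were already written; leave them untouched.
        return False
    try:
        for name, data in assets.items():
            store.write(name, data)
    finally:
        # Close the store even if a write fails.
        store.close()
    return True


with tempfile.TemporaryDirectory() as root:
    first = save_assets(root, 5, {"vocab.txt": b"a\nb\n"})
    second = save_assets(root, 5, {"vocab.txt": b"changed"})
    with open(os.path.join(root, "assets", "5", "vocab.txt"), "rb") as f:
        saved = f.read()
```

The second call is skipped, so the first write for a step is the one that persists.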