Optim refactoring #662
Draft: dario-coscia wants to merge 1 commit into `dev` from `optim`
@@ -1,13 +1,9 @@

```python
"""Module for the Optimizers and Schedulers."""

__all__ = [
    "TorchOptimizer",
    "TorchScheduler",
]

from .torch_optimizer import TorchOptimizer
from .torch_scheduler import TorchScheduler
```
@@ -0,0 +1,98 @@

```python
"""Module for the PINA Optimizer and Scheduler Connectors Interface."""

from abc import ABCMeta, abstractmethod
from functools import wraps


class OptimizerConnectorInterface(metaclass=ABCMeta):
    """
    Interface class for method definitions in the Optimizer classes.
    """

    @abstractmethod
    def parameter_hook(self, parameters):
        """
        Abstract method to define the hook logic for the optimizer. This hook
        is used to initialize the optimizer instance with the given parameters.

        :param dict parameters: The parameters of the model to be optimized.
        """

    @abstractmethod
    def solver_hook(self, solver):
        """
        Abstract method to define the hook logic for the optimizer. This hook
        is used to attach the given solver to the optimizer instance.

        :param SolverInterface solver: The solver to hook.
        """


class SchedulerConnectorInterface(metaclass=ABCMeta):
    """
    Abstract base class for defining a scheduler. All specific schedulers
    should inherit from this class and implement the required methods.
    """

    @abstractmethod
    def optimizer_hook(self):
        """
        Abstract method to define the hook logic for the scheduler. This hook
        is used to attach the scheduler instance to the given optimizer.
        """


class _HooksOptim:
    """
    Mixin class to manage and track the execution of hook methods in optimizer
    or scheduler classes.

    This class automatically detects methods ending with `_hook` and tracks
    whether they have been executed for a given instance. Subclasses defining
    `_hook` methods benefit from automatic tracking without additional
    boilerplate.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the hooks tracking dictionary `hooks_done` for this
        instance. Each hook method detected in the class hierarchy is added to
        `hooks_done` with an initial value of False (not executed).
        """
        super().__init__(*args, **kwargs)
        # Initialize hooks_done per instance
        self.hooks_done = {}
        for cls in self.__class__.__mro__:
            for attr_name, attr_value in cls.__dict__.items():
                if callable(attr_value) and attr_name.endswith("_hook"):
                    self.hooks_done.setdefault(attr_name, False)

    def __init_subclass__(cls, **kwargs):
        """
        Hook called when a subclass of _HooksOptim is created.

        Wraps all concrete `_hook` methods defined in the subclass so that
        executing the method automatically updates `hooks_done`.
        """
        super().__init_subclass__(**kwargs)
        # Wrap only concrete _hook methods defined in this subclass
        for attr_name, attr_value in cls.__dict__.items():
            if callable(attr_value) and attr_name.endswith("_hook"):
                setattr(cls, attr_name, cls.hook_wrapper(attr_name, attr_value))

    @staticmethod
    def hook_wrapper(name, func):
        """
        Wrap a hook method to mark it as executed after calling it.

        :param str name: The name of the hook method.
        :param callable func: The original hook method to wrap.
        :return: The wrapped hook method that updates `hooks_done`.
        """

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            result = func(self, *args, **kwargs)
            self.hooks_done[name] = True
            return result

        return wrapper
```
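To see what the `_HooksOptim` mixin buys a subclass, here is a minimal runnable demo. It inlines a condensed copy of the mixin from the hunk above (so the snippet is self-contained) and defines a toy subclass whose only `_hook` method is tracked automatically; the `Demo` class and `setup_hook` name are illustrative, not part of the PR.

```python
from functools import wraps


# Condensed copy of the _HooksOptim mixin from the diff above,
# reproduced here so this demo runs on its own.
class _HooksOptim:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One tracking flag per *_hook method found in the class hierarchy
        self.hooks_done = {}
        for cls in self.__class__.__mro__:
            for attr_name, attr_value in cls.__dict__.items():
                if callable(attr_value) and attr_name.endswith("_hook"):
                    self.hooks_done.setdefault(attr_name, False)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Wrap each concrete *_hook method defined in the subclass
        for attr_name, attr_value in cls.__dict__.items():
            if callable(attr_value) and attr_name.endswith("_hook"):
                setattr(cls, attr_name, cls.hook_wrapper(attr_name, attr_value))

    @staticmethod
    def hook_wrapper(name, func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            result = func(self, *args, **kwargs)
            self.hooks_done[name] = True  # mark the hook as executed
            return result

        return wrapper


# A toy subclass: any method ending in "_hook" is tracked automatically.
class Demo(_HooksOptim):
    def setup_hook(self):
        return "ready"


d = Demo()
print(d.hooks_done)  # {'setup_hook': False}
d.setup_hook()
print(d.hooks_done)  # {'setup_hook': True}
```

Note that the flag flips as a side effect of calling the wrapped method; the subclass writes no bookkeeping code at all.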
@@ -0,0 +1,100 @@

```python
"""Module for the PINA Optimizer."""

from .optim_connector_interface import OptimizerConnectorInterface, _HooksOptim


class OptimizerConnector(OptimizerConnectorInterface, _HooksOptim):
    """
    Abstract base class for defining an optimizer connector. All specific
    optimizer connectors should inherit from this class and implement the
    required methods.
    """

    def __init__(self, optimizer_class, **optimizer_class_kwargs):
        """
        Initialize the connector parameters.

        :param torch.optim.Optimizer optimizer_class: The torch optimizer
            class.
        :param dict optimizer_class_kwargs: The optimizer kwargs.
        """
        super().__init__()
        self._optimizer_class = optimizer_class
        self._optimizer_instance = None
        self._optim_kwargs = optimizer_class_kwargs
        self._solver = None

    def parameter_hook(self, parameters):
        """
        Hook used to initialize the optimizer instance with the given
        parameters.

        :param dict parameters: The parameters of the model to be optimized.
        """
        self._optimizer_instance = self._optimizer_class(
            parameters, **self._optim_kwargs
        )

    def solver_hook(self, solver):
        """
        Hook used to attach the given solver to the optimizer instance.

        :param SolverInterface solver: The solver to hook.
        """
        if not self.hooks_done["parameter_hook"]:
            raise RuntimeError(
                "Cannot run 'solver_hook' before 'parameter_hook'. "
                "Please call 'parameter_hook' first to initialize "
                "the solver parameters."
            )
        # Attach the solver to both the connector and the optimizer instance
        self._solver = solver
        self._optimizer_instance.solver = solver

    def _register_hooks(self, **kwargs):
        """
        Register the optimizer hooks. This method inspects keyword arguments
        for known keys (`parameters`, `solver`, ...) and applies the
        corresponding hooks. It allows flexible integration with different
        workflows without enforcing a strict method signature. This method is
        used inside the :class:`~pina.solver.solver.SolverInterface` class.

        :param kwargs: Expected keys may include:
            - ``parameters``: Parameters to be registered for optimization.
            - ``solver``: Solver instance.
        """
        # parameter hook
        parameters = kwargs.get("parameters", None)
        if parameters is not None:
            self.parameter_hook(parameters)
        # solver hook
        solver = kwargs.get("solver", None)
        if solver is not None:
            self.solver_hook(solver)

    @property
    def solver(self):
        """
        Get the solver hooked to the optimizer.
        """
        if not self.hooks_done["solver_hook"]:
            raise RuntimeError(
                "Solver has not been hooked. "
                "Override the method solver_hook to hook the solver to "
                "the optimizer."
            )
        return self._solver

    @property
    def instance(self):
        """
        Get the optimizer instance.

        :return: The optimizer instance.
        :rtype: torch.optim.Optimizer
        """
        return self._optimizer_instance
```
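The ordering guard in `solver_hook` is the key behavioral contract here. The following self-contained sketch reproduces that contract with a hypothetical `FakeOptimizer` and a condensed connector (no torch or pina imports, and the flag bookkeeping is written out by hand rather than via the mixin), just to show the enforced call order:

```python
# Hypothetical stand-in for a torch optimizer class.
class FakeOptimizer:
    def __init__(self, params, lr=0.1):
        self.params, self.lr = list(params), lr


# Condensed sketch of the connector above: same guard, manual bookkeeping.
class MiniOptimizerConnector:
    def __init__(self, optimizer_class, **optimizer_kwargs):
        self._optimizer_class = optimizer_class
        self._optim_kwargs = optimizer_kwargs
        self._optimizer_instance = None
        self._solver = None
        self.hooks_done = {"parameter_hook": False, "solver_hook": False}

    def parameter_hook(self, parameters):
        # Build the wrapped optimizer from the stored class and kwargs
        self._optimizer_instance = self._optimizer_class(
            parameters, **self._optim_kwargs
        )
        self.hooks_done["parameter_hook"] = True

    def solver_hook(self, solver):
        # Refuse to attach a solver before the optimizer exists
        if not self.hooks_done["parameter_hook"]:
            raise RuntimeError(
                "Cannot run 'solver_hook' before 'parameter_hook'."
            )
        self._solver = solver
        self.hooks_done["solver_hook"] = True


conn = MiniOptimizerConnector(FakeOptimizer, lr=0.01)
try:
    conn.solver_hook(object())  # wrong order: raises RuntimeError
except RuntimeError as err:
    print(err)
conn.parameter_hook([1.0, 2.0])  # initialize the wrapped optimizer first
conn.solver_hook(object())       # now the solver can be attached
print(conn.hooks_done)           # both hooks marked as done
```

In the real class the `hooks_done` flags flip automatically via the `_HooksOptim` wrapper instead of the explicit assignments shown here.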
@@ -0,0 +1,75 @@

```python
"""Module for the PINA Scheduler."""

from .optim_connector_interface import SchedulerConnectorInterface, _HooksOptim
from .optimizer_connector import OptimizerConnector
from ...utils import check_consistency


class SchedulerConnector(SchedulerConnectorInterface, _HooksOptim):
    """
    Class for defining a scheduler connector. All specific scheduler
    connectors should inherit from this class and implement the required
    methods.
    """

    def __init__(self, scheduler_class, **scheduler_kwargs):
        """
        Initialize the connector parameters.

        :param torch.optim.lr_scheduler.LRScheduler scheduler_class: The torch
            scheduler class.
        :param dict scheduler_kwargs: The scheduler kwargs.
        """
        super().__init__()
        self._scheduler_class = scheduler_class
        self._scheduler_instance = None
        self._scheduler_kwargs = scheduler_kwargs

    def optimizer_hook(self, optimizer):
        """
        Hook used to attach the scheduler instance to the given optimizer.

        :param Optimizer optimizer: The optimizer to hook.
        """
        check_consistency(optimizer, OptimizerConnector)
        if not optimizer.hooks_done["parameter_hook"]:
            raise RuntimeError(
                "Scheduler cannot be set, Optimizer not hooked "
                "to model parameters. "
                "Please call Optimizer.parameter_hook()."
            )
        self._scheduler_instance = self._scheduler_class(
            optimizer.instance, **self._scheduler_kwargs
        )

    def _register_hooks(self, **kwargs):
        """
        Register the scheduler hooks. This method inspects keyword arguments
        for known keys (`optimizer`, ...) and applies the corresponding hooks.
        It allows flexible integration with different workflows without
        enforcing a strict method signature. This method is used inside the
        :class:`~pina.solver.solver.SolverInterface` class.

        :param kwargs: Expected keys may include:
            - ``optimizer``: Optimizer connector to attach the scheduler to.
        """
        # optimizer hook
        optimizer = kwargs.get("optimizer", None)
        if optimizer is not None:
            check_consistency(optimizer, OptimizerConnector)
            self.optimizer_hook(optimizer)

    @property
    def instance(self):
        """
        Get the scheduler instance.

        :return: The scheduler instance.
        :rtype: torch.optim.lr_scheduler.LRScheduler
        """
        return self._scheduler_instance
```
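The scheduler connector adds a second layer to the same ordering contract: its `optimizer_hook` refuses to run until the optimizer connector has been hooked to model parameters. This self-contained sketch (hypothetical `FakeOptim`/`FakeScheduler` classes, hand-written flags instead of the mixin, no torch or pina required) walks through the full wiring order:

```python
# Hypothetical stand-ins for a torch optimizer and LR scheduler.
class FakeOptim:
    def __init__(self, params):
        self.params = list(params)


class FakeScheduler:
    def __init__(self, optimizer, gamma=0.9):
        self.optimizer, self.gamma = optimizer, gamma


# Condensed optimizer connector: only what the scheduler demo needs.
class MiniOptimizerConnector:
    def __init__(self, optimizer_class):
        self._cls = optimizer_class
        self.instance = None
        self.hooks_done = {"parameter_hook": False}

    def parameter_hook(self, parameters):
        self.instance = self._cls(parameters)
        self.hooks_done["parameter_hook"] = True


# Condensed scheduler connector: same guard as SchedulerConnector above.
class MiniSchedulerConnector:
    def __init__(self, scheduler_class, **scheduler_kwargs):
        self._cls, self._kwargs = scheduler_class, scheduler_kwargs
        self.instance = None

    def optimizer_hook(self, optimizer):
        # The scheduler needs a concrete optimizer instance to wrap
        if not optimizer.hooks_done["parameter_hook"]:
            raise RuntimeError("Optimizer not hooked to model parameters.")
        self.instance = self._cls(optimizer.instance, **self._kwargs)


opt = MiniOptimizerConnector(FakeOptim)
sched = MiniSchedulerConnector(FakeScheduler, gamma=0.5)
try:
    sched.optimizer_hook(opt)  # too early: optimizer has no parameters yet
except RuntimeError as err:
    print(err)
opt.parameter_hook([0.1, 0.2])  # 1) hook parameters into the optimizer
sched.optimizer_hook(opt)       # 2) then wrap opt.instance in the scheduler
print(type(sched.instance).__name__, sched.instance.gamma)  # FakeScheduler 0.5
```

The design mirrors how `torch.optim.lr_scheduler` schedulers take an already-constructed optimizer in their constructor, which is why the parameter hook must run first.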
This file was deleted.
This file was deleted.