torch.jit.freeze
- torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)
Freeze ScriptModule, inline submodules, and attributes as constants.
Freezing a ScriptModule will clone it and attempt to inline the cloned module’s submodules, parameters, and attributes as constants in the TorchScript IR Graph. By default, forward will be preserved, as well as attributes and methods specified in preserved_attrs. Additionally, any attribute that is modified within a preserved method will be preserved.
Freezing currently only accepts ScriptModules that are in eval mode.
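The eval-mode requirement can be illustrated with a minimal sketch. The module `Tiny` is hypothetical, and the sketch assumes that freezing a module still in training mode is rejected with a `RuntimeError`:

```python
import torch


class Tiny(torch.nn.Module):
    """Hypothetical module used only for illustration."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# A freshly constructed nn.Module is in training mode.
scripted = torch.jit.script(Tiny())

try:
    torch.jit.freeze(scripted)
    rejected = False
except RuntimeError:
    # Assumption: freeze refuses modules that are not in eval mode.
    rejected = True

# After switching to eval mode, freezing succeeds and the
# parameters are inlined into the graph as constants.
frozen = torch.jit.freeze(scripted.eval())
remaining_params = len(list(frozen.named_parameters()))
```

Calling `.eval()` before scripting, as the examples further down do, works equally well.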
Freezing applies generic optimization that will speed up your model regardless of machine. To further optimize using server-specific settings, run optimize_for_inference after freezing.
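A rough sketch of the freeze-then-optimize sequence follows. `ConvBlock` is a hypothetical module, and the tolerance used when comparing outputs is an arbitrary choice; the passes applied by `torch.jit.optimize_for_inference` depend on the build and the machine:

```python
import torch


class ConvBlock(torch.nn.Module):
    """Hypothetical module used only for illustration."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.bn(self.conv(x)))


eager = ConvBlock().eval()

# Step 1: generic, machine-independent freezing.
frozen = torch.jit.freeze(torch.jit.script(eager))

# Step 2: additional machine-specific optimization.
optimized = torch.jit.optimize_for_inference(frozen)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    # Outputs should agree up to small numerical differences.
    max_diff = (eager(x) - optimized(x)).abs().max().item()
```

Comparing against the eager module is a cheap sanity check, since these passes are not guaranteed to preserve numerics exactly.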
- Parameters
mod (ScriptModule) – a module to be frozen
preserved_attrs (Optional[List[str]]) – a list of attributes to preserve in addition to the forward method. Attributes modified in preserved methods will also be preserved.
optimize_numerics (bool) – If True, a set of optimization passes will be run that does not strictly preserve numerics. Full details of optimization can be found at torch.jit.run_frozen_optimizations.
- Returns
Frozen ScriptModule.
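As a sketch of how the two optional parameters can be combined, the snippet below preserves an exported method by name and disables the numerics-changing passes. The `Scaler` module and its `scale_twice` method are hypothetical names chosen for illustration:

```python
import torch


class Scaler(torch.nn.Module):
    """Hypothetical module used only for illustration."""

    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.tensor(2.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale

    @torch.jit.export  # compile this method even though forward never calls it
    def scale_twice(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale * self.scale


scripted = torch.jit.script(Scaler().eval())

# Keep `scale_twice` callable on the frozen module and skip the
# optimization passes that may alter numerics.
frozen = torch.jit.freeze(
    scripted,
    preserved_attrs=["scale_twice"],
    optimize_numerics=False,
)

forward_out = frozen(torch.tensor(3.0))
method_out = frozen.scale_twice(torch.tensor(3.0))
```

Without the entry in `preserved_attrs`, only `forward` is guaranteed to remain available on the frozen module.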
Example (Freezing a simple module with a Parameter):
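    import torch

    # The class header and constructor were missing from this example;
    # they are reconstructed here (an assumption) from how forward and
    # MyModule(2, 3) use them.
    class MyModule(torch.nn.Module):
        def __init__(self, N, M):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.rand(N, M))
            self.linear = torch.nn.Linear(N, M)

        def forward(self, input):
            output = self.weight.mm(input)
            output = self.linear(output)
            return output

    scripted_module = torch.jit.script(MyModule(2, 3).eval())
    frozen_module = torch.jit.freeze(scripted_module)
    # parameters have been removed and inlined into the Graph as constants
    assert len(list(frozen_module.named_parameters())) == 0
    # See the compiled graph as Python code
    print(frozen_module.code)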
Example (Freezing a module with preserved attributes):
    import torch

    # The class header and constructor were missing from this example;
    # the initial values below are reconstructed (an assumption) from
    # the assertions that follow.
    class MyModule2(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.modified_tensor = torch.tensor(10.)
            self.version = 1

        def forward(self, input):
            self.modified_tensor += 1
            return input + self.modified_tensor

    scripted_module = torch.jit.script(MyModule2().eval())
    frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
    # we've manually preserved `version`, so it still exists on the frozen module and can be modified
    assert frozen_module.version == 1
    frozen_module.version = 2
    # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
    # it to retain model semantics
    assert frozen_module(torch.tensor(1)) == torch.tensor(12)
    # now that we've run it once, the next result will be incremented by one
    assert frozen_module(torch.tensor(1)) == torch.tensor(13)
Note
Freezing submodule attributes is also supported:
    frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"])
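A self-contained sketch of the dotted-path form is shown below. The `Sub` and `Outer` modules and the `offset` attribute are hypothetical; only the `"submodule.version"` path comes from the note above:

```python
import torch


class Sub(torch.nn.Module):
    """Hypothetical submodule with one preserved and one inlined attribute."""

    def __init__(self):
        super().__init__()
        self.version = 1
        self.offset = 10

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.version + self.offset


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submodule = Sub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.submodule(x)


scripted_module = torch.jit.script(Outer().eval())
frozen_module = torch.jit.freeze(
    scripted_module, preserved_attrs=["submodule.version"]
)

# `version` stays a live attribute, so updating it changes the output.
before = frozen_module(torch.tensor(0))
frozen_module.submodule.version = 2
after = frozen_module(torch.tensor(0))
```

Here `offset` is not listed, so it is a candidate for inlining, while `version` remains readable and writable through the preserved submodule.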
Note
If you’re not sure why an attribute is not being inlined as a constant, you can run dump_alias_db on frozen_module.forward.graph to see if freezing has detected the attribute is being modified.
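One way to act on this note is sketched below. The `Counter` module is hypothetical, and the string check on the graph text is an informal heuristic added for illustration, not part of the documented workflow:

```python
import torch


class Counter(torch.nn.Module):
    """Hypothetical module: one attribute is mutated, one is not."""

    def __init__(self):
        super().__init__()
        self.modified_tensor = torch.tensor(10.0)
        self.fixed_tensor = torch.tensor(5.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.modified_tensor += 1
        return x + self.modified_tensor + self.fixed_tensor


frozen_module = torch.jit.freeze(torch.jit.script(Counter().eval()))
graph = frozen_module.forward.graph

# Prints the alias analysis for the graph to standard output.
graph.dump_alias_db()

# Heuristic: attributes that were not inlined are still read
# through prim::GetAttr nodes in the frozen graph.
graph_text = str(graph)
mutated_kept = 'GetAttr[name="modified_tensor"]' in graph_text
constant_inlined = 'GetAttr[name="fixed_tensor"]' not in graph_text
```

An attribute that still appears behind a `prim::GetAttr` node was preserved, usually because a preserved method writes to it.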
Note
Because freezing makes weights constants and removes module hierarchy, to() and other nn.Module methods to manipulate device or dtype no longer work. As a workaround, you can remap devices by specifying map_location in torch.jit.load; however, device-specific logic may have been baked into the model.
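The workaround can be sketched as a save-and-reload round trip. The `Affine` module is hypothetical, and an in-memory buffer stands in for a file on disk; the target device here is the CPU so that the sketch runs anywhere:

```python
import io

import torch


class Affine(torch.nn.Module):
    """Hypothetical module used only for illustration."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


frozen_module = torch.jit.freeze(torch.jit.script(Affine().eval()))

# Serialize the frozen module, then reload it while remapping
# its tensors to the requested device.
buffer = io.BytesIO()
torch.jit.save(frozen_module, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer, map_location=torch.device("cpu"))

x = torch.ones(1, 3)
same = torch.allclose(frozen_module(x), reloaded(x))
```

Remapping moves the stored constants, but any device choice that was hard-coded in the scripted code itself is unaffected, as the note cautions.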