patrick-kidger/torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Welcome! For new projects I now strongly recommend using my newer jaxtyping project instead. It supports PyTorch, doesn't actually depend on JAX, and unlike TorchTyping it is compatible with static type checkers. The 'jax' in the name is now historical!
The original torchtyping README is as follows.
Turn this:
```python
import torch

def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)
```
into this:
```python
from torchtyping import TensorType

def batch_outer_product(x: TensorType["batch", "x_channels"],
                        y: TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)
```
with programmatic checking that the shape (dtype, ...) specification is met.
Bye-bye bugs! Say hello to enforced, clear documentation of your code.
If (like me) you find yourself littering your code with comments like `# x has shape (batch, hidden_state)` or statements like `assert x.shape == y.shape`, just to keep track of what shape everything is, then this is for you.
```bash
pip install torchtyping
```
Requires Python >=3.7 and PyTorch >=1.7.0.
If using `typeguard` then it must be a version <3.0.0.
`torchtyping` allows for type annotating:

- shape: size, number of dimensions;
- dtype (float, integer, etc.);
- layout (dense, sparse);
- names of dimensions as per named tensors;
- arbitrary number of batch dimensions with `...`;
- ...plus anything else you like, as `torchtyping` is highly extensible.
If `typeguard` is (optionally) installed then at runtime the types can be checked to ensure that the tensors really are of the advertised shape, dtype, etc.
```python
# EXAMPLE

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.
```
`typeguard` also has an import hook that can be used to automatically test an entire module, without needing to manually add `@typeguard.typechecked` decorators.
If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes. If you're not already using `typeguard` for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both `typeguard` and `torchtyping` also integrate with `pytest`, so if you're concerned about any performance penalty then they can be enabled during tests only.
`torchtyping.TensorType[shape, dtype, layout, details]`
The core of the library.
Each of `shape`, `dtype`, `layout` and `details` is optional.
- The `shape` argument can be any of:
  - An `int`: the dimension must be of exactly this size. If it is `-1` then any size is allowed.
  - A `str`: the size of the dimension passed at runtime will be bound to this name, and all tensors are checked to ensure that the sizes are consistent.
  - A `...`: an arbitrary number of dimensions of any sizes.
  - A `str: int` pair (technically it's a slice), combining both `str` and `int` behaviour. (Just a `str` on its own is equivalent to `str: -1`.)
  - A `str: str` pair, in which case the size of the dimension passed at runtime will be bound to *both* names, and all dimensions with either name must have the same size. (Some people like to use this as a way to associate multiple names with a dimension, for extra documentation purposes.)
  - A `str: ...` pair, in which case the multiple dimensions corresponding to `...` will be bound to the name specified by `str`, and again checked for consistency between arguments.
  - `None`, which, when used in conjunction with `is_named` below, indicates a dimension that must *not* have a name in the sense of named tensors.
  - A `None: int` pair, combining both `None` and `int` behaviour. (Just a `None` on its own is equivalent to `None: -1`.)
  - A `None: str` pair, combining both `None` and `str` behaviour. (That is, it must not have a named dimension, but must be of a size consistent with other uses of the string.)
  - A `typing.Any`: any size is allowed for this dimension (equivalent to `-1`).
  - Any tuple of the above. For example, `TensorType["batch": ..., "length": 10, "channels", -1]`. If you just want to specify the number of dimensions then use, for example, `TensorType[-1, -1, -1]` for a three-dimensional tensor.
- The `dtype` argument can be any of:
  - `torch.float32`, `torch.float64` etc.
  - `int`, `bool`, `float`, which are converted to their corresponding PyTorch types. `float` is specifically interpreted as `torch.get_default_dtype()`, which is usually `float32`.
- The `layout` argument can be either `torch.strided` or `torch.sparse_coo`, for dense and sparse tensors respectively.
- The `details` argument offers a way to pass an arbitrary number of additional flags that customise and extend `torchtyping`. Two flags are built-in by default. `torchtyping.is_named` causes the names of tensor dimensions to be checked, and `torchtyping.is_float` can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. `TensorType[torch.float32]`.) For discussion on how to customise `torchtyping` with your own `details`, see the further documentation.
- Check multiple things at once by just putting them all together inside a single `[]`. For example `TensorType["batch": ..., "length", "channels", float, is_named]`.
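To make the shape rules above concrete, here is a dependency-free sketch of how name binding can work: each string name is bound to a size the first time it is seen, and every later use must match. This is not `torchtyping`'s implementation. It works on plain tuples of ints rather than tensors, covers only `int`, `str`, a single `...`, and `str: int` / `str: str` pairs, and the names `S`, `check_shapes` and `_check_dim` are hypothetical, chosen for illustration.

```python
from typing import Any, Dict, Sequence, Tuple


class S:
    """Hypothetical helper: S["batch", "length": 10] returns the raw subscript
    tuple, so specs can use the same slice syntax that TensorType uses."""

    def __class_getitem__(cls, item: Any) -> Tuple[Any, ...]:
        return item if isinstance(item, tuple) else (item,)


def _check_dim(dim: Any, size: int, memo: Dict[str, int]) -> None:
    """Check one dimension of the given size against one spec entry."""
    if isinstance(dim, slice):
        # A `name: size` or `name: name` pair: apply both halves to this dimension.
        _check_dim(dim.start, size, memo)
        _check_dim(dim.stop, size, memo)
    elif isinstance(dim, int):
        # An exact size, with -1 meaning "any size".
        if dim != -1 and dim != size:
            raise TypeError(f"Expected a dimension of size {dim}, got {size}.")
    elif isinstance(dim, str):
        # A name: bind it on first use, then require the same size everywhere.
        bound = memo.setdefault(dim, size)
        if bound != size:
            raise TypeError(
                f"Dimension '{dim}' of inconsistent size. Got both {size} and {bound}."
            )
    else:
        raise TypeError(f"Spec entry {dim!r} is not supported by this sketch.")


def check_shapes(*pairs: Tuple[Sequence[Any], Tuple[int, ...]]) -> Dict[str, int]:
    """Check (spec, shape) pairs that share one name -> size memo, e.g. one
    pair per argument of a function call. Returns the bound names."""
    memo: Dict[str, int] = {}
    for spec, shape in pairs:
        spec = tuple(spec)
        if ... in spec:
            # At most one `...` here: entries before it match the leading
            # dimensions, entries after it match the trailing dimensions.
            i = spec.index(...)
            head, tail = spec[:i], spec[i + 1:]
            if len(shape) < len(head) + len(tail):
                raise TypeError(f"Shape {shape} has too few dimensions for {spec}.")
            dims = list(zip(head, shape)) + list(zip(tail, shape[len(shape) - len(tail):]))
        else:
            if len(spec) != len(shape):
                raise TypeError(f"Expected {len(spec)} dimensions, got {len(shape)}.")
            dims = list(zip(spec, shape))
        for dim, size in dims:
            _check_dim(dim, size, memo)
    return memo
```

For instance, `check_shapes((S["batch"], (3,)), (S["batch"], (1,)))` fails for the same reason as `func(rand(3), rand(1))` in the EXAMPLE: `"batch"` is bound to 3 by the first argument and then seen with size 1. The real library additionally handles dtype, layout, `None`, `str: ...`, named dimensions and `details`, and operates on actual tensors.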
`torchtyping.patch_typeguard()`

`torchtyping` integrates with `typeguard` to perform runtime type checking. `torchtyping.patch_typeguard()` should be called at the global level, and will patch `typeguard` to check `TensorType`s.
This function is safe to run multiple times. (It does nothing after the first run.)
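Being safe to run multiple times simply means the patch is idempotent. A common way to get that behaviour in plain Python is a run-once guard. The sketch below illustrates the pattern only; `run_once`, `patch_checkers` and `registered_checkers` are hypothetical names, and this is not a description of how `torchtyping.patch_typeguard()` is actually implemented.

```python
import functools
from typing import Callable, List


def run_once(setup: Callable[[], None]) -> Callable[[], None]:
    """Wrap a zero-argument setup function so that only its first successful call has any effect."""
    done = False

    @functools.wraps(setup)
    def wrapper() -> None:
        nonlocal done
        if not done:
            setup()  # if this raises, `done` stays False and a later call retries
            done = True

    return wrapper


# Hypothetical stand-in for whatever table of checkers a patch would modify.
registered_checkers: List[str] = []


@run_once
def patch_checkers() -> None:
    registered_checkers.append("TensorType")
```

Calling `patch_checkers()` several times leaves `registered_checkers` with a single entry, which is what lets such a function be called at the top of every file without side effects piling up.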
- If using `@typeguard.typechecked`, then `torchtyping.patch_typeguard()` should be called any time before using `@typeguard.typechecked`. For example you could call it at the start of each file using `torchtyping`.
- If using `typeguard.importhook.install_import_hook`, then `torchtyping.patch_typeguard()` should be called any time before defining the functions you want checked. For example you could call `torchtyping.patch_typeguard()` just once, at the same time as the `typeguard` import hook. (The order of the hook and the patch doesn't matter.)
- If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes.
`pytest --torchtyping-patch-typeguard`

`torchtyping` offers a `pytest` plugin to automatically run `torchtyping.patch_typeguard()` before your tests. `pytest` will automatically discover the plugin; you just need to pass the `--torchtyping-patch-typeguard` flag to enable it. Packages can then be passed to `typeguard` as normal, either by using `@typeguard.typechecked`, `typeguard`'s import hook, or the `pytest` flag `--typeguard-packages="your_package_here"`.
See the further documentation for:

- FAQ;
  - Including `flake8` and `mypy` compatibility;
- How to write custom extensions to `torchtyping`;
- Resources and links to other libraries and materials on this topic;
- More examples.