
Welcome! For new projects I now strongly recommend using my newer jaxtyping project instead. It supports PyTorch, doesn't actually depend on JAX, and unlike TorchTyping it is compatible with static type checkers. The 'jax' in the name is now historical!




The original torchtyping README is as follows.


# torchtyping

Type annotations for a tensor's shape, dtype, names, ...

Turn this:

```python
def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)
    return x.unsqueeze(-1) * y.unsqueeze(-2)
```

into this:

```python
def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)
```

with programmatic checking that the shape (dtype, ...) specification is met.

Bye-bye bugs! Say hello to enforced, clear documentation of your code.

If (like me) you find yourself littering your code with comments like `# x has shape (batch, hidden_state)` or statements like `assert x.shape == y.shape`, just to keep track of what shape everything is, then this is for you.


## Installation

```bash
pip install torchtyping
```

Requires Python >=3.7 and PyTorch >=1.7.0.

If using `typeguard` then it must be a version `<3.0.0`.

## Usage

`torchtyping` allows for type annotating:

  • shape: size, number of dimensions;
  • dtype (float, integer, etc.);
  • layout (dense, sparse);
  • names of dimensions as per named tensors;
  • an arbitrary number of batch dimensions with `...`;
  • ...plus anything else you like, as `torchtyping` is highly extensible.

If `typeguard` is (optionally) installed then at runtime the types can be checked to ensure that the tensors really are of the advertised shape, dtype, etc.

```python
# EXAMPLE
from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))  # TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.
```

`typeguard` also has an import hook that can be used to automatically test an entire module, without needing to manually add `@typeguard.typechecked` decorators.
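As a minimal sketch (assuming `typeguard<3.0.0`, with `my_package` as a hypothetical stand-in for your own package):

```python
# A minimal sketch of typeguard's import hook (typeguard <3.0.0).
# "my_package" is a hypothetical name standing in for your own code.
from torchtyping import patch_typeguard
from typeguard.importhook import install_import_hook

patch_typeguard()                  # patch typeguard so it understands TensorType
install_import_hook("my_package")  # must run before my_package is imported

import my_package  # every annotated function in my_package is now checked
```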

If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes. If you're not already using `typeguard` for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both `typeguard` and `torchtyping` also integrate with `pytest`, so if you're concerned about any performance penalty then they can be enabled during tests only.

## API

### `torchtyping.TensorType[shape, dtype, layout, details]`

The core of the library.

Each of `shape`, `dtype`, `layout`, and `details` is optional.

  • The `shape` argument can be any of:
    • An `int`: the dimension must be of exactly this size. If it is `-1` then any size is allowed.
    • A `str`: the size of the dimension passed at runtime will be bound to this name, and all tensors checked that the sizes are consistent.
    • A `...`: an arbitrary number of dimensions of any sizes.
    • A `str: int` pair (technically it's a slice), combining both `str` and `int` behaviour. (Just a `str` on its own is equivalent to `str: -1`.)
    • A `str: str` pair, in which case the size of the dimension passed at runtime will be bound to *both* names, and all dimensions with either name must have the same size. (Some people like to use this as a way to associate multiple names with a dimension, for extra documentation purposes.)
    • A `str: ...` pair, in which case the multiple dimensions corresponding to `...` will be bound to the name specified by the `str`, and again checked for consistency between arguments.
    • `None`, which when used in conjunction with `is_named` below, indicates a dimension that must *not* have a name in the sense of named tensors.
    • A `None: int` pair, combining both `None` and `int` behaviour. (Just a `None` on its own is equivalent to `None: -1`.)
    • A `None: str` pair, combining both `None` and `str` behaviour. (That is, it must not have a named dimension, but must be of a size consistent with other uses of the string.)
    • A `typing.Any`: any size is allowed for this dimension (equivalent to `-1`).
    • Any tuple of the above. For example `TensorType["batch": ..., "length": 10, "channels", -1]`. If you just want to specify the number of dimensions then use for example `TensorType[-1, -1, -1]` for a three-dimensional tensor.
  • The `dtype` argument can be any of:
    • `torch.float32`, `torch.float64`, etc.
    • `int`, `bool`, `float`, which are converted to their corresponding PyTorch types. `float` is specifically interpreted as `torch.get_default_dtype()`, which is usually `float32`.
  • The `layout` argument can be either `torch.strided` or `torch.sparse_coo`, for dense and sparse tensors respectively.
  • The `details` argument offers a way to pass an arbitrary number of additional flags that customise and extend `torchtyping`. Two flags are built in by default: `torchtyping.is_named` causes the names of tensor dimensions to be checked, and `torchtyping.is_float` can be used to check that an arbitrary floating point type is passed in (rather than just a specific one, as with e.g. `TensorType[torch.float32]`). For discussion on how to customise `torchtyping` with your own `details`, see the further documentation.
  • Check multiple things at once by just putting them all together inside a single `[]`, for example `TensorType["batch": ..., "length", "channels", float, is_named]`. Several of these forms are combined in the sketch below.
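As a sketch, several of the forms above can be combined like so (the function `f` and its arguments are hypothetical, purely for illustration; the checks only run if `typeguard` is installed and patched):

```python
# A hypothetical function combining several TensorType forms from the list above.
import torch
from torchtyping import TensorType, is_named, is_float

def f(
    a: TensorType["batch", "channels"],              # two str-named dimensions
    b: TensorType["batch": ..., "length": 10],       # str: ... pair and str: int pair
    c: TensorType[3, -1],                            # exactly 3, then any size
    d: TensorType[float],                            # torch.get_default_dtype()
    e: TensorType[torch.float64, torch.sparse_coo],  # dtype plus sparse layout
    g: TensorType["batch", None, is_named],          # dim named "batch", then an unnamed dim
) -> TensorType["batch", "channels", is_float]:      # any floating point dtype
    ...
```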
### `torchtyping.patch_typeguard()`

`torchtyping` integrates with `typeguard` to perform runtime type checking. `torchtyping.patch_typeguard()` should be called at the global level, and will patch `typeguard` to check `TensorType`s.

This function is safe to run multiple times. (It does nothing after the first run.)

  • If using `@typeguard.typechecked`, then `torchtyping.patch_typeguard()` should be called any time before using `@typeguard.typechecked`. For example you could call it at the start of each file using `torchtyping`.
  • If using `typeguard.importhook.install_import_hook`, then `torchtyping.patch_typeguard()` should be called any time before defining the functions you want checked. For example you could call `torchtyping.patch_typeguard()` just once, at the same time as the `typeguard` import hook. (The order of the hook and the patch doesn't matter.)
  • If you're not using `typeguard` then `torchtyping.patch_typeguard()` can be omitted altogether, and `torchtyping` just used for documentation purposes.
### `pytest --torchtyping-patch-typeguard`

`torchtyping` offers a `pytest` plugin to automatically run `torchtyping.patch_typeguard()` before your tests. `pytest` will automatically discover the plugin; you just need to pass the `--torchtyping-patch-typeguard` flag to enable it. Packages can then be passed to `typeguard` as normal, either by using `@typeguard.typechecked`, `typeguard`'s import hook, or the `pytest` flag `--typeguard-packages="your_package_here"`.
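As a sketch, a test file using the explicit decorator might look like this (the file name and test are hypothetical):

```python
# test_shapes.py -- a hypothetical test; run with e.g.
#   pytest --torchtyping-patch-typeguard
import torch
from torchtyping import TensorType
from typeguard import typechecked

@typechecked
def outer(x: TensorType["batch"], y: TensorType["batch"]) -> TensorType["batch", "batch"]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)

def test_outer():
    assert outer(torch.rand(5), torch.rand(5)).shape == (5, 5)
```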

## Further documentation

See the further documentation for:

  • FAQ;
    • Including `flake8` and `mypy` compatibility;
  • How to write custom extensions to `torchtyping`;
  • Resources and links to other libraries and materials on this topic;
  • More examples.
