The goal of this library is to generate more helpful exception messages for matrix algebra expressions in NumPy, PyTorch, JAX, TensorFlow, Keras, and fastai.
parrt/tensor-sensor
See the article "Clarifying exceptions and visualizing tensor operations in deep learning code" and the TensorSensor implementation slides (PDF).
(As of September 2021, M1 Macs hit illegal-instruction errors in many of the tensor libraries installed via Anaconda, so expect TensorSensor to work only on Intel-based Macs for now. PyTorch appears to work.)
One of the biggest challenges when writing code to implement deep learning networks, particularly for us newbies, is getting all of the tensor (matrix and vector) dimensions to line up properly. It's really easy to lose track of tensor dimensionality in complicated expressions involving multiple tensors and tensor operations. Even when just feeding data into predefined TensorFlow network layers, we still need to get the dimensions right. When you ask for improper computations, you're going to run into some less-than-helpful exception messages.
To help myself and other programmers debug tensor code, I built this library. TensorSensor clarifies exceptions by augmenting messages and visualizing Python code to indicate the shape of tensor variables (see figure to the right for a teaser). It works with TensorFlow, PyTorch, JAX, and NumPy, as well as higher-level libraries like Keras and fastai.
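The core idea of an augmented message can be sketched without any tensor library. The hypothetical helper `matmul_shape` below (not part of TensorSensor's API) applies the 2-D matrix-multiply shape rule `(n, k) @ (k, m) -> (n, m)` to plain shape tuples and, on a mismatch, raises an error that names the operator and both operands, imitating the style of the `Cause:` line TensorSensor produces for the NumPy example below. It only illustrates the message format; it does not inspect live tensors the way TensorSensor does.

```python
def matmul_shape(a_name, a_shape, b_name, b_shape):
    """Return the shape of `a_name @ b_name` for 2-D operands, or raise a
    ValueError naming the operator and both operand shapes."""
    if len(a_shape) != 2 or len(b_shape) != 2:
        raise ValueError("this sketch only handles 2-D operands")
    n, k1 = a_shape
    k2, m = b_shape
    if k1 != k2:
        # Imitate the wording of TensorSensor's "Cause:" line.
        raise ValueError(
            f"Cause: @ on tensor operand {a_name} w/shape {a_shape} "
            f"and operand {b_name} w/shape {b_shape}"
        )
    return (n, m)


# Shapes from the NumPy example: W is (d, n_neurons), X.T is (d, n).
try:
    matmul_shape("W", (764, 100), "X.T", (764, 200))
except ValueError as e:
    print(e)

# Transposing W makes the inner dimensions agree: (100, 764) @ (764, 200).
print(matmul_shape("W.T", (100, 764), "X.T", (764, 200)))  # (100, 200)
```

Whether `W.T @ X.T` is the intended layer computation depends on the model; the sketch only shows that the inner dimensions of `@` must agree.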
TensorSensor is currently at 1.0 (December 2021).
For more, see examples.ipynb at Colab. (The GitHub rendering does not show images for some reason: examples.ipynb.)
```python
import numpy as np
import tsensor

n = 200          # number of instances
d = 764          # number of instance features
n_neurons = 100  # how many neurons in this layer?

W = np.random.rand(d, n_neurons)
b = np.random.rand(n_neurons, 1)
X = np.random.rand(n, d)
with tsensor.clarify() as c:
    Y = W @ X.T + b  # shapes do not line up: (764, 100) @ (764, 200)
```
This displays the following in a Jupyter notebook or a separate window:
Instead of the following default exception message:
```
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)
```

TensorSensor augments the message with more information about which operator caused the problem and includes the shapes of the operands:
```
Cause: @ on tensor operand W w/shape (764, 100) and operand X.T w/shape (764, 200)
```

You can also get the full computation graph for an expression, including the shapes of all of its sub-expressions.
```python
import sys
import torch
import tsensor

W = torch.rand(size=(2000, 2000), dtype=torch.float64)
b = torch.rand(size=(2000, 1), dtype=torch.float64)
h = torch.zeros(size=(1_000_000,), dtype=int)
x = torch.rand(size=(2000, 1))
z = torch.rand(size=(2000, 1), dtype=torch.complex64)
tsensor.astviz("b = W@b + (h+3).dot(h) + z", sys._getframe())
```
yields the following abstract syntax tree with shapes:
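That tree records a shape for every sub-expression. As a rough, standard-library-only sketch of that bookkeeping (not TensorSensor's implementation, which uses its own parser and the actual tensor values), the hypothetical `annotate` function below parses the same statement with Python's `ast` module and propagates shapes through `+` (NumPy/PyTorch-style broadcasting), 2-D `@`, and 1-D `.dot`, ignoring dtypes. The shapes are copied from the PyTorch example; `ast.unparse` needs Python 3.9 or later.

```python
import ast


def broadcast(s1, s2):
    """NumPy/PyTorch-style broadcasting of two shape tuples."""
    n = max(len(s1), len(s2))
    s1 = (1,) * (n - len(s1)) + s1  # left-pad with 1s to equal rank
    s2 = (1,) * (n - len(s2)) + s2
    out = []
    for a, b in zip(s1, s2):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"cannot broadcast {s1} with {s2}")
        out.append(b if a == 1 else a)
    return tuple(out)


def shape_of(node, shapes, seen):
    """Compute the shape of an expression node, appending
    (source text, shape) for every sub-expression to `seen`."""
    if isinstance(node, ast.Name):
        shape = shapes[node.id]
    elif isinstance(node, ast.Constant):
        shape = ()  # a Python scalar
    elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
        shape = broadcast(shape_of(node.left, shapes, seen),
                          shape_of(node.right, shapes, seen))
    elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
        a = shape_of(node.left, shapes, seen)
        b = shape_of(node.right, shapes, seen)
        if len(a) != 2 or len(b) != 2 or a[1] != b[0]:
            raise ValueError(f"@ on shapes {a} and {b}")
        shape = (a[0], b[1])
    elif (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)
          and node.func.attr == "dot" and len(node.args) == 1):
        a = shape_of(node.func.value, shapes, seen)
        b = shape_of(node.args[0], shapes, seen)
        if len(a) != 1 or a != b:
            raise ValueError(f"dot on shapes {a} and {b}")
        shape = ()  # 1-D dot 1-D gives a scalar
    else:
        raise NotImplementedError(ast.dump(node))
    seen.append((ast.unparse(node), shape))
    return shape


def annotate(statement, shapes):
    """Return [(sub-expression, shape), ...] for a one-line statement."""
    stmt = ast.parse(statement).body[0]  # an ast.Assign or ast.Expr
    seen = []
    shape_of(stmt.value, shapes, seen)
    return seen


shapes = {"W": (2000, 2000), "b": (2000, 1), "h": (1_000_000,),
          "x": (2000, 1), "z": (2000, 1)}
for expr, shape in annotate("b = W@b + (h+3).dot(h) + z", shapes):
    print(f"{expr:28} {shape}")
```

Sub-expressions are listed in post-order (operands before the operation that combines them), which is also the order in which a tree like the one in the figure is built bottom-up.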
```
pip install tensor-sensor             # installs only the library itself
pip install tensor-sensor[torch]      # also installs the PyTorch dependency
pip install tensor-sensor[tensorflow] # also installs the TensorFlow dependency
pip install tensor-sensor[jax]        # also installs jax and jaxlib
pip install tensor-sensor[all]        # also installs TensorFlow, PyTorch, and JAX
```

which gives you the module `tsensor`. I developed and tested with the following versions:
```
$ pip list | grep -i flow
tensorflow                         2.5.0
tensorflow-estimator               2.5.0
$ pip list | grep -i numpy
numpy                              1.19.5
numpydoc                           1.1.0
$ pip list | grep -i torch
torch                              1.10.0
torchvision                        0.10.0
$ pip list | grep -i jax
jax                                0.2.20
jaxlib                             0.1.71
```

For displaying abstract syntax trees (ASTs) with `tsensor.astviz(...)`, you need the `dot` executable from Graphviz, not just the Python library.
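Because the Python `graphviz` package alone is not enough, it can help to confirm that the `dot` executable is actually on your `PATH`. A minimal stdlib sketch (the function name `graphviz_dot_available` is made up for illustration and is not part of TensorSensor):

```python
import shutil


def graphviz_dot_available():
    """Return True if Graphviz's `dot` executable can be found on PATH."""
    return shutil.which("dot") is not None


if graphviz_dot_available():
    print("found dot at", shutil.which("dot"))
else:
    print("dot not found; install Graphviz (see the platform notes)")
```

Running `dot -V` in a terminal answers the same question and also reports the Graphviz version.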
On Mac, do this before or after the tensor-sensor install:

```
brew install graphviz
```

On Windows, apparently you need:
```
conda install python-graphviz  # Do this first; gets dot executable and py lib
pip install tensor-sensor      # Or one of the other installs
```

I rely on parsing lines that are assignments or expressions only, so the `clarify` and `explain` routines do not handle methods expressed like:
```python
def bar(): b + x * 3
```

Instead, use

```python
def bar():
    b + x * 3
```

Watch out for side effects! I don't do assignments, but any functions with side effects that you call will be called again while I reevaluate statements.
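To see why side effects matter, here is a minimal stdlib sketch: a statement calls a side-effecting function and then fails, and a hypothetical debugging step re-evaluates one operand to inspect it, which is the kind of reevaluation the paragraph above warns about. The names `log_and_return` and `calls` are illustrative only; this is not TensorSensor's code.

```python
calls = []  # records each call so the side effect is visible


def log_and_return(v):
    """A function with a side effect: it appends to `calls`."""
    calls.append(v)
    return v


env = {"log_and_return": log_and_return}

try:
    # The program's own statement: the call runs, then `+` fails.
    eval('log_and_return(3) + "oops"', env)
except TypeError:
    # Hypothetical debugger step: re-evaluate the left operand to inspect it.
    left = eval("log_and_return(3)", env)
    print("left operand:", left)

print(calls)  # [3, 3]: the side effect ran twice
```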
Can't handle `\` continuations.
With Python's `threading` package, don't use multiple threads calling `clarify()`. The `multiprocessing` package should be fine.
Also note: I've built my own parser to handle just the assignments / expressions tsensor can handle.
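TensorSensor uses its own parser, so the following is only an analogy built on the standard library's `ast` module. The hypothetical function `is_assignment_or_expression` checks whether a single source line, taken on its own, is exactly one assignment or expression statement. It shows why the one-line `def` and a `\`-continued line fall outside a line-at-a-time approach, while the indented body line of the multi-line `def` is fine.

```python
import ast


def is_assignment_or_expression(line):
    """Return True if `line`, parsed on its own, is exactly one assignment
    or expression statement."""
    try:
        tree = ast.parse(line.strip())  # strip indentation from nested lines
    except SyntaxError:
        return False  # e.g. a line ending in a `\` continuation
    return (len(tree.body) == 1
            and isinstance(tree.body[0], (ast.Assign, ast.Expr)))


for line in ["def bar(): b + x * 3",   # a def statement: not supported
             "    b + x * 3",          # body line of the multi-line def
             "Y = W @ X.T + b",        # an assignment
             "Y = W @ X.T + \\"]:      # line with a backslash continuation
    print(repr(line), is_assignment_or_expression(line))
```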
To deploy to PyPI:

```
$ python setup.py sdist upload
```

Or download and install locally:

```
$ cd ~/github/tensor-sensor
$ pip install .
```
- Can I call pyviz in the debugger?