PyTorch IR

gmagogsfm edited this page Jun 11, 2020 · 8 revisions

This document presents the IR as of October 17th, 2018. See the JIT OVERVIEW.md for more up-to-date info.

PyTorch uses an SSA-based IR, which is built of multiple entities:

  • Graph is generally the outermost container for the program representation. At the moment all programs are mostly pure (modulo special operators like prim::Print or prim::PythonOp), but this will change in the future.
  • Then, there are Blocks, which you can treat as functions. They have a list of inputs and outputs, and a (topologically ordered) list of Nodes. Every Graph has a top-level Block which is (using C lingo) like a main function.
  • Nodes, in turn, represent calls to functions (with possibly multiple arguments and returns).
  • Finally, every single intermediate value that appears in the program is represented using a Value: the root Blocks have a list of input values, and every Node takes them as inputs and returns some more as outputs. Every Value has a Type associated with it.

Ownership model: All of the structures above are conventionally passed around as pointers. All Nodes and Values are owned by the Graph in which they appear (and in particular they can't be shared between different Graphs). Types are pointed to only from Values, so they are wrapped in shared_ptrs.

Every Block has a canonical Node ordering (a doubly-linked node_list_), which determines both the display and the actual order in which operations will get executed once it's compiled into the JIT interpreter bytecode. It is the responsibility of the programmer to ensure that all Nodes they create appear somewhere in that list, and that the list is a valid topological ordering.
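The ownership and ordering rules can be sketched with a small toy model in Python. The class names mirror the C++ ones but are hypothetical, and a plain list stands in for the doubly-linked node_list_. The graph creates and owns every Node and Value. The toy lint() checks only the topological-order invariant: each node may use only values defined earlier in the list.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)
class Value:
    name: str
    node: Optional["Node"] = None  # producer; None for graph inputs


@dataclass(eq=False)
class Node:
    kind: str
    inputs: List[Value]
    outputs: List[Value] = field(default_factory=list)


class Graph:
    """Toy graph: owns all nodes/values and keeps nodes in execution order."""

    def __init__(self) -> None:
        self.inputs: List[Value] = []
        self.nodes: List[Node] = []  # stands in for the doubly-linked node_list_
        self._counter = 0

    def _fresh(self) -> str:
        name = f"%{self._counter}"
        self._counter += 1
        return name

    def add_input(self) -> Value:
        value = Value(self._fresh())
        self.inputs.append(value)
        return value

    def append_node(self, kind: str, inputs: List[Value], n_outputs: int = 1) -> Node:
        node = Node(kind, list(inputs))
        node.outputs = [Value(self._fresh(), node) for _ in range(n_outputs)]
        self.nodes.append(node)
        return node

    def lint(self) -> None:
        """Raise if a node uses a value that is not defined earlier in the list."""
        defined = {id(v) for v in self.inputs}
        for node in self.nodes:
            for v in node.inputs:
                if id(v) not in defined:
                    raise ValueError(f"{node.kind} uses {v.name} before it is defined")
            defined.update(id(v) for v in node.outputs)
```

Reordering `g.nodes` so that a consumer comes before its producer makes `lint()` raise. This is the kind of mistake the paragraph above warns about.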

With this amount of background, let's take a look at an example. Consider this Python program:

```python
import torch

def f(a, b):
    c = a + b
    d = c * c
    e = torch.tanh(d * c)
    return d + (e + e)
```

If we were to translate it into the IR, it could be represented as the following Graph:

How the actual translation from Python to the IR works will be described later in this document.

```
graph(%0 : Double(2)
      %1 : Double(2)) {
  %2 : int = prim::Constant[value=1]()
  %3 : Double(2) = aten::add(%0, %1, %2)
  %4 : Double(2) = aten::mul(%3, %3)
  %5 : Double(2) = aten::mul(%4, %3)
  %6 : Double(2) = aten::tanh(%5)
  %7 : Double(2) = aten::add(%6, %6, %2)
  %8 : Double(2) = aten::add(%4, %7, %2)
  return (%8);
}
```

This is the canonical textual representation of the IR. You should be able to easily find almost all of the elements we discussed above.

  • graph is the Graph
  • %x are Values
  • %x : Double(2) is a type annotation of Value %x (see below for a list of supported types).
  • %x : T1, %y : T2 = namespace::name(%z, %w) is a Node which represents the namespace::name operator (this name is usually referred to as the Node's kind). It takes %z and %w Values as inputs, and returns two outputs (%x, %y) of types T1 and T2 respectively.
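To make the shape of a node line explicit, here is a small regex-based parser for a single line of the textual form. It is a sketch under a simplifying assumption: it handles one-line nodes like the ones shown here, not nested blocks. `parse_node_line` is a hypothetical helper, not part of PyTorch. Outputs are split only at commas that are followed by `%`, so types such as `Float(1, 3)` stay intact.

```python
import re
from typing import List, NamedTuple, Tuple


class ParsedNode(NamedTuple):
    outputs: List[Tuple[str, str]]  # (value name, type annotation)
    kind: str                       # e.g. "aten::add"
    attributes: str                 # raw text between [...], or ""
    inputs: List[str]


_NODE_RE = re.compile(
    r"^\s*(?P<outs>.*?)\s*=\s*"          # "%x : T1, %y : T2" (may be empty)
    r"(?P<kind>[A-Za-z_]\w*::\w+)"       # namespace::name
    r"(?:\[(?P<attrs>[^\]]*)\])?"        # optional [attributes]
    r"\((?P<ins>[^)]*)\)\s*$"            # (inputs)
)


def parse_node_line(line: str) -> ParsedNode:
    m = _NODE_RE.match(line)
    if m is None:
        raise ValueError(f"not a node line: {line!r}")
    outputs = []
    # Split only at commas that start a new "%name", so "Float(1, 3)" survives.
    for out in re.split(r",\s*(?=%)", m["outs"]):
        out = out.strip()
        if out:
            name, _, ty = out.partition(":")
            outputs.append((name.strip(), ty.strip()))
    inputs = [s.strip() for s in m["ins"].split(",") if s.strip()]
    return ParsedNode(outputs, m["kind"], m["attrs"] or "", inputs)
```

For example, `parse_node_line("%2 : int = prim::Constant[value=1]()")` yields kind `prim::Constant`, attribute text `value=1`, and no inputs.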

Finally, nodes can have extra pieces of information assigned to them, which are called attributes. You can see one used in the prim::Constant node, which returns its value attribute when it's called. There's a fixed list of types you can attach:

  • int64_t
  • double
  • Tensor
  • Graph (useful for e.g. slicing subgraphs that are meant to be fused)
  • std::string
  • and lists of them (not nested)

Supported types

JIT supports a number of builtin types (and the list is fixed):

  • int, float, bool - scalars
  • Dynamic - tensors without any static information available
  • Float(*, *) - tensors with partial static information available (note that not all of it is shown in the textual output). It includes:
    • data type (Byte, Short, Int, Long, Half, Float, Double)
    • number of dimensions
    • the device on which their data will reside
    • boolean indicating if they will need to be differentiated (usually false for evaluation, and true for training)
  • Float(1, 3, 224, 224) - tensors with full static information available. Note that as of today this is almost unused, and you should not assume that those details will ever be present. It includes:
    • all of the above
    • sizes
    • strides

Control flow

Blocks can also be embedded inside Nodes, and are used to implement control flow combinators. You can treat them as lambda expressions passed in as arguments (no other way of passing functions by value exists). They can take and return multiple values and close over the lexical environment of the surrounding block (every Graph has a default top-level Block).

There are two combinators used today.

prim::If

Implements a conditional statement. The general semantics of this node are as follows:

```
%y_1, ..., %y_r = prim::If(%condition)
  block0() { # TRUE BRANCH, never takes arguments, has to return r outputs
    %t_1, ..., %t_k = some::node(%a_value_from_outer_block)
    -> (%t_1, ..., %t_r)
  }
  block1() { # FALSE BRANCH, never takes arguments, has to return r outputs
    %f_1, ..., %f_m = some::node(%a_value_from_outer_block)
    -> (%f_1, ..., %f_r)
  }
```

Values corresponding to %y_1, ..., %y_r will become either %t_1, ..., %t_r, or %f_1, ..., %f_r depending on the value of %condition at runtime (you can see that the node kind of acts as a Phi node in conventional SSA).

Here's an example translation of a Python program:

```python
def f(a, b, c):
    d = a + b
    if c:
        e = d + d
    else:
        e = b + d
    return e
```

```
graph(%a : Dynamic
      %b : Dynamic
      %c : Dynamic) {
  %2 : int = prim::Constant[value=1]()
  %3 : Dynamic = aten::add(%a, %b, %2)
  %5 : Dynamic = prim::If(%c)
    block0() {
      %6 : int = prim::Constant[value=1]()
      %7 : Dynamic = aten::add(%3, %3, %6)
      -> (%7)
    }
    block1() {
      %8 : int = prim::Constant[value=1]()
      %9 : Dynamic = aten::add(%b, %3, %8)
      -> (%9)
    }
  return (%5);
}
```
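The Phi-like behaviour of prim::If can be mimicked in plain Python with closures. In this sketch, `prim_if` is a hypothetical helper and not a PyTorch API. It runs exactly one zero-argument branch and forwards that branch's r outputs. The branches read `d` and `b` from the enclosing scope, just as block0 and block1 read %3 and %b from the outer block.

```python
from typing import Callable, Sequence, Tuple

# A branch takes no arguments and returns its r outputs.
Branch = Callable[[], Sequence[object]]


def prim_if(condition: bool, true_branch: Branch, false_branch: Branch, r: int) -> Tuple[object, ...]:
    """Run exactly one branch and forward its r outputs (the Phi-like behaviour)."""
    outputs = tuple(true_branch() if condition else false_branch())
    if len(outputs) != r:
        raise ValueError(f"branch returned {len(outputs)} values, expected {r}")
    return outputs


def f(a, b, c):
    d = a + b  # %3 in the graph above
    # Both branches close over values from the enclosing scope.
    (e,) = prim_if(c, lambda: (d + d,), lambda: (b + d,), r=1)
    return e
```

With plain integers, `f(1, 2, True)` takes the true branch and `f(1, 2, False)` takes the false one.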

prim::Loop

Implements a looping construct (covers both while and for loops). A valid instantiation of this node always looks like this:

```
%y_1, ..., %y_r = prim::Loop(%max_trip_count, %initial_condition, %x_1, ..., %x_r)
  block0(%i, %a_1, ..., %a_r) {
    %b_1, ..., %b_m = some::node(%a_value_from_outer_block, %a_1)
    %iter_condition = some::other_node(%a_2)
    -> (%iter_condition, %b_1, ..., %b_r)
  }
```

The simplest way to explain the semantics is to consider this Python-like pseudo-code:

```
y_1, ..., y_r = x_1, ..., x_r
condition = initial_condition
i = 0
while condition and i < max_trip_count:
    a_1, ..., a_r = y_1, ..., y_r

    ############################################################
    # Actual body of the loop
    b_1, ..., b_m = some::node(a_value_from_outside_of_the_loop, a_1)
    iter_condition = some::other_node(a_2)
    ############################################################

    y_1, ..., y_r = b_1, ..., b_r
    condition = iter_condition
    i += 1
```

Note that translations of for loops simply pass in a constant true for both %initial_condition and %iter_condition, while for while loops %max_trip_count is set to the largest value of int64_t, and %i is unused. Those patterns are recognized by our interpreter and optimized accordingly (e.g. while loops don't maintain the loop counter).

For example, this program:

```python
def f(x):
    z = x
    for i in range(x.size(0)):
        z = z * z
    return z
```

can be translated as:

```
graph(%z.1 : Dynamic) {
  %3 : bool = prim::Constant[value=1]()
  %1 : int = prim::Constant[value=0]()
  %2 : int = aten::size(%z.1, %1)
  %z : Dynamic = prim::Loop(%2, %3, %z.1)
    block0(%i : int, %5 : Dynamic) {
      %z.2 : Dynamic = aten::mul(%5, %5)
      -> (%3, %z.2)
    }
  return (%z);
}
```
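The loop semantics can be transcribed directly into runnable Python. This is a sketch: `prim_loop`, `f_for` and `f_while` are hypothetical names. The body receives `(i, a_1, ..., a_r)` and returns `(iter_condition, b_1, ..., b_r)`. The two helpers show the for-loop encoding (constant true conditions and a real trip count, with `n` playing the role of `x.size(0)`) and the while-loop encoding (an int64_t-max trip count and an unused `i`).

```python
from typing import Callable, Sequence, Tuple

INT64_MAX = 2**63 - 1  # trip count used when translating while loops

# body(i, a_1, ..., a_r) -> (iter_condition, b_1, ..., b_r)
Body = Callable[..., Sequence[object]]


def prim_loop(max_trip_count: int, initial_condition: bool, body: Body, *xs: object) -> Tuple[object, ...]:
    """Direct transcription of the pseudo-code for prim::Loop."""
    ys = tuple(xs)
    condition = initial_condition
    i = 0
    while condition and i < max_trip_count:
        iter_condition, *bs = body(i, *ys)
        if len(bs) != len(ys):
            raise ValueError("body must return one value per loop-carried value")
        ys = tuple(bs)
        condition = iter_condition
        i += 1
    return ys


def f_for(z: int, n: int) -> int:
    # for-loop encoding: constant True conditions, real trip count n
    (z,) = prim_loop(n, True, lambda i, z: (True, z * z), z)
    return z


def f_while(x: int) -> int:
    # while-loop encoding of `while x < 100: x = x * 2`; i is never used
    (x,) = prim_loop(INT64_MAX, x < 100, lambda _i, v: (v * 2 < 100, v * 2), x)
    return x
```

For instance, `f_for(2, 3)` squares 2 three times. `f_while(3)` keeps doubling until the value reaches at least 100.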

Function calls

At the moment there's no way to call a Graph from another Graph, and all function calls appearing in the frontend result in inlining of the callee's body into the caller. In particular, recursive function calls are not supported yet. This will be addressed in a future release.

Node overloading

PyTorch IR supports function overloading (but you can't have two overloads that differ only in their return types). For example, the aten::add name usually has these overloads associated with it (Scalar means float or int in this case):

  • aten::add(Tensor self, Tensor other) -> Tensor
  • aten::add(Tensor self, Scalar other) -> Tensor
  • aten::add(int self, int other) -> int
  • aten::add(float self, float other) -> float

All of the strings above can actually be parsed into FunctionSchema objects, which hold all this information in a machine-readable way. A Node can be queried for its schema using the schema() method (it will check the argument types, and will try to match one of the options for its kind()).

Note that the chosen overload is not shown in any way in the textual output. If you're unsure which function a node resolves to, you might need to check the type annotations of its input values.
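A toy resolver shows why the input types are what decide the overload. It picks the first schema whose formal types accept the actual argument types. The table mirrors the aten::add list above. The first-match policy, the "Scalar accepts int or float" rule and the absence of implicit promotion are all assumptions of this sketch, not a description of PyTorch's matcher.

```python
from typing import Dict, List, Tuple

# Toy overload table mirroring the aten::add list above: (formal types, return type).
OVERLOADS: Dict[str, List[Tuple[Tuple[str, ...], str]]] = {
    "aten::add": [
        (("Tensor", "Tensor"), "Tensor"),
        (("Tensor", "Scalar"), "Tensor"),
        (("int", "int"), "int"),
        (("float", "float"), "float"),
    ],
}


def _accepts(formal: str, actual: str) -> bool:
    # In this sketch "Scalar" accepts int or float; every other type must match exactly.
    return formal == actual or (formal == "Scalar" and actual in ("int", "float"))


def resolve(kind: str, arg_types: List[str]) -> str:
    """Return the first matching schema, rendered as a string."""
    for formals, ret in OVERLOADS[kind]:
        if len(formals) == len(arg_types) and all(
            _accepts(f, a) for f, a in zip(formals, arg_types)
        ):
            return f"{kind}({', '.join(formals)}) -> {ret}"
    raise LookupError(f"no overload of {kind} for {arg_types}")
```

Two nodes that both print as `aten::add(...)` can therefore resolve to different functions purely because of their input annotations.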

JIT interpreter bytecode

Graphs are data structures written with ease of manipulation in mind, and are not meant to be interpreted directly. Instead, they are first transformed into Code objects, which hold a list of std::functions (individual instructions) and some additional metadata regarding register use. Later, Code objects can be executed using InterpreterState objects.

The JIT interpreter is a simple stack-based VM (with a number of registers to hold local values). There's a single Stack used to pass arguments into and out of every instruction. The aforementioned metadata in Code describes how to organize stores/loads between registers and the stack.

Stack is really an std::vector<IValue>, where IValue is our custom tagged union type, which is able to represent all kinds of values that the JIT can accept (it's optimized to be small and lean, and only takes 16B).

There are many tricks that could be applied in the interpreter to make it faster, but we haven't seen it become a bottleneck so far, so we haven't spent time on it.
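The register/stack split can be illustrated with a minimal Python interpreter. This is a sketch with hypothetical names, and plain Python numbers stand in for IValue. Each instruction records which registers to load onto the single shared stack before its operation runs, and which registers receive the results afterwards. The operation itself only pops inputs and pushes outputs. `F_CODE` hand-compiles f(a, b) from the first example, working on floats instead of tensors.

```python
import math
import operator
from dataclasses import dataclass
from typing import Callable, List, Sequence, Union

IValue = Union[int, float, bool]  # toy stand-in for the C++ tagged union
Stack = List[IValue]
Operation = Callable[[Stack], None]  # pops its inputs, pushes its outputs


@dataclass
class Instruction:
    op: Operation
    inputs: Sequence[int]   # registers loaded onto the stack before `op` runs
    outputs: Sequence[int]  # registers that receive the results afterwards


def run(code: List[Instruction], args: Sequence[IValue],
        result_regs: Sequence[int], n_regs: int) -> List[IValue]:
    """Execute `code`, moving values between registers and one shared stack."""
    registers: List[IValue] = [0] * n_regs
    registers[:len(args)] = list(args)
    stack: Stack = []
    for inst in code:
        stack.extend(registers[r] for r in inst.inputs)  # loads
        inst.op(stack)
        for r in reversed(inst.outputs):  # stores; the last output is on top
            registers[r] = stack.pop()
    return [registers[r] for r in result_regs]


def unary(fn: Callable[[IValue], IValue]) -> Operation:
    def op(stack: Stack) -> None:
        stack.append(fn(stack.pop()))
    return op


def binary(fn: Callable[[IValue, IValue], IValue]) -> Operation:
    def op(stack: Stack) -> None:
        rhs = stack.pop()
        lhs = stack.pop()
        stack.append(fn(lhs, rhs))
    return op


# f(a, b) from the first example; registers 0 and 1 hold a and b.
F_CODE = [
    Instruction(binary(operator.add), [0, 1], [2]),  # c = a + b
    Instruction(binary(operator.mul), [2, 2], [3]),  # d = c * c
    Instruction(binary(operator.mul), [3, 2], [4]),  # d * c
    Instruction(unary(math.tanh), [4], [5]),         # e = tanh(d * c)
    Instruction(binary(operator.add), [5, 5], [6]),  # e + e
    Instruction(binary(operator.add), [3, 6], [7]),  # d + (e + e)
]
```

Calling `run(F_CODE, [0.5, 0.25], [7], 8)` returns a one-element list with the value of `d + (e + e)`.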

Operator registration

PyTorch JIT supports open registration of new operators, so they can be freely added at runtime, e.g. via dlopen. The syntax is as follows:

```cpp
RegisterOperators reg({
  Operator(
    // Specify a function signature
    "my_namespace::magic(Tensor a, Tensor b, int c) -> (Tensor, Tensor)",
    // An std::function that should be called to retrieve a callable implementing
    // this operator.
    [](Node *node) -> Operation {
      // Retrieve the multiplier attribute from the node
      double multiplier = node->d(attr::multiplier);
      return [multiplier](Stack& stack) -> int {
        torch::Tensor a, b;
        int c;
        torch::jit::pop(stack, a, b, c);
        // magic_impl is a hypothetical user-provided implementation (not shown)
        std::pair<torch::Tensor, torch::Tensor> result = magic_impl(a, b, c);
        torch::jit::push(stack, result.first, result.second);
        return 0; // Always return 0 here.
      };
    })
});
```
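The two-level structure of a registration can be mirrored in Python. The registered function receives the node and returns the operation that later runs on the stack. This is only a sketch of the pattern, with hypothetical names throughout. The registry is keyed by qualified name, so unlike the real one it ignores overloads. Node attributes are passed as a dict. The body of `magic_factory` is invented to stand in for magic_impl.

```python
from typing import Any, Callable, Dict, List

Stack = List[Any]
# As in the C++ snippet above, an operation returns an int (always 0 here).
Operation = Callable[[Stack], int]
# Called once per node with that node's attributes; returns the callable to run.
OperationFactory = Callable[[Dict[str, Any]], Operation]

_REGISTRY: Dict[str, OperationFactory] = {}


def register_operator(schema: str, factory: OperationFactory) -> None:
    """Register by qualified name; this toy ignores overloads."""
    name = schema.split("(", 1)[0].strip()
    if name in _REGISTRY:
        raise ValueError(f"{name} is already registered")
    _REGISTRY[name] = factory


def get_operation(kind: str, attributes: Dict[str, Any]) -> Operation:
    return _REGISTRY[kind](attributes)


def magic_factory(attributes: Dict[str, Any]) -> Operation:
    multiplier = attributes["multiplier"]  # read once, when the node is compiled

    def op(stack: Stack) -> int:
        # Inputs were pushed in the order a, b, c, so they come off in reverse.
        c = stack.pop()
        b = stack.pop()
        a = stack.pop()
        # Hypothetical body standing in for magic_impl: scale both inputs.
        stack.append(a * c * multiplier)
        stack.append(b * c * multiplier)
        return 0

    return op


register_operator("my_namespace::magic(Tensor a, Tensor b, int c) -> (Tensor, Tensor)", magic_factory)
```

Splitting registration into a factory and an operation means per-node attributes such as `multiplier` are read once, when the node is turned into an instruction, rather than on every execution.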

Graph specialization

Certain optimizations require knowledge about the data types and devices of tensors appearing in user programs. To support this, we have a GraphExecutor, which is like a wrapper around an interpreter that additionally checks what kind of inputs were given, and caches execution plans for Graphs specialized to their details. For example, Tensor inputs to Graphs get assigned TensorTypes (dtype, ndim, device, gradient status), and we later attempt to propagate that information statically (using torch/csrc/jit/passes/shape_analysis.cpp).

This has the drawback that every call to a JITed function has to go through this matching of arguments to specialized graphs, which e.g. causes a 0.5% slowdown for CNNs (which don't even get any optimization benefits at the moment). In the future we might consider ditching the specialization in favor of more JIT-like techniques (gathering statistics about run time values like tensor sizes, and making optimizations in later stages).
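The specialization cache can be sketched as a dictionary keyed by a tuple of per-argument specs. The class names here are hypothetical, and `FakeTensor` only stands in for a real tensor so the sketch runs without PyTorch. Computing the key on every call is the per-call matching overhead described above. A plan is compiled only the first time a new combination of dtype, rank, device and gradient status is seen.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass(frozen=True)
class ArgSpec:
    """The per-tensor facts the text says graphs get specialized on."""
    dtype: str
    ndim: int
    device: str
    requires_grad: bool


@dataclass
class FakeTensor:
    """Minimal stand-in for a tensor so the sketch runs without PyTorch."""
    dtype: str
    ndim: int
    device: str
    requires_grad: bool


Plan = Callable[..., Any]


class ToyGraphExecutor:
    def __init__(self, compile_for: Callable[[Tuple[ArgSpec, ...]], Plan]) -> None:
        self._compile_for = compile_for
        self._plans: Dict[Tuple[ArgSpec, ...], Plan] = {}
        self.compilations = 0

    def __call__(self, *args: FakeTensor) -> Any:
        # This matching step runs on every call: the overhead mentioned above.
        key = tuple(ArgSpec(a.dtype, a.ndim, a.device, a.requires_grad) for a in args)
        plan = self._plans.get(key)
        if plan is None:
            plan = self._compile_for(key)
            self._plans[key] = plan
            self.compilations += 1
        return plan(*args)
```

Calling the executor twice with CPU float inputs compiles once. Moving one argument to another device triggers a second compilation.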

Important files

This section contains a list of relatively important files and a brief description of their contents. All paths are relative to torch/csrc/jit.

  • ir.h - implementation of Graph, Block, Node, Value
  • type.h - implementation of Type
  • interpreter.cpp - JIT interpreter (Code, InterpreterImpl)
  • ivalue.h - implementation of IValue
  • stack.h - implementation of Stack
  • graph_executor.cpp - a runner for graphs that will specialize them to different argument configurations
  • tracer.h - tracer for PyTorch code (generates straight line Graphs from any code)
  • operator.cpp - infrastructure for overload resolution and custom operator registration
  • script/ - compiler from TorchScript (think Python AST) to Graphs
  • passes/*.cpp - optimization passes
  • fusers/**/* - CUDA and CPU codegens for pointwise subgraphs
  • autodiff.cpp - symbolic AD for Graphs
  • symbolic_variable.h - a helper to make Graph building easier

IR construction

There are three main ways of building up the IR.

Tracing

This means that you run arbitrary Python/C++ code using PyTorch operators, and we record a straight-line trace (control flow gets unrolled and inlined). Good for simple models, bad if you really have data-dependent control flow (and it's not only used for metaprogramming). The relevant entry point for this is torch.jit.trace.
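A toy tracer shows why data-dependent control flow does not survive tracing. This is not how torch.jit.trace is implemented. `TracedValue` and `trace` are hypothetical names. Each value carries a concrete number and logs every operation applied to it. Because the Python `for` reads the concrete example value, the trip count is baked into the recorded straight-line log.

```python
from typing import Callable, List, Tuple

Record = Tuple[str, str, str, str]  # (output, kind, lhs, rhs)


class TracedValue:
    """Carries a concrete value and logs every binary op applied to it."""

    def __init__(self, value: float, name: str, log: List[Record]) -> None:
        self.value = value
        self.name = name
        self._log = log

    def _record(self, kind: str, other: "TracedValue", result: float) -> "TracedValue":
        out = TracedValue(result, f"%{len(self._log) + 1}", self._log)
        self._log.append((out.name, kind, self.name, other.name))
        return out

    def __add__(self, other: "TracedValue") -> "TracedValue":
        return self._record("aten::add", other, self.value + other.value)

    def __mul__(self, other: "TracedValue") -> "TracedValue":
        return self._record("aten::mul", other, self.value * other.value)


def trace(fn: Callable[[TracedValue], TracedValue], example: float) -> List[Record]:
    log: List[Record] = []
    fn(TracedValue(example, "%0", log))
    return log


def f(x: TracedValue) -> TracedValue:
    z = x
    # The trip count comes from the concrete example value, so it is unrolled.
    for _ in range(int(x.value)):
        z = z * z
    return z
```

Tracing with example 2.0 records two multiplications and tracing with 3.0 records three. Neither log contains a loop.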

TorchScript

This method implements a simple Python-like language (it's in fact a subset of Python that conforms to its semantics) and a compiler from it to the IR. Great if you need to retain control flow, but a bit annoying if you need more advanced language features.

Manual construction

This doesn't really happen anywhere outside of the optimization passes, and is probably not recommended. SymbolicVariable is a helper that overloads many Tensor operators and makes them insert Nodes into its Graph instead of doing actual compute.
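The SymbolicVariable idea translates to a few lines of Python operator overloading. This toy analogue uses the hypothetical names `ToyGraph`, `Sym` and `build_f`, which are not the C++ helper's API. Each overloaded operator appends one IR line and returns a fresh symbolic value instead of computing anything. Running the body of f(a, b) from the first example through it records one node per operator, in program order. It omits the alpha constant that aten::add takes in the real IR.

```python
from typing import List


class ToyGraph:
    def __init__(self) -> None:
        self.lines: List[str] = []
        self._count = 0

    def fresh_name(self) -> str:
        name = f"%{self._count}"
        self._count += 1
        return name

    def add_input(self) -> "Sym":
        return Sym(self, self.fresh_name())


class Sym:
    """Records an IR node for each operator instead of computing anything."""

    def __init__(self, graph: ToyGraph, name: str) -> None:
        self.graph = graph
        self.name = name

    def _emit(self, kind: str, *inputs: "Sym") -> "Sym":
        out = Sym(self.graph, self.graph.fresh_name())
        args = ", ".join(v.name for v in inputs)
        self.graph.lines.append(f"{out.name} = {kind}({args})")
        return out

    def __add__(self, other: "Sym") -> "Sym":
        return self._emit("aten::add", self, other)

    def __mul__(self, other: "Sym") -> "Sym":
        return self._emit("aten::mul", self, other)

    def tanh(self) -> "Sym":
        return self._emit("aten::tanh", self)


def build_f(g: ToyGraph) -> Sym:
    # Same body as f(a, b) from the first example, run on symbolic values.
    a, b = g.add_input(), g.add_input()
    c = a + b
    d = c * c
    e = (d * c).tanh()
    return d + (e + e)
```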

Graph manipulation

As mentioned previously, the IR is really optimized to be easy to manipulate and change. To help with that, there are numerous methods on Graphs, Nodes and Values, and we maintain a lot of extra metadata that allows us to quickly check certain conditions (e.g. looking up all use sites of a single Value takes constant time, because we have this information cached). Here's a list of the most relevant methods you can find (think of ArrayRef as an std::vector, and of Symbol as an interned string):

Graph

  • ArrayRef<Value*> inputs()
  • ArrayRef<Value*> outputs()
  • graph_node_list nodes()
  • Value* addInput()
  • Value* insertInput(size_t offset)
  • Value* eraseInput(size_t offset)
  • size_t registerOutput(Value *output);
  • void eraseOutput(size_t offset)
  • Value* insert(Symbol opname, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs = {});
    • This is the most convenient method of adding more nodes to the Graph. An example call looks like this: graph->insert(aten::add, {some_value, 3}) (note how C++ values will get inserted into the Graph as constants automatically).
  • Block* block() (returns the top-level block)
  • void lint() (throws if the graph violates invariants, e.g. if the node list is not a valid topological order of data dependencies)
  • void dump() (prints the graph to stdout -- useful for debugging)
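The automatic promotion of C++ constants in Graph::insert can be mimicked in a short Python sketch. `ToyGraph.insert` is a hypothetical stand-in and does no schema matching. Any argument that is not already a graph Value first gets its own prim::Constant node, and the operator node then consumes that constant's output.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass(frozen=True)
class Value:
    name: str


class ToyGraph:
    def __init__(self) -> None:
        self.lines: List[str] = []
        self._count = 0

    def _fresh(self) -> Value:
        value = Value(f"%{self._count}")
        self._count += 1
        return value

    def add_input(self) -> Value:
        return self._fresh()

    def insert(self, opname: str, args: List[Any]) -> Value:
        """Append opname(args), first turning non-Value arguments into constants."""
        names = []
        for arg in args:
            if not isinstance(arg, Value):
                const = self._fresh()
                self.lines.append(f"{const.name} = prim::Constant[value={arg}]()")
                arg = const
            names.append(arg.name)
        out = self._fresh()
        self.lines.append(f"{out.name} = {opname}({', '.join(names)})")
        return out
```

The toy equivalent of `graph->insert(aten::add, {some_value, 3})` is `g.insert("aten::add", [x, 3])`. It emits a constant node for 3 followed by the add.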

Value

  • const TypePtr& type()
  • Node* node() (producer of this Value)
  • size_t offset() (offset into the output list of the node())
  • void replaceAllUsesWith(Value* other)
  • const use_list& uses() (use_list is std::vector<Use>, where Use is a struct containing a Node* and offset into its input list)
  • Graph* owningGraph()
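The cached use lists are what make replaceAllUsesWith cheap. Every Use stores a (node, input offset) pair, so rewiring touches only the actual use sites instead of scanning the whole graph. The sketch below is a toy Python model with hypothetical snake_case method names that mirror the C++ ones. `Node.add_input` keeps the use list in sync.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)
class Use:
    user: "Node"
    offset: int  # position in user.inputs


@dataclass(eq=False)
class Value:
    name: str
    uses: List[Use] = field(default_factory=list)

    def replace_all_uses_with(self, other: "Value") -> None:
        # Cost is proportional to the number of uses; no graph scan needed.
        for use in self.uses:
            use.user.inputs[use.offset] = other
            other.uses.append(use)
        self.uses.clear()


@dataclass(eq=False)
class Node:
    kind: str
    inputs: List[Value] = field(default_factory=list)

    def add_input(self, value: Value) -> Value:
        value.uses.append(Use(self, len(self.inputs)))
        self.inputs.append(value)
        return value
```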

Node

  • Symbol kind()
  • ArrayRef<Value*> inputs()
  • ArrayRef<Value*> outputs()
  • Value* namedInput(Symbol name) (lets you look up inputs by their names instead of depending on the positional order)
  • bool is_constant(Symbol name) (returns true if input name is a constant)
  • optional<IValue> get(Symbol name) (if is_constant(name), returns an IValue containing its value)
  • optional<T> get(Symbol name) (same as above but returns T instead of IValue)
  • Value* addInput(Value* value)
  • Value* insertInput(size_t offset, Value* value)
  • Value* replaceInput(size_t offset, Value* newValue)
  • Value* replaceInputWith(Value* from, Value* to)
  • Value* addOutput()
  • Value* insertOutput(size_t offset)
  • void eraseOutput(size_t offset)
  • ArrayRef<Block*> blocks()
  • Block* addBlock()
  • void eraseBlock(size_t offset)
  • void destroy() (This is dangerous! All references to Values produced by this node, and to the node itself, become invalid!)
  • void dump() (Debug print to stdout)
  • Block* owningBlock()
  • Graph* owningGraph()

A larger example (simple RNN loop)

Building on everything covered so far, here's some Python code that shows how to inspect example translations into the IR (and shows a simple single-layer LSTM). Note that Python 3 type annotations are supported as well, but the type comments used here are more portable.

```python
import torch


@torch.jit.script
def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
    hx, cx = hidden
    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy


@torch.jit.script
def simple_lstm(input, hidden, wih, whh, bih, bhh):
    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
    outputs = []
    inputs = input.unbind(0)
    for seq_idx in range(len(inputs)):
        hidden = lstm_cell(inputs[seq_idx], hidden, wih, whh, bih, bhh)
        hy, _ = hidden
        outputs.append(hy)
    return hidden


print(simple_lstm.graph)
```
```
graph(%input : Dynamic
      %hidden.1 : Tuple
      %wih : Dynamic
      %whh : Dynamic
      %bih : Dynamic
      %bhh : Dynamic) {
  %20 : int = prim::Constant[value=1]()
  %19 : int = prim::Constant[value=4]()
  %10 : bool = prim::Constant[value=1]()
  %7 : int = prim::Constant[value=0]()
  %54 : World = prim::LoadWorld()
  %outputs : Dynamic[] = prim::ListConstruct()
  %inputs : Dynamic[] = aten::unbind(%input, %7)
  %9 : int = aten::len(%inputs)
  %hidden : Tuple, %56 : World = prim::Loop(%9, %10, %hidden.1, %54)
    block0(%seq_idx : int, %14 : Tuple, %55 : World) {
      %13 : Dynamic = aten::select(%inputs, %seq_idx)
      %hx : Dynamic, %cx : Dynamic = prim::TupleUnpack(%14)
      %23 : Dynamic = aten::t(%wih)
      %24 : Dynamic = aten::mm(%13, %23)
      %25 : Dynamic = aten::t(%whh)
      %26 : Dynamic = aten::mm(%hx, %25)
      %27 : Dynamic = aten::add(%24, %26, %20)
      %28 : Dynamic = aten::add(%27, %bih, %20)
      %gates : Dynamic = aten::add(%28, %bhh, %20)
      %30 : Dynamic[] = aten::chunk(%gates, %19, %20)
      %ingate.1 : Dynamic, %forgetgate.1 : Dynamic, %cellgate.1 : Dynamic, %outgate.1 : Dynamic = prim::ListUnpack(%30)
      %ingate : Dynamic = aten::sigmoid(%ingate.1)
      %forgetgate : Dynamic = aten::sigmoid(%forgetgate.1)
      %cellgate : Dynamic = aten::tanh(%cellgate.1)
      %outgate : Dynamic = aten::sigmoid(%outgate.1)
      %39 : Dynamic = aten::mul(%forgetgate, %cx)
      %40 : Dynamic = aten::mul(%ingate, %cellgate)
      %_ : Dynamic = aten::add(%39, %40, %20)
      %42 : Dynamic = aten::tanh(%_)
      %hy : Dynamic = aten::mul(%outgate, %42)
      %hidden.2 : Tuple = prim::TupleConstruct(%hy, %_)
      %49 : World = aten::append(%55, %outputs, %hy)
      -> (%10, %hidden.2, %49)
    }
  %52 : Dynamic, %53 : Dynamic = prim::TupleUnpack(%hidden)
   = prim::StoreWorld(%56)
  return (%52, %53);
}
```
