Automatic Differentiation

To support differentiable graphics systems such as Gaussian splatters, neural radiance fields, differentiable path tracers, and more, Slang provides first-class support for differentiable programming. An overview:

  • Slang supports the fwd_diff and bwd_diff operators, which generate the forward- and backward-mode derivative propagation functions for any valid Slang function annotated with the [Differentiable] attribute.
  • The DifferentialPair<T> built-in generic type is used to pass derivatives associated with each function input.
  • The IDifferentiable and the experimental IDifferentiablePtrType interfaces denote differentiable value and pointer types respectively, and allow finer control over how types behave under differentiation.
  • Further, Slang allows for user-defined derivative functions through the [ForwardDerivative(custom_fn)] and [BackwardDerivative(custom_fn)] attributes.
  • All Slang features, such as control flow, generics, interfaces, extensions, and more, are compatible with automatic differentiation, though the bottom of this chapter documents some sharp edges and known issues.

Auto-diff operations fwd_diff and bwd_diff

In Slang, fwd_diff and bwd_diff are higher-order functions used to transform Slang functions into their forward or backward derivative methods. To better understand what these methods do, here is a small refresher on differentiable calculus:

Mathematical overview: Jacobian and its vector products

Forward and backward derivative methods are two different ways of computing a dot product with the Jacobian of a given function. Parts of this overview are based on JAX’s excellent auto-diff cookbook. The relevant Wikipedia article is also a great resource for understanding auto-diff.

The Jacobian (also called the total derivative) of a function \(\mathbf{f}(\mathbf{x})\) is represented by \(D\mathbf{f}(\mathbf{x})\).

For a general function with multiple scalar inputs and multiple scalar outputs, the Jacobian is a matrix where \(D\mathbf{f}_{ij}\) represents the partial derivative of the \(i^{th}\) output element w.r.t the \(j^{th}\) input element \(\frac{\partial f_i}{\partial x_j}\).

As an example, consider a polynomial function

\[f(x, y) = x^3 + x^2 - y\]

Here, \(f\) has 1 output and 2 inputs. \(Df\) is therefore the row matrix:

\[Df(x, y) = [\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}] = [3x^2 + 2x, -1]\]

Another, more complex example is a function that has multiple outputs (for clarity, denoted by \(f_0\), \(f_1\), etc.):

\[\mathbf{f}(x, y) = \begin{bmatrix} f_0(x, y) & f_1(x, y) & f_2(x, y) \end{bmatrix} = \begin{bmatrix} x^3 & y^2x & y^3 \end{bmatrix}\]

Here, \(D\mathbf{f}\) is a 3x2 matrix with each element containing a partial derivative:

\[D\mathbf{f}(x, y) = \begin{bmatrix} \partial f_0 / \partial x & \partial f_0 / \partial y \\ \partial f_1 / \partial x & \partial f_1 / \partial y \\\partial f_2 / \partial x & \partial f_2 / \partial y\end{bmatrix} = \begin{bmatrix} 3x^2 & 0 \\ y^2 & 2yx \\0 & 3y^2\end{bmatrix}\]

Computing full Jacobians is often unnecessary and expensive. Instead, auto-diff offers ways to compute products of the Jacobian with a vector, which is a much faster operation. There are two basic ways to compute this product:

  1. the Jacobian-vector product \(\langle D\mathbf{f}(\mathbf{x}), \mathbf{v} \rangle\), also called forward-mode autodiff, which can be computed using the fwd_diff operator in Slang, and
  2. the vector-Jacobian product \(\langle \mathbf{v}^T, D\mathbf{f}(\mathbf{x}) \rangle\), also called reverse-mode autodiff, which can be computed using the bwd_diff operator in Slang. From a linear algebra perspective, this is the transpose of the forward-mode operator.
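To make the two products concrete, here is a minimal plain-Python sketch (not Slang, and not how the compiler works internally) that writes out the Jacobian of the three-output example above by hand and multiplies it from either side. The helper names jacobian, jvp and vjp are hypothetical and chosen only for this illustration.

```python
# Plain-Python sketch (not Slang): explicit Jacobian of
# f(x, y) = (x^3, y^2 * x, y^3) and its two vector products.

def jacobian(x, y):
    # Rows are the outputs f_0, f_1, f_2; columns are the inputs x, y.
    return [
        [3 * x**2, 0.0],
        [y**2, 2 * y * x],
        [0.0, 3 * y**2],
    ]

def jvp(x, y, v):
    # Jacobian-vector product: v has one entry per input (length 2).
    J = jacobian(x, y)
    return [sum(J[i][j] * v[j] for j in range(2)) for i in range(3)]

def vjp(x, y, w):
    # Vector-Jacobian product: w has one entry per output (length 3).
    J = jacobian(x, y)
    return [sum(w[i] * J[i][j] for i in range(3)) for j in range(2)]

# Derivative of every output w.r.t. x at (1, 2).
forward = jvp(1.0, 2.0, [1.0, 0.0])
# Gradient of f_0 + f_1 + f_2 w.r.t. (x, y) at (1, 2).
reverse = vjp(1.0, 2.0, [1.0, 1.0, 1.0])
print(forward)
print(reverse)
```

The forward product yields one value per output, while the reverse product yields one value per input. Auto-diff computes these products without ever materializing the full matrix as this sketch does.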

Propagating derivatives with forward-mode auto-diff

The products described above allow the propagation of derivatives forward and backward through the function \(f\).

The forward-mode derivative (Jacobian-vector product) can convert a derivative of the inputs to a derivative of the outputs. For example, let’s say inputs \(\mathbf{x}\) depend on some scalar \(\theta\), and \(\frac{\partial \mathbf{x}}{\partial \theta}\) is a vector of partial derivatives describing that dependency.

Invoking forward-mode auto-diff with \(\mathbf{v} = \frac{\partial \mathbf{x}}{\partial \theta}\) converts this into a derivative of the outputs w.r.t the same scalar \(\theta\). This can be verified by expanding the Jacobian and applying the chain rule of derivatives:

\[\langle D\mathbf{f}(\mathbf{x}), \frac{\partial \mathbf{x}}{\partial \theta} \rangle = \langle \begin{bmatrix} \frac{\partial f_0}{\partial x_0} & \frac{\partial f_0}{\partial x_1} & \cdots \\ \frac{\partial f_1}{\partial x_0} & \frac{\partial f_1}{\partial x_1} & \cdots \\ \cdots & \cdots & \cdots \end{bmatrix}, \begin{bmatrix} \frac{\partial x_0}{\partial \theta} \\ \frac{\partial x_1}{\partial \theta} \\ \cdots \end{bmatrix} \rangle = \begin{bmatrix} \frac{\partial f_0}{\partial \theta} \\ \frac{\partial f_1}{\partial \theta} \\ \cdots \end{bmatrix} = \frac{\partial \mathbf{f}}{\partial \theta}\]
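
As a concrete check, consider the polynomial \(f(x, y) = x^3 + x^2 - y\) from the overview. The parameterization \(x = \theta\), \(y = \theta^2\) is an assumption chosen purely for illustration. Then \(\frac{\partial \mathbf{x}}{\partial \theta} = [1, 2\theta]^T\), and the Jacobian-vector product becomes:

\[\langle Df(x, y), \frac{\partial \mathbf{x}}{\partial \theta} \rangle = (3x^2 + 2x) \cdot 1 + (-1) \cdot 2\theta = 3\theta^2 + 2\theta - 2\theta = 3\theta^2\]

Substituting the parameterization directly gives \(f = \theta^3 + \theta^2 - \theta^2 = \theta^3\), whose derivative is also \(3\theta^2\), so the two routes agree.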

Propagating derivatives with reverse-mode auto-diff

The reverse-mode derivative (vector-Jacobian product) can convert a derivative w.r.t outputs into a derivative w.r.t inputs. For example, let’s say we have some scalar \(\mathcal{L}\) that depends on the outputs \(\mathbf{f}\), and \(\frac{\partial \mathcal{L}}{\partial \mathbf{f}}\) is a vector of partial derivatives describing that dependency.

Invoking reverse-mode auto-diff with \(\mathbf{v} = \frac{\partial \mathcal{L}}{\partial \mathbf{f}}\) converts this into a derivative of the same scalar \(\mathcal{L}\) w.r.t the inputs \(\mathbf{x}\). To provide more intuition for this, we can expand the Jacobian in the same way we did above:

\[\langle \frac{\partial \mathcal{L}}{\partial \mathbf{f}}^T, D\mathbf{f}(\mathbf{x}) \rangle = \langle \begin{bmatrix}\frac{\partial \mathcal{L}}{\partial f_0} & \frac{\partial \mathcal{L}}{\partial f_1} & \cdots \end{bmatrix}, \begin{bmatrix} \frac{\partial f_0}{\partial x_0} & \frac{\partial f_0}{\partial x_1} & \cdots \\ \frac{\partial f_1}{\partial x_0} & \frac{\partial f_1}{\partial x_1} & \cdots \\ \cdots & \cdots & \cdots \end{bmatrix} \rangle = \begin{bmatrix} \frac{\partial \mathcal{L}}{\partial x_0} & \frac{\partial \mathcal{L}}{\partial x_1} & \cdots \end{bmatrix} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}}^T\]

This mode is the most popular, since machine learning systems often construct their differentiable pipeline with multiple inputs (which can number in the millions or billions), and a single scalar output often referred to as the ‘loss’ denoted by \(\mathcal{L}\). The desired derivative can be constructed with a single reverse-mode invocation.
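
As a worked instance, take the three-output function \(\mathbf{f}(x, y) = [x^3, y^2x, y^3]\) from the overview and assume, purely for illustration, the loss \(\mathcal{L} = f_0 + f_1 + f_2\), so that \(\frac{\partial \mathcal{L}}{\partial \mathbf{f}}^T = [1, 1, 1]\). A single vector-Jacobian product then gives the derivative w.r.t both inputs at once:

\[\langle \begin{bmatrix} 1 & 1 & 1 \end{bmatrix}, \begin{bmatrix} 3x^2 & 0 \\ y^2 & 2yx \\ 0 & 3y^2 \end{bmatrix} \rangle = \begin{bmatrix} 3x^2 + y^2 & 2yx + 3y^2 \end{bmatrix}\]

Differentiating \(\mathcal{L} = x^3 + y^2x + y^3\) directly yields the same two partial derivatives.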

Invoking auto-diff in Slang

With the mathematical foundations established, we can describe concretely how to compute derivatives using Slang.

In Slang, derivatives are computed using fwd_diff and bwd_diff, which correspond to the Jacobian-vector and vector-Jacobian products respectively. For forward-diff, to pass the vector \(\mathbf{v}\) and receive the outputs, we use the DifferentialPair<T> type. We use pairs of inputs because every input element \(x_i\) has a corresponding element \(v_i\) in the vector, and each original output element has a corresponding output element in the product.

Example of fwd_diff:

    [Differentiable] // Auto-diff requires that functions are marked differentiable
    float2 foo(float a, float b)
    {
        return float2(a * b * b, a * a);
    }

    void main()
    {
        DifferentialPair<float> dp_a = diffPair(
            1.0, // input 'a'
            1.0  // vector 'v' for Jacobian-vector product input (for 'a')
        );

        DifferentialPair<float> dp_b = diffPair(2.4, 0.0);

        // fwd_diff to compute output and d_output w.r.t 'a'.
        // Our output is also a differential pair.
        //
        DifferentialPair<float2> dp_output = fwd_diff(foo)(dp_a, dp_b);

        // Extract output's primal part, which is just the standard output when foo is called normally.
        // Can also use `.getPrimal()`
        //
        float2 output_p = dp_output.p;

        // Extract output's derivative part. Can also use `.getDifferential()`
        float2 output_d = dp_output.d;

        printf("foo(1.0, 2.4) = (%f %f)\n", output_p.x, output_p.y);
        printf("d(foo)/d(a) at (1.0, 2.4) = (%f, %f)\n", output_d.x, output_d.y);
    }

Note that all the inputs and outputs to our function become ‘paired’. This only applies to differentiable types, such as float, float2, etc. See the section on differentiable types for more info.

diffPair<T>(primal_val, diff_val) is a built-in utility function that constructs the pair from the primal and differential values.

Additionally, invoking forward-mode also computes the regular (or ‘primal’) output value (which can be obtained from output.getPrimal() or output.p). The same is not true for reverse-mode.
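
For intuition about why inputs and outputs become pairs, here is a minimal plain-Python sketch (not Slang) of the same foo using a hand-rolled pair type. The Pair class and foo_fwd are hypothetical stand-ins for DifferentialPair<float> and fwd_diff(foo); the sketch only shows the arithmetic and makes no claim about how the Slang compiler generates its code.

```python
from dataclasses import dataclass

@dataclass
class Pair:
    p: float  # primal value
    d: float  # differential (derivative) value

    def __mul__(self, other):
        # Product rule: (u * v)' = u' * v + u * v'
        return Pair(self.p * other.p, self.d * other.p + self.p * other.d)

def foo_fwd(a, b):
    # Same arithmetic as foo(a, b) = (a*b*b, a*a), applied to pairs.
    return (a * b * b, a * a)

# Seed d = 1.0 on 'a' and d = 0.0 on 'b' to differentiate w.r.t. 'a'.
out0, out1 = foo_fwd(Pair(1.0, 1.0), Pair(2.4, 0.0))
print(out0.p, out1.p)  # primal outputs, i.e. foo(1.0, 2.4)
print(out0.d, out1.d)  # derivatives of each output w.r.t. 'a'
```

Each multiplication carries the primal and the derivative together, which is why a single forward-mode call returns both the regular output (about 5.76 and 1.0 here) and its derivative w.r.t. a (about 5.76 and 2.0, i.e. b*b and 2*a).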

For reverse-mode, the example proceeds in a similar way, and we still use the DifferentialPair<T> type. However, note that each input gets a corresponding output and each output gets a corresponding input. Thus, all inputs become inout differential pairs, to allow the function to write into the derivative part (the primal part is still accepted as an input in the same pair data-structure). The one extra rule is that the derivative corresponding to the return value of the function is accepted as the last argument (an extra input). This value does not need to be a pair.

Example:

    [Differentiable] // Auto-diff requires that functions are marked differentiable
    float2 foo(float a, float b)
    {
        return float2(a * b * b, a * a);
    }

    void main()
    {
        DifferentialPair<float> dp_a = diffPair(
            1.0 // input 'a'
        ); // Calling diffPair without a derivative part initializes to 0.

        DifferentialPair<float> dp_b = diffPair(2.4);

        // Derivatives of scalar L w.r.t output.
        float2 dL_doutput = float2(1.0, 0.0);

        // bwd_diff to compute dL_da and dL_db
        // The derivative of the output is provided as an additional _input_ to the call
        // Derivatives w.r.t inputs are written into dp_a.d and dp_b.d
        //
        bwd_diff(foo)(dp_a, dp_b, dL_doutput);

        // Extract the derivatives of L w.r.t input
        float dL_da = dp_a.d;
        float dL_db = dp_b.d;

        printf("If dL/dOutput = (1.0, 0.0), then (dL/da, dL/db) at (1.0, 2.4) = (%f, %f)", dL_da, dL_db);
    }
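
To see what values the reverse-mode call is expected to write back, here is a hand-derived plain-Python sketch (not Slang) of the backward pass for the same foo. The function foo_bwd is hypothetical and written manually from the partial derivatives of foo; it is not the code Slang generates.

```python
def foo_bwd(a, b, d_out):
    # foo(a, b) = (a*b*b, a*a); d_out = (dL/dout0, dL/dout1).
    # Each input derivative sums the contributions from every output.
    dL_da = d_out[0] * (b * b) + d_out[1] * (2.0 * a)
    dL_db = d_out[0] * (2.0 * a * b)  # out1 = a*a does not depend on b
    return dL_da, dL_db

dL_da, dL_db = foo_bwd(1.0, 2.4, (1.0, 0.0))
print(dL_da, dL_db)
```

With dL/dOutput = (1.0, 0.0), only the first output contributes, so the results are approximately b*b = 5.76 for dL/da and 2*a*b = 4.8 for dL/db.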

Differentiable Type System

Slang will only generate differentiation code for values that have a differentiable type. Differentiable types are defined through conformance to one of two built-in interfaces:

  1. IDifferentiable: For value types (e.g. float, structs of value types, etc.)
  2. IDifferentiablePtrType: For buffer, pointer and reference types that represent locations rather than values.

Differentiable Value Types

All basic types (float, int, double, etc.) and all aggregate types (i.e. struct) that use any combination of these are considered value types in Slang.

Slang uses the IDifferentiable interface to define differentiable types. Basic types that describe a continuous value (float, double and half) and their vector/matrix versions (float3, half2x2, etc.) are defined as differentiable by the standard library. For all basic types, the type used for the differential (which can be obtained with T.Differential) is the same as the primal.

Builtin Differentiable Value Types

The following built-in types are differentiable:

  • Scalars: float, double and half.
  • Vector/Matrix: vector and matrix of float, double and half types.
  • Arrays: T[n] is differentiable if T is differentiable.
  • Tuples: Tuple<each T> is differentiable if T is differentiable.

User-defined Differentiable Value Types

You can also define your own differentiable types. Typically, all you need is to implement the IDifferentiable interface.

    struct MyType : IDifferentiable
    {
        float x;
        float y;
    };

The main requirement of a type implementing IDifferentiable is the Differential associated type that the compiler uses to carry the corresponding derivative. In most cases the Differential of a type can be itself, though it can be different if necessary. You can access the differential of any differentiable type through Type.Differential.

Example:

    MyType obj;
    obj.x = 1.f;

    MyType.Differential d_obj;
    // Differentiable fields will have a corresponding field in the diff type
    d_obj.x = 1.f;

Slang can automatically derive the Differential type in the majority of cases. For instance, for MyType, Slang can infer the differential trivially:

    struct MyType : IDifferentiable
    {
        // Automatically inserted by Slang from the fact that
        // MyType has 2 floats which are both differentiable
        //
        typealias Differential = MyType;
        // ...
    }

For more complex types that aren’t fully differentiable, a new type is synthesized automatically:

    struct MyPartialDiffType : IDifferentiable
    {
        // Automatically inserted by Slang based on which fields are differentiable.
        typealias Differential = syn_MyPartialDiffType_Differential;

        float x;
        uint y;
    };

    // Synthesized
    struct syn_MyPartialDiffType_Differential
    {
        // Only one field since 'y' does not conform to IDifferentiable
        float x;
    };

You can make existing types differentiable through Slang’s extension mechanism. For instance, extension MyType : IDifferentiable { } will make MyType differentiable retroactively.

See the IDifferentiable reference documentation for more information on how to override the default behavior.

DifferentialPair: Pairs of differentiable value types

The DifferentialPair<T> type is used to pass derivatives to a derivative call by representing a pair of values of type T and T.Differential. Note that T must conform to IDifferentiable.

DifferentialPair<T> can either be created via constructor calls or the diffPair utility method.

Example:

    MyType obj = {1.f, 2.f};
    MyType.Differential d_obj = {0.4f, 3.f};

    // The differential part of a differentiable-pair is of the diff type.
    DifferentialPair<MyType> dp_obj = diffPair(obj, d_obj);

    // Use .p to extract the primal part
    MyType new_p_obj = dp_obj.p;

    // Use .d to extract the differential part
    MyType.Differential new_d_obj = dp_obj.d;

Differentiable Ptr types

Pointer types are any type that represents a location or reference to a value rather than the value itself. Examples include resource types (RWStructuredBuffer, Texture2D), pointer types (Ptr<float>) and references.

The IDifferentiablePtrType interface can be used to denote types that need to transform into pairs during auto-diff. However, unlike an IDifferentiable type whose derivative portion is an output under bwd_diff, the derivative part of IDifferentiablePtrType remains an input. This is because only the value is returned as an output, while the location where it needs to be written to is still effectively an input to the derivative methods.

Note

Support for IDifferentiablePtrType is still experimental. There are no built-in types conforming to this interface, though we plan to add stdlib support in the near future.

IDifferentiablePtrType only requires a Differential associated type to be specified.

DifferentialPtrPair: Pairs of differentiable ptr types

For types conforming to IDifferentiablePtrType, the corresponding pair to use for passing the derivative counterpart is DifferentialPtrPair<T>, which represents a pair of T and T.Differential. Objects of this type can be created using a constructor.

Example of defining and using an IDifferentiablePtrType object.

Here is an example of creating a differentiable buffer pointer type and using it within a differentiable function. You can find an interactive sample on the Slang playground.

    struct MyBufferPointer : IDifferentiablePtrType
    {
        // The differential part is another instance of MyBufferPointer.
        typealias Differential = MyBufferPointer;

        RWStructuredBuffer<float> buf;
        uint offset;
    };

    // Link a custom derivative
    [BackwardDerivative(load_bwd)]
    float load(MyBufferPointer p, uint index)
    {
        return p.buf[p.offset + index];
    }

    // Note that the backward derivative signature is still an 'in' differential pair.
    void load_bwd(DifferentialPtrPair<MyBufferPointer> p, uint index, float dOut)
    {
        MyBufferPointer diff_ptr = p.d;
        diff_ptr.buf[diff_ptr.offset + index] += dOut;
    }

    [Differentiable]
    float sumOfSquares<let N : int>(MyBufferPointer p)
    {
        float sos = 0.f;

        [MaxIters(N)]
        for (uint i = 0; i < N; i++)
        {
            float val_i = load(p, i);
            sos += val_i * val_i;
        }

        return sos;
    }

    RWStructuredBuffer<float> inputs;
    RWStructuredBuffer<float> derivs;

    void main()
    {
        MyBufferPointer ptr = {inputs, 0};
        print("Sum of squares of first 10 values: ", sumOfSquares<10>(ptr));

        MyBufferPointer deriv_ptr = {derivs, 0};

        // Pass a pair of pointers as input.
        bwd_diff(sumOfSquares<10>)(
            DifferentialPtrPair<MyBufferPointer>(ptr, deriv_ptr),
            1.0);

        print("Derivative of result w.r.t the 10 values: \n");
        for (uint i = 0; i < 10; i++)
            print("%d: %f\n", i, load(deriv_ptr, i));
    }
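
The data flow of this example can be sketched in plain Python (not Slang), with lists standing in for the two RWStructuredBuffer objects. The function sum_of_squares_bwd is a hypothetical, hand-written version of what the backward pass has to do; it illustrates why the derivative buffer is passed in as an input location that gets accumulated into.

```python
def load(buf, offset, index):
    return buf[offset + index]

def load_bwd(deriv_buf, offset, index, d_out):
    # The derivative is accumulated at the matching location
    # in the derivative buffer, which the caller supplies.
    deriv_buf[offset + index] += d_out

def sum_of_squares(buf, offset, n):
    return sum(load(buf, offset, i) ** 2 for i in range(n))

def sum_of_squares_bwd(buf, deriv_buf, offset, n, d_result):
    for i in range(n):
        val = load(buf, offset, i)
        # d(val * val)/d(val) = 2 * val, scaled by the incoming derivative.
        load_bwd(deriv_buf, offset, i, 2.0 * val * d_result)

inputs = [1.0, 2.0, 3.0]
derivs = [0.0, 0.0, 0.0]
total = sum_of_squares(inputs, 0, 3)
sum_of_squares_bwd(inputs, derivs, 0, 3, 1.0)
print(total, derivs)
```

Because load_bwd uses +=, the derivative buffer should be zero-initialized before the backward call, and repeated loads of the same element add up their contributions.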

User-Defined Derivative Functions

As an alternative to compiler-generated derivatives, you can choose to provide an implementation for the derivative, which the compiler will use instead of attempting to generate one.

This can be performed on a per-function basis by using the decorators [ForwardDerivative(fwd_deriv_func)] and [BackwardDerivative(bwd_deriv_func)] to reference the derivative from the primal function.

For instance, it often makes little sense to differentiate the body of a sin(x) implementation, when we know that the derivative is cos(x) * dx. In Slang, this can be represented in the following way:

    DifferentialPair<float> sin_fwd(DifferentialPair<float> dpx)
    {
        float x = dpx.p;
        float dx = dpx.d;
        // The primal part of the result is the regular output sin(x);
        // the differential part is cos(x) * dx.
        return DifferentialPair<float>(sin(x), cos(x) * dx);
    }

    // sin() is now considered differentiable (at least for forward-mode) since it provides
    // a derivative implementation.
    //
    [ForwardDerivative(sin_fwd)]
    float sin(float x)
    {
        // Calc sin(X) using Taylor series..
    }

    // Any uses of sin() in a `[Differentiable]` function will automatically use the
    // sin_fwd implementation when differentiated.

A similar example for a backward derivative.

    void sin_bwd(inout DifferentialPair<float> dpx, float dresult)
    {
        float x = dpx.p;

        // Write-back the derivative to each input (the primal part must be copied over as-is)
        dpx = DifferentialPair<float>(x, cos(x) * dresult);
    }

    [BackwardDerivative(sin_bwd)]
    float sin(float x)
    {
        // Calc sin(X) using Taylor series..
    }

Note that the signature of the provided forward or backward derivative function must match the expected signature from invoking fwd_diff(fn)/bwd_diff(fn). For a full list of signature rules, see the reference section for the auto-diff operators.
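
Since a hand-written derivative is not checked against the primal function's body, it can be worth validating it numerically. The following plain-Python sketch (not Slang) compares a forward and a backward derivative for sin against a central finite difference. The Python signatures are simplified stand-ins (tuples and return values instead of DifferentialPair and inout parameters), and the step size h is an arbitrary choice.

```python
import math

def sin_fwd(x, dx):
    # Returns (primal output, derivative output).
    return math.sin(x), math.cos(x) * dx

def sin_bwd(x, d_result):
    # Returns the derivative that would be written back for input x.
    return math.cos(x) * d_result

x, h = 0.7, 1e-6
# Central finite difference as an independent reference.
finite_diff = (math.sin(x + h) - math.sin(x - h)) / (2 * h)

primal, forward_d = sin_fwd(x, 1.0)
backward_d = sin_bwd(x, 1.0)
forward_ok = abs(forward_d - finite_diff) < 1e-6
backward_ok = abs(backward_d - finite_diff) < 1e-6
print(forward_ok, backward_ok)
```

For a scalar function with a seed of 1.0, the forward and backward derivatives coincide, so both should agree with the finite-difference estimate up to the chosen tolerance.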

Back-referencing User Derivative Attributes.

Sometimes, the original function’s definition might be inaccessible, so it can be tricky to add an attribute to create the association.

For such cases, Slang provides the [ForwardDerivativeOf(primal_fn)] and [BackwardDerivativeOf(primal_fn)] attributes that can be used on the derivative function and contain a reference to the function for which they are providing a derivative implementation. As long as the derivative function is in scope, the primal function will be considered differentiable.

Example:

    // Module A
    float sin(float x) { /* ... */ }

    // Module B
    import A;
    [BackwardDerivativeOf(sin)] // Add a derivative implementation for sin() in module A.
    void sin_bwd(inout DifferentialPair<float> dpx, float dresult) { /* ... */ }

User-defined derivatives also work for generic functions, member functions, accessors, and more. See the reference section for the [ForwardDerivative(fn)] and [BackwardDerivative(fn)] attributes for more.

Using Auto-diff with Generics

Automatic differentiation works seamlessly with generically-defined types and methods. For generic methods, differentiability of a type is defined either through an explicit IDifferentiable constraint or any other interface that extends IDifferentiable.

Example for generic methods:

    [Differentiable]
    T calcFoo<T : IDifferentiable>(T x) { /* ... */ }

    [Differentiable]
    T calcBar<T : __BuiltinFloatingPointType>(T x) { /* ... */ }

    [Differentiable]
    void main()
    {
        DifferentialPair<float4> dpa = /* ... */;

        // Can call with any type that is IDifferentiable. Generic parameters
        // are inferred like any other call.
        //
        bwd_diff(calcFoo)(dpa, float4(1.f));

        // But you can also be explicit with < >
        bwd_diff(calcFoo<float4>)(dpa, float4(1.f));

        // x is differentiable for calcBar because
        // __BuiltinFloatingPointType : IDifferentiable
        //
        DifferentialPair<double> dpb = /* .. */;
        bwd_diff(calcBar)(dpb, 1.0);
    }

You can implement IDifferentiable on a generic type. Automatic synthesis still applies and will use generic constraints to resolve whether a field is differentiable or not.

    struct Foo<T : IDifferentiable, U> : IDifferentiable
    {
        T t;
        U u;
    };

    // The synthesized Foo<T, U>.Differential will contain a field for
    // 't' but not 'u'
    //

Using Auto-diff with Interface Requirements and Interface Types

For interface requirements, using the [Differentiable] attribute enforces that any implementation of that method must also be differentiable. You can, of course, provide a manual derivative implementation to satisfy the requirement.

The following is a sample snippet. You can run the full sample on the playground.

    interface IFoo
    {
        [Differentiable]
        float calc(float x);
    }

    struct FooImpl : IFoo
    {
        // Implementation via automatic differentiation.
        [Differentiable]
        float calc(float x)
        { /* ... */ }
    }

    struct FooImpl2 : IFoo
    {
        // Implementation via manually providing derivative methods.
        [ForwardDerivative(calc_fwd)]
        [BackwardDerivative(calc_bwd)]
        float calc(float x)
        { /* ... */ }

        DifferentialPair<float> calc_fwd(DifferentialPair<float> x)
        { /* ... */ }

        void calc_bwd(inout DifferentialPair<float> x, float dresult)
        { /* ... */ }
    }

    [Differentiable]
    float compute(float x, uint obj_id)
    {
        // Create an instance of either FooImpl or FooImpl2
        IFoo foo = createDynamicObject<IFoo>(obj_id);

        // Dynamic dispatch to appropriate 'calc'.
        //
        // Note that foo itself is non-differentiable, and
        // has no differential data, but 'x' and 'result'
        // will carry derivatives.
        //
        var result = foo.calc(x);
        return result;
    }

Differentiable Interface (and Associated) Types

Note: This is an advanced use-case and support is currently experimental.

You can have an interface or an interface associated type extend IDifferentiable and use that in differentiable interface requirement functions. This is often important in large code-bases with modular components that are all differentiable (one example is the material system in large production renderers).

Here is a snippet of how to make an interface and associated type (and by consequence all its implementations) differentiable. For a full working sample, check out the Slang playground.

    interface IFoo : IDifferentiable
    {
        associatedtype BaseType : IDifferentiable;

        [Differentiable]
        BaseType foo(BaseType x);
    };

    [Differentiable]
    float calc(float x)
    {
        // Note that since IFoo is differentiable,
        // any data in the IFoo implementation is differentiable
        // and will carry derivatives.
        //
        IFoo obj = makeObj(/* ... */);
        return obj.foo(x);
    }

Under the hood, Slang will automatically construct an anonymous abstract type to represent the differentials. However, on targets that don’t support true dynamic dispatch, these are lowered into tagged unions. While we are working to improve the implementation, this union can currently include all active differential types, rather than just the relevant ones. This can lead to increased memory use.

Primal Substitute Functions

Sometimes it is desirable to replace a function with another when generating derivative code. Most often, this is because many shader operations do not have a function body, such as hardware intrinsics for texture sampling. In such cases, Slang provides a [PrimalSubstitute(fn)] attribute that can be used to provide a reference implementation that Slang can differentiate to generate the derivative function.

The following is a small snippet with bilinear texture sampling. For a full example application that uses this concept, see the texture differentiation sample in the Slang repository.

    [PrimalSubstitute(sampleTextureBilinear_reference)]
    float4 sampleTextureBilinear(Texture2D<float4> x, float2 loc)
    {
        // HW-accelerated sampling intrinsics.
        // Slang does not have access to body, so cannot differentiate.
        //
        x.Sample(/*...*/)
    }

    // Since the substitute is differentiable, so is `sampleTextureBilinear`.
    [Differentiable]
    float4 sampleTextureBilinear_reference(Texture2D<float4> x, float2 loc)
    {
        // Reference SW interpolation, that is differentiable.
    }

    [Differentiable]
    float computePixel(Texture2D<float4> x, float a, float b)
    {
        // Slang will use HW-accelerated sampleTextureBilinear for standard function
        // call, but differentiate the SW reference interpolation during backprop.
        //
        float4 sample1 = sampleTextureBilinear(x, float2(a, 1));
    }

Similar to the [ForwardDerivativeOf(fn)] and [BackwardDerivativeOf(fn)] attributes, Slang provides a [PrimalSubstituteOf(fn)] attribute that can be used on the substitute function to reference the primal one.

Working with Mixed Differentiable and Non-Differentiable Code

Introducing differentiability to an existing system often involves dealing with code that mixes differentiable and non-differentiable logic.Slang provides type checking and code analysis features to allow users to clarify the intention and guard against unexpected behaviors involving when to propagate derivatives through operations.

Excluding Parameters from Differentiation

Sometimes we do not wish a parameter to be considered differentiable despite it has a differentiable type. We can use theno_diff modifier on the parameter to inform the compiler to treat the parameter as non-differentiable and skip generating differentiation code for the parameter. The syntax is:

    // Only differentiate this function with regard to `x`.
    float myFunc(no_diff float a, float x);

The forward derivative and backward propagation functions of myFunc should have the following signatures:

    DifferentialPair<float> fwd_derivative(float a, DifferentialPair<float> x);
    void back_prop(float a, inout DifferentialPair<float> x, float dResult);
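
To make these signatures concrete, the following sketch calls both derivative functions through the fwd_diff and bwd_diff operators. The body given to myFunc and the helper name `example` are assumptions made for illustration; only the parameter list comes from the text above.

```slang
// Hypothetical body for illustration: myFunc(a, x) = a * x * x.
[Differentiable]
float myFunc(no_diff float a, float x)
{
    return a * x * x;
}

// Hypothetical caller.
void example()
{
    // Forward mode: `a` stays a plain float; `x` is passed as a pair
    // (primal 3.0, derivative seed 1.0).
    DifferentialPair<float> r = fwd_diff(myFunc)(2.0, diffPair(3.0, 1.0));
    // r.p is 18.0 (= 2 * 3 * 3) and r.d is 12.0 (= 2 * a * x).

    // Backward mode: the derivative with regard to `x` is written to dpX.d.
    var dpX = diffPair(3.0, 0.0);
    bwd_diff(myFunc)(2.0, dpX, 1.0);
    // dpX.d is 12.0; no derivative is produced for `a`.
}
```

Because `a` is marked no_diff, it appears as a plain float in both calls, and neither derivative function computes anything for it.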

In addition, the no_diff modifier can also be used on the return type to indicate that the return value should be considered non-differentiable. For example, the function

    no_diff float myFunc(no_diff float a, float x, out float y);

will have the following forward derivative and backward propagation function signatures:

    float fwd_derivative(float a, DifferentialPair<float> x, out DifferentialPair<float> y);
    void back_prop(float a, inout DifferentialPair<float> x, float d_y);

By default, the implicit this parameter will be treated as differentiable if the enclosing type of the member method is differentiable. If you wish to exclude the this parameter from differentiation, use the [NoDiffThis] attribute on the method:

    struct MyDifferentiableType : IDifferentiable
    {
        [NoDiffThis] // Make `this` parameter `no_diff`.
        float compute(float x) { ... }
    }

Excluding Struct Members from Differentiation

When using automatic IDifferentiable conformance synthesis for a struct type, Slang will by default treat all struct members that have a differentiable type as differentiable, and thus include a corresponding field in the generated Differential type for the struct. For example, given the following definition

    struct MyType : IDifferentiable
    {
        float member1;
        float2 member2;
    }

Slang will generate:

    struct MyType.Differential : IDifferentiable
    {
        float member1;  // derivative for MyType.member1
        float2 member2; // derivative for MyType.member2
    }

If the user does not want a certain member to be treated as differentiable even though it has a differentiable type, a no_diff modifier can be used on the struct member to exclude it from differentiation. For example, the following code excludes member1 from differentiation:

    struct MyType : IDifferentiable
    {
        no_diff float member1; // excluded from differentiation
        float2 member2;
    }

The generated Differential in this case will be:

    struct MyType.Differential : IDifferentiable
    {
        float2 member2;
    }

Assigning Differentiable Values into a Non-Differentiable Location

When a value with derivatives is being assigned to a location that is not differentiable, such as a struct member that is marked as no_diff, the derivative info is discarded and any derivative propagation is stopped at the assignment site. This may lead to unexpected results. For example:

    struct MyType : IDifferentiable
    {
        no_diff float member;
        float someOtherMember;
    }

    [Differentiable]
    float f(float x)
    {
        MyType t;
        t.member = x * x; // Error: assigning value with derivative into a non-differentiable location.
        return t.member;
    }

In this case, we are assigning the value x*x, which carries a derivative, into a non-differentiable location MyType.member, thus throwing away any derivative info. When f returns t.member, there will be no derivative associated with it, so the function will not propagate the derivative through. Code like this most likely does not intend to discard the derivative through the assignment. To help avoid this kind of unintentional behavior, Slang treats any assignment of a value with derivative info into a non-differentiable location as a compile-time error. To eliminate this error, the user should either make t.member differentiable, or force the assignment by stating the intention to discard any derivatives using the built-in detach method. The following code will compile, and the derivatives will be discarded:

    [Differentiable]
    float f(float x)
    {
        MyType t;
        // OK: the code has expressed clearly the intention to discard the derivative and perform the assignment.
        t.member = detach(x * x);
        return t.member;
    }
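
As a small sketch of what detach does to the computed derivative, consider a function in which only one of two terms is detached. The names `h` and `example` are hypothetical and exist only for this illustration.

```slang
// Hypothetical function: the first term is detached, the second is not.
[Differentiable]
float h(float x)
{
    // detach(x * x) keeps the value x * x but carries no derivative,
    // so only 3.0 * x contributes to the derivative of h.
    return detach(x * x) + 3.0 * x;
}

// Hypothetical caller.
void example()
{
    let r = fwd_diff(h)(diffPair(2.0, 1.0));
    // r.p is 10.0 (= 4 + 6); r.d is 3.0 rather than 7.0,
    // because the derivative of the detached term is dropped.
}
```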

Calling Non-Differentiable Functions from a Differentiable Function

Calling a non-differentiable function from a differentiable function is allowed. However, derivatives will not be propagated through the call. The user is required to make this intention explicit by prefixing the call with the no_diff keyword. A call to a non-differentiable function without this prefix will result in a compile-time error.

For example, consider the following code:

    float g(float x)
    {
        return 2 * x;
    }

    [Differentiable]
    float f(float x)
    {
        // Error: implicit call to non-differentiable function g.
        return g(x) + x * x;
    }

The derivative will not propagate through the call to g in f. As a result, fwd_diff(f)(diffPair(1.0, 1.0)) would return {3.0, 2.0} instead of {3.0, 4.0}, because the derivative of 2*x is lost through the non-differentiable call. To prevent this kind of unintended error, calling g from f is treated as a compile-time error. If such a non-differentiable call is intended, a no_diff prefix is required in the call:

    [Differentiable]
    float f(float x)
    {
        // OK. The intention to call a non-differentiable function is clarified.
        return no_diff g(x) + x * x;
    }
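
If the derivative of the callee is wanted rather than dropped, the alternative is to make the callee differentiable. The following sketch shows that variant; the names `gDiff`, `fFull`, and `example` are hypothetical and chosen so they do not clash with g and f above.

```slang
// Hypothetical differentiable counterpart of `g`.
[Differentiable]
float gDiff(float x)
{
    return 2 * x;
}

[Differentiable]
float fFull(float x)
{
    // No `no_diff` prefix is needed: gDiff is differentiable,
    // so its derivative propagates through the call.
    return gDiff(x) + x * x;
}

// Hypothetical caller.
void example()
{
    let r = fwd_diff(fFull)(diffPair(1.0, 1.0));
    // r.p is 3.0 and r.d is 4.0 (2 from gDiff plus 2 from x * x).
}
```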

However, the no_diff keyword is not required in a call if a non-differentiable function does not take any differentiable parameters, or if the result of the differentiable function is not dependent on the derivative being propagated through the call.

Treat Non-Differentiable Functions as Differentiable

Slang allows functions to be marked with a [TreatAsDifferentiable] attribute so that the type system considers them differentiable. When a function is marked as [TreatAsDifferentiable], the compiler will not generate derivative propagation code from the original function body or perform any additional checking on the function definition. Instead, it will generate trivial forward and backward propagation functions that return 0.

This feature can be useful if the user has marked an interface method as forward or backward differentiable, but only wishes to provide non-trivial derivative propagation functions for a subset of the types that implement the interface. For other types that do not actually need differentiation, the user can simply put [TreatAsDifferentiable] on their method implementations to satisfy the interface requirement.

See the following code for an example of [TreatAsDifferentiable]:

    interface IFoo
    {
        [Differentiable]
        float f(float v);
    }

    struct B : IFoo
    {
        [TreatAsDifferentiable]
        float f(float v)
        {
            return v * v;
        }
    }

    [Differentiable]
    float use(IFoo o, float x)
    {
        return o.f(x);
    }

    // Test:
    B obj;
    float result = fwd_diff(use)(obj, diffPair(2.0, 1.0)).d;
    // result == 0.0, since `[TreatAsDifferentiable]` causes a trivial derivative implementation
    // to be generated regardless of the original code.

Higher-Order Differentiation

Slang supports generating higher-order forward and backward derivative propagation functions. The fwd_diff and bwd_diff operators may be used inside a forward or backward differentiable function, and may also be nested. For example, fwd_diff(fwd_diff(sin)) will have the following signature:

    DifferentialPair<DifferentialPair<float>> sin_diff2(DifferentialPair<DifferentialPair<float>> x);

The input parameter x contains four fields: x.p.p, x.p.d, x.d.p, and x.d.d, where x.p.p specifies the original input value, both x.p.d and x.d.p store the first-order derivative of x, and x.d.d stores the second-order derivative of x. Calling fwd_diff(fwd_diff(sin)) with diffPair(diffPair(pi/2, 1.0), diffPair(1.0, 0.0)) will result in { { 1.0, 0.0 }, { 0.0, -1.0 } }.
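
The call described above can be written out as the following sketch, which shows how the nested pair is built and how the four result fields are read. The constant `kHalfPi` and the function name `example` are assumptions introduced for this illustration, and the values in the comments are approximate because the constant is a rounded value of pi/2.

```slang
// Hypothetical constant: a rounded value of pi / 2.
static const float kHalfPi = 1.5707963;

// Hypothetical caller.
void example()
{
    // Outer primal part: (value, first-order seed).
    // Outer differential part: (first-order seed, second-order seed).
    let x = diffPair(diffPair(kHalfPi, 1.0), diffPair(1.0, 0.0));

    let r = fwd_diff(fwd_diff(sin))(x);

    float value  = r.p.p; // sin(pi/2), approximately 1.0
    float first  = r.p.d; // cos(pi/2), approximately 0.0 (r.d.p holds the same value)
    float second = r.d.d; // -sin(pi/2), approximately -1.0
}
```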

User-defined higher-order derivative functions can be specified by using the [ForwardDerivative] or [BackwardDerivative] attribute on the derivative function, or by using the [ForwardDerivativeOf] or [BackwardDerivativeOf] attribute on the higher-order derivative function.

Restrictions and Known Issues

The compiler can generate forward derivative and backward propagation implementations for most uses of array and struct types, including arbitrary read and write access at dynamic array indices, and supports all kinds of control flow, mutable parameters, generics and interfaces. This set of operations is sufficient for many functions. However, the user needs to be aware of the following restrictions when using automatic differentiation:

  • All operations on global resources, global variables and shader parameters, including texture reads or atomic writes, are treated as non-differentiable operations. Slang provides support for special data structures (such as Tensor) through libraries such as SlangPy, which come with custom derivative implementations.
  • If a differentiable function contains calls that cause side effects such as updates to global memory, there is currently no guarantee on how many times the side effects will occur in the resulting derivative function or back-propagation function.
  • Loops: Loops must have a bounded number of iterations. If this cannot be inferred statically from the loop structure, the attribute [MaxIters(<count>)] can be used to specify a maximum number of iterations. This will be used by the compiler to allocate space to store intermediate data. If the actual number of iterations exceeds the provided maximum, the behavior is undefined. You can always mark a loop with the [ForceUnroll] attribute to instruct the Slang compiler to unroll the loop before generating derivative propagation functions. Unrolled loops will be treated the same way as ordinary code and are not subject to any additional restrictions.
  • Double backward derivatives (higher-order differentiation): The compiler does not currently support multiple backward derivative calls such as bwd_diff(bwd_diff(fn)). The vast majority of higher-order derivative applications can be achieved more efficiently via multiple forward-derivative calls or a single layer of bwd_diff on functions that use one or more fwd_diff passes.

The above restrictions do not apply if a user-defined derivative or backward propagation function is provided.
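
As an illustration of the loop restriction, the sketch below annotates a loop whose trip count is only known at run time. The function name `sumOfPowers` and the bound of 16 are assumptions chosen for this example; the bound must be at least as large as any iteration count the caller can actually produce.

```slang
// Hypothetical function: computes x + x^2 + ... + x^n.
[Differentiable]
float sumOfPowers(float x, int n)
{
    float result = 0.0;
    float term = 1.0;

    // `n` is not known at compile time, so an upper bound is supplied.
    // Assumption: callers never pass n > 16. Exceeding the bound is undefined behavior.
    [MaxIters(16)]
    for (int i = 0; i < n; i++)
    {
        term = term * x;
        result += term;
    }
    return result;
}
```

When the trip count is a small compile-time constant, marking the loop with [ForceUnroll] instead avoids the need for a bound altogether.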

Reference

This section contains some additional information for operators that are not currently included in the standard library reference.

fwd_diff(f : slang_function) -> slang_function

The fwd_diff operator can be used on a differentiable function to obtain the forward derivative propagation function.

A forward derivative propagation function computes the derivative of the result value with regard to a specific set of input parameters. Given an original function, the signature of its forward propagation function is determined using the following rules:

  • If the return type R implements IDifferentiable, the forward propagation function will return a corresponding DifferentialPair<R> that consists of both the computed original result value and the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as R.
  • If a parameter has type T that implements IDifferentiable, it will be translated into a DifferentialPair<T> parameter in the derivative function, where the differential component of the DifferentialPair holds the initial derivatives of each parameter with regard to their upstream parameters.
  • If a parameter has type T that implements IDifferentiablePtrType, it will be translated into a DifferentialPtrPair<T> parameter where the differential component references the differential component.
  • All parameter directions are unchanged. For example, an out parameter in the original function will remain an out parameter in the derivative function.
  • Differentiable methods cannot have a type implementing IDifferentiablePtrType as an out or inout parameter, or a return type. Types implementing IDifferentiablePtrType can only be used for input parameters to a differentiable method. Marking such a method as [Differentiable] will result in a compile-time diagnostic error.

For example, given the original function:

    [Differentiable]
    R original(T0 p0, inout T1 p1, T2 p2, T3 p3);

where R, T0, T1 : IDifferentiable, T2 is non-differentiable, and T3 : IDifferentiablePtrType, the forward derivative function will have the following signature:

    DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T1> p1, T2 p2, DifferentialPtrPair<T3> p3);

This forward propagation function takes the initial primal value of p0 in p0.p, and the partial derivative of p0 with regard to some upstream parameter in p0.d. It takes the initial primal and derivative values of p1 and updates p1 to hold the newly computed value and propagated derivative. Since p2 is not differentiable, it remains unchanged.

bwd_diff(f : slang_function) -> slang_function

A backward derivative propagation function propagates the derivative of the function output to all the input parameters simultaneously.

Given an original function f, the general rule for determining the signature of its backward propagation function is that a differentiable output o becomes an input parameter holding the partial derivative of a downstream output with regard to the differentiable output, i.e. \(\partial y/\partial o\); an input differentiable parameter i in the original function will become an output in the backward propagation function, holding the propagated partial derivative \(\partial y/\partial i\); and any non-differentiable outputs are dropped from the backward propagation function. This means that the backward propagation function never returns any values computed in the original function.

More specifically, the signature of its backward propagation function is determined using the following rules:

  • A backward propagation function always returns void.
  • A differentiable in parameter of type T : IDifferentiable will become an inout DifferentialPair<T> parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The original value will not be overwritten by the backward propagation function. The propagated derivative will be written to the derivative part of the differential pair after the backward propagation function returns. The initial derivative value of the pair is ignored as input.
  • A differentiable out parameter of type T : IDifferentiable will become an in T.Differential parameter, carrying the partial derivative of some downstream term with regard to the value written to that out parameter.
  • A differentiable inout parameter of type T : IDifferentiable will become an inout DifferentialPair<T> parameter, where the original value of the argument, along with the downstream partial derivative with regard to the argument, is passed as input to the backward propagation function as the original and derivative parts of the pair. The propagated derivative with regard to this input parameter will be written back and replace the derivative part of the pair. The primal value part of the parameter will not be updated.
  • A differentiable return value of type R will become an additional in R.Differential parameter at the end of the backward propagation function parameter list, carrying the result derivative of a downstream term with regard to the return value of the original function.
  • A non-differentiable return value of type NDR will be dropped.
  • A non-differentiable in parameter of type ND will remain unchanged in the backward propagation function.
  • A non-differentiable out parameter of type ND will be removed from the parameter list of the backward propagation function.
  • A non-differentiable inout parameter of type ND will become an in ND parameter.
  • Types implementing IDifferentiablePtrType work the same way as in the forward-mode case. They can only be used within parameters, and are converted into DifferentialPtrPair types. Their directions are not affected.

For example, consider the following original function:

    struct T : IDifferentiable {...}
    struct R : IDifferentiable {...}
    struct P : IDifferentiablePtrType {...}
    struct ND {} // Non differentiable

    [Differentiable]
    R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5, P p6);

The signature of its backward propagation function is:

    void back_prop(
        inout DifferentialPair<T> p0,
        T.Differential p1,
        inout DifferentialPair<T> p2,
        ND p3,
        ND p5,
        DifferentialPtrPair<P> p6,
        R.Differential dResult);

Note that although p2 is still inout in the backward propagation function, the backward propagation function will only write the propagated derivative to p2.d and will not modify p2.p.
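
The rules above can be traced on a concrete scalar function. The following sketch is an assumption-laden illustration: `squareAndAccumulate`, `example`, and all numeric seeds are hypothetical, and the values in the comments are derived by hand from the chain rule under the signature rules listed in this section.

```slang
// Hypothetical function with an `in`, an `inout`, and an `out` differentiable parameter.
[Differentiable]
float squareAndAccumulate(float a, inout float acc, out float doubled)
{
    doubled = 2.0 * a; // differentiable `out` parameter
    acc = acc + a;     // differentiable `inout` parameter
    return a * a;      // differentiable return value
}

// Hypothetical caller.
void example()
{
    // `in` parameter: only the primal part is read; the derivative part is overwritten.
    var dpA = diffPair(3.0, 0.0);
    // `inout` parameter: primal value plus the downstream derivative
    // with regard to the updated `acc`.
    var dpAcc = diffPair(1.0, 0.5);

    // Argument order: a, acc, derivative for `doubled`, derivative for the return value.
    bwd_diff(squareAndAccumulate)(dpA, dpAcc, 0.25, 1.0);

    // By the chain rule:
    //   dpA.d   = 0.25 * 2 + 0.5 * 1 + 1.0 * (2 * 3.0) = 7.0
    //   dpAcc.d = 0.5 * 1 = 0.5
    // dpA.p and dpAcc.p keep their input values (3.0 and 1.0).
}
```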

Built-in Differentiable Functions

The following built-in functions are differentiable and both their forward and backward derivative functions are already defined in the standard library’s core module:

  • Arithmetic functions: abs, max, min, sqrt, rcp, rsqrt, fma, mad, fmod, frac, radians, degrees
  • Interpolation and clamping functions: lerp, smoothstep, clamp, saturate
  • Trigonometric functions: sin, cos, sincos, tan, asin, acos, atan, atan2
  • Hyperbolic functions: sinh, cosh, tanh
  • Exponential and logarithmic functions: exp, exp2, pow, log, log2, log10
  • Vector functions: dot, cross, length, distance, normalize, reflect, refract
  • Matrix transforms: mul(matrix, vector), mul(vector, matrix), mul(matrix, matrix)
  • Matrix operations: transpose, determinant
  • Legacy blending and lighting intrinsics: dst, lit
