Catalyst enables just-in-time (JIT) and ahead-of-time (AOT) compilation of quantum programs and workflows, while taking into account both classical and quantum code, and ultimately leverages modern compilation tools to speed up quantum applications.
You can imagine compiling a function once in advance and then benefiting from faster execution on all subsequent calls of the function, similar to the `jax.jit` functionality. However, compared to JAX, we are also able to compile the quantum code natively without having to rely on callbacks to any Python-based PennyLane devices. We can thus compile and execute entire workflows (such as variational algorithms) as a single program or unit, without having to go back and forth between device execution and the Python interpreter.
The first thing we need to do is import `qjit()` and QJIT-compatible methods in Catalyst, as well as PennyLane and the version of NumPy provided by JAX.
```python
from catalyst import qjit, measure, cond, for_loop, while_loop, grad
import pennylane as qml
from jax import numpy as jnp
```
You should be able to express your quantum functions in the way you are accustomed to using PennyLane. However, some of PennyLane's features, such as optimizers, may not be fully supported yet.
Warning
Not all PennyLane devices currently work with Catalyst. Supported backend devices include `lightning.qubit`, `lightning.kokkos`, `lightning.gpu`, and `braket.aws.qubit`. For a full list of supported devices, please see Supported devices.
PennyLane tapes are still used internally by Catalyst, and you can express your circuits in the way you are used to, as long as you ensure that all operations are added to the main tape.
Let’s start learning more about Catalyst by running a simple circuit.
```python
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
    qml.Hadamard(wires=0)
    qml.RX(theta, wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(wires=1))
```
In PennyLane, the `qml.qnode()` decorator creates a device-specific quantum function. The number of wires is specified on the device that each quantum function is bound to.
The `qjit()` decorator can be used to JIT-compile a workflow of quantum functions:
```python
jitted_circuit = qjit(circuit)
```

```pycon
>>> jitted_circuit(0.7)
Array(0., dtype=float64)
```
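The `0.` result is not a compilation artifact. For this circuit ⟨Z₁⟩ vanishes for every θ: the Hadamard puts wire 0 into an equal superposition, and the CNOT flips wire 1 in exactly half of it, so the two halves contribute cos θ and −cos θ. The plain NumPy sketch below is independent of Catalyst, and its helper names are illustrative only. It reproduces the result with a two-qubit state vector and assumes PennyLane's ordering, where wire 0 is the most significant bit of the basis-state index.

```python
import numpy as np

# Single-qubit matrices
H = np.array([[1, 1], [1, -1]], dtype=complex) / np.sqrt(2)
I2 = np.eye(2, dtype=complex)
Z = np.diag([1.0, -1.0]).astype(complex)

def rx(theta):
    """RX(theta) = exp(-i * theta * X / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])

# CNOT with control wire 0 and target wire 1 (wire 0 = most significant bit)
CNOT = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0]], dtype=complex)

def expval_z1(theta):
    state = np.zeros(4, dtype=complex)
    state[0] = 1.0                          # start in |00>
    state = np.kron(H, rx(theta)) @ state   # Hadamard on wire 0, RX on wire 1
    state = CNOT @ state
    obs = np.kron(I2, Z)                    # PauliZ acting on wire 1
    return float(np.real(np.vdot(state, obs @ state)))

print(expval_z1(0.7))  # zero up to floating-point rounding
```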
In Catalyst, dynamic wire values are fully supported for operations, observables, and measurements. For example, the following circuit can be jitted with wires as arguments:
```python
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=5))
def circuit(arg0, arg1, arg2):
    qml.RX(arg0, wires=[arg1 + 1])
    qml.RY(arg0, wires=[arg2])
    qml.CNOT(wires=[arg1, arg2])
    return qml.probs(wires=[arg1 + 1])
```

```pycon
>>> circuit(jnp.pi / 3, 1, 2)
Array([0.625, 0.375], dtype=float64)
```
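With `arg1 = 1` and `arg2 = 2`, both rotations act on wire 2. The CNOT's control (wire 1) is never rotated, so it acts trivially, and the probabilities come from RX(π/3) followed by RY(π/3) applied to |0⟩. The short NumPy check below recovers the printed values. It is a hand calculation, not how Catalyst evaluates the circuit.

```python
import numpy as np

def rx(theta):
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])

def ry(theta):
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)

# Wire arg1 = 1 stays in |0>, so the CNOT does nothing; only the
# two rotations on wire arg1 + 1 = arg2 = 2 affect the probabilities.
psi = ry(np.pi / 3) @ rx(np.pi / 3) @ np.array([1.0, 0.0], dtype=complex)
probs = np.abs(psi) ** 2
print(probs)  # approximately [0.625 0.375]
```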
Catalyst allows you to use quantum operations available in PennyLane either via native support by the runtime or via PennyLane's decomposition rules. The `qml.adjoint()` and `qml.ctrl()` functions in PennyLane are also supported via the decomposition mechanism in Catalyst. For example,
```python
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit():
    qml.Rot(0.3, 0.4, 0.5, wires=0)
    qml.adjoint(qml.SingleExcitation(jnp.pi / 3, wires=[0, 1]))
    return qml.state()
```
In addition, you can qjit most PennyLane templates to easily construct and evaluate more complex quantum circuits.
Important
Decomposition will generally happen in accordance with the specification provided by devices, which can vary from device to device (e.g., `default.qubit` and `lightning.qubit` might decompose quite differently). However, Catalyst's decomposition logic will differ in the following cases:

- For devices without native controlled-gate support (e.g., `lightning.kokkos` and `lightning.gpu`), all `qml.Controlled` operations will decompose to `qml.QubitUnitary` operations.
- The set of operations supported by Catalyst itself can in some instances lead to additional decompositions compared to the device itself.
Catalyst supports PennyLane observables.
For example, the following circuit is a QJIT-compatible function that calculates the expectation value of a tensor product of `qml.PauliX`, `qml.Hadamard`, and `qml.Hermitian` observables.
```python
@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(x, y):
    qml.RX(x, 0)
    qml.RX(y, 1)
    qml.CNOT([0, 2])
    qml.CNOT([1, 2])
    h_matrix = jnp.array(
        [[complex(1.0, 0.0), complex(2.0, 0.0)],
         [complex(2.0, 0.0), complex(-1.0, 0.0)]]
    )
    return qml.expval(qml.PauliX(0) @ qml.Hadamard(1) @ qml.Hermitian(h_matrix, 2))
```
Most PennyLane measurement processes are supported in Catalyst, although not all features are supported for all measurement types.
| Measurement | Support in Catalyst |
| --- | --- |
| `qml.expval()` | The expectation value of observables is supported analytically as well as with finite shots. |
| `qml.var()` | The variance of observables is supported analytically as well as with finite shots. |
| `qml.sample()` | Samples are supported in the computational basis only. |
| `qml.counts()` | Sample counts are supported in the computational basis only. |
| `qml.probs()` | Probabilities are supported in the computational basis, analytically as well as with finite shots. |
| `qml.state()` | The state is supported in the computational basis only. |
| `qml.measure()` | The projective mid-circuit measurement is supported via its own operation in Catalyst. |
For both `qml.sample()` and `qml.counts()`, omitting the `wires` parameter produces samples on all declared qubits in the same format as in PennyLane.
Counts are returned a bit differently, namely as a pair of arrays representing a dictionary from basis states to the number of observed samples. We thus have to do a bit of extra work to display them nicely. Note that the basis states are represented in their equivalent binary integer representation, inside of a float data type. This way they are compatible with eigenvalue sampling, but this may change in the future.
```python
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2, shots=1000))
def counts():
    qml.Rot(0.1, 0.2, 0.3, wires=[0])
    return qml.counts(wires=[0])

basis_states, counts = counts()
```

```pycon
>>> {format(int(state), '01b'): count for state, count in zip(basis_states, counts)}
{'0': 985, '1': 15}
```
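The same dictionary comprehension can be wrapped in a small helper that works for any number of measured wires. `counts_to_dict` is a hypothetical name chosen for this sketch. It assumes only what is stated above: the basis states hold integer values (stored as ints or floats), and the two arrays are aligned element by element.

```python
import numpy as np

def counts_to_dict(basis_states, counts, num_wires):
    """Map each basis state to its count, keyed by a zero-padded bitstring."""
    return {
        format(int(state), f"0{num_wires}b"): int(count)
        for state, count in zip(basis_states, counts)
    }

# Using the values printed above (one measured wire)
print(counts_to_dict(np.array([0.0, 1.0]), np.array([985, 15]), num_wires=1))
# {'0': 985, '1': 15}
```

Passing `num_wires` explicitly keeps keys such as `'011'` from losing their leading zeros when several wires are measured.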
You can specify the number of shots to be used in sample-based measurements when you create a device. `qml.sample()` and `qml.counts()` will automatically use the device's `shots` parameter when performing measurements. In the following example, the number of shots is set to 500 in the device instantiation.
Note
You can return any combination of measurement processes as a tuple from quantum functions. In addition, Catalyst allows you to return any classical values computed inside quantum functions as well.
```python
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=3, shots=500))
def circuit(params):
    qml.RX(params[0], wires=0)
    qml.RX(params[1], wires=1)
    qml.RZ(params[2], wires=2)
    return (
        qml.sample(),
        qml.counts(),
        qml.expval(qml.PauliZ(0)),
        qml.var(qml.PauliZ(0)),
        qml.probs(wires=[0, 1]),
        qml.state(),
    )
```

```pycon
>>> circuit([0.3, 0.5, 0.7])
(Array([[0, 0, 0],
       [0, 0, 0],
       [0, 0, 0],
       ...,
       [0, 0, 0],
       [0, 0, 0],
       [0, 0, 0]], dtype=int64),
 (Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int64),
  Array([453, 0, 31, 0, 16, 0, 0, 0], dtype=int64)),
 Array(0.936, dtype=float64),
 Array(0.138816, dtype=float64),
 Array([0.926, 0.048, 0.026, 0. ], dtype=float64),
 Array([ 0.89994966-0.32850727j,  0. +0.j,
        -0.08388168-0.22979488j,  0. +0.j,
        -0.04964902-0.13601409j,  0. +0.j,
        -0.0347301 +0.01267748j,  0. +0.j], dtype=complex128))
```
The PennyLane projective mid-circuit measurement is also supported in Catalyst. `measure()` is a QJIT-compatible mid-circuit measurement for Catalyst that only requires a list of wires that the measurement process acts on.
Important
The `qml.measure()` function is not QJIT compatible, and `measure()` from Catalyst should be used instead:
```python
from catalyst import measure
```
In the following example, `m` will be equal to `True` if wire 0 is rotated by 180 degrees.
```python
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x):
    qml.RX(x, wires=0)
    m = measure(wires=0)
    return m
```

```pycon
>>> circuit(jnp.pi)
Array(True, dtype=bool)
>>> circuit(0.0)
Array(False, dtype=bool)
```
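The two calls above are deterministic because they sit at the extremes. In general, for RX(x) applied to |0⟩, the measurement yields outcome 1 (`True`) with probability sin²(x/2), so intermediate angles give outcomes that vary from run to run. The NumPy sketch below only evaluates this Born-rule probability. It does not simulate Catalyst's `measure()`.

```python
import numpy as np

def prob_one_after_rx(x):
    """Probability of measuring 1 on a qubit prepared as RX(x)|0>."""
    # RX(x)|0> = cos(x/2)|0> - i sin(x/2)|1>
    return float(np.sin(x / 2) ** 2)

for x in (0.0, np.pi / 2, np.pi):
    print(x, prob_one_after_rx(x))
# 0 -> 0.0, pi/2 -> 0.5, pi -> 1.0 (up to rounding)
```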
In Catalyst, there are two ways of compiling quantum functions, depending on when the compilation is triggered.
With just-in-time (JIT) compilation, compilation is triggered at the call site the first time the quantum function is executed. For example, `circuit` is compiled when it is first called.
```python
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(theta):
    qml.Hadamard(wires=0)
    qml.RX(theta, wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(wires=1))
```

```pycon
>>> circuit(0.5)  # the first call, compilation occurs here
Array(0., dtype=float64)
>>> circuit(0.5)  # the precompiled quantum function is called
Array(0., dtype=float64)
```
An alternative is to trigger the compilation without specifying any concrete values for the function parameters. This works by specifying the argument signature right in the function definition, which will trigger compilation "ahead-of-time" (AOT) before the program is executed. We can use builtin Python scalar types as well as the special `ShapedArray` type that JAX uses to represent the shape and data type of a tensor:
```python
from jax.core import ShapedArray

@qjit  # compilation happens at definition
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x: complex, z: ShapedArray(shape=(3,), dtype=jnp.float64)):
    theta = jnp.abs(x)
    qml.RY(theta, wires=0)
    qml.Rot(z[0], z[1], z[2], wires=0)
    return qml.state()
```

```pycon
>>> circuit(0.2j, jnp.array([0.3, 0.6, 0.9]))  # calls precompiled function
Array([0.75634905-0.52801002j, 0. +0.j, 0.35962678+0.14074839j, 0. +0.j], dtype=complex128)
```
At this stage the compilation has already happened, so the first call of `circuit` executes the compiled function directly, resulting in faster initial execution. Note that implicit type promotion is allowed for most data types during compilation, as long as it doesn't lead to a loss of data.
Catalyst supports natively compiled control flow as a "first-class" component of any quantum program. This gives large circuits a much smaller representation and shorter compilation times, and also enables the compilation of arbitrarily parametrized circuits.
Catalyst-provided control flow operations:
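| Operation | Description |
| --- | --- |
| `cond()` | A QJIT-compatible decorator for if-else conditionals. |
| `for_loop()` | A QJIT-compatible for-loop decorator. |
| `while_loop()` | A QJIT-compatible while-loop decorator. |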
Note
Catalyst supports automatic conversion of native Python control flow to the Catalyst control flow operations. For more details, see the AutoGraph guide.
`cond()` is a functional version of the traditional if-else conditional for Catalyst. This means that each execution path, a `True` branch and a `False` branch, is provided as a separate function. Both functions will be traced during compilation, but only one of them will be executed at runtime, depending on the value of a Boolean predicate. The JAX equivalent is the `jax.lax.cond` function, but this version is optimized to work with quantum programs in PennyLane.
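For comparison, here is the pure-JAX counterpart mentioned above. `jax.lax.cond(pred, true_fun, false_fun, *operands)` likewise traces both branches and requires them to return values of the same shape and dtype. The function `piecewise` is an illustrative name only and does not involve Catalyst.

```python
import jax
import jax.numpy as jnp

@jax.jit
def piecewise(x):
    # Both branches are traced; the predicate selects one at runtime.
    return jax.lax.cond(
        x > 1.0,
        lambda v: v ** 2,  # True branch
        lambda v: -v,      # False branch, same shape and dtype as the True branch
        x,
    )

print(piecewise(jnp.array(3.0)))  # 9.0
print(piecewise(jnp.array(0.5)))  # -0.5
```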
Note that `cond()` can also be used outside of the `qjit()` context for better interoperability with PennyLane.
Values produced inside the scope of a conditional can be returned to the outside context, but the return type signature of each branch must be identical. If no values are returned, the `False` branch is optional. Refer to the example below to learn more about the syntax of this decorator.
```python
@qjit
def workflow(x):
    @cond(x > 1.0)  # the predicate must evaluate to a Boolean
    def conditional_fn():
        # executed when the predicate is True
        return x ** 2

    @conditional_fn.otherwise
    def conditional_fn():
        # optional alternative execution path; its return type
        # must be identical to that of the True branch
        return -x

    return conditional_fn()  # the decorated function must be invoked
```
Warning
The conditional functions can only return JAX compatible data types.
`for_loop()` and `while_loop()` are functional versions of the traditional for- and while-loops for Catalyst. That is, any variables that are modified across iterations need to be provided as inputs and outputs to the loop body function. Input arguments contain the value of a variable at the start of an iteration, while output arguments contain the value at the end of the iteration. The outputs are then fed back as inputs to the next iteration. The final iteration values are also returned from the transformed function.
`for_loop()` and `while_loop()` can also be interpreted without needing to compile their surrounding context.
The for-loop statement:
The `for_loop()` executes a fixed number of iterations as indicated via the values specified in its header: a `lower_bound`, an `upper_bound`, and a `step` size.
The loop body function must always have the iteration index (in the example below, `i`) as its first argument, and its value can be used arbitrarily inside the loop body. Because the value of the index across iterations is handled automatically by the provided loop bounds, it must not be returned from the body function.
```python
@qjit
def workflow(n):
    @for_loop(0, n, 1)  # lower bound, upper bound (exclusive), step
    def loop_body(i, total):
        # i runs from 0 to n - 1 and is not returned;
        # total is the loop-carried value
        return total + i

    return loop_body(0)  # initial value of total
```
The semantics of `for_loop()` are given by the following Python implementation:

```python
for i in range(lower_bnd, upper_bnd, step):
    args = body_fn(i, *args)
```
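These semantics can also be written as a plain-Python decorator. This is handy for checking what a loop body should return before moving it under `qjit()`. `for_loop_reference` is a hypothetical helper written for this sketch. It is not part of Catalyst and performs no compilation.

```python
def for_loop_reference(lower_bnd, upper_bnd, step):
    """Plain-Python model of the for_loop semantics described above."""
    def decorator(body_fn):
        def run(*args):
            for i in range(lower_bnd, upper_bnd, step):
                out = body_fn(i, *args)
                # a single loop-carried value may be returned without a tuple
                args = out if isinstance(out, tuple) else (out,)
            return args[0] if len(args) == 1 else args
        return run
    return decorator

@for_loop_reference(0, 5, 1)
def accumulate(i, total):
    # the index i is supplied by the loop and is not returned
    return total + i

print(accumulate(0))  # 0 + 1 + 2 + 3 + 4 = 10
```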
The while-loop statement:
The `while_loop()`, on the other hand, is able to execute an arbitrary number of iterations, until the condition function specified in its header returns `False`.
The loop condition is evaluated on every iteration and can be any callable with the same signature as the loop body function. The condition function must return a Boolean.
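```python
@qjit
def workflow(x):
    # the condition takes the same arguments as the loop body
    @while_loop(lambda v, n: v < 100.0)
    def loop_body(v, n):
        # perform some work and update the loop-carried values
        return 2 * v, n + 1

    return loop_body(x, 0)  # initial values of v and n
```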
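A plain-Python model of `while_loop()` makes the same contract explicit. The condition receives the same arguments as the body, and it is re-evaluated before every iteration, so the body may never run. `while_loop_reference` is hypothetical and only mirrors the behaviour described above.

```python
def while_loop_reference(cond_fn):
    """Plain-Python model of the while_loop semantics described above."""
    def decorator(body_fn):
        def run(*args):
            while cond_fn(*args):
                out = body_fn(*args)
                args = out if isinstance(out, tuple) else (out,)
            return args[0] if len(args) == 1 else args
        return run
    return decorator

@while_loop_reference(lambda x, n: x < 100)
def doubling(x, n):
    # double x and count the iterations
    return 2 * x, n + 1

print(doubling(3, 0))  # (192, 6)
```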
Catalyst-provided gradient operations:
`grad()` is a QJIT-compatible gradient decorator in Catalyst that can differentiate a hybrid quantum function using finite-difference, parameter-shift, or adjoint-Jacobian methods. See the documentation for more details.

This decorator accepts the function to differentiate, a differentiation strategy, and the indices of the arguments with respect to which to differentiate:
```python
@qjit
def workflow(x):
    @qml.qnode(qml.device("lightning.qubit", wires=1))
    def circuit(x):
        qml.RX(jnp.pi * x, wires=0)
        return qml.expval(qml.PauliY(0))

    g = grad(circuit)
    return g(x)
```

```pycon
>>> workflow(2.0)
Array(-3.14159265, dtype=float64)
```
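The printed value can be checked by hand. With RX(πx) acting on |0⟩, ⟨Y⟩ = −sin(πx), so the derivative is −π cos(πx), which equals −π at x = 2. The NumPy sketch below reproduces this with the parameter-shift rule for RX, f′(θ) = [f(θ + π/2) − f(θ − π/2)]/2. It then multiplies by the chain-rule factor π that comes from θ = πx. This is a hand-written check, not Catalyst's implementation.

```python
import numpy as np

def expval_y(theta):
    """<PauliY> for RX(theta)|0>, which equals -sin(theta)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    psi = np.array([c, -1j * s])                 # RX(theta)|0>
    pauli_y = np.array([[0, -1j], [1j, 0]])
    return float(np.real(np.vdot(psi, pauli_y @ psi)))

def grad_workflow(x):
    theta = np.pi * x
    # parameter-shift rule for the RX gate
    d_theta = (expval_y(theta + np.pi / 2) - expval_y(theta - np.pi / 2)) / 2
    return np.pi * d_theta                       # chain rule: d(theta)/dx = pi

print(grad_workflow(2.0))  # approximately -3.14159265
```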
To specify the differentiation strategy, the `method` argument can be passed to the `grad` function:

- `method="auto"`: Quantum components of the hybrid function are differentiated according to the corresponding QNode `diff_method`, while the classical computation is differentiated using traditional autodiff. With this strategy, Catalyst currently only supports QNodes with `diff_method="parameter-shift"` and `diff_method="adjoint"`.
- `method="fd"`: First-order finite differences for the entire hybrid function. The `diff_method` argument for each QNode is ignored.
Currently, higher-order differentiation is only supported by the finite-difference method. The gradient of circuits with QJIT-compatible control flow is supported for all methods in Catalyst.
You can also pass the finite-difference step size (the `h` value) to `grad()`. For example, to differentiate `circuit` with respect to its second argument using finite differences with an `h` value of 0.1:
```python
g_fd = grad(circuit, method="fd", argnums=1, h=0.1)
```
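To see what the `h` value controls, the sketch below applies a forward finite difference, (f(x + h) − f(x))/h, to the single-parameter circuit from the first gradient example, where f(x) = −sin(πx). Treating Catalyst's `"fd"` method as a forward difference is an assumption of this sketch. The point is only that a larger `h` gives a coarser approximation of the exact derivative −π.

```python
import numpy as np

def f(x):
    # <PauliY> after RX(pi * x) on |0>
    return -np.sin(np.pi * x)

def forward_difference(fn, x, h):
    return (fn(x + h) - fn(x)) / h

for h in (0.1, 0.01, 1e-6):
    print(h, forward_difference(f, 2.0, h))
# h = 0.1 gives about -3.0902; smaller h approaches -pi
```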
Gradients of quantum functions can be calculated with respect to several parameters or a tensor of parameters. For example, `grad(circuit, argnums=[0, 1])` would calculate the gradient of `circuit` using the finite-difference method for the first and second parameters. The gradient of the following circuit, which takes a tensor of parameters, can also be computed.
```python
@qjit
def workflow(params):
    @qml.qnode(qml.device("lightning.qubit", wires=1))
    def circuit(params):
        qml.RX(params[0] * params[1], wires=0)
        return qml.expval(qml.PauliY(0))

    return grad(circuit, argnums=0)(params)
```

```pycon
>>> workflow(jnp.array([2.0, 3.0]))
Array([-2.88051099, -1.92034063], dtype=float64)
```
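Since this circuit evaluates f(p) = −sin(p₀p₁), its gradient is (−p₁cos(p₀p₁), −p₀cos(p₀p₁)). A quick NumPy evaluation at (2, 3) matches the values returned by the compiled workflow. It is shown only as an analytic cross-check.

```python
import numpy as np

def grad_f(p):
    """Analytic gradient of f(p) = -sin(p[0] * p[1])."""
    c = np.cos(p[0] * p[1])
    return np.array([-p[1] * c, -p[0] * c])

print(grad_f(np.array([2.0, 3.0])))  # approximately [-2.88051099 -1.92034063]
```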
The `grad()` decorator works for functions that return a scalar value. You can also use the `jacobian()` decorator to compute Jacobian matrices of general hybrid functions with multiple or multivariate results.
```python
from catalyst import jacobian

@qjit
def workflow(x):
    @qml.qnode(qml.device("lightning.qubit", wires=1))
    def circuit(x):
        qml.RX(jnp.pi * x[0], wires=0)
        qml.RY(x[1], wires=0)
        return qml.probs()

    g = jacobian(circuit, method="auto")
    return g(x)
```

```pycon
>>> workflow(jnp.array([2.0, 1.0]))
Array([[ 3.48786850e-16, -4.20735492e-01],
       [-8.71967125e-17,  4.20735492e-01]], dtype=float64)
```
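The entries can be understood analytically. At x₀ = 2, RX(2π) = −I, so the probabilities are locally insensitive to x₀, and the first column is zero up to rounding. The RY(x₁) rotation then gives p = (cos²(x₁/2), sin²(x₁/2)), whose derivative with respect to x₁ is (−sin(x₁)/2, sin(x₁)/2) ≈ (−0.4207, 0.4207) at x₁ = 1. The central-difference NumPy check below is an illustrative sketch, not Catalyst's method, and it reproduces the full matrix.

```python
import numpy as np

def probs(x):
    """Probabilities of RY(x[1]) RX(pi * x[0]) |0>."""
    a, b = np.pi * x[0], x[1]
    rx = np.array([[np.cos(a / 2), -1j * np.sin(a / 2)],
                   [-1j * np.sin(a / 2), np.cos(a / 2)]])
    ry = np.array([[np.cos(b / 2), -np.sin(b / 2)],
                   [np.sin(b / 2), np.cos(b / 2)]])
    psi = ry @ rx @ np.array([1.0, 0.0])
    return np.abs(psi) ** 2

def jacobian_central(fn, x, h=1e-6):
    """Column j holds d(fn)/dx_j, estimated by central differences."""
    cols = []
    for j in range(len(x)):
        e = np.zeros_like(x)
        e[j] = h
        cols.append((fn(x + e) - fn(x - e)) / (2 * h))
    return np.stack(cols, axis=1)

print(jacobian_central(probs, np.array([2.0, 1.0])))
# first column ~0, second column ~[-0.4207, 0.4207]
```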
This decorator has the same methods and API as `grad`. See the documentation for more details.
You can develop your own optimization algorithm using the `grad()` method, control-flow operators that are compatible with QJIT, or differentiable optimizers from Optax.
Warning
Catalyst currently does not provide any optimization tools and does not support the optimizers offeredby PennyLane. However, this feature is planned for future implementation.
For example, you can use `optax.sgd` in a QJIT workflow to run gradient descent. The following example shows a simple use case of this feature in Catalyst.
`optax.sgd` returns a gradient transformation. `opt.init(param)` creates the optimizer state, and `opt.update(gradient, state)` turns a gradient into parameter updates, which `optax.apply_updates` then applies. In the example, the helper `gd_fun(param)` returns both the value and the gradient of the cost function, computed with Catalyst's `grad()`. To optimize `param` iteratively, the example uses `jax.lax.fori_loop` to loop over the gradient descent steps.
```python
import optax
from jax.lax import fori_loop

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(param):
    qml.Hadamard(0)
    qml.RY(param, wires=0)
    return qml.expval(qml.PauliZ(0))

@qjit
def workflow():
    def gd_fun(param):
        diff = grad(circuit, argnums=0)
        return circuit(param), diff(param)

    opt = optax.sgd(learning_rate=0.4)

    def gd_update(i, args):
        param, state = args
        _, gradient = gd_fun(param)
        (updates, state) = opt.update(gradient, state)
        param = optax.apply_updates(param, updates)
        return (param, state)

    param = 0.1
    state = opt.init(param)
    (param, _) = fori_loop(0, 100, gd_update, (param, state))
    return param
```

```pycon
>>> workflow()
Array(1.57079633, dtype=float64)
```
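The converged value π/2 ≈ 1.5708 follows from the cost itself. After the Hadamard and RY(θ), ⟨Z⟩ = −sin θ, which is minimized at θ = π/2. Without momentum, `optax.sgd` applies the plain update θ ← θ − η ∂f/∂θ = θ + η cos θ. The pure-Python loop below repeats the 100 steps with η = 0.4 from θ = 0.1 as an illustrative check. It uses the analytic gradient in place of Catalyst's `grad()`.

```python
import math

def cost(theta):
    # <PauliZ> after Hadamard and RY(theta) on |0>
    return -math.sin(theta)

def cost_grad(theta):
    return -math.cos(theta)

theta, learning_rate = 0.1, 0.4
for _ in range(100):
    theta = theta - learning_rate * cost_grad(theta)  # plain SGD step

print(theta)  # approximately 1.5707963 (pi / 2)
```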
Catalyst programs can also be used inside of a larger JAX workflow that uses JIT compilation, automatic differentiation, and other JAX transforms.
Note
Note that, in general, the best performance will be seen when the Catalyst `@qjit` decorator is used to JIT the entire hybrid workflow. However, there may be cases where you want to delegate only the quantum part of your workflow to Catalyst and let JAX handle the classical components (for example, due to a missing feature or a compatibility issue in Catalyst).
For example, call a Catalyst qjit-compiled function from within a JAX jit-compiled function:
```python
import jax

dev = qml.device("lightning.qubit", wires=1)

@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(jnp.pi * x[0], wires=0)
    qml.RY(x[1] ** 2, wires=0)
    qml.RX(x[1] * x[2], wires=0)
    return qml.probs(wires=0)

@jax.jit
def cost_fn(weights):
    x = jnp.sin(weights)
    return jnp.sum(jnp.cos(circuit(x)) ** 2)
```

```pycon
>>> cost_fn(jnp.array([0.1, 0.2, 0.3]))
Array(1.32269195, dtype=float64)
```
Catalyst-compiled functions can now also be automatically differentiated via JAX, both in forward and reverse mode to first order,

```pycon
>>> jax.grad(cost_fn)(jnp.array([0.1, 0.2, 0.3]))
Array([0.49249037, 0.05197949, 0.02991883], dtype=float64)
```

as well as vectorized using `jax.vmap`:

```pycon
>>> jax.vmap(cost_fn)(jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]))
Array([1.32269195, 1.53905377], dtype=float64)
```