qml.capture.subroutine

subroutine(func, static_argnums=None, static_argnames=None)

Marks a function so that it is captured as a standalone function in the intermediate representation.

May be used to reduce compilation times. Without this annotation, the function body is inlined and compiled again at every call site. With it, a single version of the function is compiled and then called from each call site.
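As a rough analogy only, plain JAX shows the same idea. The example output further down lists parameters such as inline=False and keep_unused=False that also appear on nested jax.jit calls. The sketch below uses JAX alone, not PennyLane; the names helper and outer are hypothetical. A jitted helper called twice shows up as two call equations in the outer jaxpr, each carrying the helper's jaxpr as a parameter, rather than as two inlined copies of its body.

```python
import jax
import jax.numpy as jnp


@jax.jit
def helper(x, scale):
    # Stand-in for a subroutine body (hypothetical example function).
    return jnp.sin(x) * scale


def outer(x):
    # Two call sites of the same jitted helper.
    return helper(x, 2.0) + helper(x, 3.0)


closed = jax.make_jaxpr(outer)(0.5)

# Call equations carry the callee as a "jaxpr" parameter; ordinary
# primitives such as add do not have that parameter.
call_eqns = [eqn for eqn in closed.jaxpr.eqns if "jaxpr" in eqn.params]
n_calls = len(call_eqns)

print(closed)
print("call sites:", n_calls)
```

This only illustrates the call-versus-inline distinction in JAX itself; how the subroutine primitive is lowered is shown in the example below.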

Note

Subroutines are only available when using the program capture interface. To activate the program capture interface with Catalyst, pass capture=True to qml.qjit.

Parameters:
  • func (Callable) – the function to capture as a subroutine

  • static_argnums (None | int | Sequence[int]) – the indices of the static arguments

  • static_argnames (None | str | Sequence[str]) – the names of static arguments. May be provided instead of static_argnums for readability.
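The names static_argnums and static_argnames match the parameters of jax.jit. Assuming they follow the same convention (an assumption; this page does not spell out the semantics), a static argument is an ordinary Python value at trace time rather than a traced array, so it can drive Python control flow. The sketch below shows that convention with jax.jit only, not with qml.capture.subroutine; repeat_sin and repeat_sin_named are hypothetical names.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def repeat_sin(x, n):
    # n is static: a plain Python int while tracing, so range(n) works.
    for _ in range(n):
        x = jnp.sin(x)
    return x


@partial(jax.jit, static_argnames="n")
def repeat_sin_named(x, n):
    # Same function, with the static argument selected by name.
    for _ in range(n):
        x = jnp.sin(x)
    return x


by_index = repeat_sin(0.5, 2)
by_name = repeat_sin_named(0.5, n=2)
print(by_index, by_name)
```

In jax.jit, each distinct value of a static argument triggers a separate trace and compilation, so static arguments are best reserved for values that take few distinct settings.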

Example

import jax
import pennylane as qml

qml.capture.enable()

@qml.capture.subroutine
def f(x, wires):
    qml.RX(x, wires)

@qml.qnode(qml.device('lightning.qubit', wires=5))
def c(x : float):
    f(x, 0)
    f(x, 1)
    return qml.state()

print(jax.make_jaxpr(c)(0.5))
let f = { lambda ; a:f64[] b:i64[]. let
    _:AbstractOperator() = RX[n_wires=1] a b
in () } in
{ lambda ; c:f64[]. let
    d:c128[32] = qnode[
    device=<lightning.qubit device (wires=5) at 0x12aac1c40>
    execution_config=ExecutionConfig(grad_on_execution=False, use_device_gradient=None, use_device_jacobian_product=False, gradient_method='best', gradient_keyword_arguments={}, device_options={}, interface=<Interface.JAX: 'jax'>, derivative_order=1, mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True, executor_backend=<class 'pennylane.concurrency.executors.native.multiproc.MPPoolExec'>)
    n_consts=0
    qfunc_jaxpr={ lambda ; e:f64[]. let
        quantum_subroutine_p[
            compiler_options_kvs=()
            ctx_mesh=Mesh(, axis_types=())
            donated_invars=(False, False)
            in_layouts=(None, None)
            in_shardings=(UnspecifiedValue, UnspecifiedValue)
            inline=False
            jaxpr=f
            keep_unused=False
            name=f
            out_layouts=()
            out_shardings=()
        ] e 0:i64[]
        quantum_subroutine_p[
            compiler_options_kvs=()
            ctx_mesh=Mesh(, axis_types=())
            donated_invars=(False, False)
            in_layouts=(None, None)
            in_shardings=(UnspecifiedValue, UnspecifiedValue)
            inline=False
            jaxpr=f
            keep_unused=False
            name=f
            out_layouts=()
            out_shardings=()
        ] e 1:i64[]
        g:AbstractMeasurement(n_wires=0) = state_wires
        in (g,) }
    qnode=<QNode: device='<lightning.qubit device (wires=5) at 0x12aac1c40>', interface='jax', diff_method='best', shots='Shots(total=None)'>
    shots_len=0
    ] c
in (d,) }

If we create a qjit version of the QNode, we can inspect the generated MLIR and see a single FuncOp, @f, that is reused for both calls:

>>> qjit_c = qml.qjit(c)
>>> print(qjit_c.mlir[1010:1300]) 
%0 = quantum.alloc( 5) : !quantum.reg
%1 = call @f(%0, %arg0, %c_0) : (!quantum.reg, tensor<f64>, tensor<i64>) -> !quantum.reg
%2 = call @f(%1, %arg0, %c) : (!quantum.reg, tensor<f64>, tensor<i64>) -> !quantum.reg
%3 = quantum.compbasis qreg %2 : !quantum.obs
>>> print(qjit_c.mlir[1465:2070]) 
func.func private @f(%arg0: !quantum.reg, %arg1: tensor<f64>, %arg2: tensor<i64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage<internal>} {
    %extracted = tensor.extract %arg2[] : tensor<i64>
    %0 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit
    %extracted_0 = tensor.extract %arg1[] : tensor<f64>
    %out_qubits = quantum.custom "RX"(%extracted_0) %0 : !quantum.bit
    %extracted_1 = tensor.extract %arg2[] : tensor<i64>
    %1 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit
    return %1 : !quantum.reg
    }
}
