desc.derivatives.AutoDiffDerivative

class desc.derivatives.AutoDiffDerivative(fun, argnum=0, mode='fwd', **kwargs)

Computes derivatives using automatic differentiation with JAX.

Parameters:
  • fun (callable) – Function to be differentiated.

  • argnum (int, optional) – Index of the positional argument of fun to differentiate with respect to. Default = 0.

  • mode (str, optional) – Automatic differentiation mode. One of 'fwd' (forward-mode Jacobian), 'rev' (reverse-mode Jacobian), 'grad' (gradient of a scalar function), 'hess' (Hessian of a scalar function), or 'jvp' (Jacobian-vector product). Default = 'fwd'.

Raises:

ValueError – if mode is not one of the supported modes.
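To make the modes concrete, here is a minimal sketch, assuming that 'fwd', 'rev', 'grad', and 'hess' correspond to the standard JAX transforms jax.jacfwd, jax.jacrev, jax.grad, and jax.hessian. The class SketchAutoDiff and the _TRANSFORMS table are hypothetical illustrations, not DESC's implementation, and the sketch omits 'jvp' because that mode also needs a tangent vector.

```python
import jax
import jax.numpy as jnp

# Hypothetical mode -> JAX transform table (an assumption, not DESC's code).
_TRANSFORMS = {
    "fwd": jax.jacfwd,    # forward-mode Jacobian
    "rev": jax.jacrev,    # reverse-mode Jacobian
    "grad": jax.grad,     # gradient of a scalar-valued function
    "hess": jax.hessian,  # Hessian of a scalar-valued function
}


class SketchAutoDiff:
    """Hypothetical stand-in for AutoDiffDerivative ('jvp' mode omitted)."""

    def __init__(self, fun, argnum=0, mode="fwd"):
        # Reject unknown modes up front, analogous to the documented ValueError
        if mode not in _TRANSFORMS:
            raise ValueError(f"mode must be one of {sorted(_TRANSFORMS)}, got {mode!r}")
        self.fun = fun
        self.argnum = argnum
        self.mode = mode
        # Build the differentiated function once, at construction time
        self._deriv = _TRANSFORMS[mode](fun, argnums=argnum)

    def compute(self, *args):
        # Evaluate the derivative of fun with respect to args[argnum]
        return self._deriv(*args)


def f(x):
    # Vector-valued: suitable for 'fwd' and 'rev'
    return jnp.stack([x[0] ** 2, x[0] * x[1]])


def g(x):
    # Scalar-valued: suitable for 'grad' and 'hess'
    return jnp.sum(x ** 3)


x = jnp.array([1.0, 2.0])
J = SketchAutoDiff(f, mode="fwd").compute(x)   # [[2., 0.], [2., 1.]]
G = SketchAutoDiff(g, mode="grad").compute(x)  # [3., 12.]
H = SketchAutoDiff(g, mode="hess").compute(x)  # [[6., 0.], [0., 12.]]
```

Forward and reverse mode give the same Jacobian; they differ only in cost. Forward mode is typically cheaper when fun has fewer inputs than outputs, and reverse mode when it has more.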

Methods

compute(*args)

Compute the derivative selected by mode (e.g. the Jacobian, gradient, or Hessian) at the given arguments.

compute_jvp(fun, argnum, v, *args)

Compute the Jacobian-vector product df/dx*v, where x is the argument at position argnum.
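A Jacobian-vector product can be computed without forming the full Jacobian. The sketch below assumes compute_jvp behaves like jax.jvp applied to the argument at position argnum; the helper name jvp_sketch is hypothetical, and only its argument order (fun, argnum, v, *args) follows the documented signature.

```python
import jax
import jax.numpy as jnp


def jvp_sketch(fun, argnum, v, *args):
    """Hypothetical helper: df/dx * v, where x = args[argnum]."""
    def f_of_x(x):
        # Rebuild the argument tuple with args[argnum] replaced by x
        new_args = list(args)
        new_args[argnum] = x
        return fun(*new_args)

    # jax.jvp returns (fun(x), J @ v); keep only the tangent output
    _, tangent = jax.jvp(f_of_x, (args[argnum],), (v,))
    return tangent


def f(x, a):
    return a * x ** 2  # Jacobian w.r.t. x is diag(2 * a * x)


x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])
jv = jvp_sketch(f, 0, v, x, 3.0)  # diag(6, 12) @ [1, 1] = [6., 12.]
```

A single JVP costs roughly one extra forward evaluation of fun, which is why it is preferable to building df/dx and then multiplying by v when only the product is needed.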

compute_jvp2(fun, argnum1, argnum2, v1, v2, ...)

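Compute d^2f/dx^2*v1*v2, the second derivative of fun with respect to the arguments at argnum1 and argnum2, contracted with v1 and v2.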
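One way to obtain a second derivative contracted with two vectors is to nest two forward-mode JVPs. This is a sketch, assuming that is what compute_jvp2 computes. The documented signature is elided after v2, so the trailing *args in the hypothetical jvp2_sketch, and the helper _directional, are assumptions.

```python
import jax
import jax.numpy as jnp


def _directional(fun, argnum, v):
    """Hypothetical helper: returns a function of *args computing
    (d fun / d args[argnum]) applied to the tangent v."""
    def dfun(*args):
        def f_of_x(x):
            new_args = list(args)
            new_args[argnum] = x
            return fun(*new_args)
        # jax.jvp returns (primal_out, tangent_out); keep the tangent
        return jax.jvp(f_of_x, (args[argnum],), (v,))[1]
    return dfun


def jvp2_sketch(fun, argnum1, argnum2, v1, v2, *args):
    # Forward-over-forward: differentiate along v1, then differentiate that along v2
    return _directional(_directional(fun, argnum1, v1), argnum2, v2)(*args)


def g(x):
    return jnp.sum(x ** 3)  # Hessian is diag(6 * x)


x = jnp.array([1.0, 2.0])
v1 = jnp.array([1.0, 1.0])
v2 = jnp.array([1.0, 0.0])
val = jvp2_sketch(g, 0, 0, v1, v2, x)  # v1 @ diag(6x) @ v2 = 6.0
```

Because argnum1 and argnum2 may differ, the same nesting also yields mixed second derivatives such as d^2f/(dx dy).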

compute_jvp3(fun, argnum1, argnum2, argnum3, ...)

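Compute d^3f/dx^3*v1*v2*v3, the third derivative of fun with respect to the arguments at argnum1, argnum2, and argnum3, contracted with v1, v2, and v3.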
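The same nesting extends to third order with one more JVP. This sketch assumes compute_jvp3 computes this contraction. The documented signature is elided, so the argument order of the hypothetical jvp3_sketch, and the helper _directional, are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def _directional(fun, argnum, v):
    # Hypothetical helper: returns args -> (d fun / d args[argnum]) applied to v
    def dfun(*args):
        def f_of_x(x):
            new_args = list(args)
            new_args[argnum] = x
            return fun(*new_args)
        return jax.jvp(f_of_x, (args[argnum],), (v,))[1]
    return dfun


def jvp3_sketch(fun, argnum1, argnum2, argnum3, v1, v2, v3, *args):
    # Three nested forward-mode JVPs contract the third derivative with v1, v2, v3
    d1 = _directional(fun, argnum1, v1)
    d2 = _directional(d1, argnum2, v2)
    d3 = _directional(d2, argnum3, v3)
    return d3(*args)


def g(x):
    return jnp.sum(x ** 3)  # third derivative is 6 on the diagonal i = j = k, else 0


x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 2.0])
val3 = jvp3_sketch(g, 0, 0, 0, v, v, v, x)  # 6 * sum(v**3) = 54.0
```

Each level of nesting adds roughly one forward pass, so the full third-derivative tensor is never materialized.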

Attributes

argnum

Index of the argument being differentiated with respect to.

fun

Function being differentiated.

mode

Kind of derivative being computed (e.g. 'grad').