desc.derivatives.AutoDiffDerivative
- class desc.derivatives.AutoDiffDerivative(fun, argnum=0, mode='fwd', **kwargs)
Computes derivatives using automatic differentiation with JAX.
- Parameters:
fun (callable) – Function to be differentiated.
argnum (int, optional) – Specifies which positional argument to differentiate with respect to.
mode (str, optional) – Automatic differentiation mode. One of 'fwd' (forward-mode Jacobian), 'rev' (reverse-mode Jacobian), 'grad' (gradient of a scalar function), 'hess' (Hessian of a scalar function), or 'jvp' (Jacobian-vector product). Default = 'fwd'.
- Raises:
ValueError – if mode is not supported.
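Each supported mode corresponds naturally to one of JAX's built-in transforms. The sketch below assumes that direct mapping and is not DESC's implementation: `make_derivative` is a hypothetical helper, and the `jvp(v, *args)` calling convention used for the 'jvp' mode is an assumption chosen for illustration.

```python
import jax
import jax.numpy as jnp


def make_derivative(fun, argnum=0, mode="fwd"):
    """Hypothetical helper (not part of DESC): map a mode string to a JAX transform."""
    if mode == "fwd":
        return jax.jacfwd(fun, argnums=argnum)  # full Jacobian, forward mode
    if mode == "rev":
        return jax.jacrev(fun, argnums=argnum)  # full Jacobian, reverse mode
    if mode == "grad":
        return jax.grad(fun, argnums=argnum)  # gradient; fun must return a scalar
    if mode == "hess":
        return jax.hessian(fun, argnums=argnum)  # Hessian; fun must return a scalar
    if mode == "jvp":
        # Assumed calling convention: jvp(v, *args) returns df/dx * v.
        def jvp(v, *args):
            def f_of_x(x):
                # Hold every positional argument fixed except the one at `argnum`.
                return fun(*args[:argnum], x, *args[argnum + 1:])

            _, tangent = jax.jvp(f_of_x, (args[argnum],), (v,))
            return tangent

        return jvp
    raise ValueError(f"mode {mode!r} is not supported")


x = jnp.array([1.0, 2.0, 3.0])
print(make_derivative(lambda x: jnp.sum(x**2), mode="grad")(x))  # equals 2 * x
print(make_derivative(lambda x: x**2, mode="fwd")(x))  # equals diag(2 * x)
```

'fwd' and 'rev' return the same Jacobian; forward mode is generally cheaper when the function has fewer inputs than outputs, and reverse mode when it has more.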
Methods
compute(*args, **kwargs) – Compute the derivative matrix.
compute_jvp(fun, argnum, v, *args, **kwargs) – Compute df/dx*v.
compute_jvp2(fun, argnum1, argnum2, v1, v2, ...) – Compute d^2f/dx^2*v1*v2.
compute_jvp3(fun, argnum1, argnum2, argnum3, ...) – Compute d^3f/dx^3*v1*v2*v3.
compute_vjp(fun, argnum, v, *args, **kwargs) – Compute v.T * df/dx.
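The directional-derivative methods can be reproduced with `jax.jvp` and `jax.vjp` without ever forming a full Jacobian, and higher orders follow by nesting forward-mode passes. This is a minimal sketch, not DESC's code: `jvp1`, `jvp2`, `jvp3`, `vjp1` and `_freeze_others` are hypothetical names, and placing the primal arguments after the tangent vectors only mirrors the signatures listed above; the parameters hidden behind `...` are not known, so that ordering is an assumption.

```python
import jax
import jax.numpy as jnp


def _freeze_others(fun, argnum, args):
    # Return f(x) with every positional argument except `argnum` held fixed.
    def f_of_x(x):
        return fun(*args[:argnum], x, *args[argnum + 1:])

    return f_of_x


def jvp1(fun, argnum, v, *args):
    # df/dx * v: one forward-mode pass, no Jacobian is formed.
    f = _freeze_others(fun, argnum, args)
    return jax.jvp(f, (args[argnum],), (v,))[1]


def jvp2(fun, argnum1, argnum2, v1, v2, *args):
    # d^2f/dx^2 * v1 * v2: differentiate the first directional derivative again.
    def inner(*a):
        return jvp1(fun, argnum1, v1, *a)

    return jvp1(inner, argnum2, v2, *args)


def jvp3(fun, argnum1, argnum2, argnum3, v1, v2, v3, *args):
    # d^3f/dx^3 * v1 * v2 * v3: one more level of nesting.
    def inner(*a):
        return jvp2(fun, argnum1, argnum2, v1, v2, *a)

    return jvp1(inner, argnum3, v3, *args)


def vjp1(fun, argnum, v, *args):
    # v.T * df/dx: one reverse-mode pass; the result has the shape of x.
    f = _freeze_others(fun, argnum, args)
    _, pullback = jax.vjp(f, args[argnum])
    return pullback(v)[0]


x = jnp.array([1.0, 2.0])
ones = jnp.ones(2)
cube = lambda x: x**3
print(jvp2(cube, 0, 0, ones, ones, x))  # equals 6 * x
```

Each nested pass costs a constant multiple of one evaluation of `fun`, independent of the size of x, whereas materializing the full second- or third-derivative tensor grows with that size.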
Attributes
argnum – Argument being differentiated with respect to.
fun – Function being differentiated.
mode – The kind of derivative being computed (e.g. 'grad').