Interface Design
JAX is an extensible system for transforming numerical functions. The basic transformations are grad, jit, vmap, and pmap.
To make XAJ interoperate well with JAX, we design odeint to behave the same way as grad, i.e., as a transformation of functions, and to be composable with other JAX transformations. This is challenging because we are mixing autodiff, which works well with pure functions, with numerical integration, which requires some internal state.
Solution Callable
Consider the system of ordinary differential equations (ODEs)
\[\frac{dx}{dt} = f(t, x)\]
and use a Python callable x(t) to represent its (numerical) solution.
We want to design XAJ so that it works naturally with JAX’s automatic differentiation interface.
1. We would like jax.jacfwd(x) to return the right-hand-side callable f(t, x); more precisely, jax.jacfwd(x)(t) should agree with f(t, x(t)).
2. We would like evaluating x(t) at different t to be as simple as a function call, i.e., x(0.0), x(1.0), x(2.0), …
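A minimal sketch of these two wishes, assuming JAX is installed. XAJ’s own odeint is not shown on this page, so the exact solution \(x(t) = e^t - t - 1\) of \(dx/dt = x + t\) with \(x(0) = 0\) stands in for a numerical solution; the names f and x are illustrative only.

```python
import jax
import jax.numpy as jnp


def f(t, x):
    # Right-hand side of the example ODE dx/dt = x + t.
    return x + t


def x(t):
    # Stand-in for a numerical solution: the exact solution with x(0) = 0.
    return jnp.exp(t) - t - 1.0


# Evaluating the solution at different t is an ordinary function call.
values = [x(0.0), x(1.0), x(2.0)]

# jacfwd(x) is a callable of t that agrees with f(t, x(t)).
dxdt = jax.jacfwd(x)
t = 1.5
print(dxdt(t), f(t, x(t)))
```

With a real integrator in place of the exact solution, the two printed values would agree only up to the integration error.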
In addition, motivated by general relativistic ray tracing (GRRT) applications, where it is common to integrate the geodesics backward in the affine parameter, we also want the following.
3. XAJ should support integrating in both the positive and negative t directions.
Composability
To make XAJ work with the rest of the JAX ecosystem, we require the following.
4. odeint should have built-in pytree support, just like grad. See, e.g., the JAX MLP example.
5. The numerical solution x(t) should support other JAX transformations such as jit and vmap.
Because GPUs and TPUs have Single Instruction, Multiple Data (SIMD) architectures, which gain speed by applying the same instructions to many data elements, we do not expect a performance gain from evaluating different systems of ODEs, i.e., different instruction streams, at the same time. Hence, vmap over different f(t, x) is not supported in XAJ, just as JAX does not support vmapping over multiple functions, as in vmap(grad)([f1, f2, f3]).
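Although it may seem natural to allow implicit vectorization, i.e., to use x(jnp.array([0.0, 1.0, 2.0])) to evaluate x(t) pointwise on the array t, we purposely disfavor it in order to be consistent with JAX’s derivative interface grad, jacfwd, and jacrev. These transformations treat an array argument as a single input rather than as a batch; for example, grad requires a scalar-valued function, so grad(x) could not be applied to an array of times.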
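A short sketch of why an explicit vmap is preferred, assuming JAX is installed. The scalar function x below is a hypothetical stand-in for a numerical solution (the exact solution of dx/dt = x + t with x(0) = 0). It evaluates x pointwise with vmap, composes vmap with grad, and shows that grad, which requires a scalar output, fails once x is fed an array of times.

```python
import jax
import jax.numpy as jnp


def x(t):
    # Scalar-in, scalar-out stand-in for a solution: x(t) = e^t - t - 1.
    return jnp.exp(t) - t - 1.0


ts = jnp.array([0.0, 1.0, 2.0])

# Pointwise evaluation through an explicit vmap.
xs = jax.vmap(x)(ts)

# Because x stays scalar-valued, grad(x) is well defined, and the same
# vmap idiom gives the derivative at every t.
dxs = jax.vmap(jax.grad(x))(ts)

# Relying on implicit vectorization instead breaks grad: x(ts) is an
# array, and grad only accepts scalar-output functions.
try:
    jax.grad(x)(ts)
    grad_on_array_failed = False
except TypeError:
    grad_on_array_failed = True
```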
6. The preferred way to evaluate x(t) pointwise is to vmap it over an arbitrary pytree t, i.e., vmap(x)(t).
7. We should support vmap over the initial conditions. This complicates our design, but we can expect something similar to x = vmap(odeint(f))(t0, x0) to work, where t0 and x0 are arbitrary pytrees.
8. We should also support vmap over auxiliary parameters of the ODEs. This allows, for example, integrating geodesics around multiple black holes with different spins. The interface should be compatible with vmapping over initial conditions, i.e., x = vmap(odeint(f))(aux=aux), where aux is an arbitrary pytree.
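Batching over initial conditions and auxiliary parameters can be approximated with the x(t, t0, x0, aux) call convention from the call-signature discussion on this page together with the in_axes argument of jax.vmap. This is a sketch under assumptions: instead of the proposed vmap(odeint(f))(t0, x0) form, it vmaps a hypothetical solution function directly, and it uses the analytic solution \(x = (x_0 + a t_0 + a)\, e^{t - t_0} - a t - a\) of \(dx/dt = x + a t\), with \(a\) = aux, so that no integrator is needed. The dict-valued x0 exercises pytree support: in_axes=0 for that argument applies to every leaf.

```python
import jax
import jax.numpy as jnp


def solution(t, t0, x0, aux):
    """Analytic solution of dx/dt = x + aux * t with x(t0) = x0.

    x0 may be any pytree of arrays; each leaf is solved independently.
    """
    def leaf(x0_leaf):
        return (x0_leaf + aux * t0 + aux) * jnp.exp(t - t0) - aux * t - aux

    return jax.tree_util.tree_map(leaf, x0)


# Batch over initial conditions given as a pytree (a dict here).
x0 = {"u": jnp.array([0.0, 1.0, 2.0]), "v": jnp.array([0.5, 0.5, 0.5])}
xs = jax.vmap(solution, in_axes=(None, None, 0, None))(1.0, 0.0, x0, 1.0)

# Batch over the auxiliary parameter instead, keeping x0 fixed.
auxs = jnp.array([0.0, 1.0, 2.0])
ys = jax.vmap(solution, in_axes=(None, None, None, 0))(1.0, 0.0, {"u": 0.0}, auxs)
```

Both batched calls return a dict with the same keys as x0, each leaf carrying a leading batch axis of length 3.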
Call Signature
How should we design the call signature of odeint, of its inverse function, and of the numerical solution? Because the solution of an ODE is uniquely specified only once the initial conditions are given, XAJ’s integration interface is more complicated than JAX’s derivative interface.
Let’s consider a specific ODE
\[\frac{dx}{dt} = f(t, x) = x + t.\]
It has the analytical solution \(x(t) = c e^t - t - 1\). Using the initial condition \(x(t_0) = x_0\), we can rewrite the analytical solution as
\[x(t; t_0, x_0) = (x_0 + t_0 + 1)\, e^{t - t_0} - t - 1.\]
It is straightforward to verify \(\partial_t x(t; t_0, x_0) = x + t = f(t, x)\), where \(t_0\) and \(x_0\) can be seen as parameters of the solution \(x\).
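The constant \(c\) and the verification can be worked out explicitly. This derivation assumes the example ODE is \(dx/dt = x + t\), the right-hand side consistent with the solution \(x(t) = c e^t - t - 1\).

```latex
% Fix c from the initial condition x(t_0) = x_0:
c\,e^{t_0} - t_0 - 1 = x_0
\quad\Longrightarrow\quad
c = (x_0 + t_0 + 1)\,e^{-t_0}.

% Differentiate the resulting solution with respect to t:
\partial_t x(t; t_0, x_0)
  = (x_0 + t_0 + 1)\,e^{t - t_0} - 1
  = \bigl[x(t; t_0, x_0) + t + 1\bigr] - 1
  = x + t.
```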
Given that jacfwd (or jacrev or grad) by default differentiates with respect to the first argument only (argnums=0), we can use the convention
odeint: f(t, x, aux) -> x(t, t0, x0, aux)
jacfwd: x(t, t0, x0, aux) -> f(t, x, aux)
where t and t0 are scalars, x0 may be an array, and the values returned by f and x may be arrays.
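A minimal sketch of this convention, assuming JAX is installed. The odeint below is a hypothetical toy, not XAJ’s implementation: it takes f(t, x, aux) and returns x(t, t0, x0, aux) computed with a fixed number n of classical RK4 steps, handles array (not pytree) states only, and enables 64-bit floats for the accuracy checks. Because the step h = (t - t0) / n changes sign with t - t0, the same code integrates in both directions, and jax.jacfwd, which differentiates with respect to t by default, recovers the right-hand side up to integration error.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # 64-bit floats for the checks below


def odeint(f, n=200):
    """Hypothetical toy odeint: maps f(t, x, aux) to x(t, t0, x0, aux).

    Takes n fixed classical RK4 steps from t0 to t. The step
    h = (t - t0) / n is negative when t < t0, so one code path covers
    both integration directions. States must be arrays, not pytrees.
    """
    def x(t, t0, x0, aux):
        h = (t - t0) / n

        def step(i, xi):
            ti = t0 + i * h
            k1 = f(ti, xi, aux)
            k2 = f(ti + h / 2, xi + h / 2 * k1, aux)
            k3 = f(ti + h / 2, xi + h / 2 * k2, aux)
            k4 = f(ti + h, xi + h * k3, aux)
            return xi + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

        return jax.lax.fori_loop(0, n, step, jnp.asarray(x0, dtype=float))

    return x


def f(t, x, aux):
    # dx/dt = x + aux * t; aux = 1 gives the example ODE dx/dt = x + t.
    return x + aux * t


x = odeint(f)
t, t0, x0, aux = 0.7, 0.0, 0.0, 1.0

# jacfwd differentiates with respect to the first argument, t, and should
# match the right-hand side evaluated along the solution.
lhs = jax.jacfwd(x)(t, t0, x0, aux)
rhs = f(t, x(t, t0, x0, aux), aux)

# Integrating backward (t < t0) uses the same code path.
x_back = x(-1.0, t0, x0, aux)  # exact value: e^{-1}
```

Because the trip count n is a Python int, fori_loop lowers to a scan, so reverse-mode jacrev and grad would also work; an adaptive-step integrator built on while_loop would support only forward mode out of the box.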
In the special case that f is independent of t (an autonomous ODE), the solution depends on t and t0 only through t - t0, and we have
odeint: f(x, aux) -> x(t-t0, x0, aux)
jacfwd: x(t-t0, x0, aux) -> f(x, aux)
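The t - t0 reduction can be checked numerically with a small pure-Python sketch. The fixed-step RK4 helper rk4 and the wrapper autonomous_solution are hypothetical names used only for illustration: for the autonomous ODE dx/dt = -x, starting at t0 = 0 or t0 = 5 and integrating for the same duration gives identical results, whereas for the non-autonomous example dx/dt = x + t it does not.

```python
import math


def rk4(f, t, t0, x0, n=100):
    """Integrate dx/dt = f(t, x) from (t0, x0) to t with n classical RK4 steps."""
    h = (t - t0) / n
    ti, xi = t0, x0
    for _ in range(n):
        k1 = f(ti, xi)
        k2 = f(ti + h / 2, xi + h / 2 * k1)
        k3 = f(ti + h / 2, xi + h / 2 * k2)
        k4 = f(ti + h, xi + h * k3)
        xi += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        ti += h
    return xi


def g(x):
    # Autonomous right-hand side: dx/dt = -x does not depend on t.
    return -x


def autonomous_solution(g, dt, x0):
    # Hypothetical reduced signature x(t - t0, x0): since g ignores t,
    # we may start the clock at 0 and integrate for a duration dt.
    return rk4(lambda t, x: g(x), dt, 0.0, x0)


a = rk4(lambda t, x: g(x), 1.0, 0.0, 2.0)  # from t0 = 0 to t = 1
b = rk4(lambda t, x: g(x), 6.0, 5.0, 2.0)  # from t0 = 5 to t = 6
c = autonomous_solution(g, 1.0, 2.0)       # only t - t0 = 1 is needed
error = abs(a - 2.0 * math.exp(-1.0))      # exact solution: 2 e^{-(t - t0)}

# The non-autonomous example dx/dt = x + t depends on t0, not only t - t0.
p = rk4(lambda t, x: x + t, 1.0, 0.0, 0.0)
q = rk4(lambda t, x: x + t, 6.0, 5.0, 0.0)
```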
We may see odeint as a functional (or higher-order function) that adds a new independent variable t, while jacfwd is its inverse and removes the independent variable.
Dual Integrators
XAJ needs to support integration in both the positive and negative directions, according to requirement 3 above. Unless the numerical integrator is symmetric, its behavior depends on the integration direction. Thus, what should selecting an integrator mean in XAJ?
Consistency is a useful guide for this design. Suppose that an explicit integrator, e.g., forward Euler, is selected. One consistent way to use the integrator is to require that integrating between x0 = x(t0) and x1 = x(t1) always gives the same values of x0 and x1, independent of whether (t0, x0) is chosen as the initial condition and the integration is done toward t1, or (t1, x1) is chosen as the initial condition and the integration is done toward t0. However, for this requirement to hold, we need to use the backward Euler method when integrating from t1 to t0: with h = t1 - t0, the forward Euler step x1 = x0 + h f(t0, x0), read from t1 back to t0, is the implicit relation x0 = x1 - h f(t0, x0), which is exactly a backward Euler step of size -h.
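The argument above can be checked with a single step of size h in pure Python. The helpers forward_euler and backward_euler are illustrative sketches, not XAJ functions; backward_euler solves its implicit equation by fixed-point iteration, which assumes |h| times the Lipschitz constant of f is below 1. Stepping forward with forward Euler and back with backward Euler recovers x0 to rounding error, while reusing forward Euler for the return trip does not.

```python
def forward_euler(f, t, x, h):
    """One explicit Euler step: x_new = x + h * f(t, x)."""
    return x + h * f(t, x)


def backward_euler(f, t, x, h, iters=100, tol=1e-15):
    """One implicit Euler step: solve x_new = x + h * f(t + h, x_new).

    Fixed-point iteration converges when |h| times the Lipschitz
    constant of f is below 1.
    """
    x_new = x
    for _ in range(iters):
        x_next = x + h * f(t + h, x_new)
        if abs(x_next - x_new) <= tol:
            return x_next
        x_new = x_next
    return x_new


def f(t, x):
    return x + t  # example ODE dx/dt = x + t


t0, x0, h = 0.0, 1.0, 0.1
t1 = t0 + h

x1 = forward_euler(f, t0, x0, h)          # forward Euler from t0 to t1
x0_dual = backward_euler(f, t1, x1, -h)   # backward Euler from t1 to t0
x0_naive = forward_euler(f, t1, x1, -h)   # forward Euler reused from t1 to t0
```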
This motivates the concept of dual integrators. When setting up an integrator in XAJ, we should allow selecting a pair of dual integrators, in addition to using a single integrator for both the positive and negative directions.
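One way such a selection could look is sketched below in pure Python; the Dual pair, integrate, and the step functions are hypothetical names, not XAJ’s API. A Dual uses its pos step when h > 0 and its neg step when h < 0, while a plain step function is reused in both directions. The trapezoidal rule is included as an example of a symmetric integrator, which is its own dual, so it round-trips without needing a pair.

```python
from typing import Callable, NamedTuple


class Dual(NamedTuple):
    """Hypothetical dual pair: `pos` steps toward larger t, `neg` toward smaller t."""
    pos: Callable
    neg: Callable


def forward_euler(f, t, x, h):
    # Explicit step: x_new = x + h * f(t, x).
    return x + h * f(t, x)


def backward_euler(f, t, x, h, iters=200):
    # Implicit step x_new = x + h * f(t + h, x_new), via fixed-point iteration.
    x_new = x
    for _ in range(iters):
        x_new = x + h * f(t + h, x_new)
    return x_new


def trapezoid(f, t, x, h, iters=200):
    # Symmetric implicit rule: x_new = x + h/2 * (f(t, x) + f(t + h, x_new)).
    # Swapping the endpoints and flipping the sign of h gives the same relation.
    x_new = x
    for _ in range(iters):
        x_new = x + h / 2 * (f(t, x) + f(t + h, x_new))
    return x_new


def integrate(method, f, t0, x0, t1, n):
    """Take n equal steps from (t0, x0) to t1 with a Dual pair or a single step."""
    h = (t1 - t0) / n
    if isinstance(method, Dual):
        step = method.pos if h > 0 else method.neg
    else:
        step = method  # single integrator reused in both directions
    x = x0
    for i in range(n):
        x = step(f, t0 + i * h, x, h)
    return x


def f(t, x):
    return x + t  # example ODE dx/dt = x + t


t0, x0, t1, n = 0.0, 1.0, 1.0, 10
euler_pair = Dual(pos=forward_euler, neg=backward_euler)

# Round trips t0 -> t1 -> t0 with three choices of integrator.
back_dual = integrate(euler_pair, f, t1, integrate(euler_pair, f, t0, x0, t1, n), t0, n)
back_trap = integrate(trapezoid, f, t1, integrate(trapezoid, f, t0, x0, t1, n), t0, n)
back_fe = integrate(forward_euler, f, t1, integrate(forward_euler, f, t0, x0, t1, n), t0, n)
```

In this sketch back_dual and back_trap return to x0 up to rounding, while back_fe drifts visibly, which is the inconsistency the dual-integrator design is meant to avoid.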