Interface Design

JAX is an extensible system for transforming numerical functions. Its basic transformations are grad, jit, vmap, and pmap. To make XAJ interoperate well with JAX, we design odeint to behave the same way as grad and to be composable with other JAX transformations. This is challenging because we are mixing autodiff, which works well with pure functions, with numerical integration, which requires internal state.

Solution Callable

Consider the ordinary differential equations (ODEs)

\[\frac{d\mathbf{x}}{dt} = f(t, \mathbf{x})\]

and use a Python callable x(t) to represent its (numerical) solution. We want to design XAJ so that it works naturally with JAX’s automatic differentiation interface.

  1. We would like jax.jacfwd(x) to return the right-hand-side callable, i.e., jax.jacfwd(x)(t) should evaluate to f(t, x(t)).

  2. We would like evaluating x(t) at different t to be as simple as function calls, i.e., x(0.0), x(1.0), x(2.0), …

In addition, motivated by general relativistic ray tracing (GRRT) applications, where it is common to integrate the geodesics backward in the affine parameter, we also want the following.

  3. XAJ should support integrating in both the positive and negative t directions.
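To make the requirements above concrete, here is a minimal sketch of a solution callable. It assumes a fixed-step classical RK4 scheme and a hypothetical helper `make_solution`; neither is XAJ's actual integrator or API. Because the step size `h = (t - t0) / n_steps` carries the sign of `t - t0`, the same callable integrates forward or backward, and `jax.jacfwd` differentiates through the loop with respect to `t`.

```python
import jax
import jax.numpy as jnp


def rk4_step(f, t, x, h):
    """One classical fourth-order Runge-Kutta step of size h (h may be negative)."""
    k1 = f(t, x)
    k2 = f(t + h / 2, x + h / 2 * k1)
    k3 = f(t + h / 2, x + h / 2 * k2)
    k4 = f(t + h, x + h * k3)
    return x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def make_solution(f, t0, x0, n_steps=64):
    """Hypothetical helper (not XAJ's API): return a callable x(t) with x(t0) = x0."""
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def x(t):
        # Signed step size: negative when t < t0, so integration runs backward.
        h = (t - t0) / n_steps

        def body(i, xi):
            return rk4_step(f, t0 + i * h, xi, h)

        return jax.lax.fori_loop(0, n_steps, body, x0)

    return x


# dx/dt = x + t with x(0) = 1; the exact solution is x(t) = 2 e^t - t - 1.
f = lambda t, x: x + t
x = make_solution(f, 0.0, 1.0)

print(x(0.0), x(1.0), x(-1.0))             # plain function calls, both directions
print(jax.jacfwd(x)(1.0), f(1.0, x(1.0)))  # approximately equal
```

Here `jax.jacfwd(x)` agrees with `f(t, x(t))` only up to discretization error, because JAX differentiates the RK4 arithmetic itself.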

Composability

To make XAJ work with the rest of the JAX ecosystem, we require the following.

  1. odeint should have built-in pytree support just like grad. See, e.g., the JAX MLP example.

  2. The numerical solution x(t) should support other JAX transformations such as jit and vmap.
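A minimal sketch of what pytree support could look like, again assuming a fixed-step RK4 scheme: the step is written with `jax.tree_util.tree_map`, so the state can be any pytree (here a dict holding a position and a velocity), and the solution callable is compiled with `jax.jit`. The helpers `axpy`, `rk4_step`, and `make_solution` are hypothetical illustrations, not XAJ's API.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def axpy(a, x, y):
    """Pytree version of y + a * x, applied leaf by leaf."""
    return tree_util.tree_map(lambda xi, yi: yi + a * xi, x, y)


def rk4_step(f, t, x, h):
    """RK4 step where the state x and f(t, x) are pytrees of the same structure."""
    k1 = f(t, x)
    k2 = f(t + h / 2, axpy(h / 2, k1, x))
    k3 = f(t + h / 2, axpy(h / 2, k2, x))
    k4 = f(t + h, axpy(h, k3, x))
    return tree_util.tree_map(
        lambda xi, a, b, c, d: xi + h / 6 * (a + 2 * b + 2 * c + d),
        x, k1, k2, k3, k4,
    )


def make_solution(f, t0, x0, n_steps=128):
    """Hypothetical helper (not XAJ's API) returning a jit-compiled x(t)."""
    x0 = tree_util.tree_map(lambda leaf: jnp.asarray(leaf, jnp.float32), x0)

    @jax.jit
    def x(t):
        h = (t - t0) / n_steps
        return jax.lax.fori_loop(
            0, n_steps, lambda i, xi: rk4_step(f, t0 + i * h, xi, h), x0
        )

    return x


# Harmonic oscillator with a dict-valued state: dq/dt = p, dp/dt = -q.
def oscillator(t, state):
    return {"q": state["p"], "p": -state["q"]}


x = make_solution(oscillator, 0.0, {"q": 1.0, "p": 0.0})
print(x(jnp.pi / 2))  # q close to cos(pi/2) = 0, p close to -sin(pi/2) = -1
```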

Because GPUs and TPUs are Single Instruction, Multiple Data (SIMD) architectures, we do not expect a performance gain from integrating different systems of ODEs at the same time: different right-hand sides f(t, x) correspond to different instruction streams, which SIMD hardware cannot execute in lockstep. Hence, XAJ does not support vmap over different f(t, x), just as JAX does not support vmapping over multiple functions, e.g., vmap(grad)([f1, f2, f3]).

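Although it may seem natural to allow implicit vectorization, i.e., to use x(jnp.array([0.0, 1.0, 2.0])) to evaluate x(t) pointwise on the array t, we purposely disfavor it in order to stay consistent with JAX’s derivative interface grad, jacfwd, and jacrev. With implicit vectorization, jacfwd(x) applied to an array t of length n would return an n-by-n diagonal Jacobian rather than f evaluated pointwise.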

  1. The preferred way to evaluate x(t) pointwise is to vmap it, i.e., vmap(x)(t), which works for an arbitrary pytree t.

  2. We should support vmap over the initial conditions. This complicates our design, but we expect something similar to x = vmap(odeint(f))(t0, x0) to work, where t0 and x0 are arbitrary pytrees.

  3. We should also support vmap over auxiliary parameters of the ODEs. This allows, for example, integrating geodesics around multiple black holes with different spins. The interface should mirror the one for initial conditions, i.e., x = vmap(odeint(f))(aux=aux), where aux is an arbitrary pytree.
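The three vmap patterns above can be approximated today with plain `jax.vmap` and `in_axes`, assuming a solution function with the argument order x(t, t0, x0, aux) from the Call Signature section. The fixed-step RK4 integrator, the name `solution`, and the toy ODE dx/dt = aux * x + t are assumptions for illustration only, not XAJ's interface; `in_axes` selects which single argument is batched in each case.

```python
import jax
import jax.numpy as jnp


def rk4_step(f, t, x, h, aux):
    """One RK4 step of size h for dx/dt = f(t, x, aux)."""
    k1 = f(t, x, aux)
    k2 = f(t + h / 2, x + h / 2 * k1, aux)
    k3 = f(t + h / 2, x + h / 2 * k2, aux)
    k4 = f(t + h, x + h * k3, aux)
    return x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def f(t, x, aux):
    # Toy ODE with an auxiliary parameter: dx/dt = aux * x + t.
    return aux * x + t


def solution(t, t0, x0, aux, n_steps=64):
    """Hypothetical stand-in for x = odeint(f): integrate from (t0, x0) to t."""
    h = (t - t0) / n_steps

    def body(i, xi):
        return rk4_step(f, t0 + i * h, xi, h, aux)

    return jax.lax.fori_loop(0, n_steps, body, x0)


t0, x0, aux = 0.0, jnp.float32(1.0), jnp.float32(1.0)

# 1. Evaluate pointwise on many times: vmap over t only.
ts = jnp.array([0.0, 0.5, 1.0])
xs_t = jax.vmap(solution, in_axes=(0, None, None, None))(ts, t0, x0, aux)

# 2. Batch over initial conditions x0.
x0s = jnp.array([0.0, 1.0, 2.0])
xs_ic = jax.vmap(solution, in_axes=(None, None, 0, None))(1.0, t0, x0s, aux)

# 3. Batch over auxiliary parameters (e.g., several black-hole spins).
auxs = jnp.array([0.5, 1.0, 2.0])
xs_aux = jax.vmap(solution, in_axes=(None, None, None, 0))(1.0, t0, x0, auxs)
```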

Call Signature

How should we design the call signatures of odeint, its inverse, and the numerical solution? Because an ODE has a unique solution only once its initial conditions are given, XAJ’s integration interface is more complicated than JAX’s derivative interface.

Let’s consider a specific ODE

\[\frac{dx}{dt} = f(t, x) = x + t.\]

It has the analytical solution \(x(t) = c e^t - t - 1\), where \(c\) is a constant. Using the initial condition \(x(t_0) = x_0\), we can rewrite the analytical solution as

\[x(t; t_0, x_0) = (x_0 + t_0 + 1)e^{t - t_0} - t - 1.\]

It is straightforward to verify that \(\partial_t x(t; t_0, x_0) = x + t = f(t, x)\), where \(t_0\) and \(x_0\) can be seen as parameters of the solution \(x\). Since jacfwd (like jacrev and grad) differentiates only with respect to the first argument by default, we can use the convention

odeint: f(t, x, aux) -> x(t, t0, x0, aux)
jacfwd: x(t, t0, x0, aux) -> f(t, x, aux)

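where t and t0 are scalars, while x0 and the values of f and x may be arrays. In the special case that f is independent of t (an autonomous system), the solution depends on t and t0 only through t - t0, and we have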

odeint: f(x, aux) -> x(t-t0, x0, aux)
jacfwd: x(t-t0, x0, aux) -> f(x, aux)

We may view odeint as a functional (or higher-order function) that adds a new independent variable t, while jacfwd is its inverse and removes that independent variable.
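The convention can be checked on the example ODE by using its closed-form solution in place of a numerical one. In this sketch, `x_exact` is a hypothetical stand-in for the callable that odeint(f) would return, with signature x(t, t0, x0); because `jax.jacfwd` differentiates only argument 0 by default, it recovers f(t, x). The autonomous case is illustrated with the assumed example dx/dt = g(x) = x, whose solution x0 e^(t - t0) depends on t and t0 only through t - t0.

```python
import jax
import jax.numpy as jnp


def f(t, x):
    # Right-hand side of dx/dt = x + t.
    return x + t


def x_exact(t, t0, x0):
    # Closed-form solution x(t; t0, x0) = (x0 + t0 + 1) e^(t - t0) - t - 1.
    return (x0 + t0 + 1.0) * jnp.exp(t - t0) - t - 1.0


t, t0, x0 = 0.7, -0.3, 2.0

# jacfwd differentiates with respect to argument 0 (t) only.
dxdt = jax.jacfwd(x_exact)(t, t0, x0)
rhs = f(t, x_exact(t, t0, x0))

# The initial condition is recovered at t = t0.
x_at_t0 = x_exact(t0, t0, x0)


# Autonomous example dx/dt = g(x) = x: the solution depends on t - t0 only.
def g(x):
    return x


def y_exact(dt, x0):
    return x0 * jnp.exp(dt)


dydt = jax.jacfwd(y_exact)(t - t0, x0)
```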

Dual Integrators

XAJ needs to support integration in both the positive and negative directions, according to requirement 3 above. Unless the numerical integrator is symmetric, its behavior depends on the integration direction. What, then, should it mean to select an integrator in XAJ?

It is useful to guide our design by consistency. Suppose that an explicit integrator, e.g., forward Euler, is selected. One consistent way to use the integrator is to require that integrating between x0 = x(t0) and x1 = x(t1) always gives the same pair of values, independent of whether (t0, x0) is chosen as the initial condition and integration proceeds toward t1, or (t1, x1) is chosen as the initial condition and integration proceeds toward t0. However, for this requirement to hold, we need to use the backward Euler method when integrating from t1 to t0.
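A minimal sketch of the forward/backward Euler pairing described above, assuming a 1-D array state and Newton's method for the implicit solve; the helper names are illustrative, not XAJ's API. One forward Euler step followed by one backward Euler step with the negated step size returns to x0, because the backward step solves x0 = x1 - h f(t0, x0), which is the forward step rearranged. Reusing forward Euler for the return trip does not.

```python
import jax
import jax.numpy as jnp


def forward_euler(f, t, x, h):
    """Explicit Euler step from (t, x) with signed step size h."""
    return x + h * f(t, x)


def backward_euler(f, t, x, h, n_newton=5):
    """Implicit Euler step: solve y = x + h * f(t + h, y) by Newton's method.

    Assumes x is a 1-D array so the Jacobian is a square matrix.
    """
    t_new = t + h

    def residual(y):
        return y - x - h * f(t_new, y)

    y = forward_euler(f, t, x, h)  # explicit predictor as the initial guess
    for _ in range(n_newton):
        jac = jax.jacfwd(residual)(y)
        y = y - jnp.linalg.solve(jac, residual(y))
    return y


f = lambda t, x: x + t
t0, x0, h = 0.0, jnp.array([1.0]), 0.1

# Forward Euler from t0 to t1 = t0 + h.
x1 = forward_euler(f, t0, x0, h)

# Integrate back from (t1, x1) with step -h.
x0_dual = backward_euler(f, t0 + h, x1, -h)  # dual integrator: recovers x0
x0_same = forward_euler(f, t0 + h, x1, -h)   # same integrator: does not
```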

This motivates the concept of dual integrators. When setting up an integrator in XAJ, users should be able to select a pair of dual integrators, in addition to using a single integrator for both the positive and negative directions.
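One way to express "selecting a pair of dual integrators" is a small record holding one step function per direction, with the driver choosing by the sign of the step. The `DualIntegrator` record and `integrate` driver below are hypothetical, sketched under the assumption of fixed steps and a 1-D array state. With the Euler pair, a ten-step round trip returns to x0 up to rounding; with forward Euler in both directions it drifts noticeably.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


def forward_euler(f, t, x, h):
    """Explicit Euler step with signed step size h."""
    return x + h * f(t, x)


def backward_euler(f, t, x, h, n_newton=5):
    """Implicit Euler step: Newton solve of y = x + h * f(t + h, y) (1-D x)."""
    t_new = t + h

    def residual(y):
        return y - x - h * f(t_new, y)

    y = forward_euler(f, t, x, h)
    for _ in range(n_newton):
        y = y - jnp.linalg.solve(jax.jacfwd(residual)(y), residual(y))
    return y


class DualIntegrator(NamedTuple):
    """Hypothetical pair of step functions: `positive` for h > 0, `negative` for h < 0."""
    positive: Callable
    negative: Callable


def integrate(integrator, f, t0, x0, t1, n_steps):
    """Fixed-step integration from t0 to t1, choosing the step by direction."""
    h = (t1 - t0) / n_steps
    step = integrator.positive if h > 0 else integrator.negative
    x = x0
    for i in range(n_steps):
        x = step(f, t0 + i * h, x, h)
    return x


f = lambda t, x: x + t
x0 = jnp.array([1.0])

euler_pair = DualIntegrator(positive=forward_euler, negative=backward_euler)
euler_only = DualIntegrator(positive=forward_euler, negative=forward_euler)

x1 = integrate(euler_pair, f, 0.0, x0, 1.0, 10)
back_pair = integrate(euler_pair, f, 1.0, x1, 0.0, 10)  # close to x0
back_only = integrate(euler_only, f, 1.0, x1, 0.0, 10)  # drifts away from x0
```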