Implementation

To better follow JAX’s interface and design patterns, we break XAJ’s numerical solvers into multiple components.

Stepper

An xaj.core.Stepper advances the state of a system of ODEs by one step. It is implemented as a pure function using a Python closure and is composable with standard JAX transformations such as jit and vmap.
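To illustrate the closure pattern, here is a minimal sketch of a classical fourth-order Runge-Kutta stepper. The factory `make_rk4_stepper` and the `(h, t, x) -> (T, X)` signature are hypothetical, not XAJ's actual API. The returned `step` closes over the right-hand side `f` and holds no other state, so `jax.jit` and `jax.vmap` apply to it directly.

```python
import jax
import jax.numpy as jnp

def make_rk4_stepper(f):
    """Return a pure stepper for dx/dt = f(t, x) (hypothetical factory)."""
    def step(h, t, x):
        # Classical fourth-order Runge-Kutta stages
        k1 = f(t, x)
        k2 = f(t + h / 2, x + h / 2 * k1)
        k3 = f(t + h / 2, x + h / 2 * k2)
        k4 = f(t + h, x + h * k3)
        X = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return t + h, X
    return step

# Example: exponential decay dx/dt = -x
step = make_rk4_stepper(lambda t, x: -x)

# The closure is a pure function, so it composes with jit ...
T, X = jax.jit(step)(0.1, 0.0, jnp.array([1.0]))

# ... and with vmap, here over a batch of initial conditions
batched = jax.vmap(step, in_axes=(None, None, 0))
Ts, Xs = batched(0.1, 0.0, jnp.array([1.0, 2.0, 3.0]))
```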

Error Estimator

An xaj.core.ErrorEstimator estimates the numerical error of a step and is used for adaptive step size control. It may be based on step doubling or on an embedded method, or it can be a custom function based on, e.g., conserved quantities of the system of ODEs. It is implemented as a pure function using a Python closure and is composable with standard JAX transformations such as jit and vmap.
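A step-doubling estimator can be written the same way. In this sketch the names `make_euler_stepper` and `make_step_doubling_estimator` are hypothetical. The estimator wraps any stepper of known order p, takes one step of size h and two steps of size h/2, and uses the Richardson estimate (X_half - X_full) / (2**p - 1) as the error of the two-half-step result.

```python
import jax
import jax.numpy as jnp

def make_euler_stepper(f):
    # Hypothetical first-order stepper, used only to feed the estimator
    def step(h, t, x):
        return t + h, x + h * f(t, x)
    return step

def make_step_doubling_estimator(step, order):
    """Wrap a stepper of the given order with a step-doubling error estimate."""
    def estimate(h, t, x):
        # One full step of size h
        _, X_full = step(h, t, x)
        # Two half steps of size h/2
        t_mid, x_mid = step(h / 2, t, x)
        T, X_half = step(h / 2, t_mid, x_mid)
        # Richardson estimate of the error in the two-half-step result
        E = (X_half - X_full) / (2**order - 1)
        return E, T, X_half
    return estimate

f = lambda t, x: -x
estimate = jax.jit(make_step_doubling_estimator(make_euler_stepper(f), order=1))
E, T, X = estimate(0.1, 0.0, jnp.array([1.0]))
```

For forward Euler (p = 1) on dx/dt = -x with h = 0.1, this gives E = 0.0025, close to the true error exp(-0.1) - 0.9025 ≈ 0.00234.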

Stepsize Controller

An xaj.core.StepController adjusts the step size based on the error estimate and/or a custom function. Although a stateful object might seem more natural, we avoid one so that the controller fits the functional style of the JAX ecosystem. It is implemented as a pure function using a Python closure and is composable with standard JAX transformations such as jit and vmap.
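A common choice, assumed here rather than taken from XAJ, is the elementary controller that rescales the step by safety * err**(-1/(q+1)), where err is a scaled RMS norm of the error estimate and q is the order of the estimate. The sketch matches the `ctrl(E, T, X, t, x)` call used in the stepping-engine body, recovering the attempted step as h = T - t; `make_step_controller` and its parameters are hypothetical.

```python
import jax
import jax.numpy as jnp

def make_step_controller(order, rtol=1e-6, atol=1e-9,
                         safety=0.9, fac_min=0.2, fac_max=5.0):
    """Return a pure ctrl(E, T, X, t, x) -> (H, redo) (hypothetical factory).

    `order` is the order of the error estimate, so the local error
    scales like h**(order + 1).
    """
    def ctrl(E, T, X, t, x):
        h = T - t  # the step just attempted (negative when integrating backward)
        # Mixed absolute/relative scale, then an RMS norm of the scaled error
        scale = atol + rtol * jnp.maximum(jnp.abs(x), jnp.abs(X))
        err = jnp.sqrt(jnp.mean((E / scale) ** 2))
        # Elementary controller: h_new = h * safety * err**(-1/(order+1)), clipped
        factor = jnp.clip(safety * err ** (-1.0 / (order + 1)), fac_min, fac_max)
        H = h * factor
        redo = err > 1.0  # reject the step if the scaled error exceeds tolerance
        return H, redo
    return ctrl

ctrl = jax.jit(make_step_controller(order=1, rtol=1e-3, atol=1e-6))
x, X = jnp.array([1.0]), jnp.array([0.9])
H1, redo1 = ctrl(jnp.array([1e-2]), 0.1, X, 0.0, x)  # large error: rejected, smaller H
H2, redo2 = ctrl(jnp.array([1e-5]), 0.1, X, 0.0, x)  # tiny error: accepted, H capped at 5h
```

Because H is (T - t) times a positive factor, the controller preserves the direction of integration when the attempted step is negative.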

Stepping Engine

We use jax.lax.while_loop to implement the Stepping Engine. It has the same semantics as the following Python loop:

def while_loop(cond, body, state):
    while cond(state):
        state = body(state)
    return state
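As a quick check of these semantics, the sketch below runs the same cond and body (a hypothetical toy that halves x while it exceeds 1) through a plain Python loop and through jax.lax.while_loop. Both stop after 7 iterations with x = 0.78125.

```python
import jax

def cond(state):
    n, x = state
    return x > 1.0          # keep looping while x exceeds 1

def body(state):
    n, x = state
    return n + 1, x / 2.0   # halve x and count the iteration

# Reference semantics: a plain Python loop
def python_while_loop(cond, body, state):
    while cond(state):
        state = body(state)
    return state

init = (0, 100.0)
n_py, x_py = python_while_loop(cond, body, init)
# Compiled version: cond and body are traced once, so the state must keep
# the same structure, shapes, and dtypes on every iteration.
n_jax, x_jax = jax.jit(lambda s: jax.lax.while_loop(cond, body, s))(init)
```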

Because the functions cond and body are pure, the full state must be passed to them explicitly: the current step size, the current solution of the ODEs, the states shared across steps, etc. Logically, the body function can be implemented as follows.

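import jax
import jax.numpy as jnp

def body(state):
    # (t, x): current point; (h, k): step size and state shared across steps;
    # (i, r): total iterations and consecutive retries
    aux, (t, x), (h, k), (i, r) = state

    E, T, X, K = step(h, t, x, k)   # trial step with error estimate E
    H, redo = ctrl(E, T, X, t, x)   # new step size and rejection flag

    accepted = (aux, (T, X), (H, K), (i + 1, jnp.zeros_like(r)))  # continue
    rejected = (aux, (t, x), (H, k), (i + 1, r + 1))              # retry
    # redo is a traced boolean, so select the branch with lax.cond, not `if`
    return jax.lax.cond(redo, lambda: rejected, lambda: accepted)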
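A complete, runnable version of this loop needs concrete step and ctrl functions. The sketch below assumes a Heun-Euler embedded pair for step, the elementary controller for ctrl, and a toy problem dx/dt = -x on [0, 1]; none of these choices come from XAJ. Here k holds f(t, x), the state shared across steps, and it is replaced by K only when a step is accepted. The step is also clamped so that the last one lands on the end point, which is a simplification: the Treker described later integrates past the target and interpolates instead.

```python
import jax
import jax.numpy as jnp

T_END, RTOL, ATOL, MAX_ITERS = 1.0, 1e-5, 1e-8, 10_000

def f(t, x):
    return -x  # toy right-hand side: exponential decay, x(t) = exp(-t)

def step(h, t, x, k):
    # Heun-Euler embedded pair; k = f(t, x) is the state shared across steps
    k2 = f(t + h, x + h * k)
    X = x + h / 2 * (k + k2)          # second-order (Heun) solution
    E = h / 2 * (k2 - k)              # Heun minus first-order Euler
    T = t + h
    return E, T, X, f(T, X)           # f(T, X) becomes k if the step is accepted

def ctrl(E, T, X, t, x):
    h = T - t
    scale = ATOL + RTOL * jnp.maximum(jnp.abs(x), jnp.abs(X))
    err = jnp.sqrt(jnp.mean((E / scale) ** 2))
    H = h * jnp.clip(0.9 * err ** -0.5, 0.2, 5.0)   # E is O(h**2)
    return H, err > 1.0

def body(state):
    aux, (t, x), (h, k), (i, r) = state
    h = jnp.minimum(h, T_END - t)     # clamp so the last step lands on T_END
    E, T, X, K = step(h, t, x, k)
    H, redo = ctrl(E, T, X, t, x)
    accepted = (aux, (T, X), (H, K), (i + 1, jnp.zeros_like(r)))
    rejected = (aux, (t, x), (H, k), (i + 1, r + 1))
    return jax.lax.cond(redo, lambda: rejected, lambda: accepted)

def cond(state):
    _, (t, _), _, (i, _) = state
    return (t < T_END) & (i < MAX_ITERS)

x0 = jnp.array([1.0])
init = (None, (jnp.float32(0.0), x0), (jnp.float32(0.01), f(0.0, x0)),
        (jnp.int32(0), jnp.int32(0)))
final = jax.jit(lambda s: jax.lax.while_loop(cond, body, s))(init)
_, (t1, x1), _, (n_iters, _) = final
```

The iteration count n_iters includes rejected attempts; the MAX_ITERS guard in cond keeps the loop finite even if the controller misbehaves.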

Cache and Dense Output

An xaj.core.Cache stores the solution of the system of ODEs at full steps and/or as dense output. It is implemented as a pytree.
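One way to make a cache that works under jit and while_loop is a fixed-capacity buffer, since JAX needs array shapes to be known at trace time. In the sketch below, the field layout, `empty_cache`, and `append` are hypothetical; storing f(t, x) alongside x is an assumption about what dense output needs (it suffices for cubic Hermite interpolation).

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class Cache(NamedTuple):
    """Fixed-capacity buffer of accepted steps (hypothetical layout).

    A NamedTuple is automatically a JAX pytree, so a Cache can be carried
    through jax.lax.while_loop or passed to jit-compiled functions.
    """
    ts: jax.Array   # times of accepted steps, shape (capacity,)
    xs: jax.Array   # solutions at those times, shape (capacity, dim)
    fs: jax.Array   # derivatives f(t, x) there, for Hermite dense output
    n: jax.Array    # number of filled slots (int32 scalar)

def empty_cache(capacity, dim):
    return Cache(ts=jnp.full((capacity,), jnp.nan),
                 xs=jnp.full((capacity, dim), jnp.nan),
                 fs=jnp.full((capacity, dim), jnp.nan),
                 n=jnp.int32(0))

@jax.jit
def append(cache, t, x, fx):
    # Functional update: .at[].set returns new arrays. Shapes never change,
    # so this traces once however full the cache is. The caller must keep
    # n below capacity.
    i = cache.n
    return Cache(ts=cache.ts.at[i].set(t),
                 xs=cache.xs.at[i].set(x),
                 fs=cache.fs.at[i].set(fx),
                 n=i + 1)

cache = empty_cache(capacity=4, dim=1)
cache = append(cache, 0.0, jnp.array([1.0]), jnp.array([-1.0]))
cache = append(cache, 0.1, jnp.array([0.905]), jnp.array([-0.905]))
```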

Pacer

pace (verb): walk at a steady and consistent speed, especially back and forth and as an expression of one’s anxiety or annoyance.

An xaj.core.Pacer combines the Stepper, ErrorEstimator, StepController, and states to advance the solution of the system of ODEs by one step in the requested direction.
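A pace can be sketched as an inner loop that retries from the same point until the controller accepts a step. The factory `make_pacer`, the toy Heun-Euler pair, and the controller below are hypothetical. The direction of integration is carried by the sign of the initial step size h.

```python
import jax
import jax.numpy as jnp

def make_pacer(step, ctrl, max_retries=20):
    """Combine a stepper/estimator and a controller into one accepted step.

    Hypothetical signatures: step(h, t, x) -> (E, T, X) and
    ctrl(E, T, X, t, x) -> (H, redo). The sign of h sets the direction.
    """
    def pace(t, x, h):
        def attempt(h):
            E, T, X = step(h, t, x)        # every attempt restarts from (t, x)
            H, redo = ctrl(E, T, X, t, x)
            return T, X, H, redo

        def cond(s):
            *_, redo, n = s
            return redo & (n < max_retries)

        def body(s):
            _, _, H, _, n = s
            return (*attempt(H), n + 1)    # retry with the reduced step size

        init = (*attempt(h), jnp.int32(1))
        T, X, H, _, _ = jax.lax.while_loop(cond, body, init)
        return T, X, H    # accepted point and the proposed next step size
    return pace

f = lambda t, x: -x

def heun_euler(h, t, x):
    # Toy embedded pair: returns (error estimate, new time, Heun solution)
    k1 = f(t, x)
    k2 = f(t + h, x + h * k1)
    return h / 2 * (k2 - k1), t + h, x + h / 2 * (k1 + k2)

def ctrl(E, T, X, t, x, tol=1e-4):
    err = jnp.max(jnp.abs(E)) / tol
    H = (T - t) * jnp.clip(0.9 * err ** -0.5, 0.2, 5.0)  # keeps the sign of h
    return H, err > 1.0

pace = jax.jit(make_pacer(heun_euler, ctrl))
x0 = jnp.array([1.0])
T_fwd, X_fwd, _ = pace(jnp.float32(0.0), x0, jnp.float32(0.5))   # forward
T_bwd, X_bwd, _ = pace(jnp.float32(0.0), x0, jnp.float32(-0.5))  # backward
```

If max_retries is exhausted, this sketch returns the last rejected attempt; a real Pacer would need to flag that case.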

Treker

trek (verb): go on a long arduous journey, typically on foot.

An xaj.core.Treker loops over multiple paces to integrate the system of ODEs. Because the adaptive step size controller may propose a step that carries the solution past the target, Treker may integrate beyond the target and then return the solution at the target by interpolating the dense output.
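The interpolation step can be sketched with cubic Hermite interpolation, which needs only the solution and its derivative at the two ends of the last step. This is an assumed choice of dense output, not necessarily XAJ's, and `hermite_interp` is a hypothetical name. In the usage below, the last accepted step is taken to run from t = 0 to t = 0.5 and to overshoot a target at t = 0.3, with exact endpoint values of x(t) = exp(-t) standing in for a real integration.

```python
import jax.numpy as jnp

def hermite_interp(t0, x0, f0, t1, x1, f1, t):
    """Cubic Hermite interpolant on [t0, t1] from values x and derivatives f."""
    h = t1 - t0
    s = (t - t0) / h                     # normalized position in [0, 1]
    h00 = (1 + 2 * s) * (1 - s) ** 2     # Hermite basis functions
    h10 = s * (1 - s) ** 2
    h01 = s**2 * (3 - 2 * s)
    h11 = s**2 * (s - 1)
    return h00 * x0 + h10 * h * f0 + h01 * x1 + h11 * h * f1

t0, t1, target = 0.0, 0.5, 0.3
x0, x1 = jnp.exp(-t0), jnp.exp(-t1)     # solution at the step endpoints
# For dx/dt = -x the derivatives at the endpoints are -x0 and -x1
x_target = hermite_interp(t0, x0, -x0, t1, x1, -x1, target)
```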

In addition, because XAJ allows integrating backward, Treker keeps track of both the lower and upper limits of the independent variable covered so far. When a solution is first set up, treker.lower = treker.upper = t0. If we ask for a solution at t > treker.upper, we integrate forward and move treker.upper to a value >= t. If we ask for a solution at t < treker.lower, we integrate backward and move treker.lower to a value <= t. If a later query asks for a t with treker.lower <= t <= treker.upper, no integration is performed; x(t) is simply evaluated from the dense output.
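The interval bookkeeping reduces to two small pure functions, sketched below under hypothetical names: `plan` decides whether a query needs forward integration, backward integration, or none, and `widen` grows the covered interval after a trek. The overshoot to 2.3 in the usage is an illustrative value.

```python
import jax
import jax.numpy as jnp

@jax.jit
def plan(lower, upper, t):
    """+1: integrate forward past upper; -1: backward past lower; 0: covered."""
    return jnp.where(t > upper, 1, jnp.where(t < lower, -1, 0))

@jax.jit
def widen(lower, upper, T):
    """Grow the covered interval after integrating to T (T may overshoot t)."""
    return jnp.minimum(lower, T), jnp.maximum(upper, T)

t0 = 0.0
lower, upper = t0, t0                      # freshly set up: lower = upper = t0
d = plan(lower, upper, 2.0)                # 1: integrate forward
lower, upper = widen(lower, upper, 2.3)    # e.g. the last pace overshot to 2.3
d2 = plan(lower, upper, 1.5)               # 0: answered from the dense output
```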

Solution

An xaj.core.Solution is a pytree with a __call__() method. It keeps track of the internal states and cache of numerical solutions to a system of ODEs. Because it is a pytree, it supports JAX transformations such as jit and vmap.
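A minimal sketch of such a class registers it with `jax.tree_util.register_pytree_node_class`, so that its arrays become leaves JAX can trace. The fields `ts` and `xs` are hypothetical, and linear interpolation via `jnp.interp` stands in for the real dense output.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Solution:
    """Minimal callable pytree (hypothetical fields, not XAJ's actual class)."""

    def __init__(self, ts, xs):
        self.ts = ts   # increasing sample times, shape (m,)
        self.xs = xs   # solution values at ts, shape (m,)

    def __call__(self, t):
        # Placeholder dense output: linear interpolation between samples
        return jnp.interp(t, self.ts, self.xs)

    def tree_flatten(self):
        return (self.ts, self.xs), None      # children, static aux data

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

ts = jnp.linspace(0.0, 1.0, 11)
sol = Solution(ts, jnp.exp(-ts))

# Because Solution is a pytree, it can be an argument to a jit-compiled function...
x_half = jax.jit(lambda s, t: s(t))(sol, 0.5)
# ...and its call can be vmapped over query times.
x_many = jax.vmap(lambda t: sol(t))(jnp.array([0.25, 0.5, 0.75]))
```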