Implementation
To follow JAX’s interface and design patterns, we break the numerical solution of a system of ODEs into multiple components.
Stepper
An xaj.core.Stepper advances a system of ODEs from its current state to a new state.
It is implemented as a pure function using a Python closure, and is composable with standard JAX transformations such as jit and vmap.
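A minimal sketch of the closure pattern follows. It assumes a hypothetical factory `make_rk4_stepper` and a `step(h, t, x) -> (T, X)` signature; neither is taken from the xaj API. The right-hand side `f` is captured by the closure, so `step` has no hidden state and its result depends only on its arguments. The classic fourth-order Runge-Kutta method stands in for whatever scheme a real Stepper uses.

```python
import math


def make_rk4_stepper(f):
    """Hypothetical factory: capture the right-hand side f(t, x) in a closure and
    return a pure function step(h, t, x) -> (T, X) for one classic RK4 step."""
    def step(h, t, x):
        k1 = f(t, x)
        k2 = f(t + h / 2, x + h / 2 * k1)
        k3 = f(t + h / 2, x + h / 2 * k2)
        k4 = f(t + h, x + h * k3)
        return t + h, x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return step


def growth(t, x):
    return x  # dx/dt = x, exact solution x(t) = exp(t) for x(0) = 1


step = make_rk4_stepper(growth)
T, X = step(0.1, 0.0, 1.0)
print(T, X - math.exp(0.1))  # the difference is tiny: RK4's local error is O(h**5)
```

Because `step` only does arithmetic on its arguments, the same closure should also accept JAX arrays. That means wrapping it as `jax.jit(step)`, or mapping it over a batch of initial conditions with `jax.vmap`, is expected to work. The signature itself is an assumption, and xaj's Stepper may differ.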
Error Estimator
An xaj.core.ErrorEstimator estimates the numerical error of a step and is used for adaptive step size control.
It may be based on step doubling or an embedded method, or it can be a custom function based on, e.g., quantities conserved by the system of ODEs.
It is implemented as a pure function using a Python closure, and is composable with standard JAX transformations such as jit and vmap.
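The step-doubling option can be sketched as a closure that wraps an existing stepper. The factory name `make_step_doubling_estimator`, the `estimate(h, t, x) -> (E, T, X)` signature, and the forward Euler stand-in are all hypothetical. The estimate compares one step of size h with two steps of size h/2 and scales the difference by 1 / (2**order - 1). An embedded method would instead compare two solutions of different order built from the same stages.

```python
def make_step_doubling_estimator(step, order):
    """Hypothetical factory: wrap a stepper step(h, t, x) -> (T, X) of the given
    order and return a pure estimate(h, t, x) -> (E, T, X) based on step doubling."""
    def estimate(h, t, x):
        T, big = step(h, t, x)            # one step of size h
        tm, xm = step(h / 2, t, x)        # two steps of size h/2
        _, small = step(h / 2, tm, xm)
        # Richardson-style estimate of the local error in `small`
        E = (small - big) / (2 ** order - 1)
        return E, T, small
    return estimate


def euler(h, t, x):
    return t + h, x + h * x  # forward Euler (order 1) for dx/dt = x


estimate = make_step_doubling_estimator(euler, order=1)
E, T, X = estimate(0.1, 0.0, 1.0)
print(E)  # about 0.0025; the true error of X = 1.1025 is exp(0.1) - 1.1025, about 0.0027
```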
Stepsize Controller
An xaj.core.StepController adjusts the step size based on the error estimator and/or a custom function.
Although a stateful object may be a more natural design, we implement it as a pure function, using a Python closure, so that it fits the JAX ecosystem better.
Like the other components, it is composable with standard JAX transformations such as jit and vmap.
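As a sketch, a controller can be written as a closure with the `ctrl(E, T, X, t, x) -> (H, redo)` shape used by the Stepping Engine's body function. The step size that produced the error is recovered as `T - t`. The factory `make_controller`, its tolerance parameters, and the elementary rule H = h * safety * err**(-1/(order+1)) are assumptions rather than xaj's documented controller. The code targets Python scalars; a JAX version would swap in `jnp.abs`, `jnp.maximum`, and `jnp.clip`, and reduce vector errors with a norm.

```python
def make_controller(order, atol=1e-6, rtol=1e-6, safety=0.9, fac_min=0.2, fac_max=5.0):
    """Hypothetical factory: return a pure controller ctrl(E, T, X, t, x) -> (H, redo)
    implementing the elementary rule H = h * safety * err**(-1 / (order + 1))."""
    def ctrl(E, T, X, t, x):
        h = T - t                                      # step that produced E (negative if backward)
        scale = atol + rtol * max(abs(x), abs(X))      # mixed absolute/relative tolerance
        err = abs(E) / scale                           # err <= 1 means the step is acceptable
        factor = safety * (err + 1e-16) ** (-1.0 / (order + 1))  # 1e-16 avoids 0 ** negative
        factor = min(fac_max, max(fac_min, factor))    # limit how fast the step size may change
        return h * factor, err > 1.0
    return ctrl


ctrl = make_controller(order=4)
H, redo = ctrl(1.0, 0.1, 1.0, 0.0, 1.0)    # huge error: reject and shrink h to 0.2 * 0.1
H2, redo2 = ctrl(0.0, 0.1, 1.0, 0.0, 1.0)  # zero error: accept and grow h to 5 * 0.1
```

Because the controller multiplies the signed step `T - t`, a backward step stays backward after rescaling.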
Stepping Engine
We use jax.lax.while_loop to implement the Stepping Engine; it has the following semantics:

```python
def while_loop(cond, body, state):
    while cond(state):
        state = body(state)
    return state
```
Because the functions cond and body are pure, we need to pass them the full state, which includes the current step size, the current solution of the ODEs, the states shared across steps, etc.
Logically, the body function can be implemented as follows, where i counts loop iterations and r counts consecutive retries of the current step.
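```python
def body(state):
    _, (t, x), (h, k), (i, r) = state
    E, T, X, K = step(h, t, x, k)  # attempt a step of size h from (t, x)
    H, redo = ctrl(E, T, X, t, x)  # propose the next step size; decide whether to redo
    if redo:
        state = _, (t, x), (H, k), (i+1, r+1)  # retry from (t, x) with the new H
    else:
        state = _, (T, X), (H, K), (i+1, 0)    # accept (T, X) and continue
    return state
```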
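The pieces can be assembled into a runnable plain-Python sketch of the engine that uses the reference `while_loop` semantics. Several parts are assumptions, not XAJ's code:

- The `_` slot of the state holds the end time `t_end`.
- `step` is a hypothetical forward Euler stepper with step doubling for dx/dt = -x.
- `ctrl` is a hypothetical elementary controller.
- The loop clips the last step at `t_end` instead of overshooting and interpolating.

Under `jax.jit`, the Python `if` and `min` in `body` would have to become `jax.lax.cond` (or `jnp.where`) and `jnp.minimum`, because traced values cannot drive Python control flow.

```python
import math


def while_loop(cond, body, state):
    # Reference semantics of jax.lax.while_loop(cond_fun, body_fun, init_val)
    while cond(state):
        state = body(state)
    return state


def step(h, t, x, k):
    # Hypothetical stepper for dx/dt = -x: forward Euler with step doubling.
    # k stands in for the shared state; it is unused here.
    big = x - h * x                        # one Euler step of size h
    half = x - 0.5 * h * x                 # two Euler steps of size h/2
    small = half - 0.5 * h * half
    return small - big, t + h, small, k    # E = (small - big) / (2**1 - 1)


def ctrl(E, T, X, t, x, tol=1e-4):
    # Hypothetical elementary controller for an order-1 method (absolute tolerance only)
    err = abs(E) / tol
    factor = min(5.0, max(0.2, 0.9 * (err + 1e-16) ** -0.5))
    return (T - t) * factor, err > 1.0


def cond(state):
    t_end, (t, _), _, (i, _) = state
    return t < t_end and i < 10_000        # stop at t_end, or after too many iterations


def body(state):
    t_end, (t, x), (h, k), (i, r) = state  # the `_` slot is assumed to hold t_end
    h = min(h, t_end - t)                  # this sketch clips the step at t_end
    E, T, X, K = step(h, t, x, k)
    H, redo = ctrl(E, T, X, t, x)
    if redo:
        return t_end, (t, x), (H, k), (i + 1, r + 1)  # reject: retry from (t, x)
    return t_end, (T, X), (H, K), (i + 1, 0)          # accept: continue from (T, X)


_, (t, x), (h, k), (i, r) = while_loop(cond, body, (1.0, (0.0, 1.0), (0.1, None), (0, 0)))
print(t, x, math.exp(-1.0))  # x approximates exp(-1) at t = 1
```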
Cache and Dense Output
An xaj.core.Cache stores the solution of the system of ODEs at full steps or as dense output.
They are implemented as pytrees.
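One way to store dense output is a small record per accepted step. The sketch below assumes cubic Hermite interpolation (a common choice, not necessarily XAJ's) and a hypothetical `HermiteSegment` type. It is a NamedTuple, which JAX treats as a pytree, so such records can pass through jit and vmap. The interpolant matches x and dx/dt at both ends of the step, so it reproduces any cubic exactly.

```python
from typing import NamedTuple


class HermiteSegment(NamedTuple):
    """Hypothetical dense-output record for one full step [t0, t1]."""
    t0: float
    t1: float
    x0: float
    x1: float
    f0: float  # dx/dt at t0
    f1: float  # dx/dt at t1

    def __call__(self, t):
        h = self.t1 - self.t0
        s = (t - self.t0) / h                  # normalized position in [0, 1]
        h00 = 2 * s**3 - 3 * s**2 + 1          # cubic Hermite basis functions
        h10 = s**3 - 2 * s**2 + s
        h01 = -2 * s**3 + 3 * s**2
        h11 = s**3 - s**2
        return h00 * self.x0 + h10 * h * self.f0 + h01 * self.x1 + h11 * h * self.f1


# x(t) = t**3 on [0, 2]: x(0) = 0, x(2) = 8, dx/dt(0) = 0, dx/dt(2) = 12
seg = HermiteSegment(t0=0.0, t1=2.0, x0=0.0, x1=8.0, f0=0.0, f1=12.0)
print(seg(1.5))  # 3.375, equal to 1.5**3
```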
Pacer
Pace (verb): walk at a steady and consistent speed, especially back and forth and as an expression of one’s anxiety or annoyance.
An xaj.core.Pacer combines Stepper, ErrorEstimator, StepController and States to integrate the system of ODEs one step in the requested direction.
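The combination can be sketched as one more closure. The factory `make_pacer` and both callable signatures are hypothetical:

- `attempt(h, t, x) -> (E, T, X)` pairs a stepper with an error estimator.
- `control(E, h) -> (H, redo)` is a step size controller.

The Pacer points the step toward the target and retries rejected steps until one is accepted. It does not clip at the target, which matches the text's note that Treker may overshoot. The retry loop uses Python `for` and an exception, so it shows the logic only; inside jit it would be a `jax.lax.while_loop`.

```python
import math


def make_pacer(attempt, control, max_retries=32):
    """Hypothetical factory (not the xaj API): return a pure function
    pace(t, x, h, target) -> (T, X, H) that takes one accepted step toward target."""
    def pace(t, x, h, target):
        direction = math.copysign(1.0, target - t)  # +1 forward, -1 backward
        h = direction * abs(h)                      # point the step toward the target
        for _ in range(max_retries):
            E, T, X = attempt(h, t, x)
            H, redo = control(E, h)
            if not redo:
                return T, X, H                      # one accepted step
            h = H                                   # rejected: retry with the smaller step
        raise RuntimeError("too many rejected steps")
    return pace


def attempt(h, t, x):
    # Euler with step doubling for dx/dt = x; E = small - big for an order-1 method
    big = x + h * x
    half = x + 0.5 * h * x
    small = half + 0.5 * h * half
    return small - big, t + h, small


def control(E, h, tol=1e-3):
    err = abs(E) / tol
    factor = min(2.0, max(0.2, 0.9 * (err + 1e-16) ** -0.5))
    return h * factor, err > 1.0


pace = make_pacer(attempt, control)
print(pace(0.0, 1.0, 0.1, target=1.0))   # forward: T > 0
print(pace(0.0, 1.0, 0.1, target=-1.0))  # backward: T < 0
```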
Treker
Trek (verb): go on a long arduous journey, typically on foot.
An xaj.core.Treker loops through multiple paces to integrate the system of ODEs.
Because the adaptive step size controller may propose a step that goes past the target, Treker may integrate beyond the target and then return the value at the target by interpolating the dense output.
In addition, because XAJ allows integrating backward, Treker needs to keep track of both the lower and the upper limits of the independent variable.
When a solution is first set up, we have treker.lower = treker.upper = t0.
If we ask for a solution at t > t0, we integrate forward in time and move treker.upper until treker.upper >= t.
If we ask for a solution at t < t0, we integrate backward in time and move treker.lower until treker.lower <= t.
If another t is requested later while treker.lower <= t <= treker.upper, no actual integration is done; the value of x(t) is simply given by the dense output.
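The lower/upper bookkeeping can be sketched in plain Python. `TrekerSketch` and its `pace(t, x, target) -> (T, X)` argument are hypothetical. The class mutates itself for brevity, whereas a JAX version would return an updated pytree instead. Linear interpolation between stored steps stands in for the dense output.

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class TrekerSketch:
    """Hypothetical plain-Python sketch of Treker's bookkeeping (not xaj.core.Treker).
    pace(t, x, target) -> (T, X) takes one accepted step toward target."""
    pace: Callable
    t0: float
    x0: float
    steps: list = field(init=False)  # accepted (t, x) pairs, sorted by t

    def __post_init__(self):
        self.steps = [(self.t0, self.x0)]  # initially lower == upper == t0

    @property
    def lower(self):
        return self.steps[0][0]

    @property
    def upper(self):
        return self.steps[-1][0]

    def __call__(self, t):
        while t > self.upper:  # integrate forward; the last pace may overshoot t
            self.steps.append(self.pace(*self.steps[-1], t))
        while t < self.lower:  # integrate backward; the last pace may overshoot t
            self.steps.insert(0, self.pace(*self.steps[0], t))
        # Now lower <= t <= upper: interpolate between stored steps
        # (linear here for brevity, where a real Treker would use dense output)
        for (ta, xa), (tb, xb) in zip(self.steps, self.steps[1:]):
            if ta <= t <= tb:
                return xa + (xb - xa) * (t - ta) / (tb - ta)
        return self.steps[0][1]  # only one stored point, so t == t0


def pace(t, x, target):
    h = 0.3 if target > t else -0.3  # hypothetical fixed step toward target
    return t + h, x + 2.0 * h        # exact for dx/dt = 2


treker = TrekerSketch(pace, t0=0.0, x0=1.0)
print(treker(1.0), treker.upper)  # 3.0 (up to rounding); upper overshoots to about 1.2
```

A second request inside the covered interval, such as `treker(0.5)`, returns without calling `pace` again.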
Solution
An xaj.core.Solution is a pytree with a __call__() method.
It keeps track of the internal states and the cache of numerical solutions to a system of ODEs.
Because it is a pytree, it can be passed through JAX transformations such as jit and vmap.
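A minimal sketch of a pytree-style container with `__call__`, assuming a hypothetical `SolutionSketch` class (not xaj.core.Solution). Its `tree_flatten`/`tree_unflatten` pair follows the protocol that `jax.tree_util.register_pytree_node_class` expects. The sampled times and values are the leaves, and the interpolation method name is static auxiliary data. The decorator is left out so the sketch runs without JAX. A JAX-ready version would add it and replace the `bisect` lookup with `jax.numpy.interp`, since traced arrays cannot be searched with `bisect`.

```python
import bisect


class SolutionSketch:
    """Hypothetical pytree-style container with a __call__ method."""

    def __init__(self, ts, xs, method="linear"):
        self.ts, self.xs, self.method = ts, xs, method  # method is metadata only here

    def tree_flatten(self):
        # (children, aux_data): array-like leaves first, static metadata second
        return (self.ts, self.xs), self.method

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, method=aux_data)

    def __call__(self, t):
        # Pick the interval [ts[i-1], ts[i]] containing t; outside the stored range
        # the first or last interval is used, i.e. linear extrapolation
        i = min(max(bisect.bisect_right(self.ts, t), 1), len(self.ts) - 1)
        t0, t1 = self.ts[i - 1], self.ts[i]
        x0, x1 = self.xs[i - 1], self.xs[i]
        return x0 + (x1 - x0) * (t - t0) / (t1 - t0)


sol = SolutionSketch([0.0, 1.0, 2.0], [0.0, 2.0, 6.0])
children, aux = sol.tree_flatten()
rebuilt = SolutionSketch.tree_unflatten(aux, children)
print(sol(1.5), rebuilt(0.5))  # 4.0 1.0
```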