Design Patterns
While the auto-vectorizing and parallel maps vmap and pmap are
straightforward for pure functions because they are embarrassingly
(i.e., naturally) parallelizable, numerical integrations are not.
Depending on the initial conditions, different realizations of an ODE
solution may take different stepsizes.
To maintain composability with JAX transformations, we need to limit
the scope of XAJ, which leads to interesting design patterns.
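As a point of reference for the "straightforward" case, the sketch below maps a pure right-hand-side function over a batch with vmap. The function name rhs and the harmonic-oscillator problem are illustrative assumptions and are not part of XAJ.

```python
import jax
import jax.numpy as jnp

def rhs(x, omega):
    """Pure right-hand side of a harmonic oscillator with state x = (q, p).

    Hypothetical example function: no side effects, no internal state.
    """
    return jnp.array([x[1], -omega**2 * x[0]])

# A batch of three states and three frequencies.
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
omegas = jnp.array([1.0, 2.0, 3.0])

# Every batch member does the same amount of work, so vmap applies directly.
batched = jax.vmap(rhs)(xs, omegas)
print(batched.shape)
```

An adaptive integrator built on top of such a function is harder to batch, because the number of steps then depends on each batch member's data.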
Stepping Engine
Consider a scenario in which we evolve n planets around the sun.
Because of their different orbital time scales and ellipticities,
after, e.g., 4 steps, their solutions would advance to
p0 |---->|-->|--->|---->|
p1 |-->|-->|-->|-->|
p2 |->|->|>|>|
p3 |>|>|>|>|
   ^t0     ^t1     ^t2
where t0 is the time at which we specify the initial conditions, and
t1 is the minimal final time among all the planets.
Suppose we want to evolve all planets up to t2. The most
computationally efficient way is to change the batch size and
integrate only p2 and p3 up to t2.
However, we do not support this in XAJ’s core API.
This is because such smart rebatching is inconsistent with the
assumptions in JAX and would break composability.
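To illustrate the fixed-shape constraint, here is a minimal sketch that keeps the batch size unchanged and masks out the members that have already reached the target time. The function advance_unfinished and the toy problem dx/dt = -x are assumptions made for this example; they are not XAJ's API.

```python
import jax
import jax.numpy as jnp

@jax.jit
def advance_unfinished(t, x, h, t_end):
    """Take one explicit Euler step of dx/dt = -x (toy problem) for the
    batch members with t < t_end; finished members are left untouched.

    All arrays keep their shapes, so the function stays jit-compatible.
    """
    active = t < t_end
    h_eff = jnp.minimum(h, t_end - t)   # do not overshoot t_end
    x_new = x + h_eff * (-x)
    t_new = t + h_eff
    return jnp.where(active, t_new, t), jnp.where(active, x_new, x)

t = jnp.array([0.0, 0.5, 1.0, 1.0])     # the last two members are done
x = jnp.ones(4)
h = jnp.full(4, 0.25)
t1, x1 = advance_unfinished(t, x, h, 1.0)
print(t1, x1)
```

The finished members still take part in the computation and their results are discarded, which is the price of keeping the shapes static.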
Nevertheless, this does not mean we cannot do better than the above
chart.
For p0, the changing stepsize suggests that there are multiple
trials in the integration.
Hence, a better representation of this job is probably
p0 |---->|--->|
         |-->|--->|
p1 |-->|-->|-->|-->|
p2 |->|->|->|
         |>|>|
p3 |>|>|>|>|
   ^t0     ^t1     ^t2
In this new scenario, while each particle goes through 4 steps, for
p0 and p2 only 3 steps satisfy the tolerance and contribute to the
solution.
Given the SIMD architecture of GPUs and TPUs, a naive nested-loop
implementation would evolve the above problem in the following steps
p0 |S|T:S|S:D|E|
p1 |S|S:D|S:D|S|
p2 |S|S:D|T:S|E|
p3 |S|S:D|S:D|S|
where S stands for (successful) stepping, T stands for trial,
D stands for dropped calculation, E stands for extra, and :
separates the steps in an inner loop.
Although the 2 trial T calculations in the above chart are
unavoidable, there are 6 dropped D and 2 extra E calculations
that do not help to achieve the required numerical solution.
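The counts quoted above can be checked by tallying the letters in the chart. The snippet below is only a bookkeeping aid; it reads the fourth row of the chart as p3.

```python
from collections import Counter

# Rows of the naive nested-loop chart (fourth row read as p3).
naive = {
    "p0": "S|T:S|S:D|E",
    "p1": "S|S:D|S:D|S",
    "p2": "S|S:D|T:S|E",
    "p3": "S|S:D|S:D|S",
}

# Count each kind of calculation, ignoring the | and : separators.
counts = Counter(ch for row in naive.values() for ch in row if ch in "STDE")
total = sum(counts.values())
wasted = counts["D"] + counts["E"]
print(dict(counts), "total:", total, "wasted:", wasted)
```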
One interesting observation here is that the trial and the actual successful calculations are computationally identical. Therefore, if we integrate the step controller with the driver, it is possible to fuse the two types of calculations. This results in
p0 |S|T:S|S|
p1 |S|S|S|S|
p2 |S|S|T:S|
p3 |S|S|S|S|
which saves 33% of the computation relative to the naive implementation (16 calculations instead of 24). In other words, it avoids the 50% overhead that the naive implementation adds on top of the optimal one.
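One way to picture the fused design is a single loop in which every iteration performs exactly one step calculation and the controller decides afterwards whether that calculation counts as a trial or as a success. The sketch below is an illustration under stated assumptions, not XAJ's implementation: the name fused_integrate, the Heun-Euler 2(1) pair, the float32 carry, and the controller constants (0.9, 0.2, 5.0) are all choices made for this example.

```python
import jax
import jax.numpy as jnp
from jax import lax

def fused_integrate(rhs, t0, x0, t1, h0=0.1, rtol=1e-4, atol=1e-9):
    """Integrate dx/dt = rhs(t, x) from t0 to t1 with one fused loop.

    Hypothetical sketch: each iteration is one Heun-Euler step that is
    either accepted (S) or rejected (T); there is no inner retry loop.
    """
    f32 = jnp.float32  # fixed carry dtype, chosen for simplicity
    t0, x0, t1, h0 = (jnp.asarray(v, f32) for v in (t0, x0, t1, h0))

    def cond(carry):
        t, x, h, n_all, n_ok = carry
        return t < t1

    def body(carry):
        t, x, h, n_all, n_ok = carry
        remaining = t1 - t
        h = jnp.minimum(h, remaining)            # do not step past t1
        k1 = rhs(t, x)
        k2 = rhs(t + h, x + h * k1)
        x_new = x + 0.5 * h * (k1 + k2)          # 2nd-order (Heun) result
        delta = 0.5 * h * (k2 - k1)              # Heun minus Euler
        scale = atol + rtol * jnp.maximum(jnp.abs(x), jnp.abs(x_new))
        err = jnp.max(jnp.abs(delta) / scale)
        accept = err <= 1.0
        # Land on t1 exactly when this is the final (clamped) step.
        t_new = jnp.where(h >= remaining, t1, t + h)
        # Step controller fused with the driver: same iteration, no retry.
        factor = jnp.clip(0.9 / jnp.sqrt(jnp.maximum(err, 1e-12)), 0.2, 5.0)
        return (jnp.where(accept, t_new, t),
                jnp.where(accept, x_new, x),
                h * factor,
                n_all + 1,
                n_ok + accept.astype(jnp.int32))

    zero = jnp.asarray(0, jnp.int32)
    _, x, _, n_all, n_ok = lax.while_loop(cond, body, (t0, x0, h0, zero, zero))
    return x, n_all, n_ok

# Batch over decay rates: each member needs a different number of steps.
rates = jnp.array([1.0, 2.0, 4.0])
solve = jax.vmap(lambda k: fused_integrate(lambda t, x: -k * x, 0.0, 1.0, 1.0))
x1, n_all, n_ok = solve(rates)
print(x1, n_all, n_ok)
```

Under vmap, JAX runs the loop until every batch member satisfies the exit condition and leaves the carry of finished members unchanged, so the batch shape never changes. The difference n_all - n_ok is the number of trial calculations per member.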
Dense Output
In XAJ’s interface signatures, x(t, t0, x0, aux) can in principle be
a pure function without side effects (or internal state).
However, many adaptive integrators support dense output, where the
numerical solutions are saved as piece-wise polynomials for later
interpolation.
This is particularly efficient if we want to obtain numerical
solutions on a large number of sampling points.
How should we design XAJ to take advantage of dense output?
A natural choice is to cache the dense outputs with x(t, t0, x0, aux),
and reuse the cache whenever possible.
Given that the numerical solutions in general depend on t0, x0, and
aux, we can use the tuple (t0, x0, aux) as the cache index.
The cache should also be a pytree in order to be composable with
other JAX functions.
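A minimal sketch of such a cache is given below, assuming the dense output is stored as knots, states, and derivatives and evaluated by cubic Hermite interpolation. The names DenseCache, same_problem, and evaluate are hypothetical and do not come from XAJ; a NamedTuple is used because JAX treats it as a pytree without further registration.

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class DenseCache(NamedTuple):
    """Hypothetical dense-output cache, indexed by (t0, x0, aux)."""
    t0: jnp.ndarray   # cache index: initial time
    x0: jnp.ndarray   # cache index: initial state
    aux: jnp.ndarray  # cache index: auxiliary parameters
    ts: jnp.ndarray   # knots of the piece-wise polynomial, shape (n,)
    xs: jnp.ndarray   # solution at the knots
    fs: jnp.ndarray   # derivative dx/dt at the knots

def same_problem(cache, t0, x0, aux):
    """Return True if the cache was built for the index (t0, x0, aux)."""
    return (cache.t0 == t0) & jnp.all(cache.x0 == x0) & jnp.all(cache.aux == aux)

@jax.jit
def evaluate(cache, t):
    """Evaluate the cached piece-wise cubic Hermite polynomial at time t."""
    n = cache.ts.shape[0]
    i = jnp.clip(jnp.searchsorted(cache.ts, t, side="right") - 1, 0, n - 2)
    ta, tb = cache.ts[i], cache.ts[i + 1]
    h = tb - ta
    s = (t - ta) / h
    h00 = 2 * s**3 - 3 * s**2 + 1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return (h00 * cache.xs[i] + h10 * h * cache.fs[i]
            + h01 * cache.xs[i + 1] + h11 * h * cache.fs[i + 1])

# Toy data: x(t) = t**3, which cubic Hermite interpolation reproduces.
ts = jnp.array([0.0, 1.0, 2.0])
cache = DenseCache(t0=ts[0], x0=ts[0] ** 3, aux=jnp.array(0.0),
                   ts=ts, xs=ts**3, fs=3 * ts**2)

# The cache passes through jit and vmap like any other pytree.
samples = jax.vmap(evaluate, in_axes=(None, 0))(cache, jnp.array([0.5, 1.5]))
print(samples, same_problem(cache, 0.0, 0.0, 0.0))
```

Because the cache is an explicit argument and return value rather than hidden state, the function that uses it remains pure.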