Implementation
==============

To better implement ``JAX``'s :doc:`interface ` and :doc:`design patterns `, we split the numerical ODE solver into several components.

Stepper
-------

An ``xaj.core.Stepper`` advances the state of a system of ODEs by one step. It is implemented as a pure function built from a Python closure, and is therefore composable with standard ``JAX`` transformations such as ``jit`` and ``vmap``.

Error Estimator
---------------

An ``xaj.core.ErrorEstimator`` estimates the numerical error of a step and is used for adaptive step-size control. It may be based on step doubling or on an embedded method, or it can be a custom function based on, e.g., quantities conserved by the system of ODEs. Like the stepper, it is a pure function built from a Python closure and is composable with ``jit`` and ``vmap``.

Stepsize Controller
-------------------

An ``xaj.core.StepController`` adjusts the step size based on the error estimate and/or a custom function. Although it may seem more natural to implement it as a stateful object, we do not, so that it fits better with the ``JAX`` ecosystem, whose transformations expect pure functions. It is therefore also a pure function built from a Python closure and is composable with ``jit`` and ``vmap``.

Stepping Engine
---------------

We use ``jax.lax.while_loop`` to implement the :ref:`sec_pattern_stepping-engine`, which has the semantics

.. code-block:: python3

    def while_loop(cond, body, state):
        while cond(state):
            state = body(state)
        return state

Because the functions ``cond`` and ``body`` are pure, we must pass them the full state, which includes the current step size, the current solution of the ODEs, the states shared across steps, etc. Logically, the body function can be implemented as follows; ``step`` and ``ctrl`` stand for the stepper and the step-size controller, and ``redo`` signals a rejected step that must be retried.

.. code-block:: python3

    import jax
    import jax.numpy as jnp

    def body(state):
        # aux: data carried through unchanged; (t, x): current solution;
        # (h, k): step size and state shared across steps;
        # (i, r): iteration count and number of consecutive retries
        aux, (t, x), (h, k), (i, r) = state
        E, T, X, K = step(h, t, x, k)    # attempt a step; E is its estimated error
        H, redo = ctrl(E, T, X, t, x)    # next step size, and whether to reject
        # A Python `if` cannot branch on a traced value, so use lax.cond
        return jax.lax.cond(
            redo,
            lambda: (aux, (t, x), (H, k), (i + 1, r + 1)),  # retry from (t, x)
            lambda: (aux, (T, X), (H, K), (i + 1, jnp.zeros_like(r))),  # continue
        )

Cache and Dense Output
----------------------

An ``xaj.core.Cache`` stores the solution of the system of ODEs, either at full steps or as dense output. Caches are implemented as pytrees.

Pacer
-----

pace (*verb*): walk at a steady and consistent speed, especially back and forth and as an expression of one's anxiety or annoyance.

An ``xaj.core.Pacer`` combines a ``Stepper``, an ``ErrorEstimator``, a ``StepController`` and ``States`` to integrate the system of ODEs by one step in the requested direction.

Treker
------

trek (*verb*): go on a long arduous journey, typically on foot.

An ``xaj.core.Treker`` loops over multiple paces to integrate the system of ODEs. Because the adaptive step-size controller may propose a step that overshoots the target, ``Treker`` may integrate *beyond* the target and then obtain the value at the target by interpolating the dense output.

In addition, because ``XAJ`` allows integrating backward, ``Treker`` needs to keep track of both the lower and the upper limit of the independent variable. When a solution is first set up, we have :code:`treker.lower = treker.upper = t0`. If we ask for a solution at :code:`t > t0`, we integrate forward and advance :code:`treker.upper` until :code:`treker.upper >= t`. If we ask for a solution at :code:`t < t0`, we integrate backward and move :code:`treker.lower` until :code:`treker.lower <= t`. If a later request asks for a ``t`` with :code:`treker.lower <= t <= treker.upper`, no integration is performed; ``x(t)`` is simply given by the dense output.

Solution
--------

An ``xaj.core.Solution`` is a pytree with a ``__call__()`` method. It keeps track of the internal states and the cache of numerical solutions to a system of ODEs.
Because it is a pytree, it supports ``JAX`` transformations such as ``jit`` and ``vmap``.
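
The pieces above can be tied together in a small, self-contained sketch. It is not ``XAJ``'s actual code: ``make_stepper``, ``make_estimator``, ``make_controller``, ``integrate`` and this ``Solution`` class are hypothetical stand-ins, assuming a classical RK4 stepper, a step-doubling error estimator and a standard power-law step-size controller. It illustrates two properties described above: every component is a pure closure that can run inside ``jax.lax.while_loop``, and a ``Solution`` registered as a pytree can be passed through ``jit`` and ``vmap`` like any other argument.

.. code-block:: python3

    import jax
    import jax.numpy as jnp
    from jax.tree_util import register_pytree_node_class


    def make_stepper(f):
        """Hypothetical stepper: a closure over f(t, x), one classical RK4 step."""
        def step(h, t, x):
            k1 = f(t, x)
            k2 = f(t + h / 2, x + h / 2 * k1)
            k3 = f(t + h / 2, x + h / 2 * k2)
            k4 = f(t + h, x + h * k3)
            return t + h, x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return step


    def make_estimator(step):
        """Hypothetical error estimator: one full step vs. two half steps."""
        def estimate(h, t, x):
            T, X1 = step(h, t, x)
            tm, xm = step(h / 2, t, x)
            _, X2 = step(h / 2, tm, xm)
            return jnp.max(jnp.abs(X2 - X1)), T, X2  # keep the more accurate X2
        return estimate


    def make_controller(tol, order=4, safety=0.9):
        """Hypothetical controller: propose the next h; flag a redo if E > tol."""
        def ctrl(E, h):
            ratio = jnp.maximum(E / tol, 1e-10)  # avoid dividing by zero when E == 0
            H = h * jnp.clip(safety * ratio ** (-1.0 / (order + 1)), 0.2, 5.0)
            return H, E > tol
        return ctrl


    def integrate(f, t0, x0, t1, h0=0.1, tol=1e-6, max_steps=10_000):
        """Hypothetical stepping engine built on jax.lax.while_loop (forward only)."""
        estimate = make_estimator(make_stepper(f))
        ctrl = make_controller(tol)

        def cond(state):
            t, _, _, i = state
            return (t < t1) & (i < max_steps)

        def body(state):
            t, x, h, i = state
            h = jnp.minimum(h, t1 - t)  # clip so that the last step lands on t1
            E, T, X = estimate(h, t, x)
            H, redo = ctrl(E, h)
            # A rejected step keeps (t, x) and is retried with the smaller H
            return jnp.where(redo, t, T), jnp.where(redo, x, X), H, i + 1

        init = (jnp.asarray(t0, jnp.float32), jnp.asarray(x0, jnp.float32),
                jnp.asarray(h0, jnp.float32), jnp.int32(0))
        _, x, _, _ = jax.lax.while_loop(cond, body, init)
        return x


    @register_pytree_node_class
    class Solution:
        """Hypothetical pytree: holds the initial condition; calling it returns x(t)."""

        def __init__(self, f, t0, x0):
            self.f, self.t0, self.x0 = f, t0, x0

        def tree_flatten(self):
            # Numbers/arrays are children (traced by JAX); f is static aux data
            return (self.t0, self.x0), self.f

        @classmethod
        def tree_unflatten(cls, f, children):
            return cls(f, *children)

        def __call__(self, t):
            return integrate(self.f, self.t0, self.x0, t)


    sol = Solution(lambda t, x: -x, 0.0, jnp.array([1.0]))  # dx/dt = -x, x(0) = 1

    # The Solution pytree is an ordinary argument to jit- and vmap-ed functions
    x1 = jax.jit(lambda s, t: s(t))(sol, 1.0)
    xs = jax.vmap(lambda s, t: s(t), in_axes=(None, 0))(sol, jnp.array([0.5, 1.0]))

For brevity, this sketch clips the final step so that it lands exactly on the target and only integrates forward. A ``Treker``-style engine would instead let the last step overshoot, interpolate the dense output, and track both ``lower`` and ``upper`` limits so that backward requests are also supported.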