Numerical Methods¶
This page documents the ODE formulation, solver configuration, step-size control, event detection, and automatic differentiation strategy used by the v2 JAX/diffrax backend. For the physical model equations see Scientific Description; for practical usage see the API reference.
The ODE system¶
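The parcel model integrates a system of \(7 + n_r\) ordinary differential equations, where \(n_r\) is the total number of aerosol size bins across all species. The state vector is
\[\mathbf{y} = \left(z,\, P,\, T,\, w_v,\, w_c,\, w_i,\, S,\, r_1,\, \ldots,\, r_{n_r}\right)^\top\]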
with the bulk parcel equations (altitude \(z\), pressure \(P\), temperature \(T\), water
mixing ratios \(w_v\)/\(w_c\)/\(w_i\), supersaturation \(S\)) driving per-bin droplet
growth equations for the wet radii \(r_i\). The right-hand side
\(\mathbf{f}(t, \mathbf{y}, \boldsymbol{\theta})\) — parameterised by aerosol
properties \(\boldsymbol{\theta} = (\mathbf{r}_\text{dry}, \mathbf{N}_i,
\boldsymbol{\kappa}, \alpha_c, V)\) — is implemented as a single JAX-native
array operation in pyrcel.parcel_aux and is jax.jit-compiled on first call.
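The bulk state equations follow [Nenes2001] and [Ghan2011]; see Scientific Description for the full derivation. The per-bin growth equation is
\[\frac{dr_i}{dt} = \frac{G}{r_i}\left(S - S_{\mathrm{eq},i}\right), \qquad G = \left[\frac{\rho_w R T}{e_s D_v' M_w} + \frac{L \rho_w \left(L M_w/(R T) - 1\right)}{k_a' T}\right]^{-1},\]
where \(D_v'\) and \(k_a'\) are the vapour diffusivity and thermal conductivity corrected for non-continuum effects [Seinfeld2006], \(S_{\mathrm{eq},i}\) is the \(\kappa\)-Köhler equilibrium supersaturation [Petters2007], and \(\rho_w\), \(M_w\), \(L\), \(e_s\), and \(R\) are the density and molar mass of water, the latent heat of vaporisation, the saturation vapour pressure, and the universal gas constant.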
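The equilibrium supersaturation \(S_{\mathrm{eq},i}\) that enters the growth equation can be sketched in a few lines. This is a minimal standalone illustration of the \(\kappa\)-Köhler form from [Petters2007], not the code in pyrcel.parcel_aux; the function name `kohler_seq`, the constant surface tension, and the example values of `r_dry` and `kappa` are assumptions made for the sketch.

```python
import math

# Physical constants (SI). Treating the surface tension as a constant is a
# simplifying assumption of this sketch.
R_GAS = 8.314      # J mol^-1 K^-1
M_W = 18.015e-3    # kg mol^-1, molar mass of water
RHO_W = 1000.0     # kg m^-3
SIGMA_W = 0.072    # N m^-1, approximate surface tension of water


def kohler_seq(r, r_dry, kappa, T=283.15):
    """Equilibrium supersaturation over a droplet of wet radius r (m)."""
    A = 2.0 * M_W * SIGMA_W / (R_GAS * T * RHO_W)                 # Kelvin length (m)
    a_w = (r**3 - r_dry**3) / (r**3 - r_dry**3 * (1.0 - kappa))   # water activity
    return a_w * math.exp(A / r) - 1.0


r_dry, kappa = 50e-9, 0.6   # hypothetical particle
radii = [r_dry * f for f in (1.5, 3.0, 6.0, 12.0, 24.0, 100.0)]
seq = [kohler_seq(r, r_dry, kappa) for r in radii]
for r, s in zip(radii, seq):
    print(f"r = {r * 1e9:8.1f} nm   S_eq = {s:+.5f}")
```

A droplet grows when the ambient \(S\) exceeds its \(S_{\mathrm{eq},i}\) and shrinks otherwise, so the sign change of \(S_{\mathrm{eq}}\) along this list marks where the solute effect gives way to the curvature effect.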
Solver selection¶
Why Kvaerno5 (ESDIRK-5/4)?¶
The parcel ODE is mildly stiff near activation, where the rapid growth of critical-size droplets drives a fast \(r_i\) relaxation that is coupled back to the supersaturation tendency \(dS/dt\). The v1 backend used SUNDIALS CVode with a BDF scheme of order up to 5 and an internally adaptive Newton iteration. The v2 backend uses diffrax.Kvaerno5 [Kvaerno2004], an ESDIRK method (a singly diagonally implicit Runge–Kutta method with an explicit first stage) of order 5 with an embedded order-4 error estimator.
Kvaerno5 is A-stable, L-stable, and stiffly accurate, making it robust for the stiff components of the system without requiring a variable-order BDF controller. The algebraic stage equations are solved with diffrax's default VeryChord root-finder, a chord-type quasi-Newton iteration that inherits the same rtol/atol vector used by the step-size controller, so there is no separate Newton tolerance to tune.
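To see why an L-stable implicit method matters for the stiff components, the following sketch compares forward and backward Euler on the scalar test equation \(y' = -\lambda y\). Backward Euler is used here only as the simplest L-stable stand-in; it is not Kvaerno5, and the values of `lam` and `h` are arbitrary choices for illustration.

```python
lam = 1000.0   # stiffness parameter of the test equation y' = -lam * y
h = 0.01       # step size, deliberately far above the explicit limit 2 / lam
n_steps = 20

y_explicit = 1.0
y_implicit = 1.0
for _ in range(n_steps):
    # Forward Euler: amplification factor (1 - h*lam) = -9, so |y| grows.
    y_explicit = y_explicit * (1.0 - h * lam)
    # Backward Euler: amplification factor 1 / (1 + h*lam) = 1/11, so y decays.
    y_implicit = y_implicit / (1.0 + h * lam)

print(f"forward Euler : {y_explicit:.3e}")
print(f"backward Euler: {y_implicit:.3e}")
```

The implicit update damps the fast mode at any step size, which is the property that lets the step-size controller choose steps on accuracy grounds alone.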
For a brief comparison with the legacy solver:
| Property | v1 (CVode BDF) | v2 (Kvaerno5) |
|---|---|---|
| Method family | Linear multistep (BDF) | Implicit Runge–Kutta (ESDIRK) |
| Maximum order | 5 (adaptive) | 5 (fixed) |
| A-stability | Orders 1–2 only; A(α)-stable for orders 3–5 | ✓ |
| L-stability | — | ✓ |
| Step-size control | Internal (SUNDIALS) | PID controller |
| jit/vmap/grad | — | ✓ |
The accuracy difference between the two solvers, at matching tolerances, is within the physical uncertainty of the model; see Numerical accuracy below.
When to consider a different solver¶
For the vast majority of parcel simulations Kvaerno5 with the default
tolerances is the right choice. Three situations where you might deviate:
- Very high aerosol loading (\(N \gtrsim 10^4\,\text{cm}^{-3}\) with many bins): the stiffness ratio grows, and reducing rtol to \(10^{-8}\) and atol for radii to \(10^{-14}\) m may be needed to suppress spurious oscillation near activation.
- Coarse output cadence / speed over accuracy: diffrax.Tsit5 (explicit Runge–Kutta 5/4, non-stiff) runs faster at reduced accuracy when the simulation is below the cloud base, where the ODE is effectively non-stiff. In practice the adaptive step-size controller already handles this automatically with Kvaerno5, so switching solvers is rarely necessary.
- Adjoint failure at extreme conditions: when jax.grad is called at very high updraft speeds (\(V \gtrsim 2\) m/s) combined with very small median radii (\(\mu \lesssim 0.03\,\mu\)m), the implicit solver's Newton step Jacobian can become nearly singular, causing the adjoint backward pass to raise EquinoxRuntimeError: A linear solver returned non-finite output. See Adjoint conditioning failures below for diagnosis and workarounds.
Solver tolerances¶
The default tolerances match the CVode configuration from the legacy model:
The mixed absolute tolerance vector reflects the disparate scales of the state
variables: mixing ratios near \(10^{-5}\) kg/kg demand tighter absolute control
than temperature or pressure, and wet radii near \(10^{-8}\) m demand even tighter.
The atol for radii (\(10^{-12}\) m) is deliberately more stringent than the
initial equilibrium tolerance so that the radius trajectory does not accumulate
truncation error during the ascent.
You can override tolerances per-call through the integrator functions in
pyrcel.integrator:
from pyrcel.integrator import integrate_parcel
import jax.numpy as jnp

nr = y0.shape[0] - 7  # number of radius bins; the first 7 entries are the bulk state

# Tighten radius tolerance for high-concentration runs
custom_atol = jnp.concatenate([
    jnp.array([1e-4, 1e-4, 1e-4, 1e-10, 1e-10, 1e-4, 1e-8]),  # z, P, T, wv, wc, wi, S
    jnp.full(nr, 1e-14),  # tighter radius atol
])
sol = integrate_parcel(y0, args, ts, rtol=1e-8, atol=custom_atol)
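The effect of a mixed absolute tolerance vector is easiest to see in the error test itself. The sketch below assumes the common scaled-RMS form, in which each component of the local error is divided by \(\text{atol}_i + \text{rtol}\,|y_i|\) before averaging; the helper name `scaled_rms_error` and the two-component state are hypothetical and chosen only to contrast a temperature-like and a radius-like variable.

```python
import math


def scaled_rms_error(err, y, rtol, atol):
    """RMS of the per-component error, each scaled by atol_i + rtol * |y_i|."""
    scaled = [e / (a + rtol * abs(v)) for e, v, a in zip(err, y, atol)]
    return math.sqrt(sum(s * s for s in scaled) / len(scaled))


# Hypothetical two-component state: temperature (K) and one wet radius (m).
y = [283.15, 1.0e-8]
err = [1.0e-6, 1.0e-11]   # local error estimates for one step
rtol = 1.0e-7

uniform = scaled_rms_error(err, y, rtol, atol=[1.0e-4, 1.0e-4])
mixed = scaled_rms_error(err, y, rtol, atol=[1.0e-4, 1.0e-12])

print(f"uniform atol: e = {uniform:.3e}  -> step accepted: {uniform <= 1.0}")
print(f"mixed atol  : e = {mixed:.3e}  -> step accepted: {mixed <= 1.0}")
```

With a uniform atol of \(10^{-4}\) the radius error is invisible to the controller; the per-component value makes the same step fail the test and forces a shorter one.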
Step-size control¶
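The v2 backend uses diffrax.PIDController with proportional, integral, and derivative terms acting on the local error estimate [Soderlind2003]. Up to the safety factor and the limits on the step ratio, the update is
\[h_{n+1} = h_n \, e_n^{-k_I} \left(\frac{e_{n-1}}{e_n}\right)^{k_P} \left(\frac{e_{n-1}^2}{e_n\, e_{n-2}}\right)^{k_D},\]
where \(e_n = \|\mathbf{y}_n^\text{err}\|_\text{rms}\) is the normalised local error. The coefficients are \(k_I = 0.3/q\), \(k_P = 0.4/q\), \(k_D = 0\) (pure PI), where \(q = \min(\text{order}_\text{sol}, \text{order}_\text{err}) + 1 = 5\). In diffrax terms these are pcoeff=0.4 and icoeff=0.3; a PIDController constructed without them falls back to pcoeff=0, icoeff=1, dcoeff=0, which is a plain integral controller. The PI form is the Gustafsson-type controller widely used in stiff ODE software.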
The dtmax parameter (forwarded to PIDController) can be used to cap the
maximum internal step size, which is occasionally useful for ensuring adequate
density of the trajectory near activation:
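As a rough illustration of how a dtmax cap interacts with the PI update, the sketch below reimplements the step-size arithmetic in plain Python with \(k_I = 0.3/q\) and \(k_P = 0.4/q\) as quoted above. It is not diffrax code: the function `pi_step`, the safety factor of 0.9, and the step-ratio limits of 0.2 and 10 are assumptions made for the example.

```python
def pi_step(h, e_n, e_prev, q=5, k_i=0.3, k_p=0.4,
            safety=0.9, factor_min=0.2, factor_max=10.0, dtmax=None):
    """Propose the next step size from the current and previous scaled errors."""
    factor = safety * e_n ** (-k_i / q) * (e_prev / e_n) ** (k_p / q)
    factor = min(factor_max, max(factor_min, factor))   # limit the step ratio
    h_new = h * factor
    if dtmax is not None:
        h_new = min(h_new, dtmax)                        # hard cap on the step
    return h_new


h = 0.5  # current step (s)
# Error far below tolerance: the controller wants to grow the step.
free = pi_step(h, e_n=1e-6, e_prev=1e-6)
capped = pi_step(h, e_n=1e-6, e_prev=1e-6, dtmax=1.0)
# Error above tolerance: the step shrinks regardless of dtmax.
shrunk = pi_step(h, e_n=50.0, e_prev=1.0, dtmax=1.0)

print(f"uncapped: {free:.3f} s, capped: {capped:.3f} s, after large error: {shrunk:.3f} s")
```

The cap only ever binds when the error is small, so it adds trajectory density in smooth stretches without interfering with step rejection near activation.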
Dense output and trajectory storage¶
Rather than integrating in fixed-size chunks and manually accumulating output
(the v1 solver_dt/output_dt loop), the v2 backend uses diffrax's
SaveAt(ts=...) facility. A single adaptive solve covers the entire time
interval \([0, t_\text{end}]\); the solution is dense-interpolated at the
requested output times using the solver's built-in continuous extension. This
removes the chunking bookkeeping entirely and reduces the number of JAX
compilations to one per unique (nr, n_output) configuration.
For the nd_from_parcel diagnostic, the backend uses SaveAt(t1=True) to
retain only the final state, minimising the memory footprint of the adjoint
checkpoints. max_supersaturation uses SaveAt(ts=ts) — the same kernel as
integrate_parcel — so no second JIT compilation is triggered for gradient
workflows.
The live=True mode on ParcelModel.run() is the one exception to single-call
integration: it deliberately re-introduces a Python chunk loop to print per-step
diagnostics, at the cost of JIT overhead per chunk and incompatibility with
jax.grad. See the migration guide for the full display-option
comparison (progress, live, trajectory_table).
For interactive runs without early termination (terminate=False), \(S_\text{max}\)
is located post-hoc from the saved trajectory rather than by a second ODE solve.
The interactive layer constructs cubic Hermite polynomials over the two
sub-intervals bracketing the discrete argmax, using parcel_ode_sys to supply
the exact endpoint derivatives \(\dot{S}\) — the same polynomial family diffrax
stores per solver step when SaveAt(dense=True) is used — and finds the
analytic maximum by solving the resulting quadratic \(dp/du = 0\). The cost is
three RHS evaluations; the accuracy is O(\(\Delta t^4\)) where \(\Delta t\) is the
output spacing.
Supersaturation-maximum event detection¶
The parcel model's termination criterion is defined by the supersaturation
maximum: the point at which \(dS/dt\) changes sign from positive to negative,
marking the end of primary droplet activation. After this point the model is
integrated a further terminate_depth metres before stopping.
Event detection is implemented as a continuous-root-finding solve in
diffrax.Event:
import diffrax as dfx
import optimistix as optx

def dS_dt_event(t, y, args, **kwargs):
    return parcel_ode_sys(t, y, args)[6]  # the S component of dy/dt

event = dfx.Event(dS_dt_event,
                  root_finder=optx.Newton(rtol=1e-8, atol=1e-12),
                  direction=False)  # trigger only on a downward zero-crossing
The direction=False flag ensures that only the downward zero-crossing
(supersaturation maximum) triggers the event, not the initial upward sweep.
The root is located precisely by Newton iteration to tolerances of \(10^{-8}\)
(relative) and \(10^{-12}\) (absolute), giving \(t_\text{smax}\) accurate to
\(\lesssim 10^{-6}\) s.
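The logic of a direction-filtered event can be sketched without any solver machinery. The example below scans a toy \(dS/dt\) for a positive-to-negative sign change and refines it by bisection; bisection stands in for the Newton root find, and the toy profile \(S(t) = t\,e^{-t/\tau}\), the function `find_downward_crossing`, and its parameters are all hypothetical.

```python
import math

TAU = 12.5  # s; the toy supersaturation S(t) = t * exp(-t / TAU) peaks at t = TAU


def dS_dt(t):
    return (1.0 - t / TAU) * math.exp(-t / TAU)


def find_downward_crossing(g, t0, t1, dt=1.0, tol=1e-10):
    """Return the first t in [t0, t1] where g goes from positive to non-positive."""
    t_lo = t0
    while t_lo < t1:
        t_hi = min(t_lo + dt, t1)
        if g(t_lo) > 0.0 and g(t_hi) <= 0.0:      # downward sign change only
            while t_hi - t_lo > tol:               # refine by bisection
                t_mid = 0.5 * (t_lo + t_hi)
                if g(t_mid) > 0.0:
                    t_lo = t_mid
                else:
                    t_hi = t_mid
            return 0.5 * (t_lo + t_hi)
        t_lo = t_hi
    return None                                     # no event in the interval


t_smax = find_downward_crossing(dS_dt, 0.0, 60.0)
print(f"t_smax = {t_smax:.8f} s (analytic: {TAU} s)")
```

An upward crossing never satisfies the sign test, which mirrors the role of direction=False in suppressing the initial upward sweep.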
Why the event-based solve is not differentiable¶
The diffrax.Event mechanism makes the stop time \(t_\text{smax}\) a
data-dependent function of the initial conditions: JAX cannot differentiate
through a data-dependent control-flow decision. As a result,
integrate_parcel_terminated (the interactive path) is not suitable for
jax.grad. For gradient-based workflows, use the fixed-horizon functions
instead:
import jax

# Differentiable: Hermite cubic peak finder (exact via envelope theorem)
from pyrcel.integrator import max_supersaturation
smax = max_supersaturation(y0, args, ts)
grad_smax = jax.grad(max_supersaturation, argnums=(0, 1))(y0, args, ts)

# Differentiable: fixed t_end, soft Nd via sigmoid threshold
from pyrcel.integrator import nd_from_parcel
nd = nd_from_parcel(y0, args, t_end, epsilon=1e-8)
grad_nd = jax.grad(nd_from_parcel, argnums=(0, 1))(y0, args, t_end)
The ts array passed to max_supersaturation must span the supersaturation
peak; see Differentiable \(S_\text{max}\)
below for guidance on sizing and scaling ts.
Differentiable \(S_\text{max}\): Hermite cubic interpolation¶
Computing a gradient-accurate \(S_\text{max}\) requires more than returning
jnp.max(sol.ys[:, 6]). The adaptive solver takes slightly different internal
step structures for slightly different parameter values (especially updraft
speed \(V\)). With a discrete output grid, the argmax occasionally jumps by one
time step as \(V\) varies, producing an aliased, non-monotone \(S_\text{max}(V)\)
curve with alternating-sign gradient kinks of \(\sim 10^{-3}\) relative
amplitude.
max_supersaturation eliminates the aliasing via analytic peak-finding on a
cubic Hermite interpolant constructed from the discrete trajectory returned
by _solve (the same kernel as integrate_parcel). No second dense-output
solve is required.
Stage 1 — coarse bracket¶
The argmax of the saved discrete supersaturation values \(\{S_j\}\) locates the approximate peak bin:
\[i^\ast = \arg\max_j S_j .\]
\(i^\ast\) is stop_gradient'd and used solely to define the two sub-intervals
\([t_{i^\ast - 1},\, t_{i^\ast}]\) and \([t_{i^\ast},\, t_{i^\ast + 1}]\).
Stage 2 — endpoint derivatives¶
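Three evaluations of the ODE right-hand side at the bracket points supply exact endpoint slopes:
\[\left(S_{i^\ast-1},\, \dot{S}_{i^\ast-1}\right), \qquad \left(S_{i^\ast},\, \dot{S}_{i^\ast}\right), \qquad \left(S_{i^\ast+1},\, \dot{S}_{i^\ast+1}\right),\]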
where \(\dot{S}_j = \mathbf{f}(t_j, \mathbf{y}_j, \boldsymbol{\theta})[6]\).
These are the same Hermite coefficients diffrax stores per step internally when
SaveAt(dense=True) is used.
Stage 3 — analytic quadratic solve¶
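On each sub-interval \([t_0, t_1]\) with \(h = t_1 - t_0\), the Hermite cubic in the normalised coordinate \(u = (t - t_0)/h \in [0,1]\) is
\[p(u) = (2u^3 - 3u^2 + 1)\,S_0 + (u^3 - 2u^2 + u)\,h\dot{S}_0 + (-2u^3 + 3u^2)\,S_1 + (u^3 - u^2)\,h\dot{S}_1 .\]
Setting \(dp/du = 0\) yields a quadratic in \(u\):
\[a u^2 + b u + c = 0, \qquad a = 3\left[2(S_0 - S_1) + h(\dot{S}_0 + \dot{S}_1)\right], \quad b = 2\left[3(S_1 - S_0) - h(2\dot{S}_0 + \dot{S}_1)\right], \quad c = h\dot{S}_0 .\]
The root in \((0,1)\) at which \(d^2p/du^2 = 2au + b < 0\) is the interior maximum. If both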
sub-intervals have an interior maximum, the larger wins. All root-finding
operates under jax.lax.stop_gradient (including the discriminant and the
maximum-of-two selection).
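The three stages can be sketched in plain Python on a toy trajectory. This is an illustration of the technique under stated assumptions, not the max_supersaturation implementation: the helpers `hermite_value` and `hermite_peak`, the toy profile \(S(t) = t\,e^{-t/12.5}\), and the 5 s output spacing are all hypothetical, and no gradient handling is shown.

```python
import math


def hermite_value(u, s0, s1, d0, d1, h):
    """Cubic Hermite interpolant on [0, 1]; d0 and d1 are time derivatives."""
    return ((2*u**3 - 3*u**2 + 1) * s0 + (u**3 - 2*u**2 + u) * h * d0
            + (-2*u**3 + 3*u**2) * s1 + (u**3 - u**2) * h * d1)


def hermite_peak(s0, s1, d0, d1, h):
    """Return (u*, p(u*)) for the interior maximum, or None if there is none."""
    a = 3.0 * (2.0 * (s0 - s1) + h * (d0 + d1))
    b = 2.0 * (3.0 * (s1 - s0) - h * (2.0 * d0 + d1))
    c = h * d0
    if a == 0.0:                                   # dp/du is linear
        roots = [-c / b] if b != 0.0 else []
    else:
        disc = b * b - 4.0 * a * c
        if disc < 0.0:
            return None
        sq = math.sqrt(disc)
        roots = [(-b - sq) / (2.0 * a), (-b + sq) / (2.0 * a)]
    for u in roots:
        if 0.0 < u < 1.0 and 2.0 * a * u + b < 0.0:   # interior and concave
            return u, hermite_value(u, s0, s1, d0, d1, h)
    return None


def S_toy(t):
    return t * math.exp(-t / 12.5)                 # true peak at t = 12.5 s


def dS_toy(t):
    return (1.0 - t / 12.5) * math.exp(-t / 12.5)


ts = [5.0 * k for k in range(8)]                   # coarse output grid (s)
i_star = max(range(len(ts)), key=lambda k: S_toy(ts[k]))   # stage 1: coarse bracket

best = None
for lo, hi in ((i_star - 1, i_star), (i_star, i_star + 1)):
    h = ts[hi] - ts[lo]
    # Stage 2 supplies endpoint slopes; stage 3 solves the quadratic.
    peak = hermite_peak(S_toy(ts[lo]), S_toy(ts[hi]), dS_toy(ts[lo]), dS_toy(ts[hi]), h)
    if peak is not None and (best is None or peak[1] > best[1]):
        best = (ts[lo] + peak[0] * h, peak[1])

print(f"discrete max : {S_toy(ts[i_star]):.6f} at t = {ts[i_star]} s")
print(f"Hermite peak : {best[1]:.6f} at t = {best[0]:.4f} s")
print(f"analytic peak: {S_toy(12.5):.6f} at t = 12.5 s")
```

Even with a 5 s grid, the interpolated peak lands much closer to the analytic maximum than the discrete argmax does, which is what removes the aliasing in \(S_\text{max}(V)\).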
Why stop_gradient on \(u^\ast\) is exact¶
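Treating the peak parameter \(u^\ast\) as a non-differentiable constant is not an approximation. By the envelope theorem, at a smooth interior maximum
\[\frac{d}{d\theta}\, p\big(u^\ast(\theta);\, \theta\big) = \frac{\partial p}{\partial \theta}\bigg|_{u^\ast} + \underbrace{\frac{\partial p}{\partial u}\bigg|_{u^\ast}}_{=\,0}\, \frac{du^\ast}{d\theta} = \frac{\partial p}{\partial \theta}\bigg|_{u^\ast},\]
so the implicit dependence of \(u^\ast\) on \(\theta\) drops out identically. The gradient flows through \(p(u^\ast)\) via the endpoint values \(S_{i^\ast-1},\, S_{i^\ast},\, S_{i^\ast+1}\) (ODE adjoint) and the endpoint derivatives \(\dot{S}_{i^\ast-1},\, \dot{S}_{i^\ast},\, \dot{S}_{i^\ast+1}\) (ODE RHS at those points).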
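The envelope-theorem argument can be checked numerically on a toy objective. In the sketch below, `p`, `u_star`, and `peak` are hypothetical stand-ins for the Hermite cubic and its maximiser; the derivative of the peak value is computed once with \(u^\ast\) re-solved at each perturbed \(\theta\) and once with \(u^\ast\) frozen, which plays the role of stop_gradient.

```python
def p(u, theta):
    """Toy objective with an interior maximum in u that moves with theta."""
    return theta * u - 0.5 * (1.0 + theta**2) * u**2


def u_star(theta):
    return theta / (1.0 + theta**2)      # solves dp/du = 0


def peak(theta):
    return p(u_star(theta), theta)       # value of p at its maximum


theta, eps = 0.7, 1e-6

# Total derivative of the peak value, by central finite differences.
total = (peak(theta + eps) - peak(theta - eps)) / (2.0 * eps)

# Partial derivative with u frozen at u*(theta): the "stop_gradient" version.
u_frozen = u_star(theta)
frozen = (p(u_frozen, theta + eps) - p(u_frozen, theta - eps)) / (2.0 * eps)

print(f"d(peak)/d(theta), u* re-solved: {total:.8f}")
print(f"d(peak)/d(theta), u* frozen   : {frozen:.8f}")
```

The two numbers agree to finite-difference accuracy, because the term that carries \(du^\ast/d\theta\) is multiplied by \(\partial p/\partial u = 0\) at the maximum.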
Output-grid sizing and \(t_\text{end}\) scaling¶
ts is used directly for both the ODE output grid and the coarse bracket,
so it must extend past the supersaturation peak and be fine enough to bracket it.
Rule 1 — ensure the peak lies in the interior. ts[-1] must extend past
\(S_\text{max}\) with at least one grid point beyond the peak. For a parcel
ascending from cloud base, the peak occurs at altitude
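\(z \approx z_0 + V \cdot t_\text{peak}\). A robust horizon choice is
\[t_\text{end}(V) = \max\!\left(200\ \text{s},\ \min\!\left(\frac{1500\ \text{m}}{V},\ 15000\ \text{s}\right)\right),\]
which lets the solver reach \(\sim 1500\) m above cloud base for updraft speeds between 0.1 and 7.5 m/s. Faster updrafts hit the 200 s floor and rise further; slower ones hit the 15000 s ceiling and stop short of 1500 m.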
Rule 2 — fix the array shape across \(V\) for JIT reuse. All calls with the
same ts.shape[0] reuse the same compiled XLA kernel. Scale only t_end with
\(V\), keeping the number of output points constant:
import jax.numpy as jnp

_N_TS = 600  # fixed; triggers one JIT compilation

def ts_for_V(V: float) -> jnp.ndarray:
    t_end = max(200.0, min(1500.0 / V, 15000.0))
    return jnp.linspace(0.0, t_end, _N_TS)
Rule 3 — t_end(V) does not bias the adjoint. Because \(S_\text{max}\) is
at an interior maximum, the envelope theorem also gives
\(\partial S_\text{max}/\partial t_\text{end} = 0\), so the \(t_\text{end} \propto
1/V\) coupling introduces no gradient error.
Differentiating through the ODE: RecursiveCheckpointAdjoint¶
Reverse-mode differentiation through an ODE (the adjoint method) requires
storing or recomputing intermediate trajectory values during the backward pass.
diffrax implements this via its adjoint argument to diffeqsolve; the
default in diffrax ≥ 0.4 is RecursiveCheckpointAdjoint [Chen2018]
[Kidger2021].
RecursiveCheckpointAdjoint uses a recursive binary tree of checkpoints along
the forward trajectory: the trajectory is divided in half, and each half is
re-solved during the backward pass from the stored midpoint. This gives memory
complexity \(O(\log N_\text{steps})\) in contrast to the \(O(N_\text{steps})\)
cost of storing the full forward trajectory. No configuration is needed — the
default is in force whenever _solve is called via
max_supersaturation or nd_from_parcel.
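The bisect-and-recompute idea described above can be sketched on a toy recurrence. This is not diffrax's checkpointing scheme, which is more elaborate; the functions `step`, `dstep_dx`, and `backprop` and the scalar ODE \(dx/dt = -x^2\) are hypothetical, and the recursion depth is used as a proxy for the number of states held at once.

```python
def step(x, dt=0.01):
    """One explicit Euler step of the toy ODE dx/dt = -x**2."""
    return x - dt * x * x


def dstep_dx(x, dt=0.01):
    """Derivative of `step` with respect to its input."""
    return 1.0 - 2.0 * dt * x


def backprop(x_start, k0, k1, adjoint, stats, depth=1):
    """Propagate `adjoint` from step k1 back to step k0, given only x at k0."""
    stats["max_depth"] = max(stats["max_depth"], depth)
    if k1 - k0 == 1:
        return adjoint * dstep_dx(x_start)
    mid = (k0 + k1) // 2
    x_mid = x_start
    for _ in range(mid - k0):                  # recompute forward to the midpoint
        x_mid = step(x_mid)
        stats["recomputed_steps"] += 1
    adjoint = backprop(x_mid, mid, k1, adjoint, stats, depth + 1)   # later half first
    return backprop(x_start, k0, mid, adjoint, stats, depth + 1)    # then earlier half


n_steps, x0 = 1024, 2.0
stats = {"max_depth": 0, "recomputed_steps": 0}
grad_checkpointed = backprop(x0, 0, n_steps, 1.0, stats)

# Reference: store every state (O(N) memory) and multiply the step Jacobians.
xs = [x0]
for _ in range(n_steps):
    xs.append(step(xs[-1]))
grad_reference = 1.0
for x in reversed(xs[:-1]):
    grad_reference *= dstep_dx(x)

print(f"gradient (checkpointed): {grad_checkpointed:.12f}")
print(f"gradient (full storage): {grad_reference:.12f}")
print(f"recursion depth {stats['max_depth']} vs {len(xs)} stored states; "
      f"{stats['recomputed_steps']} recomputed steps")
```

The trade is visible in the counters: memory grows with \(\log_2 N\) while the forward work grows by a factor of roughly \(\tfrac{1}{2}\log_2 N\).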
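In continuous form, the backward pass computes \(\partial \mathcal{L}/\partial \mathbf{y}_0\) and \(\partial \mathcal{L}/\partial \boldsymbol{\theta}\) by integrating the adjoint equations backward in time
\[\frac{d\boldsymbol{\lambda}}{dt} = -\left(\frac{\partial \mathbf{f}}{\partial \mathbf{y}}\right)^{\!\top} \boldsymbol{\lambda}, \qquad \frac{\partial \mathcal{L}}{\partial \mathbf{y}_0} = \boldsymbol{\lambda}(0), \qquad \frac{\partial \mathcal{L}}{\partial \boldsymbol{\theta}} = \int_0^T \boldsymbol{\lambda}^\top \frac{\partial \mathbf{f}}{\partial \boldsymbol{\theta}}\, dt,\]
where \(\boldsymbol{\lambda}(T) = \partial \mathcal{L}/\partial \mathbf{y}(T)\) is the terminal condition. Because parcel_ode_sys is fully JAX-traceable, diffrax constructs the vector-Jacobian products \(\boldsymbol{\lambda}^\top (\partial \mathbf{f}/\partial \mathbf{y})\) with jax.vjp automatically, so no hand-coded adjoint is required.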
Adjoint conditioning failures¶
At extreme parameter combinations, particularly high updraft speeds (\(V \gtrsim 2\) m/s) paired with very small particles (\(\mu \lesssim 0.03\,\mu\)m), the parcel ODE develops very rapid dynamics near activation. The implicit solver must then solve Newton systems whose Jacobians can be nearly singular. RecursiveCheckpointAdjoint backpropagates through these discrete steps, and the linear solve in the adjoint inherits that ill-conditioning, eventually producing NaN or inf and raising:
EquinoxRuntimeError: A linear solver returned non-finite output
This is a numerical issue with the discrete adjoint of a stiff implicit solver, not a bug in the model. The forward pass succeeds at the same point; only the gradient computation fails.
Mitigations (roughly in order of invasiveness):
- Tighten forward tolerances: smaller rtol/atol force shorter steps, keeping each step's Jacobian better-conditioned. Try rtol=1e-9 and radius atol=1e-14. The adjoint then inherits a smoother trajectory.
- Use BacksolveAdjoint: replaces the discrete adjoint with a continuous adjoint ODE solved backwards in time (Pontryagin), which avoids differentiating through the Newton iterations entirely. More numerically stable in this regime but can accumulate drift relative to the true gradient for very stiff problems. Pass via diffeqsolve(..., adjoint=BacksolveAdjoint(...)).
- Lower-order implicit solver: diffrax.Kvaerno3 or diffrax.ImplicitEuler have simpler per-step Jacobian structure and can improve conditioning at the cost of accuracy.
- Numerical fallback: for parameter-space sweeps where exact gradients at isolated grid points fail, numpy.gradient on the pre-computed \(S_\text{max}\) grid provides a reliable finite-difference approximation. The Sensitivity Sweep example demonstrates this pattern.
MBN2014 activation: gradient via the Implicit Function Theorem¶
The Morales Betancourt & Nenes (2014) [MBN2014] activation parameterization finds \(S_\text{max}\) as the root of a balance equation
\[F(S_\text{max};\, \boldsymbol{\theta}) = 0,\]
where the zero-crossing of \(F\) depends on the aerosol properties \(\boldsymbol{\theta} = (\mathbf{N}, \boldsymbol{\mu}, \boldsymbol{\sigma}, \boldsymbol{\kappa}, V, T, P, \alpha_c)\). The root is located by bisection, which is not differentiable by default (the bisection loop contains a data-dependent branch).
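v2 implements the gradient via the Implicit Function Theorem (IFT): at the root \(F = 0\), total differentiation gives
\[\frac{dS_\text{max}}{d\boldsymbol{\theta}} = -\left(\frac{\partial F}{\partial S}\right)^{-1} \frac{\partial F}{\partial \boldsymbol{\theta}}\,\bigg|_{S = S_\text{max}} .\]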
This is implemented using jax.custom_vjp in pyrcel.activation._mbn2014. The
forward pass runs the bisection to find \(S_\text{max}\); the custom backward
pass evaluates the IFT formula above using jax.grad(F), which is
\(O(1)\) in memory and independent of the number of bisection iterations.
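The two partial derivatives needed are
\[\frac{\partial F}{\partial S}\bigg|_{S_\text{max}} \qquad \text{and} \qquad \frac{\partial F}{\partial \boldsymbol{\theta}}\bigg|_{S_\text{max}},\]
both evaluated at the converged root. Because \(F\) is a smooth, JAX-traceable function of all its arguments, automatic differentiation returns both partial derivatives to machine precision, with no finite-difference error.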
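The IFT gradient can be checked against finite differences on a toy balance function. The sketch below is plain Python rather than a jax.custom_vjp rule, and the function `F`, its single scalar parameter `theta`, and the helper names are hypothetical; only the structure (bisection forward, IFT formula backward) mirrors the description above.

```python
def F(S, theta):
    """Toy balance function, monotone increasing in S for theta > 0."""
    return S**3 + theta * S - 1.0


def solve_root(theta, lo=0.0, hi=1.0, n_iter=60):
    """Bisection for F(S, theta) = 0 on [lo, hi]; F(lo) < 0 < F(hi) is assumed."""
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        if F(mid, theta) < 0.0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def root_gradient_ift(theta):
    """dS*/dtheta = -(dF/dtheta) / (dF/dS), evaluated at the converged root."""
    S = solve_root(theta)
    dF_dS = 3.0 * S**2 + theta       # analytic partial derivatives of the toy F
    dF_dtheta = S
    return -dF_dtheta / dF_dS


theta, eps = 2.0, 1e-5
grad_ift = root_gradient_ift(theta)
grad_fd = (solve_root(theta + eps) - solve_root(theta - eps)) / (2.0 * eps)

print(f"IFT gradient              : {grad_ift:.8f}")
print(f"finite-difference gradient: {grad_fd:.8f}")
```

The IFT path needs one root solve and two partial derivatives, whereas differentiating through the loop would have to unroll all 60 bisection iterations.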
Differentiable activated droplet number: sigmoid soft threshold¶
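The hard-threshold Nd diagnostic (model.run(mode="nd")) counts bins where the
final wet radius exceeds the Köhler critical radius:
\[N_d^\text{hard} = \sum_i N_i\, \mathbf{1}\!\left[r_i(t_\text{end}) > r_{\text{crit},i}\right].\]
The Heaviside indicator \(\mathbf{1}[\cdot]\) has zero gradient almost everywhere,
making this non-differentiable. The nd_from_parcel function replaces the
indicator with a logistic sigmoid:
\[N_d^\text{soft} = \sum_i N_i\, \operatorname{sigmoid}\!\left(\frac{r_i(t_\text{end}) - r_{\text{crit},i}}{\varepsilon}\right), \qquad \operatorname{sigmoid}(x) = \frac{1}{1 + e^{-x}},\]
where \(\varepsilon\) (default \(10^{-8}\) m \(= 10\) nm) controls the sharpness of the threshold. In the limit \(\varepsilon \to 0\), \(N_d^\text{soft} \to N_d^\text{hard}\) (using the approximate Köhler formula); in practice \(|N_d^\text{soft} - N_d^\text{hard}|/N_d^\text{hard} \lesssim 10\%\) at \(\varepsilon = 10^{-8}\) m, with the residual discrepancy arising from the approximate vs. exact Köhler critical radius.
The critical radius is computed with the approximate Köhler formula [Petters2007]:
\[r_{\text{crit},i} = \sqrt{\frac{3\, \kappa_i\, r_{\text{dry},i}^3}{A}}, \qquad A = \frac{2 M_w \sigma_w}{R\, T \rho_w},\]
where \(\sigma_w\) is the surface tension of water.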
This form is JAX-traceable and vectorises over all bins in a single operation,
unlike the exact fminbound solve used by binned_activation.
Gradients of \(N_d^\text{soft}\) with respect to \((\mathbf{y}_0,
\boldsymbol{\theta})\) flow through both the ODE trajectory (via
RecursiveCheckpointAdjoint) and the Köhler computation:
from pyrcel.integrator import nd_from_parcel
grad_fn = jax.grad(nd_from_parcel, argnums=(0, 1))
grad_y0, grad_args = grad_fn(y0, args, t_end, epsilon=1e-8)
dNd_dV = grad_args[4].V # gradient w.r.t. updraft speed (ConstantV module)
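The difference between the hard and soft counts is easy to see on a handful of bins. The sketch below is a plain-Python illustration, not nd_from_parcel itself; the helper names and the four-bin radii, critical radii, and number concentrations are hypothetical values chosen so that one bin sits close to its threshold.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def nd_hard(r_final, r_crit, N):
    """Sum the number concentration of bins whose radius exceeds r_crit."""
    return sum(n for r, rc, n in zip(r_final, r_crit, N) if r > rc)


def nd_soft(r_final, r_crit, N, epsilon=1e-8):
    """Same count with the indicator replaced by a sigmoid of width epsilon (m)."""
    return sum(n * sigmoid((r - rc) / epsilon) for r, rc, n in zip(r_final, r_crit, N))


# Hypothetical four-bin example: final wet radii, critical radii (m), number (m^-3).
r_final = [0.08e-6, 0.30e-6, 2.0e-6, 5.0e-6]
r_crit = [0.15e-6, 0.32e-6, 0.60e-6, 1.10e-6]
N = [2.0e8, 3.0e8, 3.0e8, 2.0e8]

hard = nd_hard(r_final, r_crit, N)
soft = nd_soft(r_final, r_crit, N)
sharp = nd_soft(r_final, r_crit, N, epsilon=1e-10)

print(f"hard: {hard:.4e}  soft(eps=1e-8): {soft:.4e}  soft(eps=1e-10): {sharp:.4e}")
```

The bin that sits 20 nm below its critical radius contributes a fractional count at \(\varepsilon = 10^{-8}\) m, which is exactly what gives the soft diagnostic a non-zero gradient there.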
Low-level integrator API¶
ParcelModel.run() is the recommended entry point, but the underlying functions
are public and fully composable for advanced workflows — building a custom
optimization loop, wrapping the ODE in an external framework, or constructing
partial pipelines with jax.vmap.
Equilibration¶
from pyrcel.equilibrate import equilibrate_initial_state
y0 = equilibrate_initial_state(T0, S0, P0, r_drys, kappas, Nis)
# y0: shape (7 + nr,) — [z, P, T, wv, wc, wi, S, r_0, ..., r_{nr-1}]
equilibrate_initial_state computes the equilibrium wet-radius spectrum at the
initial conditions and assembles the full (7 + nr)-element state vector. It is
the direct equivalent of ParcelModel._setup_run and is itself JAX-traceable
(uses optimistix bisection).
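For a single bin, the equilibration step amounts to solving \(S_\text{eq}(r) = S_0\) for the wet radius. The sketch below does this with a plain bisection between the dry radius and the approximate critical radius; it is a standalone illustration rather than equilibrate_initial_state, and the function names, the constant surface tension, and the example particle are assumptions.

```python
import math

R_GAS, M_W, RHO_W, SIGMA_W = 8.314, 18.015e-3, 1000.0, 0.072  # SI; sigma_w held constant


def kohler_seq(r, r_dry, kappa, T):
    """kappa-Koehler equilibrium supersaturation at wet radius r (m)."""
    A = 2.0 * M_W * SIGMA_W / (R_GAS * T * RHO_W)
    a_w = (r**3 - r_dry**3) / (r**3 - r_dry**3 * (1.0 - kappa))
    return a_w * math.exp(A / r) - 1.0


def equilibrium_radius(S0, r_dry, kappa, T, n_iter=100):
    """Bisect S_eq(r) = S0 between r_dry and the approximate critical radius.

    Valid for S0 below the critical supersaturation of the particle.
    """
    A = 2.0 * M_W * SIGMA_W / (R_GAS * T * RHO_W)
    lo = r_dry * (1.0 + 1e-9)                    # just above the dry radius
    hi = math.sqrt(3.0 * kappa * r_dry**3 / A)   # approximate critical radius
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        if kohler_seq(mid, r_dry, kappa, T) < S0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


r_dry, kappa, T, S0 = 50e-9, 0.6, 283.15, -0.02   # S0 = -0.02 is 98 % RH
r_eq = equilibrium_radius(S0, r_dry, kappa, T)
print(f"equilibrium wet radius: {r_eq * 1e9:.1f} nm")
```

Applying the same solve to every bin gives the radius part of the initial state vector; the bulk entries come directly from the initial conditions.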
Core solve¶
| Function | Returns | When to use |
|---|---|---|
| integrate_parcel(y0, args, ts) | diffrax.Solution | Full solution object with sol.ts, sol.ys, sol.result |
| integrate_parcel_arrays(y0, args, ts) | (ts, ys, success) | Raw arrays; avoids unpacking the diffrax object |
from pyrcel.integrator import integrate_parcel_arrays
ts_out = jnp.linspace(0.0, t_end, n_out)
ts, ys, ok = integrate_parcel_arrays(y0, args, ts_out)
# ys: shape (n_out, 7 + nr)
args is the tuple (r_drys, Nis, kappas, accom, V) where V is a
ConstantV or InterpolatedUpdraft.
Differentiable diagnostics¶
| Function | Returns | Notes |
|---|---|---|
| max_supersaturation(y0, args, ts) | float | \(S_\text{max}\) via Hermite cubic; jax.grad-able |
| nd_from_parcel(y0, args, t_end) | float | Soft-threshold \(N_d\) (m⁻³); jax.grad-able |
Both are designed for the differentiable path and documented in detail in the adjoint and sigmoid Nd sections above.
Terminated-run pipeline¶
When terminate=True the interactive path uses a three-function pipeline:
from pyrcel.integrator import (
find_smax,
terminate_cutoff_time,
integrate_parcel_terminated,
)
# 1. Precisely localize S_max via a dS/dt = 0 event
t_smax, smax, y_smax, activated = find_smax(y0, args, t_end)
# 2. Compute the altitude-based cutoff time (terminate_depth m past S_max)
t_cut = terminate_cutoff_time(
y0, args, t_end,
terminate_depth=10.0, t_smax=t_smax, y_smax=y_smax, activated=activated,
)
# 3. Integrate to the cutoff and return the full trajectory
ts, ys = integrate_parcel_terminated(y0, args, t_cut, output_dt=1.0)
find_smax is not differentiable (event detection is discontinuous; see
the event section). Use
max_supersaturation on the differentiable path.
Tolerance constants¶
The module-level constants STATE_RTOL, STATE_ATOL, and RADIUS_ATOL are
importable and match the CVode configuration from the legacy model:
from pyrcel.integrator import STATE_RTOL, atol_vector
atol = atol_vector(nr) # per-component vector matching legacy CVode setup
Performance tips¶
JIT compilation warm-up¶
The first call to any jax.jit-compiled function traces and compiles the
computation graph to XLA HLO. For a 200-bin parcel simulation this takes
\(\sim 10\)–\(30\) s on a modern CPU. Subsequent calls with the same shapes are
fast (typically \(< 1\) s). To amortise this cost:
- Call the function once with representative inputs before timing production runs.
- Use jax.block_until_ready to ensure the first call completes before starting a timer.
import jax
from pyrcel.integrator import max_supersaturation
# Warm up the JIT cache
_ = jax.block_until_ready(max_supersaturation(y0, args, ts))
# Now time the production call
import time
t0 = time.perf_counter()
smax = float(jax.block_until_ready(max_supersaturation(y0, args, ts)))
print(f"S_max = {smax:.5f} [{time.perf_counter()-t0:.3f}s]")
jax.vmap for ensemble runs¶
For a batch of \(n\) parcel simulations that share the same ODE structure but
differ in initial conditions or parameters (e.g., an updraft-speed ensemble),
jax.vmap compiles a single batched kernel that runs all \(n\) members in parallel:
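import jax
import jax.numpy as jnp
import pyrcel as pm  # assumed alias; adjust to however pyrcel is imported
from pyrcel.integrator import max_supersaturation

# Batch over updraft speeds
V_samples = jnp.linspace(0.1, 3.0, 128)

def single_run(V):
    new_args = (r_drys, Nis, kappas, accom, pm.ConstantV(V))
    return max_supersaturation(y0, new_args, ts)

smax_batch = jax.jit(jax.vmap(single_run))(V_samples)  # shape (128,)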
This is substantially faster than a Python for loop: the XLA compiler can fuse
the batch dimension with the ODE internal loop and exploit SIMD or GPU
parallelism. The run_updraft_ensemble convenience function does exactly this:
result = pm.run_updraft_ensemble(
[aerosol], T0=283.15, S0=0.0, P0=85000.0,
mean=0.5, std=0.2, n=1024
)
Gradient computation and jax.jacfwd¶
For scalar outputs (max_supersaturation, nd_from_parcel), jax.grad is the
most efficient gradient computation. For vector outputs or Jacobian matrices,
jax.jacfwd (forward-mode) is usually faster when the number of inputs is small
relative to outputs, while jax.jacrev (reverse-mode) is faster in the opposite
regime:
# Gradient of S_max w.r.t. a vector of N_i (nr inputs, 1 output → jacrev)
grad_N = jax.jacrev(lambda N: max_supersaturation(
y0, (r_drys, N, kappas, accom, V), ts
))(Nis)
GPU dispatch¶
Pass device="gpu" to ParcelModel or wrap integrator calls in a
jax.default_device context:
For ensemble runs the GPU backend provides the largest speedups: a 1024-member ensemble that takes \(\sim 60\) s on 8 CPU cores typically completes in \(< 5\) s on a single A100.
Numerical accuracy¶
The v2 solver at default tolerances agrees with the v1 CVode baseline to within:
| Quantity | Typical discrepancy |
|---|---|
| \(S_\text{max}\) | \(< 10^{-5}\) (absolute) |
| Activated fraction | \(\lesssim 0.1\%\) (relative) |
| Droplet number \(N_d\) | \(\lesssim 1\%\) (relative) |
| Trajectory \(r_i(t)\) | \(< 10^{-10}\) m (absolute) |
These values are confirmed by the golden-data regression test suite in
tests/test_golden.py and tests/test_integration.py, which compare v2 output
against frozen v1 reference trajectories across a range of aerosol configurations
and updraft speeds.
References¶
[Nenes2001] Nenes, A., Ghan, S., Abdul-Razzak, H., Chuang, P. Y., & Seinfeld, J. H. (2001). Kinetic limitations on cloud droplet formation and impact on cloud albedo. Tellus B, 53(2), 133–149. https://doi.org/10.1034/j.1600-0889.2001.d01-12.x
[Ghan2011] Ghan, S. J., et al. (2011). Droplet nucleation: Physically-based parameterizations and comparative evaluation. J. Adv. Model. Earth Syst., 3, M10001. https://doi.org/10.1029/2011MS000074
[Seinfeld2006] Seinfeld, J. H., & Pandis, S. N. (2006). Atmospheric Chemistry and Physics: From Air Pollution to Climate Change (2nd ed.). Wiley.
[Petters2007] Petters, M. D., & Kreidenweis, S. M. (2007). A single parameter representation of hygroscopic growth and cloud condensation nucleus activity. Atmos. Chem. Phys., 7, 1961–1971. https://doi.org/10.5194/acp-7-1961-2007
[Kvaerno2004] Kværnø, A. (2004). Singly diagonally implicit Runge–Kutta methods with an explicit first stage. BIT Numerical Mathematics, 44(3), 489–502. https://doi.org/10.1023/B:BITN.0000046811.70614.38
[Soderlind2003] Söderlind, G. (2003). Digital filters in adaptive time-stepping. ACM Trans. Math. Softw., 29(1), 1–26. https://doi.org/10.1145/641876.641877
[Chen2018] Chen, R. T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. (2018). Neural ordinary differential equations. NeurIPS, 6571–6583. https://arxiv.org/abs/1806.07366
[Kidger2021] Kidger, P. (2021). On Neural Differential Equations (Ph.D. thesis). University of Oxford. https://arxiv.org/abs/2202.02435
[MBN2014] Morales Betancourt, R., & Nenes, A. (2014). Droplet activation parameterization: the population-splitting concept revisited. Geosci. Model Dev., 7, 2345–2357. https://doi.org/10.5194/gmd-7-2345-2014
[Rothenberg2016] Rothenberg, D., & Wang, C. (2016). Metamodeling of droplet activation for global climate models. J. Atmos. Sci., 73(4), 1255–1272. https://doi.org/10.1175/JAS-D-15-0223.1