Integrator
Low-level ODE integration functions. These are the building blocks used by
ParcelModel.run() and are public for advanced workflows: custom optimization
loops, partial pipelines, or composing with external JAX transforms.
All functions are JAX-traceable and compatible with jax.grad / jax.vmap
unless noted otherwise.
When to use these directly
Use the integrator functions directly when you need:
- Gradients through the ODE (use max_supersaturation or nd_from_parcel)
- A batched vmap kernel without the ParcelModel overhead
- The raw diffrax.Solution object
- Fine-grained control over ts, tolerances, or step limits
Core solve
integrate_parcel
integrate_parcel(
y0: ArrayLike,
args: tuple,
ts: ArrayLike,
*,
rtol: float = STATE_RTOL,
atol: ArrayLike | None = None,
dtmax: float | None = None,
max_steps: int = 100000,
progress_meter: AbstractProgressMeter | None = None,
) -> dfx.Solution
Integrate the parcel ODE and return the full diffrax solution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y0 | array, shape `(7 + nr,)` | Initial (equilibrated) state vector. | required |
| args | tuple | | required |
| ts | array, shape `(n_out,)` | Output times (monotonic). Output is dense-interpolated at these times from a single adaptive solve. | required |
| rtol | float / array | Solver relative tolerance. | STATE_RTOL |
| atol | float / array | Solver absolute tolerance. | None |
| dtmax | float | Maximum internal step. | None |
| max_steps | int | Upper bound on adaptive steps (must be finite under …). | 100000 |
| progress_meter | AbstractProgressMeter | Live progress reporting (e.g. …). | None |
Returns:
| Type | Description |
|---|---|
| Solution | |
Source code in pyrcel/integrator.py
integrate_parcel_arrays
integrate_parcel_arrays(
y0: ArrayLike, args: tuple, ts: ArrayLike, **kwargs: Any
) -> tuple[Array, Array, bool]
Integrate the parcel ODE and return raw arrays.
Convenience wrapper around integrate_parcel that unpacks the diffrax solution object into plain arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y0 | array, shape `(7 + nr,)` | Initial (equilibrated) state vector. | required |
| args | tuple | | required |
| ts | array, shape `(n_out,)` | Output times (monotonic); see integrate_parcel. | required |
| `**kwargs` | Any | Forwarded to integrate_parcel. | {} |
Returns:
| Name | Type | Description |
|---|---|---|
| ts | array | Output times. |
| ys | array, shape `(n_out, 7 + nr)` | State trajectory. |
| success | bool | Whether the solve completed successfully. |
Source code in pyrcel/integrator.py
Differentiable diagnostics
max_supersaturation
max_supersaturation(
y0: ArrayLike,
args: tuple,
ts: ArrayLike,
*,
rtol: float = STATE_RTOL,
atol: ArrayLike | None = None,
max_steps: int = 100000,
) -> Array
Peak supersaturation via Hermite cubic interpolation of the saved trajectory.
Solves the parcel ODE with SaveAt(ts=ts) (the same kernel as
integrate_parcel), then locates the supersaturation peak using cubic Hermite
interpolation over the two sub-intervals bracketing the coarse argmax. The ODE
right-hand side supplies exact endpoint derivatives — the same polynomial family
diffrax uses for SaveAt(dense=True) — applied analytically to the saved
discrete output.
The peak is found by:
- Argmax of sol.ys[:, S_IDX] → coarse bracket index (stop-gradient'd).
- Three parcel_ode_sys evaluations at the three bracket points (i-1, i, i+1) for exact endpoint derivatives dS/dt.
- Analytic solution of dp/du = 0 (a quadratic in the normalised parameter u) over each sub-interval; the larger interior maximum wins.
- Evaluate the Hermite polynomial at u* (stop-gradient'd).
By the envelope theorem, d(S_max)/dV = ∂S/∂V|_{u*} — the dependence of the
peak location on V drops out — so stop_gradient on u* is exact. The
gradient flows through the diffrax adjoint via the endpoint state values and
their ODE derivatives.
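The envelope-theorem step can be written out in one line. This is a sketch in notation introduced here: \(p(u; V)\) denotes the Hermite interpolant of the supersaturation as a function of the normalised parameter \(u\) and the updraft \(V\), and \(u^*\) is assumed to be an interior stationary point of \(p\).

\[
\frac{dS_{\max}}{dV}
= \frac{d}{dV}\, p\bigl(u^*(V);\, V\bigr)
= \underbrace{\left.\frac{\partial p}{\partial u}\right|_{u^*}}_{=\,0} \frac{du^*}{dV}
+ \left.\frac{\partial p}{\partial V}\right|_{u^*}
= \left.\frac{\partial p}{\partial V}\right|_{u^*}
\]

The term carrying \(du^*/dV\) is multiplied by zero, so holding \(u^*\) fixed during differentiation leaves the derivative unchanged.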
Cost: one ODE solve (shared JIT kernel with integrate_parcel) + 3 ODE RHS
evaluations. No SaveAt(dense=True) second kernel; no vmap fine grid.
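The peak-location steps above can be sketched in plain Python on a toy curve. This is an illustration only, not the library's implementation: `hermite_interval_max`, `peak_from_samples`, and the toy functions are hypothetical names, the toy cubic stands in for the supersaturation trajectory, and treating the interval endpoints as fallback candidates is an assumption of the sketch.

```python
import math


def hermite_interval_max(t0, t1, s0, s1, ds0, ds1):
    """Maximum of the cubic Hermite interpolant on [t0, t1].

    s0, s1 are endpoint values; ds0, ds1 are endpoint time derivatives.
    Returns (t_peak, s_peak).
    """
    h = t1 - t0
    # Monomial coefficients of p(u) = a*u^3 + b*u^2 + c*u + d for u in [0, 1].
    a = 2.0 * s0 + h * ds0 - 2.0 * s1 + h * ds1
    b = -3.0 * s0 - 2.0 * h * ds0 + 3.0 * s1 - h * ds1
    c = h * ds0
    d = s0

    def p(u):
        return ((a * u + b) * u + c) * u + d

    # Endpoints are kept as fallback candidates (assumption of this sketch).
    candidates = [0.0, 1.0]
    # Stationary points solve the quadratic dp/du = 3a*u^2 + 2b*u + c = 0.
    qa, qb, qc = 3.0 * a, 2.0 * b, c
    if abs(qa) > 1e-14:
        disc = qb * qb - 4.0 * qa * qc
        if disc >= 0.0:
            root = math.sqrt(disc)
            candidates += [(-qb + root) / (2.0 * qa), (-qb - root) / (2.0 * qa)]
    elif abs(qb) > 1e-14:
        candidates.append(-qc / qb)  # quadratic degenerates to a linear equation

    u_best = max((u for u in candidates if 0.0 <= u <= 1.0), key=p)
    return t0 + u_best * h, p(u_best)


def peak_from_samples(ts, ss, dsdt):
    """Refine the sampled maximum using the two sub-intervals around the argmax."""
    n = len(ts)
    i = max(range(n), key=lambda k: ss[k])   # coarse bracket index
    i = min(max(i, 1), n - 2)                # keep both neighbours in range
    d_prev, d_mid, d_next = dsdt(ts[i - 1]), dsdt(ts[i]), dsdt(ts[i + 1])
    left = hermite_interval_max(ts[i - 1], ts[i], ss[i - 1], ss[i], d_prev, d_mid)
    right = hermite_interval_max(ts[i], ts[i + 1], ss[i], ss[i + 1], d_mid, d_next)
    return max(left, right, key=lambda pair: pair[1])


# Toy curve with a local maximum at t = 1 where its value is 4.
def s_toy(t):
    return t**3 - 6.0 * t**2 + 9.0 * t


def ds_toy(t):
    return 3.0 * t**2 - 12.0 * t + 9.0


ts = [0.2, 0.9, 1.6, 2.3]
ss = [s_toy(t) for t in ts]
t_peak, s_peak = peak_from_samples(ts, ss, ds_toy)
print(t_peak, s_peak)
```

Because a cubic Hermite interpolant reproduces a cubic exactly, the toy peak is recovered to rounding error even though no sample falls on it. On a real trajectory the result is only as accurate as the interpolant over the bracketing intervals.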
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y0 | array, shape `(7 + nr,)` | Initial (equilibrated) state vector. | required |
| args | tuple | | required |
| ts | array, shape `(n_out,)` | Output times (monotonic, must span the supersaturation peak). | required |
| rtol | float or array | Solver relative tolerance. | STATE_RTOL |
| atol | float or array | Solver absolute tolerance. | None |
| max_steps | int | Upper bound on adaptive steps. | 100000 |
Returns:
| Type | Description |
|---|---|
| float | Maximum supersaturation. |
Source code in pyrcel/integrator.py
nd_from_parcel
nd_from_parcel(
y0: ArrayLike,
args: tuple,
t_end: float,
*,
epsilon: float = 1e-08,
rtol: float = STATE_RTOL,
atol: ArrayLike | None = None,
max_steps: int = 100000,
) -> Array
Differentiable activated droplet number concentration via sigmoid soft threshold.
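Integrates the parcel ODE to t_end and evaluates:
\[
N_{d,\text{soft}} = \sum_i N_i \, \sigma\!\left( \frac{r_i(t_{\text{end}}) - r_{\text{crit},i}}{\varepsilon} \right)
\]
where \(\sigma\) is the logistic sigmoid, \(N_i\) and \(r_i\) are the number concentration and wet radius of bin i, \(r_{\text{crit},i}\) is the approximate Köhler critical radius for bin i, and \(\varepsilon\) controls the sharpness of the threshold.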
Unlike the hard-threshold Nd diagnostic, this function is fully differentiable with respect to y0 and args via diffrax's default RecursiveCheckpointAdjoint:
dNd_dV = jax.grad(nd_from_parcel, argnums=1)(y0, args, t_end)[4].V
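The soft threshold itself can be illustrated without running the ODE. The sketch below assumes the diagnostic is a number-weighted sum of logistic sigmoids of (radius minus critical radius) divided by epsilon. The function names, bin values, and the strict `>` in the hard count are all hypothetical choices made for illustration, not taken from the library.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def nd_soft(radii, r_crit, number, epsilon=1e-8):
    """Soft activated number: each bin counts with weight sigmoid((r - r_crit) / epsilon)."""
    return sum(
        n * sigmoid((r - rc) / epsilon)
        for r, rc, n in zip(radii, r_crit, number)
    )


def nd_hard(radii, r_crit, number):
    """Hard threshold for comparison: a bin counts fully once r exceeds r_crit."""
    return sum(n for r, rc, n in zip(radii, r_crit, number) if r > rc)


# Hypothetical three-bin state (radii in m, concentrations in m^-3).
radii = [5e-6, 2e-7, 1e-6]    # well above, well below, exactly at r_crit
r_crit = [1e-6, 5e-7, 1e-6]
number = [1e8, 2e8, 5e7]

soft = nd_soft(radii, r_crit, number)
hard = nd_hard(radii, r_crit, number)
print(soft, hard)
```

Bins far from their critical radius contribute essentially 0 or 1 times their number concentration, so the two diagnostics differ only for bins within a few epsilon of the threshold. In this sketch the bin sitting exactly at its critical radius contributes half its concentration to the soft count and nothing to the hard one.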
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y0 | array, shape `(7 + nr,)` | Initial (equilibrated) state vector. | required |
| args | tuple | | required |
| t_end | float | Integration time (s). Should extend past \(S_{\text{max}}\) so that kinetically limited droplets have time to grow past their critical radius. | required |
| epsilon | float | Sigmoid half-width (m). Controls the accuracy/smoothness trade-off: smaller … | 1e-08 |
| rtol | float or array | Solver relative tolerance; defaults mirror the CVode-equivalent values. | STATE_RTOL |
| atol | float or array | Solver absolute tolerance; defaults mirror the CVode-equivalent values. | None |
| max_steps | int | Upper bound on adaptive ODE steps. | 100000 |
Returns:
| Name | Type | Description |
|---|---|---|
| Nd_soft | Array | Soft activated droplet number concentration (m⁻³) at … |
Source code in pyrcel/integrator.py
Terminated-run pipeline
find_smax
Precisely localize the supersaturation maximum via a dS/dt = 0 event.
Unlike sampling a saved trajectory, t_smax is root-found rather than
quantized to the output cadence.
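The difference between a sampled maximum and a root-found one can be seen on a toy curve. In this sketch a plain bisection on dS/dt stands in for the solver's event root-finding; the curve, the time constant `TAU`, and all function names are hypothetical.

```python
import math

TAU = 3.7  # hypothetical time constant; the toy curve peaks at t = TAU


def s_toy(t):
    """Toy supersaturation-like curve that rises and then decays."""
    return t * math.exp(-t / TAU)


def ds_toy(t):
    """Analytic time derivative of s_toy; zero exactly at t = TAU."""
    return math.exp(-t / TAU) * (1.0 - t / TAU)


def bisect_root(f, lo, hi, iters=200):
    """Bisection for a sign change of f on [lo, hi]."""
    f_lo = f(lo)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        f_mid = f(mid)
        if (f_mid > 0.0) == (f_lo > 0.0):
            lo, f_lo = mid, f_mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


# Sampled estimate: limited to the 1 s output cadence.
ts = [float(k) for k in range(11)]
i = max(range(len(ts)), key=lambda k: s_toy(ts[k]))
t_sampled = ts[i]

# Root-found estimate: solve dS/dt = 0 inside the bracketing samples.
t_event = bisect_root(ds_toy, ts[i - 1], ts[i + 1])
print(t_sampled, t_event)
```

The sampled estimate lands on the nearest output time, while the root-found estimate recovers the stationary point itself, independent of the cadence.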
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y0 | array, shape `(7 + nr,)` | Initial (equilibrated) state vector. | required |
| args | tuple | | required |
| t_end | float | Upper bound on integration time, s. | required |
| rtol | float or array | Solver tolerances. | STATE_RTOL |
| atol | float or array | Solver tolerances. | STATE_RTOL |
| max_steps | int | Upper bound on adaptive steps. | 100000 |
Returns:
| Name | Type | Description |
|---|---|---|
| t_smax | float | Time of the supersaturation maximum, s. |
| smax | float | Peak supersaturation value. |
| y_smax | array, shape `(7 + nr,)` | State vector at … |
| activated | bool | |
Source code in pyrcel/integrator.py
Not differentiable
find_smax uses event detection (dS/dt = 0) which is discontinuous and
therefore not compatible with jax.grad. Use max_supersaturation on the
differentiable path.
terminate_cutoff_time
terminate_cutoff_time(
y0,
args,
t_end,
*,
terminate_depth: float,
t_smax,
y_smax,
activated: bool,
rtol: float = STATE_RTOL,
atol=None,
max_steps: int = 100000,
) -> float
Compute the integration stop time for altitude-based cutoff past S_max.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| y0 | array, shape `(7 + nr,)` | Initial state vector. | required |
| args | tuple | | required |
| t_end | float | Upper bound on integration time, s. | required |
| terminate_depth | float | Extra vertical distance (m) past … | required |
| t_smax | float | Time of the supersaturation maximum, s. | required |
| y_smax | array | State vector at … | required |
| activated | bool | If … | required |
| rtol | float or array | Solver relative tolerance. | STATE_RTOL |
| atol | float or array | Solver absolute tolerance. | None |
| max_steps | int | Upper bound on adaptive steps. | 100000 |
Returns:
| Type | Description |
|---|---|
| float | Cutoff time, s. Always … |
Source code in pyrcel/integrator.py
integrate_parcel_terminated
integrate_parcel_terminated(
y0,
args,
t_end,
output_dt,
*,
terminate_depth: float = 100.0,
rtol: float = STATE_RTOL,
atol=None,
max_steps: int = 100000,
progress_meter=None,
phase_timer=None,
)
Integrate, stopping terminate_depth metres past the supersaturation max.
Reproduces the master terminate=True semantics: locate S_max (via the
dS/dt event), then continue an extra terminate_depth metres
(= terminate_depth / V seconds) and stop. Output is produced at output_dt
cadence over [0, t_cutoff] from a single adaptive solve.
The cutoff is altitude-based (stop terminate_depth metres above z_smax), so
it is correct for both a constant updraft and a time-varying V(t): for constant
V the cutoff time is computed analytically, otherwise it is root-found with a
second z-crossing event.
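The two cutoff branches can be sketched in plain Python. This is an illustration under stated assumptions: the function names are hypothetical, clamping the result to t_end is an assumption, the altitude z(t) is supplied as a closed-form callable rather than read from the ODE state, and a bisection stands in for the solver's z-crossing event.

```python
def cutoff_constant_v(t_smax, terminate_depth, v, t_end):
    """Constant updraft: the extra depth takes terminate_depth / v seconds."""
    return min(t_smax + terminate_depth / v, t_end)  # clamp is an assumption


def cutoff_varying_v(t_smax, terminate_depth, z_of_t, t_end, iters=200):
    """Time-varying updraft: find t where z(t) = z(t_smax) + terminate_depth.

    Assumes z_of_t is increasing on [t_smax, t_end] (positive updraft).
    """
    target = z_of_t(t_smax) + terminate_depth
    if z_of_t(t_end) <= target:
        return t_end  # target altitude is never reached within the horizon
    lo, hi = t_smax, t_end
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if z_of_t(mid) < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


# Constant V = 1 m/s: 100 m past the peak takes 100 s.
t_const = cutoff_constant_v(50.0, 100.0, 1.0, 1000.0)


# Hypothetical accelerating updraft V(t) = 0.5 + 0.01 t, so z(t) = 0.5 t + 0.005 t^2.
def z_toy(t):
    return 0.5 * t + 0.005 * t * t


t_vary = cutoff_varying_v(50.0, 100.0, z_toy, 1000.0)
print(t_const, t_vary)
```

With an accelerating updraft the parcel covers the extra 100 m sooner than the constant-V formula evaluated at the initial speed would suggest, which is why the time-varying case needs a crossing search on altitude rather than a fixed time offset.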
This is the interactive / parity path: t_cutoff is data-dependent so the output
grid length is dynamic (the outer call runs eagerly; the solves are jitted). For
jit/vmap/grad use the fixed-horizon integrate_parcel, find_smax, or max_supersaturation instead.
Returns (ts, ys, info) with ts/ys as numpy arrays and info a dict
of t_smax, smax, t_cutoff, activated, success.
Source code in pyrcel/integrator.py
Tolerance constants
atol_vector
Build the per-component absolute tolerance vector for the ODE solver.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| nr | int | Number of aerosol radius bins. | required |
Returns:
| Type | Description |
|---|---|
| jnp.ndarray, shape `(7 + nr,)` | Absolute tolerance vector: CVode-equivalent values for bulk state variables followed by … |
Source code in pyrcel/integrator.py
The module also exports STATE_RTOL, STATE_ATOL, and RADIUS_ATOL as
importable constants that match the CVode configuration from the legacy model.