
Integrator

Low-level ODE integration functions. These are the building blocks used by ParcelModel.run() and are public for advanced workflows: custom optimization loops, partial pipelines, or composing with external JAX transforms.

All functions are JAX-traceable and compatible with jax.grad / jax.vmap unless noted otherwise.

When to use these directly

Use the integrator functions directly when you need:

  • Gradients through the ODE (use max_supersaturation or nd_from_parcel)
  • A batched vmap kernel without the ParcelModel overhead
  • The raw diffrax.Solution object
  • Fine-grained control over ts, tolerances, or step limits

Core solve

integrate_parcel

integrate_parcel(
    y0: ArrayLike,
    args: tuple,
    ts: ArrayLike,
    *,
    rtol: float = STATE_RTOL,
    atol: ArrayLike | None = None,
    dtmax: float | None = None,
    max_steps: int = 100000,
    progress_meter: AbstractProgressMeter | None = None,
) -> dfx.Solution

Integrate the parcel ODE and return the full diffrax solution.

Parameters:

  • y0 (array, shape (7 + nr,); required): Initial (equilibrated) state vector.
  • args (tuple; required): (r_drys, Nis, kappas, accom, V) for parcel_ode_sys.
  • ts (array, shape (n_out,); required): Output times (monotonic). Output is dense-interpolated at these times from a single adaptive solve; ts[0] is t0 and ts[-1] is t1.
  • rtol (float or array; default STATE_RTOL): Relative solver tolerance.
  • atol (float or array; default None): Absolute solver tolerance. None selects the per-component CVode vector returned by atol_vector(nr).
  • dtmax (float; default None): Maximum internal step. None lets the step-size controller choose freely.
  • max_steps (int; default 100000): Upper bound on adaptive steps (must be finite under jit).
  • progress_meter (AbstractProgressMeter; default None): Live progress reporting (e.g. diffrax.TextProgressMeter()) for interactive runs. None selects NoProgressMeter; keep that default under vmap, for large batches, and in the differentiable core to avoid per-element host syncs.

Returns:

  • diffrax.Solution: sol.ts, sol.ys (shape (n_out, 7 + nr)), and sol.result.
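A small pre-flight helper can catch a malformed output grid before it reaches the solver. check_output_times below is hypothetical (it is not part of pyrcel); it only encodes the ts contract stated above: a 1-D, monotonic grid whose first and last entries act as t0 and t1. max_supersaturation additionally needs at least three points, which the min_points argument can enforce.

```python
import numpy as np


def check_output_times(ts, min_points=2):
    """Hypothetical helper (not part of pyrcel) that checks the ts contract.

    Mirrors the documented requirements: ts is 1-D and monotonic, ts[0] is the
    start time t0 of the single adaptive solve and ts[-1] is its end time t1.
    """
    ts = np.asarray(ts, dtype=float)
    if ts.ndim != 1:
        raise ValueError(f"ts must be 1-D, got shape {ts.shape}")
    if ts.size < min_points:
        raise ValueError(f"ts needs at least {min_points} points, got {ts.size}")
    if np.any(np.diff(ts) < 0.0) or ts[-1] <= ts[0]:
        raise ValueError("ts must be non-decreasing with ts[-1] > ts[0]")
    return float(ts[0]), float(ts[-1])


# 1 s output cadence over a 600 s ascent; pass min_points=3 for max_supersaturation.
t0, t1 = check_output_times(np.linspace(0.0, 600.0, 601))
```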

Source code in pyrcel/integrator.py
def integrate_parcel(
    y0: ArrayLike,
    args: tuple,
    ts: ArrayLike,
    *,
    rtol: float = STATE_RTOL,
    atol: ArrayLike | None = None,
    dtmax: float | None = None,
    max_steps: int = 100_000,
    progress_meter: dfx.AbstractProgressMeter | None = None,
) -> dfx.Solution:
    """Integrate the parcel ODE and return the full ``diffrax`` solution.

    Parameters
    ----------
    y0 : array, shape ``(7 + nr,)``
        Initial (equilibrated) state vector.
    args : tuple
        ``(r_drys, Nis, kappas, accom, V)`` for `parcel_ode_sys`.
    ts : array, shape ``(n_out,)``
        Output times (monotonic). Output is dense-interpolated at these times from a
        single adaptive solve; ``ts[0]`` is ``t0`` and ``ts[-1]`` is ``t1``.
    rtol, atol : float / array
        Solver tolerances. ``atol`` defaults to the per-component CVode vector.
    dtmax : float, optional
        Maximum internal step. ``None`` lets the controller choose freely.
    max_steps : int
        Upper bound on adaptive steps (must be finite under ``jit``).
    progress_meter : diffrax.AbstractProgressMeter, optional
        Live progress reporting (e.g. ``diffrax.TextProgressMeter()``) for interactive
        runs. Defaults to ``NoProgressMeter`` -- keep it that way under ``vmap``/large
        batches and in the differentiable core to avoid per-element host syncs.

    Returns
    -------
    diffrax.Solution
        ``sol.ts``, ``sol.ys`` (shape ``(n_out, 7 + nr)``), ``sol.result``.
    """
    y0 = jnp.asarray(y0)
    ts = jnp.asarray(ts)
    nr = int(y0.shape[0] - N_STATE_VARS)
    if atol is None:
        atol = atol_vector(nr)
    if progress_meter is None:
        progress_meter = dfx.NoProgressMeter()
    return _solve(y0, args, ts, rtol, atol, dtmax, max_steps, progress_meter)

integrate_parcel_arrays

integrate_parcel_arrays(
    y0: ArrayLike, args: tuple, ts: ArrayLike, **kwargs: Any
) -> tuple[Array, Array, bool]

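Integrate the parcel ODE and return raw arrays.

Convenience wrapper around integrate_parcel that unpacks the diffrax solution object into plain arrays. Because success is converted with Python bool(), the wrapper must be called eagerly, not inside jax.jit or jax.vmap; under a transform, call integrate_parcel and compare sol.result with diffrax.RESULTS.successful instead.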

Parameters:

  • y0 (array, shape (7 + nr,); required): Initial (equilibrated) state vector.
  • args (tuple; required): (r_drys, Nis, kappas, accom, V) for parcel_ode_sys.
  • ts (array, shape (n_out,); required): Output times (monotonic); see integrate_parcel.
  • **kwargs (Any; optional): Forwarded to integrate_parcel.

Returns:

  • ts (array): Output times.
  • ys (array, shape (n_out, 7 + nr)): State trajectory.
  • success (bool): Whether the solve completed successfully.

Source code in pyrcel/integrator.py
def integrate_parcel_arrays(
    y0: ArrayLike, args: tuple, ts: ArrayLike, **kwargs: Any
) -> tuple[Array, Array, bool]:
    """Integrate the parcel ODE and return raw arrays.

    Convenience wrapper around [integrate_parcel][pyrcel.integrator.integrate_parcel] that
    unpacks the ``diffrax`` solution object into plain arrays.

    Parameters
    ----------
    y0 : array, shape ``(7 + nr,)``
        Initial (equilibrated) state vector.
    args : tuple
        ``(r_drys, Nis, kappas, accom, V)`` for `parcel_ode_sys`.
    ts : array, shape ``(n_out,)``
        Output times (monotonic); see [integrate_parcel][pyrcel.integrator.integrate_parcel].
    **kwargs
        Forwarded to [integrate_parcel][pyrcel.integrator.integrate_parcel].

    Returns
    -------
    ts : array
        Output times.
    ys : array, shape ``(n_out, 7 + nr)``
        State trajectory.
    success : bool
        Whether the solve completed successfully.
    """
    sol = integrate_parcel(y0, args, ts, **kwargs)
    success = bool(sol.result == dfx.RESULTS.successful)
    return sol.ts, sol.ys, success

Differentiable diagnostics

max_supersaturation

max_supersaturation(
    y0: ArrayLike,
    args: tuple,
    ts: ArrayLike,
    *,
    rtol: float = STATE_RTOL,
    atol: ArrayLike | None = None,
    max_steps: int = 100000,
) -> Array

Peak supersaturation via Hermite cubic interpolation of the saved trajectory.

Solves the parcel ODE with SaveAt(ts=ts) (the same kernel as integrate_parcel), then locates the supersaturation peak using cubic Hermite interpolation over the two sub-intervals bracketing the coarse argmax. The ODE right-hand side supplies exact endpoint derivatives — the same polynomial family diffrax uses for SaveAt(dense=True) — applied analytically to the saved discrete output.

The peak is found by:

  1. Argmax of sol.ys[:, S_IDX] → coarse bracket index (stop-gradient'd).
  2. Three parcel_ode_sys evaluations at the three bracket points (i-1, i, i+1) for exact endpoint derivatives dS/dt.
  3. Analytic solution of dp/du = 0 (a quadratic in the normalised parameter u) over each sub-interval; the larger interior maximum wins.
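  4. Evaluate the Hermite polynomial at the stop-gradient'd u*; gradients still flow through the endpoint values S and derivatives dS/dt that define the polynomial.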
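Steps 3 and 4 can be sketched for a single sub-interval in plain Python. This is an illustration, not pyrcel's JAX code: it drops the stop_gradient and jnp.where machinery, and it swaps in the cancellation-free form of the quadratic formula (q = -(B + sign(B)·sqrt(disc))/2, roots q/A and C/q), which also covers the linear case A = 0 that the source handles in a separate branch. The test profile is a made-up quadratic, which a cubic Hermite fit reproduces exactly.

```python
import math


def hermite(S0, S1, dS0, dS1, h, u):
    """Cubic Hermite interpolant on one sub-interval, in u = (t - t0) / h."""
    return ((2 * u**3 - 3 * u**2 + 1) * S0
            + (-2 * u**3 + 3 * u**2) * S1
            + h * (u**3 - 2 * u**2 + u) * dS0
            + h * (u**3 - u**2) * dS1)


def peak_u(S0, S1, dS0, dS1, h):
    """Interior maximiser u* in (0, 1), else the larger endpoint (0.0 or 1.0)."""
    # dp/du = h * (A u^2 + B u + C), with the same coefficients as the source listing.
    A = 6.0 * (S0 - S1) / h + 3.0 * (dS0 + dS1)
    B = 6.0 * (S1 - S0) / h - 2.0 * (2.0 * dS0 + dS1)
    C = dS0
    disc = B * B - 4.0 * A * C
    roots = []
    if disc >= 0.0:
        # Cancellation-free quadratic formula; the C / q root also covers A == 0.
        q = -0.5 * (B + math.copysign(math.sqrt(disc), B))
        if A != 0.0:
            roots.append(q / A)
        if q != 0.0:
            roots.append(C / q)
    # Keep interior roots where d2p/du2 = h * (2 A u + B) < 0, i.e. maxima.
    maxima = [u for u in roots if 0.0 < u < 1.0 and 2.0 * A * u + B < 0.0]
    if maxima:
        return max(maxima, key=lambda u: hermite(S0, S1, dS0, dS1, h, u))
    return 0.0 if S0 >= S1 else 1.0  # no interior maximum: larger endpoint


# S(t) = 1 - (t - 0.3)**2 on [0, 1], with exact endpoint derivatives.
S = lambda t: 1.0 - (t - 0.3) ** 2
dS = lambda t: -2.0 * (t - 0.3)
u_star = peak_u(S(0.0), S(1.0), dS(0.0), dS(1.0), 1.0)
s_max = hermite(S(0.0), S(1.0), dS(0.0), dS(1.0), 1.0, u_star)
```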

By the envelope theorem, d(S_max)/dV = ∂S/∂V|_{u*} — the dependence of the peak location on V drops out — so stop_gradient on u* is exact. The gradient flows through the diffrax adjoint via the endpoint state values and their ODE derivatives.
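The envelope-theorem argument can be checked numerically on a toy profile. The sketch below uses a hypothetical one-parameter function p(u; θ) = θu - u² rather than the parcel model: the total derivative of the peak value (re-optimising u for every θ) matches the partial derivative taken with u* held fixed, which is what stop_gradient on u* computes.

```python
import math


def p(u, theta):
    """Toy peaked profile p(u; theta) = theta*u - u**2, maximised at u* = theta/2."""
    return theta * u - u * u


def argmax_u(theta, lo=0.0, hi=1.0, iters=200):
    """Golden-section search for the maximiser of p(., theta) on [lo, hi]."""
    g = (math.sqrt(5.0) - 1.0) / 2.0
    a, b = lo, hi
    for _ in range(iters):
        c, d = b - g * (b - a), a + g * (b - a)
        if p(c, theta) > p(d, theta):
            b = d  # maximum lies in [a, d]
        else:
            a = c  # maximum lies in [c, b]
    return 0.5 * (a + b)


theta, step = 0.8, 1e-6
# Total derivative of the peak value: re-optimise u at each perturbed theta.
total = (p(argmax_u(theta + step), theta + step)
         - p(argmax_u(theta - step), theta - step)) / (2.0 * step)
# Partial derivative with u* frozen (the stop_gradient'd evaluation).
u_star = argmax_u(theta)
partial = (p(u_star, theta + step) - p(u_star, theta - step)) / (2.0 * step)
```

Both values equal θ/2 here, because dp/du vanishes at the interior maximum and the shift in u* contributes nothing to first order.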

Cost: one ODE solve (shared JIT kernel with integrate_parcel) + 3 ODE RHS evaluations. No SaveAt(dense=True) second kernel; no vmap fine grid.

Parameters:

  • y0 (array, shape (7 + nr,); required): Initial (equilibrated) state vector.
  • args (tuple; required): (r_drys, Nis, kappas, accom, V) for parcel_ode_sys.
  • ts (array, shape (n_out,); required): Output times (monotonic). The grid must span the supersaturation peak and contain at least three points, because the bracket index is clipped to [1, n_out - 2].
  • rtol (float or array; default STATE_RTOL): Relative solver tolerance.
  • atol (float or array; default None): Absolute solver tolerance. None selects atol_vector(nr).
  • max_steps (int; default 100000): Upper bound on adaptive steps.

Returns:

  • Array (scalar): Maximum supersaturation S_max.

Source code in pyrcel/integrator.py
def max_supersaturation(
    y0: ArrayLike,
    args: tuple,
    ts: ArrayLike,
    *,
    rtol: float = STATE_RTOL,
    atol: ArrayLike | None = None,
    max_steps: int = 100_000,
) -> Array:
    """Peak supersaturation via Hermite cubic interpolation of the saved trajectory.

    Solves the parcel ODE with ``SaveAt(ts=ts)`` (the same kernel as
    ``integrate_parcel``), then locates the supersaturation peak using cubic Hermite
    interpolation over the two sub-intervals bracketing the coarse argmax.  The ODE
    right-hand side supplies exact endpoint derivatives — the same polynomial family
    diffrax uses for ``SaveAt(dense=True)`` — applied analytically to the saved
    discrete output.

    The peak is found by:

    1. Argmax of ``sol.ys[:, S_IDX]`` → coarse bracket index (stop-gradient'd).
    2. Three ``parcel_ode_sys`` evaluations at the three bracket points (``i-1``,
       ``i``, ``i+1``) for exact endpoint derivatives ``dS/dt``.
    3. Analytic solution of ``dp/du = 0`` (a quadratic in the normalised parameter
       ``u``) over each sub-interval; the larger interior maximum wins.
    4. Evaluate the Hermite polynomial at ``u*`` (stop-gradient'd).

    By the envelope theorem, ``d(S_max)/dV = ∂S/∂V|_{u*}`` — the dependence of the
    peak location on V drops out — so ``stop_gradient`` on ``u*`` is exact.  The
    gradient flows through the diffrax adjoint via the endpoint state values and
    their ODE derivatives.

    Cost: one ODE solve (shared JIT kernel with ``integrate_parcel``) + 3 ODE RHS
    evaluations.  No ``SaveAt(dense=True)`` second kernel; no vmap fine grid.

    Parameters
    ----------
    y0 : array, shape ``(7 + nr,)``
        Initial (equilibrated) state vector.
    args : tuple
        ``(r_drys, Nis, kappas, accom, V)`` for `parcel_ode_sys`.
    ts : array, shape ``(n_out,)``
        Output times (monotonic, must span the supersaturation peak).
    rtol, atol : float or array, optional
        Solver tolerances.
    max_steps : int, optional
        Upper bound on adaptive steps.

    Returns
    -------
    float
        Maximum supersaturation ``S_max``.
    """
    y0 = jnp.asarray(y0)
    ts = jnp.asarray(ts)
    nr = int(y0.shape[0] - N_STATE_VARS)
    if atol is None:
        atol = atol_vector(nr)

    sol = _solve(y0, args, ts, rtol, atol, None, max_steps)

    # Coarse bracket: argmax of saved S values (no polynomial evaluations needed).
    # i_pk locates the ±1-cell window; stop-gradient'd so discrete index ops are safe.
    S_coarse = sol.ys[:, _S_IDX]
    i_pk = jax.lax.stop_gradient(jnp.clip(jnp.argmax(S_coarse), 1, ts.shape[0] - 2))

    # Extract bracket states via dynamic indexing (i_pk is stop-gradient'd).
    t_m, t_c, t_p = ts[i_pk - 1], ts[i_pk], ts[i_pk + 1]
    y_m = sol.ys[i_pk - 1]
    y_c = sol.ys[i_pk]
    y_p = sol.ys[i_pk + 1]

    # ODE RHS at the three bracket points — exact dS/dt, no finite differences.
    # Gradient flows through both the ODE adjoint (via y_m/y_c/y_p) and the
    # explicit V-dependence in parcel_ode_sys.
    f_m = parcel_ode_sys(t_m, y_m, args)
    f_c = parcel_ode_sys(t_c, y_c, args)
    f_p = parcel_ode_sys(t_p, y_p, args)

    S_m, S_c, S_p = y_m[_S_IDX], y_c[_S_IDX], y_p[_S_IDX]
    dS_m, dS_c, dS_p = f_m[_S_IDX], f_c[_S_IDX], f_p[_S_IDX]

    # Hermite cubic on [t0, t1] parameterised by u = (t - t0) / h:
    #   p(u) = h00(u)*S0 + h01(u)*S1 + h*h10(u)*dS0 + h*h11(u)*dS1
    # Gradient flows through (S0, S1, dS0, dS1) at the stop-gradient'd u*.
    def _S_at_u(S0, S1, dS0, dS1, h, u):
        return (
            (2 * u**3 - 3 * u**2 + 1) * S0
            + (-2 * u**3 + 3 * u**2) * S1
            + h * (u**3 - 2 * u**2 + u) * dS0
            + h * (u**3 - u**2) * dS1
        )

    def _peak_u(S0, S1, dS0, dS1, h) -> Array:
        """Stop-gradient'd interior peak u in (0,1), or boundary fallback."""
        # All inputs are stop-gradient'd so u* carries no gradient.
        S0s, S1s, dS0s, dS1s, hs = (jax.lax.stop_gradient(x) for x in (S0, S1, dS0, dS1, h))
        A = 6.0 * (S0s - S1s) / hs + 3.0 * (dS0s + dS1s)
        B = 6.0 * (S1s - S0s) / hs - 2.0 * (2.0 * dS0s + dS1s)
        C = dS0s

        # Quadratic roots (valid when |A| is non-negligible).
        disc = B * B - 4.0 * A * C
        sq = jnp.sqrt(jnp.maximum(disc, 0.0))
        A_safe = jnp.where(jnp.abs(A) > 1e-30, A, 1.0)
        u1 = (-B + sq) / (2.0 * A_safe)
        u2 = (-B - sq) / (2.0 * A_safe)

        # Linear root: when A ≈ 0, dp/du = Bu + C is linear with root u = -C/B.
        # d²p/du² at the linear root ≈ B, so B < 0 is the maximum condition.
        B_safe = jnp.where(jnp.abs(B) > 1e-30, B, 1.0)
        u_lin = -C / B_safe

        def _is_interior_max_quad(u):
            return (
                (disc >= 0.0)
                & (jnp.abs(A) > 1e-30)
                & (u > 0.0)
                & (u < 1.0)
                & (2.0 * A * u + B < 0.0)
            )

        v1, v2 = _is_interior_max_quad(u1), _is_interior_max_quad(u2)
        v_lin = (
            (jnp.abs(A) <= 1e-30) & (jnp.abs(B) > 1e-30) & (B < 0.0) & (u_lin > 0.0) & (u_lin < 1.0)
        )
        S1v = _S_at_u(S0s, S1s, dS0s, dS1s, hs, u1)
        S2v = _S_at_u(S0s, S1s, dS0s, dS1s, hs, u2)
        return jnp.where(
            v1 & (~v2 | (S1v >= S2v)),
            u1,
            jnp.where(v2, u2, jnp.where(v_lin, u_lin, jnp.where(S0s >= S1s, 0.0, 1.0))),
        )

    # Find stop-gradient'd u* on each sub-interval.
    h_lo = jax.lax.stop_gradient(t_c - t_m)
    h_hi = jax.lax.stop_gradient(t_p - t_c)
    u_lo = _peak_u(S_m, S_c, dS_m, dS_c, h_lo)
    u_hi = _peak_u(S_c, S_p, dS_c, dS_p, h_hi)

    # Evaluate S at each candidate peak (gradient flows here via envelope theorem).
    S_lo = _S_at_u(S_m, S_c, dS_m, dS_c, h_lo, u_lo)
    S_hi = _S_at_u(S_c, S_p, dS_c, dS_p, h_hi, u_hi)

    # Take the sub-interval with the larger S_max (condition is stop-gradient'd).
    use_lo = jax.lax.stop_gradient(jax.lax.stop_gradient(S_lo) >= jax.lax.stop_gradient(S_hi))
    return jnp.where(use_lo, S_lo, S_hi)

nd_from_parcel

nd_from_parcel(
    y0: ArrayLike,
    args: tuple,
    t_end: float,
    *,
    epsilon: float = 1e-08,
    rtol: float = STATE_RTOL,
    atol: ArrayLike | None = None,
    max_steps: int = 100000,
) -> Array

Differentiable activated droplet number concentration via sigmoid soft threshold.

Integrates the parcel ODE to t_end and evaluates:

\[ N_d^{\text{soft}} = \sum_i N_i \cdot \sigma\!\left( \frac{r_i(t_{\text{end}}) - r_{\text{crit},i}}{\varepsilon} \right) \]

where \(\sigma\) is the logistic sigmoid, \(r_{\text{crit},i}\) is the approximate Köhler critical radius for bin i, and \(\varepsilon\) controls the sharpness of the threshold.
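The soft count is a weighted sum that is easy to reproduce outside the solver. The minimal NumPy sketch below uses made-up bin values and supplies the critical radii directly (pyrcel computes them with an internal approximate Köhler routine); it shows the soft count approaching the hard-threshold count as ε shrinks.

```python
import numpy as np


def nd_soft(Nis, rs, r_crits, epsilon):
    """Sum over bins of N_i * sigma((r_i - r_crit_i) / epsilon)."""
    x = (np.asarray(rs) - np.asarray(r_crits)) / epsilon
    sigma = 0.5 * (1.0 + np.tanh(0.5 * x))  # logistic sigmoid, overflow-free form
    return float(np.sum(np.asarray(Nis) * sigma))


def nd_hard(Nis, rs, r_crits):
    """Hard threshold: count only bins whose radius exceeds r_crit."""
    return float(np.sum(np.asarray(Nis)[np.asarray(rs) > np.asarray(r_crits)]))


# Illustrative three-bin example (made-up numbers, SI units).
Nis = [1e8, 5e7, 2e7]          # m^-3
r_crits = [1e-7, 2e-7, 4e-7]   # m
rs = [1.5e-6, 2.5e-7, 1e-7]    # m: well above, 50 nm above, 300 nm below
hard = nd_hard(Nis, rs, r_crits)
soft = {eps: nd_soft(Nis, rs, r_crits, eps) for eps in (1e-8, 1e-9, 1e-10)}
```

With ε = 1e-8 m the bin sitting 50 nm above its threshold is only partly counted (x = 5), so the soft value falls slightly below the hard count; at 1e-9 m and smaller the two agree.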

Unlike the hard-threshold Nd diagnostic, this function is fully differentiable with respect to y0 and args through diffrax's default RecursiveCheckpointAdjoint. For example, with a ConstantV updraft in args[4]:

dNd_dV = jax.grad(nd_from_parcel, argnums=1)(y0, args, t_end)[4].V

Parameters:

  • y0 (array, shape (7 + nr,); required): Initial (equilibrated) state vector.
  • args (tuple; required): (r_drys, Nis, kappas, accom, V) for parcel_ode_sys. Nis must be in m⁻³ (the unit stored on AerosolSpecies.Nis).
  • t_end (float; required): Integration time (s). Should extend past the time of \(S_{\text{max}}\) so that kinetically limited droplets have time to grow past their critical radius.
  • epsilon (float; default 1e-8): Sigmoid half-width (m), i.e. the scale \(\varepsilon\) in the formula above. Controls the accuracy/smoothness trade-off: a smaller ε is more accurate but produces steeper gradients. The default 1e-8 m (10 nm) is much smaller than a typical 100 nm critical radius.
  • rtol (float or array; default STATE_RTOL): Relative solver tolerance (CVode-equivalent default).
  • atol (float or array; default None): Absolute solver tolerance. None selects the CVode-equivalent atol_vector(nr).
  • max_steps (int; default 100000): Upper bound on adaptive ODE steps.

Returns:

  • Nd_soft (Array): Soft activated droplet number concentration (m⁻³) at t_end. As ε → 0 this converges to the hard-threshold count using the approximate Köhler critical radius.
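The ε trade-off follows from the chain rule: each bin contributes N_i·σ'(x)/ε to ∂N_d/∂r_i, with σ'(x) = σ(x)(1 - σ(x)) ≤ 1/4. The plain-Python sketch below, with illustrative numbers, shows that shrinking ε raises the peak gradient N_i/(4ε) at the threshold while the gradient a few ε away from it collapses towards zero.

```python
import math


def sigmoid(x):
    """Logistic sigmoid written via tanh to avoid overflow for large |x|."""
    return 0.5 * (1.0 + math.tanh(0.5 * x))


def dnd_dr(N_i, r_i, r_crit, epsilon):
    """d/dr_i of N_i * sigma((r_i - r_crit) / epsilon) = N_i * s * (1 - s) / epsilon."""
    s = sigmoid((r_i - r_crit) / epsilon)
    return N_i * s * (1.0 - s) / epsilon


N_i, r_crit = 1e8, 1e-7  # illustrative: 1e8 m^-3 in a bin with a 100 nm critical radius
# For each epsilon: (gradient exactly at the threshold, gradient 50 nm past it).
grads = {
    eps: (dnd_dr(N_i, r_crit, r_crit, eps), dnd_dr(N_i, r_crit + 5e-8, r_crit, eps))
    for eps in (1e-8, 1e-9)
}
```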

Source code in pyrcel/integrator.py
def nd_from_parcel(
    y0: ArrayLike,
    args: tuple,
    t_end: float,
    *,
    epsilon: float = 1e-8,
    rtol: float = STATE_RTOL,
    atol: ArrayLike | None = None,
    max_steps: int = 100_000,
) -> Array:
    """Differentiable activated droplet number concentration via sigmoid soft threshold.

    Integrates the parcel ODE to ``t_end`` and evaluates:

    $$
    N_d^{\\text{soft}} = \\sum_i N_i \\cdot \\sigma\\!\\left(
        \\frac{r_i(t_{\\text{end}}) - r_{\\text{crit},i}}{\\varepsilon}
    \\right)
    $$

    where $\\sigma$ is the logistic sigmoid, $r_{\\text{crit},i}$ is
    the approximate Köhler critical radius for bin *i*, and $\\varepsilon$
    controls the sharpness of the threshold.

    Unlike the hard-threshold [Nd][pyrcel.model_output.ModelOutput.Nd] diagnostic, this
    function is fully differentiable with respect to ``y0`` and ``args`` via
    ``diffrax``'s default ``RecursiveCheckpointAdjoint``::

        dNd_dV = jax.grad(nd_from_parcel, argnums=1)(y0, args, t_end)[4].V

    Parameters
    ----------
    y0 : array, shape ``(7 + nr,)``
        Initial (equilibrated) state vector.
    args : tuple
        ``(r_drys, Nis, kappas, accom, V)`` for `parcel_ode_sys`.
        ``Nis`` must be in m⁻³ (the unit stored on ``AerosolSpecies.Nis``).
    t_end : float
        Integration time (s). Should extend past $S_{\\text{max}}$ so that
        kinetically limited droplets have time to grow past their critical radius.
    epsilon : float, optional
        Sigmoid half-width (m). Controls the accuracy/smoothness trade-off:
        smaller ``ε`` is more accurate but has a steeper gradient. Default
        ``1e-8`` (≈ 10 nm, much smaller than a typical 100 nm critical radius).
    rtol, atol : float or array, optional
        Solver tolerances; defaults mirror the CVode-equivalent values.
    max_steps : int, optional
        Upper bound on adaptive ODE steps.

    Returns
    -------
    Nd_soft : Array
        Soft activated droplet number concentration (m⁻³) at ``t_end``.
        In the limit ``ε → 0`` this converges to the hard-threshold count
        using the approximate Köhler critical radius.
    """
    y0 = jnp.asarray(y0)
    nr = int(y0.shape[0] - N_STATE_VARS)
    if atol is None:
        atol = atol_vector(nr)

    ts = jnp.stack([jnp.zeros((), dtype=jnp.float64), jnp.asarray(t_end, dtype=jnp.float64)])
    sol = _solve(y0, args, ts, rtol, atol, None, max_steps)

    y_final = sol.ys[-1]  # shape (7 + nr,); T is state index 2
    T_final = y_final[2]
    rs_final = y_final[N_STATE_VARS:]  # wet radii (m), shape (nr,)

    r_drys = jnp.asarray(args[0])
    Nis = jnp.asarray(args[1])
    kappas = jnp.asarray(args[2])

    r_crits, _ = _kohler_crit_approx(T_final, r_drys, kappas)
    return jnp.sum(Nis * jax.nn.sigmoid((rs_final - r_crits) / epsilon))

Terminated-run pipeline

find_smax

find_smax(y0, args, t_end, *, rtol=STATE_RTOL, atol=None, max_steps=100000)

Precisely localize the supersaturation maximum via a dS/dt = 0 event.

Unlike sampling a saved trajectory, t_smax is root-found rather than quantized to the output cadence.
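The difference between root-finding and sampling can be seen on a made-up supersaturation history S(t) = 0.01·t·exp(-t/τ), whose peak is at t = τ. In this plain-Python sketch, bisection on the first downward sign change of dS/dt stands in for diffrax's event root-finder; it is not pyrcel's implementation. The event time lands on τ, while an argmax over a 1 s output grid snaps to a grid point.

```python
import math

TAU = 13.37  # illustrative time of the peak, s


def S(t):
    """Toy supersaturation history with a single maximum at t = TAU."""
    return 0.01 * t * math.exp(-t / TAU)


def dSdt(t):
    return 0.01 * math.exp(-t / TAU) * (1.0 - t / TAU)


def find_peak_event(t_end, dt=1.0, tol=1e-12):
    """Return (t_smax, activated) by bisecting the first + to - crossing of dS/dt."""
    t = 0.0
    while t < t_end:
        t_next = min(t + dt, t_end)
        if dSdt(t) > 0.0 >= dSdt(t_next):  # downward zero-crossing in this step
            a, b = t, t_next
            while b - a > tol:
                m = 0.5 * (a + b)
                if dSdt(m) > 0.0:
                    a = m  # still rising: the peak is later
                else:
                    b = m
            return 0.5 * (a + b), True
        t = t_next
    return float(t_end), False  # no maximum before t_end


t_event, activated = find_peak_event(60.0)
t_sampled = max(range(61), key=S)  # argmax over a 1 s output grid
```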

Parameters:

  • y0 (array, shape (7 + nr,); required): Initial (equilibrated) state vector.
  • args (tuple; required): (r_drys, Nis, kappas, accom, V) for parcel_ode_sys.
  • t_end (float; required): Upper bound on integration time, s.
  • rtol (float or array; default STATE_RTOL): Relative solver tolerance.
  • atol (float or array; default None): Absolute solver tolerance. None selects atol_vector(nr).
  • max_steps (int; default 100000): Upper bound on adaptive steps.

Returns:

  • t_smax (float): Time of the supersaturation maximum, s.
  • smax (float): Peak supersaturation value.
  • y_smax (array, shape (7 + nr,)): State vector at t_smax.
  • activated (bool): False if no downward zero-crossing of dS/dt occurred before t_end, i.e. the parcel never reached a maximum within the integration window.

Source code in pyrcel/integrator.py
def find_smax(y0, args, t_end, *, rtol=STATE_RTOL, atol=None, max_steps=100_000):
    """Precisely localize the supersaturation maximum via a ``dS/dt = 0`` event.

    Unlike sampling a saved trajectory, ``t_smax`` is root-found rather than
    quantized to the output cadence.

    Parameters
    ----------
    y0 : array, shape ``(7 + nr,)``
        Initial (equilibrated) state vector.
    args : tuple
        ``(r_drys, Nis, kappas, accom, V)`` for `parcel_ode_sys`.
    t_end : float
        Upper bound on integration time, s.
    rtol, atol : float or array, optional
        Solver tolerances.
    max_steps : int, optional
        Upper bound on adaptive steps.

    Returns
    -------
    t_smax : float
        Time of the supersaturation maximum, s.
    smax : float
        Peak supersaturation value.
    y_smax : array, shape ``(7 + nr,)``
        State vector at ``t_smax``.
    activated : bool
        ``False`` if no downward zero-crossing of ``dS/dt`` occurred before
        ``t_end`` (the parcel never reached a maximum in the integration window).
    """
    y0 = jnp.asarray(y0)
    nr = int(y0.shape[0] - N_STATE_VARS)
    if atol is None:
        atol = atol_vector(nr)
    sol = _solve_to_smax(y0, args, t_end, rtol, atol, max_steps)
    t_smax = sol.ts[-1]
    y_smax = sol.ys[-1]
    smax = y_smax[6]
    activated = bool(t_smax < t_end)
    return t_smax, smax, y_smax, activated

Not differentiable

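find_smax locates S_max with event detection (dS/dt = 0), which is discontinuous and therefore not compatible with jax.grad. It also converts activated to a Python bool, so it must be called eagerly rather than under jax.jit or jax.vmap. Use max_supersaturation on the differentiable path.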


terminate_cutoff_time

terminate_cutoff_time(
    y0,
    args,
    t_end,
    *,
    terminate_depth: float,
    t_smax,
    y_smax,
    activated: bool,
    rtol: float = STATE_RTOL,
    atol=None,
    max_steps: int = 100000,
) -> float

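Compute the integration stop time for an altitude-based cutoff terminate_depth metres past S_max. For a constant updraft (a ConstantV or a plain number) this is t_smax + terminate_depth / V; for a callable V(t) the time is root-found with a second z-crossing solve from y0. The result is capped at t_end.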

Parameters:

  • y0 (array, shape (7 + nr,); required): Initial state vector.
  • args (tuple; required): (r_drys, Nis, kappas, accom, V) for parcel_ode_sys.
  • t_end (float; required): Upper bound on integration time, s.
  • terminate_depth (float; required): Extra vertical distance (m) past z_smax before stopping.
  • t_smax (float; required): Time of the supersaturation maximum, s.
  • y_smax (array; required): State vector at t_smax.
  • activated (bool; required): If False, t_end is returned without any further integration.
  • rtol (float or array; default STATE_RTOL): Relative solver tolerance.
  • atol (float or array; default None): Absolute solver tolerance. None selects atol_vector(nr).
  • max_steps (int; default 100000): Upper bound on adaptive steps.

Returns:

  • float: Cutoff time, s. Always <= t_end.

Source code in pyrcel/integrator.py
def terminate_cutoff_time(
    y0,
    args,
    t_end,
    *,
    terminate_depth: float,
    t_smax,
    y_smax,
    activated: bool,
    rtol: float = STATE_RTOL,
    atol=None,
    max_steps: int = 100_000,
) -> float:
    """Compute the integration stop time for altitude-based cutoff past ``S_max``.

    Parameters
    ----------
    y0 : array, shape ``(7 + nr,)``
        Initial state vector.
    args : tuple
        ``(r_drys, Nis, kappas, accom, V)`` for `parcel_ode_sys`.
    t_end : float
        Upper bound on integration time, s.
    terminate_depth : float
        Extra vertical distance (m) past ``z_smax`` before stopping.
    t_smax : float
        Time of the supersaturation maximum, s.
    y_smax : array
        State vector at ``t_smax``.
    activated : bool
        If ``False``, returns ``t_end`` without any further integration.
    rtol, atol : float or array, optional
        Solver tolerances.
    max_steps : int, optional
        Upper bound on adaptive steps.

    Returns
    -------
    float
        Cutoff time, s. Always ``<= t_end``.
    """
    if not activated:
        return float(t_end)
    t_smax_f = float(t_smax)
    V = args[4]
    if isinstance(V, ConstantV):
        return min(t_smax_f + terminate_depth / float(V.V), float(t_end))
    if not callable(V):
        return min(t_smax_f + terminate_depth / float(V), float(t_end))
    z_target = float(y_smax[0]) + terminate_depth
    nr = int(jnp.asarray(y0).shape[0] - N_STATE_VARS)
    if atol is None:
        atol = atol_vector(nr)
    sol_d = _solve_to_depth(y0, args, t_end, z_target, rtol, atol, max_steps)
    return min(float(sol_d.ts[-1]), float(t_end))

integrate_parcel_terminated

integrate_parcel_terminated(
    y0,
    args,
    t_end,
    output_dt,
    *,
    terminate_depth: float = 100.0,
    rtol: float = STATE_RTOL,
    atol=None,
    max_steps: int = 100000,
    progress_meter=None,
    phase_timer=None,
)

Integrate, stopping terminate_depth metres past the supersaturation max.

Reproduces the master terminate=True semantics: locate S_max (via the dS/dt event), continue a further terminate_depth metres (terminate_depth / V seconds for a constant updraft V), then stop. Output is produced at output_dt cadence over [0, t_cutoff] from a single adaptive solve.
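A minimal sketch of how the output grid over [0, t_cutoff] can be built at output_dt cadence. It follows the np.arange-then-append pattern used in the source, with a guard that drops any arange point that floating-point rounding places at or past t_cutoff, so the appended endpoint keeps the grid strictly increasing (diffrax expects ordered save times).

```python
import numpy as np


def output_grid(t_cutoff, output_dt):
    """output_dt cadence over [0, t_cutoff], always ending exactly at t_cutoff."""
    ts = np.arange(0.0, t_cutoff, output_dt)
    # arange sizes its output as ceil((stop - start) / step); rounding in that
    # quotient can emit a point at or just past t_cutoff, so drop it here.
    ts = ts[ts < t_cutoff]
    return np.append(ts, t_cutoff)


grid = output_grid(2.5, 1.0)  # 0.0, 1.0, 2.0, then the endpoint 2.5
```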

The cutoff is altitude-based (stop terminate_depth metres above z_smax), so it is correct for both a constant updraft and a time-varying V(t): for constant V the cutoff time is computed analytically, otherwise it is root-found with a second z-crossing event.
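The altitude-based cutoff can be sketched in plain Python under simplifying assumptions: the sketch integrates only dz/dt = V(t) from (t_smax, z_smax) with the trapezoidal rule and interpolates the crossing linearly, as a stand-in for pyrcel's second z-crossing event (which re-solves the full parcel ODE from y0). The constant-V branch matches the analytic formula in terminate_cutoff_time.

```python
def cutoff_time(t_smax, z_smax, t_end, terminate_depth, V, dt=1e-3):
    """Time at which the parcel is terminate_depth metres above z_smax, capped at t_end.

    V is a constant updraft speed (m/s) or a callable V(t); terminate_depth > 0.
    """
    if not callable(V):
        # Constant updraft: the extra climb takes terminate_depth / V seconds.
        return min(t_smax + terminate_depth / float(V), float(t_end))
    # Time-varying updraft: march z(t) = z_smax + integral of V from t_smax
    # with the trapezoidal rule until z crosses the target altitude.
    z_target = z_smax + terminate_depth
    t, z = float(t_smax), float(z_smax)
    while t < t_end:
        t_next = min(t + dt, float(t_end))
        z_next = z + 0.5 * (V(t) + V(t_next)) * (t_next - t)
        if z_next >= z_target:
            # Linearly interpolate inside the final step to locate the crossing.
            return t + (z_target - z) / (z_next - z) * (t_next - t)
        t, z = t_next, z_next
    return float(t_end)


t_const = cutoff_time(100.0, 50.0, 1000.0, 100.0, V=0.5)  # 100 + 100 / 0.5 = 300 s
# Ramp V(t) = 0.01 t, so z(t) = 0.005 t**2: z_smax = 50 m at t_smax = 100 s.
t_ramp = cutoff_time(100.0, 50.0, 1000.0, 100.0, V=lambda t: 0.01 * t)
```

For the ramp, the analytic answer is the t where 0.005 t² = 150 m, i.e. t = √30000 ≈ 173.2 s.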

This is the interactive / parity path: t_cutoff is data-dependent so the output grid length is dynamic (the outer call runs eagerly; the solves are jitted). For jit/vmap/grad use the fixed-horizon integrate_parcel / find_smax / max_supersaturation instead.

Returns (ts, ys, info): ts and ys are NumPy arrays, and info is a dict with keys t_smax, smax, t_cutoff, activated, success, z_smax (altitude at S_max, m), and z_end (final altitude, m).

Source code in pyrcel/integrator.py
def integrate_parcel_terminated(
    y0,
    args,
    t_end,
    output_dt,
    *,
    terminate_depth: float = 100.0,
    rtol: float = STATE_RTOL,
    atol=None,
    max_steps: int = 100_000,
    progress_meter=None,
    phase_timer=None,
):
    """Integrate, stopping ``terminate_depth`` metres past the supersaturation max.

    Reproduces the ``master`` ``terminate=True`` semantics: locate ``S_max`` (via the
    ``dS/dt`` event), then continue an extra ``terminate_depth`` metres
    (``= terminate_depth / V`` seconds) and stop. Output is produced at ``output_dt``
    cadence over ``[0, t_cutoff]`` from a single adaptive solve.

    The cutoff is altitude-based (stop ``terminate_depth`` metres above ``z_smax``), so
    it is correct for both a constant updraft and a time-varying ``V(t)``: for constant
    ``V`` the cutoff time is computed analytically, otherwise it is root-found with a
    second ``z``-crossing event.

    This is the interactive / parity path: ``t_cutoff`` is data-dependent so the output
    grid length is dynamic (the outer call runs eagerly; the solves are jitted). For
    ``jit``/``vmap``/``grad`` use the fixed-horizon
    [integrate_parcel][pyrcel.integrator.integrate_parcel] /
    [find_smax][pyrcel.integrator.find_smax] /
    [max_supersaturation][pyrcel.integrator.max_supersaturation] instead.

    Returns ``(ts, ys, info)`` with ``ts``/``ys`` as numpy arrays and ``info`` a dict
    of ``t_smax``, ``smax``, ``t_cutoff``, ``activated``, ``success``, ``z_smax``,
    and ``z_end``.
    """
    y0 = jnp.asarray(y0)
    nr = int(y0.shape[0] - N_STATE_VARS)
    if atol is None:
        atol = atol_vector(nr)

    def _find():
        return find_smax(y0, args, t_end, rtol=rtol, atol=atol, max_steps=max_steps)

    if phase_timer is not None:
        (t_smax, smax, y_smax, activated), _ = phase_timer.run(
            ("smax", nr), "S_max event solve", _find
        )
    else:
        t_smax, smax, y_smax, activated = _find()
    t_smax_f = float(t_smax)
    if not activated:
        t_cutoff = float(t_end)
    else:
        t_cutoff = terminate_cutoff_time(
            y0,
            args,
            t_end,
            terminate_depth=terminate_depth,
            t_smax=t_smax,
            y_smax=y_smax,
            activated=True,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

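    # Drop any arange point that rounding places at or past t_cutoff so the grid
    # stays strictly increasing once t_cutoff is appended.
    ts = np.arange(0.0, t_cutoff, output_dt)
    ts = np.append(ts[ts < t_cutoff], t_cutoff)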
    ts = jnp.asarray(ts)
    if progress_meter is None:
        progress_meter = dfx.NoProgressMeter()

    def _traj():
        return _solve(y0, args, ts, rtol, atol, None, max_steps, progress_meter)

    if phase_timer is not None:
        sol, _ = phase_timer.run(("traj", nr, len(ts)), "trajectory solve", _traj)
    else:
        sol = _traj()
    info = {
        "t_smax": t_smax_f,
        "smax": float(smax),
        "t_cutoff": t_cutoff,
        "activated": activated,
        "success": bool(sol.result == dfx.RESULTS.successful),
        "z_smax": float(y_smax[0]),
        "z_end": float(np.asarray(sol.ys)[-1, 0]),
    }
    return np.asarray(sol.ts), np.asarray(sol.ys), info

Tolerance constants

atol_vector

atol_vector(nr: int) -> Array

Build the per-component absolute tolerance vector for the ODE solver.

Parameters:

  • nr (int; required): Number of aerosol radius bins.

Returns:

  • jnp.ndarray, shape (7 + nr,): Absolute tolerance vector: CVode-equivalent values for the bulk state variables, followed by RADIUS_ATOL for each radius bin.
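A per-component vector matters because adaptive step-size controllers (diffrax's PIDController among them) judge each step by a norm, typically an RMS norm, of error_i / (atol_i + rtol·|y_i|). The sketch below uses placeholder tolerances, not pyrcel's real STATE_ATOL and RADIUS_ATOL values, to show that a single scalar atol sized for O(1) bulk variables would let a 1 nm error on a 100 nm radius pass unnoticed.

```python
import numpy as np

# Hypothetical placeholder tolerances, NOT pyrcel's real STATE_ATOL / RADIUS_ATOL.
STATE_ATOL_EXAMPLE = [1e-4] * 7
RADIUS_ATOL_EXAMPLE = 1e-12


def atol_vector_sketch(nr):
    """Same layout as atol_vector: 7 bulk entries, then one entry per radius bin."""
    return np.asarray(STATE_ATOL_EXAMPLE + [RADIUS_ATOL_EXAMPLE] * int(nr))


def scaled_rms_error(y_err, y, rtol, atol):
    """RMS of y_err / (atol + rtol * |y|); a step is typically accepted when <= 1."""
    return float(np.sqrt(np.mean((y_err / (atol + rtol * np.abs(y))) ** 2)))


nr = 3
y = np.array([1.0] * 7 + [1e-7] * nr)  # O(1) bulk values, 100 nm wet radii
y_err = np.full(y.shape, 1e-9)         # 1e-9 local error everywhere (1 nm for radii)
per_component = scaled_rms_error(y_err, y, 1e-7, atol_vector_sketch(nr))
scalar_only = scaled_rms_error(y_err, y, 1e-7, 1e-4)
```

With the per-component vector the radius error dominates the norm and the step would be rejected; with the scalar tolerance the same error is scaled away.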

Source code in pyrcel/integrator.py
def atol_vector(nr: int) -> Array:
    """Build the per-component absolute tolerance vector for the ODE solver.

    Parameters
    ----------
    nr : int
        Number of aerosol radius bins.

    Returns
    -------
    jnp.ndarray, shape ``(7 + nr,)``
        Absolute tolerance vector: CVode-equivalent values for bulk state
        variables followed by ``RADIUS_ATOL`` for each radius bin.
    """
    return jnp.asarray(STATE_ATOL + [RADIUS_ATOL] * int(nr))

The module also exports STATE_RTOL, STATE_ATOL, and RADIUS_ATOL as importable constants that match the CVode configuration from the legacy model.