# Source code for monee.simulation.stepper

"""Stepper - externally-paced co-simulation adapter.

Steps a network forward in time at caller-supplied dt_h values (the external
framework owns the clock), maintaining inter-step state (linepack, LTC,
storage SoC, ...) via the shared StepState plumbing. Each
:meth:`Stepper.step` works on a fresh network copy.
"""

from __future__ import annotations

import logging
from collections.abc import Mapping
from typing import Any

import pandas

from monee.model import Network
from monee.simulation.step_state import StepState
from monee.simulation.timeseries import (
    StepResult,
    TimeseriesData,
    TimeseriesResult,
)
from monee.solver.dispatch import resolve_solver

# Module-level logger; Stepper.step reports skipped step failures through it.
_log = logging.getLogger(name=__name__)


class Stepper:
    """Externally-paced co-simulation adapter.

    Holds base network, solver, optional problem and timeseries data, and the
    persistent :class:`StepState`.

    ``max_history`` caps how many solved steps are retained (``None`` =
    unlimited): every step keeps a full solved network copy (once in the
    :class:`StepState`, once in the :class:`StepResult` history), so an
    open-ended co-simulation grows memory without bound otherwise. Set it to a
    small number (the longest lookback any ``inter_step_equations`` needs,
    e.g. 8) for long runs; :meth:`to_timeseries_result` then only covers the
    retained window.
    """

    __slots__ = (
        "_base_net",
        "_solver",
        "_optimization_problem",
        "_timeseries_data",
        "_initial_state",
        "_on_step_error",
        "_max_history",
        "_solver_kwargs",
        "_state",
        "_history",
        "_step_count",
        "_t_h",
    )

    def __init__(
        self,
        net: Network,
        *,
        solver=None,
        backend: str | None = None,
        optimization_problem=None,
        timeseries_data: TimeseriesData | None = None,
        initial_state: Mapping[tuple, float] | None = None,
        on_step_error: str = "raise",
        max_history: int | None = None,
        **solver_kwargs: Any,
    ) -> None:
        """Configure the stepper. No solve happens here.

        Args:
            net: Base network. Every :meth:`step` solves a fresh ``copy()`` of it.
            solver: Resolved together with ``backend`` via ``resolve_solver``.
            backend: See ``solver``.
            optimization_problem: Forwarded to every ``solver.solve`` call.
            timeseries_data: Applied per step when :meth:`step` receives a
                ``ts_index``.
            initial_state: Seed values for the StepState. The mapping is copied,
                so later changes to the caller's mapping have no effect.
            on_step_error: ``"raise"`` re-raises solver exceptions. ``"skip"``
                logs them and records a failed :class:`StepResult`.
            max_history: Number of retained steps (``None`` = unlimited). It is
                also passed to the StepState as ``max_steps``.
            **solver_kwargs: Extra keyword arguments for every ``solver.solve``.

        Raises:
            ValueError: ``on_step_error`` is not ``"raise"``/``"skip"``, or
                ``max_history`` is below 1.
        """
        if on_step_error not in ("raise", "skip"):
            raise ValueError(
                f"on_step_error must be 'raise' or 'skip', got {on_step_error!r}"
            )
        if max_history is not None and max_history < 1:
            raise ValueError(f"max_history must be >= 1 or None, got {max_history}")
        self._base_net = net
        self._solver = resolve_solver(solver, backend=backend)
        self._optimization_problem = optimization_problem
        self._timeseries_data = timeseries_data
        self._initial_state: dict = dict(initial_state) if initial_state else {}
        self._on_step_error = on_step_error
        self._max_history = max_history
        self._solver_kwargs = dict(solver_kwargs)
        # Needs _initial_state and _max_history, assigned just above.
        self._state = self._new_state()
        self._history: list[StepResult] = []
        self._step_count: int = 0
        self._t_h: float = 0.0

    def _new_state(self) -> StepState:
        """Build a fresh StepState from the stored seed and ``max_history``.

        The seed dict is copied so the StepState never aliases the dict that a
        later :meth:`reset` (without a new ``initial_state``) reuses.
        """
        return StepState(
            initial_state=dict(self._initial_state), max_steps=self._max_history
        )

    @property
    def state(self) -> StepState:
        """The persistent StepState shared across steps."""
        return self._state

    @property
    def history(self) -> list[StepResult]:
        """Retained step results (the last ``max_history`` ones, or all), as a copy."""
        return list(self._history)

    @property
    def step_count(self) -> int:
        """Number of recorded steps since construction or the last reset.

        This includes successful steps, steps already dropped from the
        ``max_history`` window, and failures recorded under
        ``on_step_error="skip"``. Calls that raised are not counted.
        """
        return self._step_count

    @property
    def t_h(self) -> float:
        """Simulated time in hours: the sum of ``dt_h`` over all recorded steps."""
        return self._t_h

    def step(
        self,
        dt_h: float,
        *,
        data_overrides: Mapping[tuple, float] | None = None,
        ts_index: int | None = None,
    ) -> StepResult:
        """Advance by *dt_h* hours.

        The method first builds a fresh copy of the base network. It then
        applies the ``ts_index`` slice of the timeseries data, but only when
        both the data and the index are present. Otherwise ``ts_index`` is
        ignored. Next it applies ``data_overrides``, which win over the
        ``ts_index`` slice on conflicts. Finally it solves the copy with the
        persistent StepState.

        Raises:
            ValueError: ``dt_h`` is not a finite number > 0.
            KeyError, AttributeError: ``data_overrides`` names an unknown
                component id or attribute. These are raised regardless of
                ``on_step_error``, because overrides are applied before the
                guarded solve.
            Exception: Whatever the solver raises, when ``on_step_error`` is
                ``"raise"``.
        """
        # ``0 < dt_h < inf`` also rejects NaN and +inf. A bare ``dt_h <= 0``
        # test lets NaN through, because every NaN comparison is False.
        # Either value would permanently poison the accumulated ``t_h``.
        if not 0 < dt_h < float("inf"):
            raise ValueError(f"dt_h must be > 0 and finite, got {dt_h}")
        net_copy = self._base_net.copy()
        if self._timeseries_data is not None and ts_index is not None:
            self._timeseries_data.apply_to_network(net_copy, ts_index)
        if data_overrides:
            _apply_overrides(net_copy, data_overrides)
        self._state.dt_h = dt_h
        step_idx = self._step_count
        try:
            result = self._solver.solve(
                net_copy,
                optimization_problem=self._optimization_problem,
                step_state=self._state,
                **self._solver_kwargs,
            )
        except Exception as exc:
            if self._on_step_error == "raise":
                raise
            _log.warning("Stepper step %d failed: %s", step_idx, exc)
            sr = StepResult(step=step_idx, result=None, failed=True, error=exc)
            self._record(sr, dt_h)
            return sr
        # Only successful solves feed the inter-step state.
        self._state.push(result.network)
        sr = StepResult(step=step_idx, result=result)
        self._record(sr, dt_h)
        return sr

    def get(self, component_id, attr: str, step: int = -1):
        """Solved value of ``attr`` on component ``component_id``.

        By default the value comes from the most recent successful step. This
        is the *get* side of a co-simulation adapter's set/step/get contract;
        ``data_overrides`` is the *set* side. ``step`` follows
        :meth:`StepState.get`: a negative value is relative to the latest
        solve, and a non-negative value is an absolute step index. Returns
        ``None`` (or the ``initial_state`` fallback) when no solve has written
        the value.
        """
        return self._state.get(component_id, attr, step=step)

    def _record(self, sr: StepResult, dt_h: float) -> None:
        """Append *sr* to the history and advance the counter and clock.

        The oldest entry is dropped once the history exceeds ``max_history``.
        The step counter and ``t_h`` both advance by this step (``t_h`` by
        *dt_h*).
        """
        self._history.append(sr)
        if self._max_history is not None and len(self._history) > self._max_history:
            del self._history[0]
        self._step_count += 1
        self._t_h += dt_h

    def reset(
        self,
        *,
        initial_state: Mapping[tuple, float] | None = None,
    ) -> None:
        """Clear step history and recreate the StepState.

        A given ``initial_state`` replaces the stored seed; an empty mapping
        clears it. ``None`` keeps the seed from construction or the previous
        reset. The step counter and ``t_h`` return to 0.
        """
        if initial_state is not None:
            self._initial_state = dict(initial_state)
        self._state = self._new_state()
        self._history = []
        self._step_count = 0
        self._t_h = 0.0

    def to_timeseries_result(
        self,
        datetime_index: pandas.DatetimeIndex | None = None,
    ) -> TimeseriesResult:
        """Wrap the retained history as a :class:`TimeseriesResult`.

        With ``max_history`` set, this covers only the retained window.
        """
        return TimeseriesResult(list(self._history), datetime_index=datetime_index)

    def __enter__(self) -> Stepper:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Nothing to release. Returning None lets any exception propagate.
        return None

    def __repr__(self) -> str:
        return (
            f"Stepper(steps={self._step_count}, t_h={self._t_h:.4g}, "
            f"solver={type(self._solver).__name__})"
        )
def _build_id_index(net: Network) -> dict:
    """Group every model in *net* by component id.

    Nodes are visited first (each immediately followed by its attached
    children), then branches, then compounds. Models that share an id
    therefore appear in that order in the id's list.
    """
    index: dict = {}

    def _register(component) -> None:
        index.setdefault(component.id, []).append(component.model)

    for node in net.nodes:
        _register(node)
        for attached in net.childs_by_ids(node.child_ids):
            _register(attached)
    for group in (net.branches, net.compounds):
        for component in group:
            _register(component)
    return index


def _apply_overrides(net: Network, overrides: Mapping[tuple, float]) -> None:
    """Write ``{(comp_id, attr): value}`` overrides onto the models of *net*.

    Each value is set through :meth:`TimeseriesData._set_model_attr` on every
    model with id ``comp_id`` that has attribute ``attr``. An empty or ``None``
    mapping is a no-op. Unknown ids or attributes raise, because they are
    treated as wiring bugs rather than transient failures.

    Raises:
        KeyError: No model in *net* has ``comp_id``.
        AttributeError: None of the models with ``comp_id`` has ``attr``.
    """
    if not overrides:
        return
    index = _build_id_index(net)
    for key, value in overrides.items():
        comp_id, attr = key
        candidates = index.get(comp_id)
        if not candidates:
            raise KeyError(
                f"data_overrides: component id {comp_id!r} not found in network"
            )
        hits = 0
        for model in candidates:
            if not hasattr(model, attr):
                continue
            TimeseriesData._set_model_attr(model, attr, value)
            hits += 1
        if hits == 0:
            raise AttributeError(
                f"data_overrides: attribute {attr!r} not found on any model "
                f"with id {comp_id!r}"
            )