Source code for pygotm.fabm.engine

"""Thin Python wrapper around ``pyfabm.Model``."""

from __future__ import annotations

from collections.abc import Callable, Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal

import numpy as np
from numpy.typing import NDArray

from pygotm.fabm.state import fabm_state_variable_names

# Alias for the float64 NumPy arrays this module allocates and passes around.
FloatArray = NDArray[np.float64]

# Public API of this module.
__all__ = ["FABMEngine", "FABMOutputKind", "FABMOutputSpec"]

# Output layout of one FABM variable: "z" for interior (per-layer) variables,
# "scalar" for bottom/surface state variables and horizontal diagnostics.
FABMOutputKind = Literal["z", "scalar"]


[docs] @dataclass(frozen=True, slots=True) class FABMOutputSpec: """Metadata for one FABM output variable exposed by pyfabm.""" name: str kind: FABMOutputKind units: str long_name: str
class FABMEngine:
    """Manage one pyfabm model while pyGOTM owns transport and timestepping.

    The engine wraps a ``pyfabm.Model`` (or whatever an injected
    *model_factory* returns) and probes several attribute / method spellings
    (``getRates``/``get_rates``, ``state_variables``/``stateVariables``, ...)
    so it can drive different pyfabm API generations.  pyGOTM keeps its own
    copy of the FABM state in :attr:`state` and pushes it into the model with
    :meth:`set_state`.
    """

    def __init__(
        self,
        config_path: str | Path,
        *,
        model_factory: Callable[[str], Any] | None = None,
    ) -> None:
        """Record the FABM YAML path and an optional model factory.

        *model_factory* replaces ``pyfabm.Model``; :meth:`initialize` calls it
        as ``factory(path)`` or ``factory(path, shape=(nlev,))``.  No model is
        built until :meth:`initialize` is called.
        """
        self.config_path = Path(config_path)
        self._model_factory = model_factory
        self.model: Any | None = None
        self._state = np.zeros(0, dtype=np.float64)
        self._rates = np.zeros(0, dtype=np.float64)
        # Per-dependency arrays: array values are copied into a stable buffer
        # that is reused while the shape stays the same (presumably because
        # pyfabm may keep a reference to an assigned array — confirm).
        self._dependency_buffers: dict[str, FloatArray] = {}
        # name -> dependency object, or None meaning "looked up, not exposed".
        self._dependency_cache: dict[str, Any | None] = {}

    def initialize(
        self,
        nlev: int | None = None,
        h_col: np.ndarray | None = None,
        *,
        skip_start: bool = False,
    ) -> None:
        """Construct and start the pyfabm model for a 1-D water column.

        *nlev* is the number of grid cells (pyGOTM ``nlev``). When provided,
        the model is created with ``shape=(nlev,)`` so pyfabm allocates 1-D
        arrays of the correct size.

        *h_col* (length *nlev*, bottom→surface) is the initial cell-thickness
        array required by pyfabm before ``start()``.  It is only applied when
        *nlev* is also given, and skipped if the model rejects the
        ``cell_thickness`` attribute.

        With ``skip_start=True`` the model is constructed but not started;
        call :meth:`start` once initial dependency values have been set.

        Calling this again rebuilds the model and drops the cached dependency
        lookups and buffers that belonged to the previous model.

        Raises:
            RuntimeError: if the YAML file is missing, pyfabm cannot be
                imported, model construction fails, or (unless *skip_start*)
                the model is not ready or its ``start()`` fails.
        """
        if not self.config_path.is_file():
            msg = f"FABM model YAML not found: {self.config_path}"
            raise RuntimeError(msg)

        factory = self._model_factory
        if factory is None:
            try:
                import pyfabm
            except ImportError as exc:
                msg = (
                    "FABM is enabled, but pyfabm could not be imported. "
                    "Install FABM's Python front-end and ensure its compiled "
                    "FABM libraries are discoverable before running this case."
                )
                raise RuntimeError(msg) from exc
            factory = pyfabm.Model

        try:
            if nlev is not None:
                self.model = factory(str(self.config_path), shape=(nlev,))  # type: ignore[call-arg]
            else:
                self.model = factory(str(self.config_path))
        except Exception as exc:
            msg = f"failed to initialize pyfabm.Model from {self.config_path}: {exc}"
            raise RuntimeError(msg) from exc

        # Cached lookups and buffers refer to the previous model (if any).
        self._dependency_cache.clear()
        self._dependency_buffers.clear()

        # pyfabm 3.0: cell_thickness is a write-only property setter — use direct
        # assignment so pyfabm registers the array before start() / getRates().
        if nlev is not None and h_col is not None:
            try:
                arr = np.ascontiguousarray(h_col, dtype=np.float64)
                self.model.cell_thickness = arr
            except AttributeError:
                # Best effort: models that do not accept the attribute skip it.
                pass

        if not skip_start:
            self._start_model()
        self._sync_state_from_model()

    def has_dependency(self, name: str) -> bool:
        """Return whether the pyfabm model exposes a dependency by *name*."""
        return self._find_dependency(name) is not None

    def set_dependency(
        self,
        name: str,
        value: float | np.ndarray,
    ) -> None:
        """Set one FABM dependency from a scalar or contiguous float64 array.

        Raises:
            KeyError: if the model does not expose a dependency *name*.
        """
        dependency = self._require_dependency(name)
        self._set_dependency_value(name, dependency, value)

    def set_dependency_if_present(
        self,
        name: str,
        value: float | np.ndarray,
    ) -> bool:
        """Set one optional FABM dependency and report whether it exists."""
        dependency = self._find_dependency(name)
        if dependency is None:
            return False
        self._set_dependency_value(name, dependency, value)
        return True

    def _set_dependency_value(
        self,
        name: str,
        dependency: Any,
        value: float | np.ndarray,
    ) -> None:
        """Push *value* into *dependency*.

        Arrays are copied (as contiguous float64) into a per-name buffer that
        is reallocated only when the shape changes; anything else is
        converted with ``float()``.
        """
        if isinstance(value, np.ndarray):
            array = np.ascontiguousarray(value, dtype=np.float64)
            existing = self._dependency_buffers.get(name)
            if existing is None or existing.shape != array.shape:
                existing = np.empty_like(array)
                self._dependency_buffers[name] = existing
            np.copyto(existing, array)
            self._assign_value(dependency, existing)
            return
        self._assign_value(dependency, float(value))

    def set_state(self, state: np.ndarray) -> None:
        """Replace the pyGOTM-owned FABM state buffer and push it to the model.

        The values are copied into :attr:`state` (reallocated only when the
        shape changes).  If a model exists it is updated in place through its
        ``state`` array, or otherwise row by row through its state-variable
        objects; variables without a matching row are left untouched.  Before
        :meth:`initialize` only the local buffer is updated.
        """
        array = np.ascontiguousarray(state, dtype=np.float64)
        if self._state.shape != array.shape:
            self._state = np.empty_like(array)
        np.copyto(self._state, array)
        if self.model is None:
            return
        if hasattr(self.model, "state"):
            self.model.state[...] = self._state
            return
        variables = self._state_variables()
        # Nothing to push; this also avoids reshape((0, -1)), which NumPy
        # rejects because the inferred dimension would be ambiguous.
        if not variables or self._state.size == 0:
            return
        # One row per state variable; a 0-d state is treated as a single row.
        rows = np.atleast_1d(self._state)
        flat_state = rows.reshape((rows.shape[0], -1))
        for variable, row in zip(variables, flat_state):
            self._assign_value(variable, row)

    @property
    def state(self) -> np.ndarray:
        """Return the current mutable FABM state buffer."""
        return self._state

    def get_rates(
        self,
        *,
        surface: bool = True,
        bottom: bool = True,
        time: float | None = None,
    ) -> np.ndarray:
        """Return FABM source rates for each layer.

        When *surface=True* (default for backwards compat) the returned rates
        include the air-sea surface exchange distributed over ALL layers by
        pyfabm — which is physically wrong for a 1D column. Pass
        ``surface=False, bottom=False`` to get bulk-only reaction rates that
        can safely be applied to every layer, then handle boundary exchange
        explicitly on the top/bottom layers.

        If the model's getRates does not accept the ``surface``/``bottom``
        keywords and non-default values were requested, a ``RuntimeWarning``
        is emitted because the flags cannot be honored.

        The returned array is an internal buffer that is overwritten by the
        next call.

        Raises:
            RuntimeError: if the model exposes neither ``getRates`` nor
                ``get_rates``.
        """
        model = self._require_model()
        get_fn = getattr(model, "getRates", None) or getattr(model, "get_rates", None)
        if get_fn is None:
            msg = "pyfabm model does not expose getRates()"
            raise RuntimeError(msg)
        rates = self._call_get_rates_flags(
            get_fn,
            surface=surface,
            bottom=bottom,
            time=time,
        )
        array = np.ascontiguousarray(rates, dtype=np.float64)
        if self._rates.shape != array.shape:
            self._rates = np.empty_like(array)
        np.copyto(self._rates, array)
        return self._rates

    def get_vertical_movement(self) -> np.ndarray | None:
        """Return vertical movement velocities for all interior state variables.

        Returns an array of shape ``(n_state_vars, nlev)`` in m s⁻¹ (positive
        upward), or ``None`` if the underlying model does not expose this API.
        ``None`` is also returned when the call fails or yields an array that
        is not 2-D. Must be called after ``get_rates()`` has updated the
        model's internal diagnostics.
        """
        model = self._require_model()
        fn = getattr(model, "get_vertical_movement", None)
        if not callable(fn):
            return None
        # Best effort: any failure is reported as "not available".
        try:
            result = fn()
            if result is None:
                return None
            arr = np.ascontiguousarray(result, dtype=np.float64)
            return arr if arr.ndim == 2 else None
        except Exception:
            return None

    def diagnostics(self) -> dict[str, np.ndarray | float]:
        """Return current FABM diagnostic values as NumPy arrays or floats.

        Values come from a ``diagnostics`` dict on the model (if present)
        and then from its diagnostic-variable objects; the latter win on name
        clashes.  All arrays are copies.
        """
        model = self._require_model()
        diagnostics: dict[str, np.ndarray | float] = {}
        raw = getattr(model, "diagnostics", None)
        if isinstance(raw, dict):
            for name, value in raw.items():
                diagnostics[str(name)] = self._diagnostic_value(value, copy=True)
        for variable in self._diagnostic_variables():
            name = str(getattr(variable, "name", getattr(variable, "id", "")))
            if not name:
                continue
            diagnostics[name] = self._diagnostic_value(
                getattr(variable, "value", 0.0),
                copy=True,
            )
        return diagnostics

    def diagnostic(
        self,
        name: str,
        *,
        copy: bool = True,
    ) -> np.ndarray | float | None:
        """Return one FABM diagnostic by name, or ``None`` when unavailable.

        A ``diagnostics`` dict on the model is consulted first (entries whose
        value is ``None`` count as unavailable), then the model's
        diagnostic-variable objects.  With ``copy=False`` array values may
        alias model memory.
        """
        model = self._require_model()
        raw = getattr(model, "diagnostics", None)
        if isinstance(raw, dict):
            value = raw.get(name)
            if value is not None:
                return self._diagnostic_value(value, copy=copy)
            # Keys may be non-str objects whose str() equals *name*.  Skip
            # None values so they are reported as unavailable, not as NaN.
            for raw_name, raw_value in raw.items():
                if raw_value is not None and str(raw_name) == name:
                    return self._diagnostic_value(raw_value, copy=copy)
        for variable in self._diagnostic_variables():
            variable_name = str(getattr(variable, "name", getattr(variable, "id", "")))
            if variable_name == name:
                return self._diagnostic_value(
                    getattr(variable, "value", 0.0),
                    copy=copy,
                )
        return None

    def unresolved_dependencies(self) -> tuple[str, ...]:
        """Return a best-effort list of pyfabm dependencies still unset.

        Names are de-duplicated while keeping first-seen order; unnamed
        dependencies are reported by their ``repr``.
        """
        missing: list[str] = []
        for dependency in self._dependencies():
            name = str(getattr(dependency, "name", getattr(dependency, "id", "")))
            if not name:
                name = repr(dependency)
            if self._dependency_is_missing(dependency):
                missing.append(name)
        return tuple(dict.fromkeys(missing))

    def start(self) -> None:
        """Call model.start() after dependencies are set; update state buffers.

        Use this when ``initialize()`` was called with ``skip_start=True`` to
        defer start until the caller has supplied initial dependency values.

        Raises:
            RuntimeError: if no model exists, it is not ready, or its
                ``start()`` fails.
        """
        self._start_model()
        self._sync_state_from_model()

    def state_variable_names(self) -> tuple[str, ...]:
        """Return state-variable names exposed by the wrapped pyfabm model."""
        return fabm_state_variable_names(self._require_model())

    def output_variable_specs(self) -> tuple[FABMOutputSpec, ...]:
        """Return state and enabled diagnostic variables for NetCDF output.

        Order: interior state ("z"), bottom state ("scalar"), surface state
        ("scalar"), interior diagnostics ("z"), horizontal diagnostics
        ("scalar").  Diagnostics whose ``output`` flag is falsy are skipped,
        and a diagnostic also listed as horizontal is emitted only once, as
        "scalar".

        Raises:
            RuntimeError: if two variables map to the same output name or a
                variable has no usable name.
        """
        model = self._require_model()
        specs: list[FABMOutputSpec] = []
        seen: set[str] = set()

        def add(variable: Any, kind: FABMOutputKind) -> None:
            name = self._variable_output_name(variable)
            if name in seen:
                msg = f"duplicate FABM output variable {name!r}"
                raise RuntimeError(msg)
            seen.add(name)
            specs.append(
                FABMOutputSpec(
                    name=name,
                    kind=kind,
                    units=str(getattr(variable, "units", "") or ""),
                    long_name=str(getattr(variable, "long_name", "") or ""),
                )
            )

        for variable in self._interior_state_variables(model):
            add(variable, "z")
        for variable in self._bottom_state_variables(model):
            add(variable, "scalar")
        for variable in self._surface_state_variables(model):
            add(variable, "scalar")

        horizontal_diagnostics = [
            variable
            for variable in self._horizontal_diagnostic_variables()
            if self._variable_outputs_enabled(variable)
        ]
        horizontal_diagnostic_names = {
            self._variable_output_name(variable) for variable in horizontal_diagnostics
        }
        for variable in self._diagnostic_variables():
            if not self._variable_outputs_enabled(variable):
                continue
            # The generic diagnostic list may also contain horizontal ones;
            # those are emitted below as "scalar".
            if self._variable_output_name(variable) in horizontal_diagnostic_names:
                continue
            add(variable, "z")
        for variable in horizontal_diagnostics:
            add(variable, "scalar")
        return tuple(specs)

    def _require_model(self) -> Any:
        """Return the wrapped model or raise if :meth:`initialize` was not run."""
        if self.model is None:
            msg = "FABMEngine.initialize() has not been called"
            raise RuntimeError(msg)
        return self.model

    def _sync_state_from_model(self) -> None:
        """Reload :attr:`state` from the model and reset the rate buffer."""
        self._state = self._read_model_state()
        self._rates = np.zeros_like(self._state)

    def _start_model(self) -> None:
        """Check readiness, then call the model's ``start()`` if it has one.

        Raises:
            RuntimeError: if the readiness check reports the model is not
                ready, or if ``start()`` raises.
        """
        model = self._require_model()
        if not self._check_ready():
            missing = self.unresolved_dependencies()
            # The dependency probe is heuristic and may find nothing even
            # though checkReady() failed; avoid ending the message in "".
            deps = ", ".join(missing) if missing else "<unknown>"
            msg = f"pyfabm model is not ready; unresolved dependencies: {deps}"
            raise RuntimeError(msg)
        start = getattr(model, "start", None)
        if callable(start):
            try:
                start()
            except Exception as exc:
                msg = f"pyfabm model.start() failed: {exc}"
                raise RuntimeError(msg) from exc

    def _check_ready(self) -> bool:
        """Run the model's ``checkReady`` if available; ``True`` otherwise.

        Several call signatures are tried in turn; if every one raises
        ``TypeError`` the model is assumed ready.
        """
        model = self._require_model()
        check_ready = getattr(model, "checkReady", None)
        if not callable(check_ready):
            return True
        for args, kwargs in (((), {}), ((), {"stop": False}), ((False,), {})):
            try:
                return bool(check_ready(*args, **kwargs))
            except TypeError:
                continue
            except Exception as exc:
                msg = f"pyfabm model readiness check failed: {exc}"
                raise RuntimeError(msg) from exc
        return True

    def _read_model_state(self) -> FloatArray:
        """Return a float64 copy of the model state.

        Uses the model's ``state`` array when present, otherwise stacks the
        ``value`` of each state-variable object (empty array if none).
        """
        model = self._require_model()
        if hasattr(model, "state"):
            return np.ascontiguousarray(model.state, dtype=np.float64)
        variables = self._state_variables()
        if not variables:
            return np.zeros(0, dtype=np.float64)
        values = [
            np.asarray(getattr(variable, "value", 0.0), dtype=np.float64)
            for variable in variables
        ]
        return np.ascontiguousarray(np.stack(values), dtype=np.float64)

    def _state_variables(self) -> list[Any]:
        """Return the model's state-variable objects (``[]`` if not exposed)."""
        model = self._require_model()
        variables = getattr(model, "state_variables", None)
        if variables is None:
            variables = getattr(model, "stateVariables", None)
        if variables is None:
            return []
        return list(variables)

    def _diagnostic_variables(self) -> list[Any]:
        """Return the model's diagnostic-variable objects (``[]`` if none)."""
        model = self._require_model()
        for attr in ("diagnostic_variables", "diagnosticVariables"):
            variables = getattr(model, attr, None)
            if variables is not None:
                return list(variables)
        return []

    def _horizontal_diagnostic_variables(self) -> list[Any]:
        """Return the model's horizontal diagnostic variables (``[]`` if none)."""
        model = self._require_model()
        for attr in (
            "horizontal_diagnostic_variables",
            "horizontalDiagnosticVariables",
        ):
            variables = getattr(model, attr, None)
            if variables is not None:
                return list(variables)
        return []

    @staticmethod
    def _interior_state_variables(model: Any) -> list[Any]:
        """Return interior state variables, trying newer then older names.

        The generic ``state_variables`` fallback may also include bottom and
        surface variables, depending on the model.
        """
        for attr in (
            "interior_state_variables",
            "bulk_state_variables",
            "state_variables",
            "stateVariables",
        ):
            variables = getattr(model, attr, None)
            if variables is not None:
                return list(variables)
        return []

    @staticmethod
    def _bottom_state_variables(model: Any) -> list[Any]:
        """Return bottom state variables (``[]`` if not exposed)."""
        variables = getattr(model, "bottom_state_variables", None)
        return [] if variables is None else list(variables)

    @staticmethod
    def _surface_state_variables(model: Any) -> list[Any]:
        """Return surface state variables (``[]`` if not exposed)."""
        variables = getattr(model, "surface_state_variables", None)
        return [] if variables is None else list(variables)

    @staticmethod
    def _variable_output_name(variable: Any) -> str:
        """Return the first non-empty name-like attribute, with '/' -> '_'.

        Raises:
            RuntimeError: if none of the candidate attributes is usable.
        """
        for attr in ("output_name", "name", "id", "long_name"):
            value = getattr(variable, attr, None)
            if value:
                return str(value).replace("/", "_")
        msg = f"FABM variable {variable!r} does not expose a usable name"
        raise RuntimeError(msg)

    @staticmethod
    def _variable_outputs_enabled(variable: Any) -> bool:
        """Return the variable's ``output`` flag; a missing flag means enabled."""
        output = getattr(variable, "output", None)
        if output is None:
            return True
        return bool(output)

    def _dependencies(self) -> Iterable[Any]:
        """Return the first dependency collection the model exposes."""
        model = self._require_model()
        for attr in (
            "dependencies",
            "required_dependencies",
            "dependencies_unfulfilled",
        ):
            dependencies = getattr(model, attr, None)
            if dependencies is not None:
                return list(dependencies)
        return ()

    def _find_dependency(self, name: str) -> Any | None:
        """Look up a dependency by name, caching hits and misses.

        ``findDependency`` is tried first (any exception counts as "not
        found"), then a scan of :meth:`_dependencies` by name/id.
        """
        if name in self._dependency_cache:
            return self._dependency_cache[name]
        model = self._require_model()
        finder = getattr(model, "findDependency", None)
        if callable(finder):
            try:
                dependency = finder(name)
            except Exception:
                dependency = None
            if dependency is not None:
                self._dependency_cache[name] = dependency
                return dependency
        for dependency in self._dependencies():
            dependency_name = getattr(dependency, "name", getattr(dependency, "id", ""))
            if str(dependency_name) == name:
                self._dependency_cache[name] = dependency
                return dependency
        self._dependency_cache[name] = None
        return None

    def _require_dependency(self, name: str) -> Any:
        """Return dependency *name* or raise ``KeyError`` if it is not exposed."""
        dependency = self._find_dependency(name)
        if dependency is None:
            msg = f"FABM dependency {name!r} is not exposed by this model"
            raise KeyError(msg)
        return dependency

    @staticmethod
    def _assign_value(target: Any, value: float | np.ndarray) -> None:
        """Write *value* into a pyfabm variable/dependency object.

        If both the target's current ``value`` and *value* are arrays, the
        data is copied in place; otherwise ``value`` is rebound, or
        ``target.set(value)`` is used when there is no ``value`` attribute.

        Raises:
            RuntimeError: if the target supports neither mechanism.
        """
        if hasattr(target, "value"):
            current = target.value
            if isinstance(current, np.ndarray) and isinstance(value, np.ndarray):
                np.copyto(current, value)
            else:
                target.value = value
            return
        if callable(getattr(target, "set", None)):
            target.set(value)
            return
        msg = f"FABM target {target!r} cannot receive a value"
        raise RuntimeError(msg)

    @staticmethod
    def _diagnostic_value(value: object, *, copy: bool) -> np.ndarray | float:
        """Normalize a diagnostic to a float64 array or a Python float.

        0-d results become floats; arrays are copied when *copy* is true.
        """
        if isinstance(value, np.ndarray):
            array = np.asarray(value, dtype=np.float64)
            if copy:
                return np.ascontiguousarray(array).copy()
            return array
        if isinstance(value, (int, float, np.floating)):
            return float(value)
        array = np.asarray(value, dtype=np.float64)
        if array.ndim == 0:
            return float(array)
        if copy:
            return np.ascontiguousarray(array).copy()
        return array

    @staticmethod
    def _dependency_is_missing(dependency: Any) -> bool:
        """Heuristically decide whether a dependency still lacks a value.

        The first of ``is_set``/``is_fulfilled``/``fulfilled``/``ready`` that
        exists (attribute or zero-arg method) decides; otherwise a truthy
        ``missing`` flag or a ``None`` value means missing.
        """
        for attr in ("is_set", "is_fulfilled", "fulfilled", "ready"):
            value = getattr(dependency, attr, None)
            if value is not None:
                return not bool(value() if callable(value) else value)
        if bool(getattr(dependency, "missing", False)):
            return True
        return getattr(dependency, "value", None) is None

    def _call_get_rates_flags(
        self,
        method: Callable[..., Any],
        *,
        surface: bool,
        bottom: bool,
        time: float | None,
    ) -> Any:
        """Call getRates respecting pyfabm 3.0 surface/bottom kwargs.

        Call forms are tried in order and the first that does not raise
        ``TypeError`` wins. First come the forms with ``surface``/``bottom``
        keywords, passing the time positionally and then as ``t=``. Next come
        the same forms without the flags. The last resort is
        ``method(self._state)``.

        If every flag-carrying form fails, the model's own defaults apply.
        When the caller asked for anything other than
        ``surface=True, bottom=True``, a ``RuntimeWarning`` is emitted instead
        of silently returning rates that may include boundary exchange.

        Note: a ``TypeError`` raised inside the model's getRates cannot be
        told apart from a signature mismatch, so it also moves on to the next
        form.
        """
        flags: dict[str, object] = {"surface": surface, "bottom": bottom}
        flagged: tuple[tuple[tuple[float, ...], dict[str, object]], ...]
        legacy: tuple[tuple[tuple[float, ...], dict[str, object]], ...]
        if time is None:
            flagged = (((), flags),)
            legacy = (((), {}),)
        else:
            t = float(time)
            flagged = (((t,), flags), ((), {"t": t, **flags}))
            legacy = (((t,), {}), ((), {}))

        for args, kwargs in flagged:
            try:
                return method(*args, **kwargs)
            except TypeError:
                continue

        if not (surface and bottom):
            import warnings

            warnings.warn(
                "pyfabm getRates() did not accept surface/bottom flags; "
                f"requested surface={surface}, bottom={bottom} could not be "
                "honored and the model's own defaults are used instead",
                RuntimeWarning,
                stacklevel=3,
            )

        for args, kwargs in legacy:
            try:
                return method(*args, **kwargs)
            except TypeError:
                continue
        return method(self._state)