"""Thin Python wrapper around ``pyfabm.Model``."""
from __future__ import annotations
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
import numpy as np
from numpy.typing import NDArray
from pygotm.fabm.state import fabm_state_variable_names
# Public names exported by ``from ... import *``.
__all__ = ["FABMEngine", "FABMOutputKind", "FABMOutputSpec"]

# Type alias for float64 NumPy arrays.
FloatArray = NDArray[np.float64]

# Placement of an output variable: "z" for a per-layer profile, "scalar" for
# a single value (see ``FABMEngine.output_variable_specs``).
FABMOutputKind = Literal["z", "scalar"]
[docs]
@dataclass(frozen=True, slots=True)
class FABMOutputSpec:
"""Metadata for one FABM output variable exposed by pyfabm."""
name: str
kind: FABMOutputKind
units: str
long_name: str
class FABMEngine:
    """Manage one pyfabm model while pyGOTM owns transport and timestepping.

    The engine keeps float64 buffers for the model state, the most recent
    source rates and array-valued dependencies, and copies values between
    those buffers and the wrapped model. Model attributes are looked up under
    several spellings (for example ``getRates``/``get_rates`` or
    ``state_variables``/``stateVariables``) so that different pyfabm
    front-end versions, or test doubles, can be wrapped.
    """

    def __init__(
        self,
        config_path: str | Path,
        *,
        model_factory: Callable[[str], Any] | None = None,
    ) -> None:
        """Store the configuration; the model itself is built by :meth:`initialize`.

        Args:
            config_path: Path to the FABM model YAML file.
            model_factory: Optional replacement for ``pyfabm.Model``. It is
                called with the YAML path as a string, plus ``shape=(nlev,)``
                when :meth:`initialize` receives *nlev*.
        """
        self.config_path = Path(config_path)
        self._model_factory = model_factory
        self.model: Any | None = None
        self._state = np.zeros(0, dtype=np.float64)
        self._rates = np.zeros(0, dtype=np.float64)
        # Persistent arrays handed to the model for array-valued dependencies,
        # reused across calls while the shape stays the same.
        self._dependency_buffers: dict[str, FloatArray] = {}
        # Dependency lookups by name; ``None`` records a confirmed miss.
        self._dependency_cache: dict[str, Any | None] = {}

    def initialize(
        self,
        nlev: int | None = None,
        h_col: np.ndarray | None = None,
        *,
        skip_start: bool = False,
    ) -> None:
        """Construct and start the pyfabm model for a 1-D water column.

        *nlev* is the number of grid cells (pyGOTM ``nlev``). When provided,
        the model is created with ``shape=(nlev,)`` so pyfabm allocates 1-D
        arrays of the correct size. *h_col* (length *nlev*, bottom→surface) is
        the initial cell-thickness array required by pyfabm before ``start()``.

        Calling this again replaces the model. Dependency lookups, dependency
        buffers and state/rate buffers belonging to the previous model are
        discarded so that nothing is shared between the two models.

        Args:
            nlev: Number of grid cells, or ``None`` to let the factory decide.
            h_col: Initial cell thicknesses; only used together with *nlev*.
            skip_start: Do not start the model yet; call :meth:`start` once
                the initial dependency values have been supplied.

        Raises:
            RuntimeError: If the YAML file is missing, pyfabm cannot be
                imported, model construction fails, the model reports it is
                not ready, or ``start()`` fails.
        """
        if not self.config_path.is_file():
            msg = f"FABM model YAML not found: {self.config_path}"
            raise RuntimeError(msg)
        factory = self._model_factory
        if factory is None:
            try:
                import pyfabm
            except ImportError as exc:
                msg = (
                    "FABM is enabled, but pyfabm could not be imported. "
                    "Install FABM's Python front-end and ensure its compiled "
                    "FABM libraries are discoverable before running this case."
                )
                raise RuntimeError(msg) from exc
            factory = pyfabm.Model
        try:
            if nlev is not None:
                self.model = factory(str(self.config_path), shape=(nlev,))  # type: ignore[call-arg]
            else:
                self.model = factory(str(self.config_path))
        except Exception as exc:
            msg = f"failed to initialize pyfabm.Model from {self.config_path}: {exc}"
            raise RuntimeError(msg) from exc
        # Forget everything tied to the previous model: a reused dependency
        # buffer may still be linked into it, and the old state buffer may
        # alias its ``state`` array.
        self._dependency_cache.clear()
        self._dependency_buffers.clear()
        self._state = np.zeros(0, dtype=np.float64)
        self._rates = np.zeros(0, dtype=np.float64)
        # pyfabm 3.0: cell_thickness is a write-only property setter — use direct
        # assignment so pyfabm registers the array before start() / getRates().
        if nlev is not None and h_col is not None:
            try:
                self.model.cell_thickness = np.ascontiguousarray(
                    h_col, dtype=np.float64
                )
            except AttributeError:
                # Models without a settable cell_thickness: best effort, skip.
                pass
        if not skip_start:
            self.start()

    def has_dependency(self, name: str) -> bool:
        """Return whether the pyfabm model exposes a dependency by *name*."""
        return self._find_dependency(name) is not None

    def set_dependency(
        self,
        name: str,
        value: float | np.ndarray,
    ) -> None:
        """Set one FABM dependency from a scalar or a float64-convertible array.

        Raises:
            KeyError: If the model exposes no dependency called *name*.
            RuntimeError: If :meth:`initialize` has not been called.
        """
        dependency = self._require_dependency(name)
        self._set_dependency_value(name, dependency, value)

    def set_dependency_if_present(
        self,
        name: str,
        value: float | np.ndarray,
    ) -> bool:
        """Set one optional FABM dependency and report whether it exists."""
        dependency = self._find_dependency(name)
        if dependency is None:
            return False
        self._set_dependency_value(name, dependency, value)
        return True

    def _set_dependency_value(
        self,
        name: str,
        dependency: Any,
        value: float | np.ndarray,
    ) -> None:
        """Assign *value* to *dependency*, routing arrays through a persistent buffer.

        Array values are copied into a C-contiguous float64 buffer kept per
        dependency *name*. The buffer is reallocated only when the shape
        changes, and it is what the model receives. The buffer has exactly
        the shape of *value*, so 0-d arrays stay 0-d. Anything else is passed
        on as ``float``.
        """
        if isinstance(value, np.ndarray):
            source = np.asarray(value, dtype=np.float64)
            buffer = self._dependency_buffers.get(name)
            if buffer is None or buffer.shape != source.shape:
                buffer = np.empty(source.shape, dtype=np.float64)
                self._dependency_buffers[name] = buffer
            np.copyto(buffer, source)
            self._assign_value(dependency, buffer)
            return
        self._assign_value(dependency, float(value))

    def set_state(self, state: np.ndarray) -> None:
        """Replace the pyGOTM-owned FABM state buffer.

        The values are copied into the engine's buffer (reallocated only when
        the shape changes) and then pushed to the model:

        - into ``model.state`` when the model exposes it;
        - otherwise one leading-axis slice per state variable.

        Each slice keeps its own shape, mirroring how ``_read_model_state``
        stacks per-variable values. Scalar variables therefore receive scalars
        rather than 1-element arrays.
        """
        array = np.ascontiguousarray(state, dtype=np.float64)
        if self._state.shape != array.shape:
            self._state = np.empty_like(array)
        np.copyto(self._state, array)
        model = self.model
        if model is None:
            return
        if hasattr(model, "state"):
            model.state[...] = self._state
            return
        variables = self._state_variables()
        if not variables:
            return
        # zip() stops at the shorter of the variable list and the leading axis.
        for variable, values in zip(variables, self._state):
            self._assign_value(variable, values)

    @property
    def state(self) -> np.ndarray:
        """Return the current mutable FABM state buffer.

        When the model exposes a contiguous float64 ``state`` array, this
        buffer may be that very array (see ``_read_model_state``).
        """
        return self._state

    def get_rates(
        self,
        *,
        surface: bool = True,
        bottom: bool = True,
        time: float | None = None,
    ) -> np.ndarray:
        """Return FABM source rates for each layer.

        When *surface=True* (default for backwards compat) the returned rates
        include the air-sea surface exchange distributed over ALL layers by
        pyfabm — which is physically wrong for a 1D column. Pass
        ``surface=False, bottom=False`` to get bulk-only reaction rates
        that can safely be applied to every layer, then handle boundary
        exchange explicitly on the top/bottom layers.

        *time* is forwarded when the model's ``getRates`` accepts it. The
        result is copied into an engine-owned buffer. That buffer is reused,
        and overwritten, by the next call while the shape is unchanged.

        Raises:
            RuntimeError: If the model is not initialized or exposes neither
                ``getRates`` nor ``get_rates``.
        """
        model = self._require_model()
        get_fn = getattr(model, "getRates", None) or getattr(model, "get_rates", None)
        if get_fn is None:
            msg = "pyfabm model does not expose getRates()"
            raise RuntimeError(msg)
        rates = self._call_get_rates_flags(
            get_fn,
            surface=surface,
            bottom=bottom,
            time=time,
        )
        array = np.ascontiguousarray(rates, dtype=np.float64)
        if self._rates.shape != array.shape:
            self._rates = np.empty_like(array)
        np.copyto(self._rates, array)
        return self._rates

    def get_vertical_movement(self) -> np.ndarray | None:
        """Return vertical movement velocities for all interior state variables.

        Returns an array of shape ``(n_state_vars, nlev)`` in m s⁻¹ (positive
        upward), or ``None`` if the underlying model does not expose this API.
        Must be called after ``get_rates()`` has updated the model's internal
        diagnostics. ``None`` is also returned when the call raises or the
        result is not two-dimensional.
        """
        model = self._require_model()
        fn = getattr(model, "get_vertical_movement", None)
        if not callable(fn):
            return None
        try:
            result = fn()
            if result is None:
                return None
            arr = np.ascontiguousarray(result, dtype=np.float64)
            return arr if arr.ndim == 2 else None
        except Exception:
            # Best effort: callers treat None as "no vertical movement".
            return None

    def diagnostics(self) -> dict[str, np.ndarray | float]:
        """Return current FABM diagnostic values as NumPy arrays or floats.

        Entries from a ``model.diagnostics`` dict (if present) are collected
        first. Values of the model's diagnostic variable objects follow and
        overwrite same-named entries. Unnamed variables are skipped, and all
        arrays are copies.
        """
        model = self._require_model()
        diagnostics: dict[str, np.ndarray | float] = {}
        raw = getattr(model, "diagnostics", None)
        if isinstance(raw, dict):
            for name, value in raw.items():
                diagnostics[str(name)] = self._diagnostic_value(value, copy=True)
        for variable in self._diagnostic_variables():
            name = self._variable_name(variable)
            if not name:
                continue
            diagnostics[name] = self._diagnostic_value(
                getattr(variable, "value", 0.0),
                copy=True,
            )
        return diagnostics

    def diagnostic(
        self,
        name: str,
        *,
        copy: bool = True,
    ) -> np.ndarray | float | None:
        """Return one FABM diagnostic by name, or ``None`` when unavailable.

        The ``model.diagnostics`` dict is searched first: a direct key lookup,
        then keys compared via ``str``. The model's diagnostic variable
        objects are searched after that. With ``copy=False`` a returned array
        may share memory with the model.
        """
        model = self._require_model()
        raw = getattr(model, "diagnostics", None)
        if isinstance(raw, dict):
            value = raw.get(name)
            if value is not None:
                return self._diagnostic_value(value, copy=copy)
            for raw_name, raw_value in raw.items():
                if str(raw_name) == name:
                    return self._diagnostic_value(raw_value, copy=copy)
        for variable in self._diagnostic_variables():
            if self._variable_name(variable) == name:
                return self._diagnostic_value(
                    getattr(variable, "value", 0.0),
                    copy=copy,
                )
        return None

    def unresolved_dependencies(self) -> tuple[str, ...]:
        """Return a best-effort list of pyfabm dependencies still unset.

        Names are de-duplicated in first-seen order. A dependency without a
        name or id is reported by its ``repr``.
        """
        missing: list[str] = []
        for dependency in self._dependencies():
            if self._dependency_is_missing(dependency):
                missing.append(self._variable_name(dependency) or repr(dependency))
        return tuple(dict.fromkeys(missing))

    def start(self) -> None:
        """Call model.start() after dependencies are set; update state buffers.

        Use this when ``initialize()`` was called with ``skip_start=True`` to
        defer start until the caller has supplied initial dependency values.
        The state buffer is re-read from the model and the rate buffer is
        reset to zeros of the same shape.
        """
        self._start_model()
        self._state = self._read_model_state()
        self._rates = np.zeros_like(self._state)

    def state_variable_names(self) -> tuple[str, ...]:
        """Return state-variable names exposed by the wrapped pyfabm model."""
        return fabm_state_variable_names(self._require_model())

    def output_variable_specs(self) -> tuple[FABMOutputSpec, ...]:
        """Return state and enabled diagnostic variables for NetCDF output.

        Order: interior state ("z"), bottom and surface state ("scalar"),
        interior diagnostics ("z"), then horizontal diagnostics ("scalar").
        Diagnostics whose output is disabled are left out. Diagnostics that
        also appear in the horizontal list are emitted only once, as
        "scalar".

        Raises:
            RuntimeError: If two variables map to the same output name, or a
                variable has no usable name.
        """
        model = self._require_model()
        specs: list[FABMOutputSpec] = []
        seen: set[str] = set()

        def add(variable: Any, kind: FABMOutputKind) -> None:
            name = self._variable_output_name(variable)
            if name in seen:
                msg = f"duplicate FABM output variable {name!r}"
                raise RuntimeError(msg)
            seen.add(name)
            specs.append(
                FABMOutputSpec(
                    name=name,
                    kind=kind,
                    units=str(getattr(variable, "units", "") or ""),
                    long_name=str(getattr(variable, "long_name", "") or ""),
                )
            )

        for variable in self._interior_state_variables(model):
            add(variable, "z")
        for variable in self._bottom_state_variables(model):
            add(variable, "scalar")
        for variable in self._surface_state_variables(model):
            add(variable, "scalar")
        horizontal_diagnostics = [
            variable
            for variable in self._horizontal_diagnostic_variables()
            if self._variable_outputs_enabled(variable)
        ]
        horizontal_diagnostic_names = {
            self._variable_output_name(variable) for variable in horizontal_diagnostics
        }
        for variable in self._diagnostic_variables():
            if not self._variable_outputs_enabled(variable):
                continue
            if self._variable_output_name(variable) in horizontal_diagnostic_names:
                continue
            add(variable, "z")
        for variable in horizontal_diagnostics:
            add(variable, "scalar")
        return tuple(specs)

    def _require_model(self) -> Any:
        """Return the wrapped model or raise ``RuntimeError`` if not initialized."""
        if self.model is None:
            msg = "FABMEngine.initialize() has not been called"
            raise RuntimeError(msg)
        return self.model

    def _start_model(self) -> None:
        """Check readiness, then call ``model.start()`` when the model has one.

        Raises:
            RuntimeError: If the readiness check returns False (listing any
                unresolved dependencies found) or ``start()`` raises.
        """
        model = self._require_model()
        if not self._check_ready():
            # The dependency scan is best effort and may find nothing to name.
            deps = ", ".join(self.unresolved_dependencies()) or "none identified"
            msg = f"pyfabm model is not ready; unresolved dependencies: {deps}"
            raise RuntimeError(msg)
        start = getattr(model, "start", None)
        if callable(start):
            try:
                start()
            except Exception as exc:
                msg = f"pyfabm model.start() failed: {exc}"
                raise RuntimeError(msg) from exc

    def _check_ready(self) -> bool:
        """Return the model's ``checkReady`` verdict, tolerating several signatures.

        The call is tried as ``checkReady()``, ``checkReady(stop=False)`` and
        ``checkReady(False)``. A ``TypeError`` moves on to the next form. A
        model without ``checkReady``, or one accepting none of these forms,
        is treated as ready.
        """
        model = self._require_model()
        check_ready = getattr(model, "checkReady", None)
        if not callable(check_ready):
            return True
        for args, kwargs in (((), {}), ((), {"stop": False}), ((False,), {})):
            try:
                return bool(check_ready(*args, **kwargs))
            except TypeError:
                continue
            except Exception as exc:
                msg = f"pyfabm model readiness check failed: {exc}"
                raise RuntimeError(msg) from exc
        return True

    def _read_model_state(self) -> FloatArray:
        """Return the model state as a contiguous float64 array.

        If ``model.state`` exists and is already contiguous float64,
        ``np.ascontiguousarray`` returns it without copying, so the result
        aliases the model's own state. Otherwise the per-variable values are
        stacked along a new leading axis (empty when there are no variables).
        """
        model = self._require_model()
        if hasattr(model, "state"):
            return np.ascontiguousarray(model.state, dtype=np.float64)
        variables = self._state_variables()
        if not variables:
            return np.zeros(0, dtype=np.float64)
        values = [
            np.asarray(getattr(variable, "value", 0.0), dtype=np.float64)
            for variable in variables
        ]
        return np.ascontiguousarray(np.stack(values), dtype=np.float64)

    def _state_variables(self) -> list[Any]:
        """Return ``state_variables`` (or ``stateVariables``) as a list, else ``[]``."""
        model = self._require_model()
        variables = getattr(model, "state_variables", None)
        if variables is None:
            variables = getattr(model, "stateVariables", None)
        if variables is None:
            return []
        return list(variables)

    def _diagnostic_variables(self) -> list[Any]:
        """Return the model's diagnostic variables under either spelling, else ``[]``."""
        model = self._require_model()
        for attr in ("diagnostic_variables", "diagnosticVariables"):
            variables = getattr(model, attr, None)
            if variables is not None:
                return list(variables)
        return []

    def _horizontal_diagnostic_variables(self) -> list[Any]:
        """Return the model's horizontal diagnostic variables, else ``[]``."""
        model = self._require_model()
        for attr in (
            "horizontal_diagnostic_variables",
            "horizontalDiagnosticVariables",
        ):
            variables = getattr(model, attr, None)
            if variables is not None:
                return list(variables)
        return []

    @staticmethod
    def _interior_state_variables(model: Any) -> list[Any]:
        """Return interior state variables from the first attribute spelling present."""
        for attr in (
            "interior_state_variables",
            "bulk_state_variables",
            "state_variables",
            "stateVariables",
        ):
            variables = getattr(model, attr, None)
            if variables is not None:
                return list(variables)
        return []

    @staticmethod
    def _bottom_state_variables(model: Any) -> list[Any]:
        """Return ``model.bottom_state_variables`` as a list, else ``[]``."""
        variables = getattr(model, "bottom_state_variables", None)
        return [] if variables is None else list(variables)

    @staticmethod
    def _surface_state_variables(model: Any) -> list[Any]:
        """Return ``model.surface_state_variables`` as a list, else ``[]``."""
        variables = getattr(model, "surface_state_variables", None)
        return [] if variables is None else list(variables)

    @staticmethod
    def _variable_name(variable: Any) -> str:
        """Return the first non-empty of ``variable.name`` / ``variable.id``, else ``""``."""
        for attr in ("name", "id"):
            value = getattr(variable, attr, None)
            if value:
                return str(value)
        return ""

    @staticmethod
    def _variable_output_name(variable: Any) -> str:
        """Return an output name with ``/`` replaced by ``_``.

        Raises:
            RuntimeError: If none of ``output_name``, ``name``, ``id`` or
                ``long_name`` is set.
        """
        for attr in ("output_name", "name", "id", "long_name"):
            value = getattr(variable, attr, None)
            if value:
                return str(value).replace("/", "_")
        msg = f"FABM variable {variable!r} does not expose a usable name"
        raise RuntimeError(msg)

    @staticmethod
    def _variable_outputs_enabled(variable: Any) -> bool:
        """Return ``bool(variable.output)``; variables without the flag count as enabled."""
        output = getattr(variable, "output", None)
        if output is None:
            return True
        return bool(output)

    def _dependencies(self) -> Iterable[Any]:
        """Return the model's dependency objects from the first attribute present."""
        model = self._require_model()
        for attr in (
            "dependencies",
            "required_dependencies",
            "dependencies_unfulfilled",
        ):
            dependencies = getattr(model, attr, None)
            if dependencies is not None:
                return list(dependencies)
        return ()

    def _find_dependency(self, name: str) -> Any | None:
        """Look up a dependency by *name*, caching hits and misses.

        ``model.findDependency`` is tried first. Any exception it raises is
        treated as "not found". The dependency list is then scanned by
        name/id.
        """
        if name in self._dependency_cache:
            return self._dependency_cache[name]
        model = self._require_model()
        dependency = None
        finder = getattr(model, "findDependency", None)
        if callable(finder):
            try:
                dependency = finder(name)
            except Exception:
                dependency = None
        if dependency is None:
            dependency = next(
                (
                    candidate
                    for candidate in self._dependencies()
                    if self._variable_name(candidate) == name
                ),
                None,
            )
        self._dependency_cache[name] = dependency
        return dependency

    def _require_dependency(self, name: str) -> Any:
        """Return the dependency called *name* or raise ``KeyError``."""
        dependency = self._find_dependency(name)
        if dependency is None:
            msg = f"FABM dependency {name!r} is not exposed by this model"
            raise KeyError(msg)
        return dependency

    @staticmethod
    def _assign_value(target: Any, value: float | np.ndarray) -> None:
        """Write *value* into a pyfabm variable/dependency object.

        If the target's current ``value`` is an array and *value* is an
        array, the data is copied in place, keeping the target's array. In
        other cases ``target.value`` is rebound. Objects without ``value``
        but with a ``set`` method receive ``set(value)``.

        Raises:
            RuntimeError: If the target supports neither form.
        """
        if hasattr(target, "value"):
            current = target.value
            if isinstance(current, np.ndarray) and isinstance(value, np.ndarray):
                np.copyto(current, value)
            else:
                target.value = value
            return
        if callable(getattr(target, "set", None)):
            target.set(value)
            return
        msg = f"FABM target {target!r} cannot receive a value"
        raise RuntimeError(msg)

    @staticmethod
    def _diagnostic_value(value: object, *, copy: bool) -> np.ndarray | float:
        """Normalize a diagnostic value to a float64 array or a Python float.

        ndarray inputs stay arrays with their shape preserved, even when
        0-d. Other scalar inputs become ``float``. With *copy*, arrays are
        returned as one fresh C-contiguous copy. Without it, the result may
        share memory with *value*.
        """
        if isinstance(value, np.ndarray):
            if copy:
                # np.array always copies here; unlike ascontiguousarray it
                # does not promote 0-d input to shape (1,).
                return np.array(value, dtype=np.float64, order="C")
            return np.asarray(value, dtype=np.float64)
        if isinstance(value, (int, float, np.floating)):
            return float(value)
        array = np.asarray(value, dtype=np.float64)
        if array.ndim == 0:
            return float(array)
        if copy:
            return np.array(array, order="C")
        return array

    @staticmethod
    def _dependency_is_missing(dependency: Any) -> bool:
        """Return whether *dependency* looks unset, using whatever flags it has.

        The first of ``is_set``, ``is_fulfilled``, ``fulfilled`` and
        ``ready`` that exists decides (called if callable). Otherwise a
        truthy ``missing`` flag, or a ``value`` of ``None``, marks it as
        missing.
        """
        for attr in ("is_set", "is_fulfilled", "fulfilled", "ready"):
            value = getattr(dependency, attr, None)
            if value is not None:
                return not bool(value() if callable(value) else value)
        if bool(getattr(dependency, "missing", False)):
            return True
        return getattr(dependency, "value", None) is None

    def _call_get_rates_flags(
        self,
        method: Callable[..., Any],
        *,
        surface: bool,
        bottom: bool,
        time: float | None,
    ) -> Any:
        """Call getRates respecting pyfabm 3.0 surface/bottom kwargs.

        Candidate signatures are tried in order, and the first that does not
        raise ``TypeError`` wins. If none accepts ``surface``/``bottom``, the
        model is called without them, so its own defaults apply and the
        requested flags are not honoured. A ``TypeError`` raised inside the
        model is indistinguishable from a signature mismatch here. The last
        resort is ``method(state)``.
        """
        calls: tuple[tuple[tuple[float, ...], dict[str, object]], ...]
        if time is None:
            calls = (
                ((), {"surface": surface, "bottom": bottom}),
                ((), {}),
            )
        else:
            calls = (
                ((float(time),), {"surface": surface, "bottom": bottom}),
                ((), {"t": float(time), "surface": surface, "bottom": bottom}),
                ((float(time),), {}),
                ((), {}),
            )
        for args, kwargs in calls:
            try:
                return method(*args, **kwargs)
            except TypeError:
                continue
        return method(self._state)