# Source code for masa.common.constraints.base

"""
Overview
--------

Base constraint interfaces and Gymnasium wrappers.

This module defines:

- :data:`CostFn`: a callable that maps a set/iterable of atomic proposition labels
  (strings) to a scalar cost.
- :class:`Constraint`: a protocol for stateful constraint monitors that can be
  reset and updated from labels at each environment step.
- :class:`BaseConstraintEnv`: a Gymnasium wrapper that enforces the convention
  that the wrapped environment is a :class:`~masa.common.labelled_env.LabelledEnv`
  and that the constraint monitor is updated using ``info["labels"]``.

The overall convention used throughout MASA is:

1. The base environment (or a wrapper) provides a *labelling function* that maps
   an observation/state to a set of atomic propositions ``labels``.
2. Each call to :meth:`gymnasium.Env.step` returns these labels in the ``info``
   dict under the key ``"labels"``.
3. Constraint monitors are updated as ``constraint.update(labels)``.
4. Constraint wrappers expose metrics for logging/training.

Mathematically, a (labelled) MDP is typically written as

.. math::

   \\mathcal{M} = (\\mathcal{S}, \\mathcal{A}, P, r, L),

where:

- :math:`\\mathcal{S}` is the state space,
- :math:`\\mathcal{A}` is the action space,
- :math:`P(s'\\mid s,a)` is the transition kernel,
- :math:`r(s,a,s')` is a reward signal,
- :math:`L : \\mathcal{S} \\to 2^{\\mathsf{AP}}` is a labelling function mapping
  states to sets of atomic propositions from a finite alphabet :math:`\\mathsf{AP}`.

A cost function then maps labels to a scalar:

.. math::

   c(s) \\triangleq \\mathrm{cost}(L(s)) \\in \\mathbb{R}.

API Reference
-------------

"""

from __future__ import annotations
from typing import Any, Dict, Iterable, Mapping, Protocol, Callable
from masa.common.labelled_env import LabelledEnv
from typing import Dict, Protocol, Any
import gymnasium as gym

# A cost function maps the label set of a step to a scalar cost.
# ``Callable`` takes its parameter types as a *list*: the single positional
# argument is an iterable of atomic-proposition strings, the result is a float.
CostFn = Callable[[Iterable[str]], float]

class Constraint(Protocol):
    """Protocol for stateful constraint monitors.

    A :class:`Constraint` is a *monitor* that consumes atomic proposition labels
    at each step and maintains internal state (e.g., cumulative cost, whether an
    LTL automaton is in an accepting/unsafe state, etc.).

    Implementations are intended to be lightweight and compatible with Gymnasium
    wrappers: call :meth:`reset` at episode start and :meth:`update` after each
    environment transition using the label set from ``info["labels"]``.

    **Required interface**

    Implementations should provide:

    - :meth:`reset`: clear any episode state.
    - :meth:`update`: incorporate the current label set.
    - :attr:`constraint_type`: a stable identifier string for logging/dispatch.

    **Metrics interface**

    The protocol declares:

    - :meth:`step_metric`
    - :meth:`episode_metric`

    The method bodies here are placeholders: they only declare the interface
    and return ``None`` if called directly on the protocol.
    """

    def reset(self) -> None:
        """Reset any episode-dependent internal state."""
        ...

    def update(self, labels: Iterable[str]) -> None:
        """Update internal state given the current set of labels.

        Args:
            labels: Iterable of atomic proposition strings active at the current
                step (typically taken from ``info["labels"]``).
        """
        ...

    @property
    def constraint_type(self) -> str:
        """A stable identifier for the constraint (e.g., ``"cmdp"``, ``"ltl_safety"``)."""
        ...

    def step_metric(self) -> Dict[str, float]:
        """Return per-step logging metrics.

        Metrics returned here should be:

        - cheap to compute,
        - non-destructive (do not mutate state),
        - meaningful at *any* time step.

        Examples include running cumulative cost, a per-step violation flag,
        a current probability estimate, etc.

        Returns:
            Dictionary of scalar metrics (values should be JSON/log friendly).
        """
        ...

    def episode_metric(self) -> Dict[str, float]:
        """Return end-of-episode logging metrics.

        This is intended to summarize what matters for evaluation/logging at
        episode termination (terminated or truncated).

        Returns:
            Dictionary of scalar metrics (values should be JSON/log friendly).
        """
        ...
class BaseConstraintEnv(gym.Wrapper, Constraint):
    """Common base wrapper for constraint-aware environments.

    This wrapper enforces the MASA convention that the wrapped environment is a
    :class:`~masa.common.labelled_env.LabelledEnv` and provides ``info["labels"]``
    as a ``set`` (or ``frozenset``) of atomic propositions at each step.

    The wrapper:

    1. Delegates reset/step to the underlying environment.
    2. Extracts ``labels = info.get("labels", set())``.
    3. Validates that ``labels`` is a ``set`` or ``frozenset``.
    4. Calls ``self._constraint.update(labels)``.

    Attributes:
        env: The wrapped Gymnasium environment (must be a :class:`LabelledEnv`).
        _constraint: The underlying constraint monitor.

    Raises:
        TypeError: If ``env`` is not an instance of :class:`LabelledEnv`.
        ValueError: If ``info["labels"]`` exists but is not a ``set``/``frozenset``.

    Notes:
        The properties :attr:`label_fn` and :attr:`cost_fn` are convenience
        accessors for downstream algorithms. Depending on how wrappers are
        composed, these may be ``None``.
    """

    # NOTE(review): ``update``, ``step_metric`` and ``episode_metric`` are not
    # defined here, so they resolve to the placeholder bodies on ``Constraint``
    # unless a subclass overrides them — confirm that is intended.

    def __init__(self, env: gym.Env, constraint: Constraint, **kw):
        """Initialize the wrapper.

        Args:
            env: Base environment. Must already be wrapped as a
                :class:`~masa.common.labelled_env.LabelledEnv` so that
                step/reset provide label sets in ``info["labels"]``.
            constraint: A constraint monitor implementing :class:`Constraint`.
            **kw: Unused extra keyword arguments (kept for wrapper compatibility).

        Raises:
            TypeError: If ``env`` is not a :class:`LabelledEnv`.
        """
        # Fail fast, before gym.Wrapper stores the env.
        if not isinstance(env, LabelledEnv):
            raise TypeError(
                f"{self.__class__.__name__} must wrap a LabelledEnv, "
                f"but got {type(env).__name__}. "
                "Please wrap your environment with LabelledEnv before applying a constraint wrapper."
            )
        super().__init__(env)
        self._constraint = constraint

    def _update_constraint_from_info(self, info: Dict[str, Any]) -> None:
        """Validate ``info["labels"]`` and feed it to the constraint monitor.

        A missing ``"labels"`` key is treated as the empty label set.

        Args:
            info: The ``info`` dict returned by the wrapped env's reset/step.

        Raises:
            ValueError: If ``info["labels"]`` is present but not a set/frozenset.
        """
        labels = info.get("labels", set())
        if not isinstance(labels, (set, frozenset)):
            raise ValueError(
                f"Expected 'labels' in info to be a set of atomic propositions, got {type(labels).__name__}"
            )
        self._constraint.update(labels)

    def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None):
        """Reset environment and constraint state.

        This calls ``env.reset(...)`` and then resets and updates the constraint
        using the initial label set in ``info["labels"]``.

        Args:
            seed: Optional RNG seed forwarded to the base environment.
            options: Optional reset options forwarded to the base environment.

        Returns:
            A tuple ``(obs, info)`` following the Gymnasium API.

        Raises:
            ValueError: If ``info["labels"]`` is present but not a set/frozenset.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        # Clear the monitor first, then feed it the labels of the initial state.
        self._constraint.reset()
        self._update_constraint_from_info(info)
        return obs, info

    def step(self, action: Any):
        """Step environment and update constraint from labels.

        Args:
            action: Action to pass to the underlying environment.

        Returns:
            A 5-tuple ``(obs, reward, terminated, truncated, info)`` following
            the Gymnasium API.

        Raises:
            ValueError: If ``info["labels"]`` is present but not a set/frozenset.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        self._update_constraint_from_info(info)
        return obs, reward, terminated, truncated, info

    @property
    def cost_fn(self):
        """Expose the cost function if available.

        Returns:
            The ``cost_fn`` attribute of the constraint monitor held by this
            wrapper, or ``None`` if the monitor is ``None`` or has no such
            attribute.
        """
        # Read from this wrapper's own monitor: the wrapped env is not where
        # the constraint is stored. ``getattr`` also covers a ``None`` monitor.
        return getattr(self._constraint, "cost_fn", None)

    @property
    def label_fn(self):
        """Expose the labelling function if available.

        Returns:
            The environment labelling function if present, else ``None``.
        """
        return getattr(self.env, "label_fn", None)

    @property
    def constraint_type(self) -> str:
        """Constraint identifier forwarded from the underlying monitor."""
        return self._constraint.constraint_type

    def constraint_step_metrics(self) -> Dict[str, float]:
        """Return per-step metrics from the underlying constraint.

        Returns:
            Dictionary of scalar metrics.
        """
        return self._constraint.step_metric()

    def constraint_episode_metrics(self) -> Dict[str, float]:
        """Return end-of-episode metrics from the underlying constraint.

        Returns:
            Dictionary of scalar metrics.
        """
        return self._constraint.episode_metric()