# Source code for masa.common.constraints.reach_avoid

"""
Overview
--------

Reach-avoid constraint monitor.

A reach-avoid property requires eventually reaching a target set while never
visiting an unsafe set. Using atomic propositions:

- ``reach_label`` indicates a target condition,
- ``avoid_label`` indicates an unsafe condition.

A typical reach-avoid specification can be described as:

.. math::

   (\\neg\\mathsf{avoid})\\ \\mathcal{U}\\ \\mathsf{reach}

i.e., "avoid is never true until reach becomes true" (informally).

This implementation tracks:

- ``reached``: whether ``reach_label`` has been observed at least once,
- ``violated``: whether ``avoid_label`` has been observed at least once,
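- ``satisfied``: whether ``reach_label`` has been observed while ``violated``
  was still false (an unsafe label seen at the same step counts as a
  violation); once set, it stays set for the rest of the episode.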

API Reference
-------------
"""

from __future__ import annotations
from typing import Any, Dict, Iterable
from masa.common.constraints.base import Constraint, BaseConstraintEnv

class ReachAvoid(Constraint):
    """Reach target label set while avoiding unsafe label set.

    At each step, given a label set ``labels``:

    - reaching condition: ``reach = (reach_label in labels)``
    - avoiding condition: ``avoid_ok = (avoid_label not in labels)``

    State updates:

    - ``reached`` becomes true once reach is observed,
    - ``violated`` becomes true once avoid is violated,
    - ``satisfied`` becomes true once reached is true and violated is false,
      and then stays true for the rest of the episode.

    Args:
        avoid_label: Atomic proposition name indicating unsafe/avoid condition.
        reach_label: Atomic proposition name indicating the target condition.

    Attributes:
        avoid_label: Name of unsafe label.
        reach_label: Name of target label.
        reached: Whether target has been reached at least once.
        violated: Whether unsafe has been observed at least once.
        satisfied: Whether reach-avoid has been satisfied so far.
        reach: Whether ``reach_label`` was present at the most recent step.
        avoid: Whether ``avoid_label`` was *absent* at the most recent step
            (``True`` means the step was safe).
    """

    def __init__(self, avoid_label: str, reach_label: str):
        self.avoid_label = avoid_label
        self.reach_label = reach_label
        # Start with well-defined flags so update()/step_metric() are usable
        # even if the caller has not invoked reset() yet.
        self.reset()

    def reset(self):
        """Reset episode flags and the per-step observations."""
        self.reached = False
        self.violated = False
        self.satisfied = False
        # Per-step values read by step_metric(); before any update() the
        # monitor reports "target not seen, nothing unsafe seen".
        self.reach = False
        self.avoid = True

    def update(self, labels: Iterable[str]):
        """Update reach/avoid flags from the current label set.

        Args:
            labels: Iterable of atomic propositions for the current step.
        """
        # Two membership tests follow. An object without ``__contains__``
        # (e.g. a generator) would be partially consumed by the first test,
        # so materialize it once. Real containers are used as-is.
        if not hasattr(labels, "__contains__"):
            labels = tuple(labels)

        self.reach = self.reach_label in labels
        self.avoid = self.avoid_label not in labels

        self.reached = self.reached or self.reach
        self.violated = self.violated or not self.avoid
        # ``violated`` is updated first, so an unsafe label seen in the same
        # step as the target prevents satisfaction at that step.
        self.satisfied = self.satisfied or (self.reached and not self.violated)

    def episode_metric(self) -> Dict[str, float]:
        """End-of-episode metrics.

        Returns:
            Dict containing:

            - ``"reached"``: 1.0 if the target was ever reached else 0.0,
            - ``"violated"``: 1.0 if unsafe was ever visited else 0.0,
            - ``"satisfied"``: 1.0 if satisfied else 0.0.
        """
        return {
            "reached": float(self.reached),
            "violated": float(self.violated),
            "satisfied": float(self.satisfied),
        }

    def step_metric(self) -> Dict[str, float]:
        """Per-step metrics.

        Returns:
            Dict containing:

            - ``"cost"``: 1.0 if avoid is violated at this step else 0.0,
            - ``"violation"``: 1.0 if avoid violated else 0.0,
            - ``"reached"``: 1.0 if reach holds at this step else 0.0.
        """
        unsafe_now = float(not self.avoid)
        return {
            "cost": unsafe_now,
            "violation": unsafe_now,
            "reached": float(self.reach),
        }

    @property
    def constraint_type(self) -> str:
        """Stable identifier string: ``"reach_avoid"``."""
        return "reach_avoid"
class ReachAvoidEnv(BaseConstraintEnv):
    """Gymnasium wrapper for the :class:`ReachAvoid` monitor.

    Builds a :class:`ReachAvoid` monitor from the two label names and hands
    it, together with the wrapped environment, to :class:`BaseConstraintEnv`.

    Args:
        env: Base environment (must be a
            :class:`~masa.common.labelled_env.LabelledEnv`).
        avoid_label: Atomic proposition name for unsafe/avoid condition.
        reach_label: Atomic proposition name for target condition.
        **kw: Extra keyword arguments forwarded to :class:`BaseConstraintEnv`.
    """

    # NOTE(review): ``gym`` is not imported in the visible part of this
    # module; the ``gym.Env`` annotation is never evaluated at runtime only
    # because of ``from __future__ import annotations``.
    def __init__(
        self,
        env: gym.Env,
        avoid_label: str = "unsafe",
        reach_label: str = "target",
        **kw,
    ):
        monitor = ReachAvoid(avoid_label=avoid_label, reach_label=reach_label)
        super().__init__(env, monitor, **kw)