Source code for masa.common.constraints.prob

"""
Overview
--------

Undiscounted probabilistic safety constraint.

This monitor tracks the empirical fraction of unsafe steps in an episode and
requires it to be at most ``alpha``.

Let:

- :math:`u_t \\in \\{0,1\\}` indicate whether step :math:`t` is unsafe,
  computed from labels by thresholding a cost function:

  .. math::

     u_t = \\mathbf{1}[\\mathrm{cost}(L(s_t)) \\ge 0.5].

- Then the empirical unsafe fraction after :math:`T` steps is:

  .. math::

     \\hat{p}_{\\text{unsafe}} = \\frac{1}{T} \\sum_{t=0}^{T-1} u_t.

The episode is satisfied iff:

.. math::

   \\hat{p}_{\\text{unsafe}} \\le \\alpha.

API Reference
-------------
"""

from __future__ import annotations
from typing import Any, Dict
from masa.common.constraints.base import Constraint, BaseConstraintEnv, CostFn
from masa.common.dummy import cost_fn as dummy_cost_fn

[docs] class ProbabilisticSafety(Constraint): """Undiscounted probabilistic constraint based on unsafe-step frequency. Args: cost_fn: Mapping from a label set to a scalar cost. alpha: Allowed maximum fraction of unsafe steps in an episode. Attributes: total: Number of steps observed so far. total_unsafe: Number of steps considered unsafe so far. step_cost: Most recent cost. """ def __init__(self, cost_fn: CostFn, alpha: float): self.cost_fn = cost_fn self.alpha = alpha
[docs] def reset(self): """Reset episode counters.""" self.total = 0 self.total_unsafe = 0.0 self.step_cost = 0.0
[docs] def update(self, labels: Iterable[str]): """Update counters from the current label set. Args: labels: Iterable of atomic propositions for the current step. """ self.step_cost = self.cost_fn(labels) self.total_unsafe += float(self.step_cost >= 0.5) self.total += 1
[docs] def prob_unsafe(self) -> float: """Return the empirical fraction of unsafe steps. Returns: ``total_unsafe / total``. """ if not self.total: return 0.0 else: return self.total_unsafe / self.total
[docs] def satisfied(self) -> bool: """Check whether the unsafe fraction is within the threshold.""" return self.prob_unsafe() <= self.alpha
[docs] def episode_metric(self) -> Dict[str, float]: """End-of-episode metrics. Returns: Dict containing: - ``"cum_unsafe"``: count of unsafe steps, - ``"p_unsafe"``: proportion of unsafe states in the current trace, - ``"satisfied"``: 1.0 if p_unsafe <= self.alpha else 0.0. """ return {"cum_unsafe": float(self.total_unsafe), "p_unsafe": self.prob_unsafe(), "satisfied": float(self.satisfied())}
[docs] def step_metric(self) -> Dict[str, float]: """Per-step metrics. Returns: Dict containing: - ``"cost"``: current step cost, - ``"violation"``: 1.0 if ``cost >= 0.5`` else 0.0. """ return {"cost": self.step_cost, "violation": float(self.step_cost >= 0.5)}
@property def constraint_type(self) -> str: """Stable identifier string: ``"prob"``.""" return "prob"
[docs] class ProbabilisticSafetyEnv(BaseConstraintEnv): """Gymnasium wrapper for :class:`ProbabilisticSafety`. Args: env: Base environment (must be a :class:`~masa.common.labelled_env.LabelledEnv`). cost_fn: Cost function mapping labels to a scalar. alpha: Allowed maximum unsafe-step fraction. **kw: Extra keyword arguments forwarded to :class:`BaseConstraintEnv`. """ def __init__(self, env: gym.Env, cost_fn: CostFn = dummy_cost_fn, alpha: float = 0.01, **kw): super().__init__(env, ProbabilisticSafety(cost_fn=cost_fn, alpha=alpha), **kw)