Source code for masa.common.constraints.pctl
"""
Overview
--------
Probabilistic CTL (PCTL) style constraint monitor (simplified).
The current implementation mirrors the safety-style structure used elsewhere:
it accumulates unsafe occurrences based on a per-step cost function applied to
labels. While named "PCTL", this monitor presently behaves like a boolean
"safety so far" tracker under the local convention ``cost >= 0.5``.
Conceptually, a PCTL-style safety constraint might aim to bound the probability
of reaching unsafe states, e.g.:

.. math::

    \\Pr(\\Diamond\\,\\mathsf{unsafe}) \\le \\alpha,

but note that this file's current implementation does not compute an explicit
probability estimate; it tracks whether any unsafe event occurred.
API Reference
-------------
"""
from __future__ import annotations

from typing import Any, Dict, Iterable

from masa.common.constraints.base import Constraint, BaseConstraintEnv, CostFn
from masa.common.dummy import cost_fn as dummy_cost_fn
[docs]
class PCTL(Constraint):
    """Simplified PCTL-named monitor tracking whether any unsafe step occurred.

    A step counts as unsafe when ``cost_fn(labels) >= UNSAFE_THRESHOLD``.
    No probability estimate is computed; ``alpha`` is only stored.

    Args:
        cost_fn: Mapping from label sets to scalar cost.
        alpha: Threshold parameter stored for downstream use (not currently
            used in the logic of this class).

    Attributes:
        cost_fn: Cost function ``labels -> float``.
        alpha: User-specified parameter (reserved for probabilistic variants).
        safe: True until an unsafe cost is observed.
        step_cost: Most recent cost value.
        total_unsafe: Count of unsafe steps (stored as a float).
    """

    # Costs at or above this value mark a step as unsafe. Shared by
    # ``update`` and ``step_metric`` so the two can never disagree.
    UNSAFE_THRESHOLD = 0.5

    def __init__(self, cost_fn: CostFn, alpha: float):
        self.cost_fn = cost_fn
        self.alpha = alpha
        # Create the per-episode state up front so that querying the monitor
        # before the first explicit ``reset()`` does not raise AttributeError.
        self.reset()

    def reset(self) -> None:
        """Reset episode counters."""
        self.safe = True
        self.step_cost = 0.0
        self.total_unsafe = 0.0

    def update(self, labels: Iterable[str]) -> None:
        """Update safety flags from the current label set.

        Args:
            labels: Iterable of atomic propositions.
        """
        self.step_cost = self.cost_fn(labels)
        self.total_unsafe += float(self.step_cost >= self.UNSAFE_THRESHOLD)
        # ``safe`` is latched: once any unsafe step has been counted it
        # stays False until the next ``reset()``.
        self.safe = self.safe and self.total_unsafe == 0.0

    def satisfied(self) -> bool:
        """Whether the episode remains safe so far."""
        return self.safe

    def episode_metric(self) -> Dict[str, float]:
        """End-of-episode metrics.

        Returns:
            Dict containing:

            - ``"cum_unsafe"``: count of unsafe steps,
            - ``"satisfied"``: 1.0 if safe else 0.0.
        """
        return {
            "cum_unsafe": float(self.total_unsafe),
            "satisfied": float(self.satisfied()),
        }

    def step_metric(self) -> Dict[str, float]:
        """Per-step metrics.

        Returns:
            Dict containing:

            - ``"cost"``: current step cost,
            - ``"violation"``: 1.0 if ``cost >= 0.5`` else 0.0.
        """
        return {
            # Coerce so the value matches the declared ``float`` type even
            # when ``cost_fn`` returns an int or another scalar type.
            "cost": float(self.step_cost),
            "violation": float(self.step_cost >= self.UNSAFE_THRESHOLD),
        }

    @property
    def constraint_type(self) -> str:
        """Stable identifier string: ``"pctl"``."""
        return "pctl"
[docs]
class PCTLEnv(BaseConstraintEnv):
    """Environment wrapper that attaches a :class:`PCTL` monitor to ``env``.

    Args:
        env: Environment to wrap (documented upstream as required to be a
            :class:`~masa.common.labelled_env.LabelledEnv`).
        cost_fn: Function mapping labels to a scalar cost; defaults to the
            module's ``dummy_cost_fn``.
        alpha: Threshold parameter handed to the :class:`PCTL` monitor.
        **kw: Additional keyword arguments passed through unchanged to
            :class:`BaseConstraintEnv`.
    """

    # NOTE(review): ``gym`` is not among this module's visible imports; the
    # ``gym.Env`` annotation only avoids a NameError because
    # ``from __future__ import annotations`` keeps annotations unevaluated.
    # Confirm and add the import where the project allows it.
    def __init__(self, env: gym.Env, cost_fn: CostFn = dummy_cost_fn, alpha: float = 0.01, **kw):
        # Build the monitor first, then let the base wrapper take ownership.
        monitor = PCTL(cost_fn=cost_fn, alpha=alpha)
        super().__init__(env, monitor, **kw)