Step-wise Probabilistic Constraint

Overview

Undiscounted probabilistic safety constraint.

This monitor tracks the empirical fraction of unsafe steps in an episode and requires that fraction to be at most \(\alpha\).

Let:

  • \(u_t \in \{0,1\}\) indicate whether step \(t\) is unsafe, computed by thresholding the cost of the label set \(L(s_t)\) observed at that step:

    \[u_t = \mathbf{1}[\mathrm{cost}(L(s_t)) \ge 0.5].\]
  • \(\hat{p}_{\text{unsafe}}\) be the empirical unsafe fraction after \(T \ge 1\) steps:

    \[\hat{p}_{\text{unsafe}} = \frac{1}{T} \sum_{t=0}^{T-1} u_t.\]

The constraint is satisfied for an episode if and only if:

\[\hat{p}_{\text{unsafe}} \le \alpha.\]
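The definitions above can be illustrated with a small, self-contained sketch. It does not use the library: `unsafe_fraction` and `toy_cost` are hypothetical helpers introduced only for this example, and raising on an empty episode is a choice made here, not documented behavior.

```python
from typing import Callable, Iterable, List


def unsafe_fraction(
    label_sets: Iterable[Iterable[str]],
    cost_fn: Callable[[Iterable[str]], float],
    threshold: float = 0.5,
) -> float:
    """Return the fraction of steps whose cost is at or above the threshold."""
    # u_t = 1 if cost(L(s_t)) >= 0.5 else 0
    indicators: List[int] = [
        1 if cost_fn(labels) >= threshold else 0 for labels in label_sets
    ]
    if not indicators:
        # The fraction is undefined for T = 0; this sketch chooses to raise.
        raise ValueError("the unsafe fraction is undefined for an empty episode")
    return sum(indicators) / len(indicators)


def toy_cost(labels: Iterable[str]) -> float:
    """Hypothetical cost function: 1.0 when the 'unsafe' label is present."""
    return 1.0 if "unsafe" in set(labels) else 0.0


# Four steps, exactly one of which carries the 'unsafe' label.
episode = [{"start"}, {"unsafe"}, set(), {"goal"}]
p_hat = unsafe_fraction(episode, toy_cost)
satisfied = p_hat <= 0.25  # alpha = 0.25
```

Because the comparison is non-strict, an episode whose unsafe fraction equals \(\alpha\) exactly still satisfies the constraint.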

API Reference

class masa.common.constraints.prob.ProbabilisticSafety(cost_fn: Callable[[Iterable[str]], float], alpha: float)[source]

Bases: Constraint

Undiscounted probabilistic constraint based on unsafe-step frequency.

Parameters:
  • cost_fn – Mapping from a label set to a scalar cost.

  • alpha – Allowed maximum fraction of unsafe steps in an episode.

Variables:
  • total – Number of steps observed so far.

  • total_unsafe – Number of steps considered unsafe so far.

  • step_cost – Most recent cost.

reset()[source]

Reset episode counters.

update(labels: Iterable[str])[source]

Update counters from the current label set.

Parameters:

labels – Iterable of atomic propositions for the current step.

prob_unsafe() → float[source]

Return the empirical fraction of unsafe steps.

Returns:

total_unsafe / total.

satisfied() → bool[source]

Check whether the unsafe fraction is within the threshold.

episode_metric() → Dict[str, float][source]

End-of-episode metrics.

Returns:

Dict containing:

  • "cum_unsafe": count of unsafe steps,

  • "p_unsafe": proportion of unsafe states in the current trace,

  • "satisfied": 1.0 if p_unsafe <= self.alpha else 0.0.

Return type:

Dict[str, float]

step_metric() → Dict[str, float][source]

Per-step metrics.

Returns:

Dict containing:

  • "cost": current step cost,

  • "violation": 1.0 if cost >= 0.5 else 0.0.

Return type:

Dict[str, float]

property constraint_type: str

Stable identifier string: "prob".

Type:

str
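The documented interface can be mirrored by a small stand-in class, which may help when reading the method descriptions above. This is a sketch written from the documentation, not the library source: `ProbabilisticSafetySketch` and `toy_cost` are hypothetical names, and returning 0.0 from `prob_unsafe()` before any step has been observed is an assumption, since the reference does not say what happens when `total` is zero.

```python
from typing import Callable, Dict, Iterable


class ProbabilisticSafetySketch:
    """Illustrative stand-in for the documented monitor (not the library class)."""

    def __init__(self, cost_fn: Callable[[Iterable[str]], float], alpha: float):
        self.cost_fn = cost_fn
        self.alpha = alpha
        self.reset()

    def reset(self) -> None:
        """Reset episode counters."""
        self.total = 0          # steps observed so far
        self.total_unsafe = 0   # steps with cost >= 0.5 so far
        self.step_cost = 0.0    # most recent cost

    def update(self, labels: Iterable[str]) -> None:
        """Update counters from the current label set."""
        self.step_cost = float(self.cost_fn(labels))
        self.total += 1
        if self.step_cost >= 0.5:
            self.total_unsafe += 1

    def prob_unsafe(self) -> float:
        # Assumption: report 0.0 before any step instead of dividing by zero.
        return self.total_unsafe / self.total if self.total else 0.0

    def satisfied(self) -> bool:
        return self.prob_unsafe() <= self.alpha

    def step_metric(self) -> Dict[str, float]:
        violation = 1.0 if self.step_cost >= 0.5 else 0.0
        return {"cost": self.step_cost, "violation": violation}

    def episode_metric(self) -> Dict[str, float]:
        p_unsafe = self.prob_unsafe()
        return {
            "cum_unsafe": float(self.total_unsafe),
            "p_unsafe": p_unsafe,
            "satisfied": 1.0 if p_unsafe <= self.alpha else 0.0,
        }


def toy_cost(labels: Iterable[str]) -> float:
    """Hypothetical cost function: 1.0 when the 'unsafe' label is present."""
    return 1.0 if "unsafe" in set(labels) else 0.0


monitor = ProbabilisticSafetySketch(toy_cost, alpha=0.5)
for labels in [{"start"}, {"unsafe"}, {"unsafe"}]:
    monitor.update(labels)
metrics = monitor.episode_metric()  # 2 of 3 steps unsafe, above alpha = 0.5
```

Note that `satisfied()` is evaluated on the running fraction, so its value can change from step to step within an episode as more steps are observed.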

class masa.common.constraints.prob.ProbabilisticSafetyEnv(env: gym.Env, cost_fn: CostFn = dummy_cost_fn, alpha: float = 0.01, **kw)[source]

Bases: BaseConstraintEnv

Gymnasium wrapper for ProbabilisticSafety.

Parameters:
  • env – Base environment (must be a LabelledEnv).

  • cost_fn – Cost function mapping labels to a scalar.

  • alpha – Allowed maximum unsafe-step fraction.

  • **kw – Extra keyword arguments forwarded to BaseConstraintEnv.

Initialize the wrapper.

Parameters:
  • env – Base environment. Must already be wrapped as a LabelledEnv so that step/reset provide label sets in info["labels"].

  • constraint – A constraint monitor implementing Constraint.

  • **kw – Unused extra keyword arguments (kept for wrapper compatibility).

Raises:

TypeError – If env is not a LabelledEnv.
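The data flow the wrapper relies on, label sets arriving in info["labels"] and being turned into unsafe-step counts, can be sketched without the library. Everything named here is hypothetical: `ToyLabelledEnv` only imitates the Gymnasium `reset`/`step` return shapes, `run_episode` is not part of masa, and counting only the label sets returned by `step` (not the one returned by `reset`) is an assumption of this sketch.

```python
from typing import Callable, Dict, Iterable, List, Set


class ToyLabelledEnv:
    """Hypothetical environment that replays a fixed trace of label sets."""

    def __init__(self, trace: List[Set[str]]):
        self._trace = trace
        self._t = 0

    def reset(self):
        self._t = 0
        return 0, {"labels": set()}  # (observation, info)

    def step(self, action):
        labels = self._trace[self._t]
        self._t += 1
        terminated = self._t >= len(self._trace)
        # (observation, reward, terminated, truncated, info)
        return self._t, 0.0, terminated, False, {"labels": labels}


def run_episode(
    env: ToyLabelledEnv,
    cost_fn: Callable[[Iterable[str]], float],
    alpha: float = 0.01,
) -> Dict[str, float]:
    """Run one episode and return end-of-episode metrics."""
    total = 0
    total_unsafe = 0
    env.reset()
    done = False
    while not done:
        _, _, terminated, truncated, info = env.step(0)
        total += 1
        if cost_fn(info["labels"]) >= 0.5:
            total_unsafe += 1
        done = terminated or truncated
    p_unsafe = total_unsafe / total
    return {
        "cum_unsafe": float(total_unsafe),
        "p_unsafe": p_unsafe,
        "satisfied": 1.0 if p_unsafe <= alpha else 0.0,
    }


def toy_cost(labels: Iterable[str]) -> float:
    """Hypothetical cost function: 1.0 when the 'unsafe' label is present."""
    return 1.0 if "unsafe" in set(labels) else 0.0


env = ToyLabelledEnv([{"start"}, set(), {"unsafe"}, {"goal"}])
strict = run_episode(env, toy_cost)               # alpha = 0.01
lenient = run_episode(env, toy_cost, alpha=0.25)  # alpha = 0.25
```

With the documented default of alpha = 0.01, a single unsafe step is only tolerated in episodes of at least 100 steps, so short episodes effectively require zero unsafe steps.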