Source code for masa.common.constraints.cmdp
"""
Overview
--------
Cumulative-cost constraints in the CMDP style.
This module provides a simple *budgeted cumulative cost* constraint, commonly
used to model constrained MDPs (CMDPs). At each step a cost is computed from the
current label set:
.. math::
c_t \\triangleq c(L(s_t)),
and accumulated over the episode:
.. math::
C_T \\triangleq \\sum_{t=0}^{T-1} c_t.
The episode is considered *satisfied* when:
.. math::
C_T \\le B,
where :math:`B` is the user-specified budget.
The wrapper :class:`CumulativeCostEnv` updates the monitor each step by reading
``info["labels"]`` from the wrapped :class:`~masa.common.labelled_env.LabelledEnv`.
API Reference
-------------
"""
from __future__ import annotations
from typing import Any, Dict
from masa.common.constraints.base import Constraint, BaseConstraintEnv, CostFn
from masa.common.dummy import cost_fn as dummy_cost_fn
[docs]
class CumulativeCost(Constraint):
"""CMDP-style cumulative cost constraint with a fixed budget.
The monitor keeps:
- ``step_cost``: the instantaneous cost :math:`c_t`,
- ``total``: the accumulated cost :math:`C_T`.
Args:
cost_fn: Mapping from a label set to a scalar cost.
budget: Episode budget :math:`B`. The episode is satisfied if
``total <= budget``.
Attributes:
cost_fn: The cost function ``labels -> float``.
budget: Maximum allowed cumulative cost.
total: Running cumulative cost for the current episode.
step_cost: Cost at the most recent update.
"""
def __init__(self, cost_fn: CostFn, budget: float):
self.cost_fn = cost_fn
self.budget = budget
[docs]
def reset(self):
"""Reset episode counters."""
self.total = 0.0
self.step_cost = 0.0
[docs]
def update(self, labels: Iterable[str]):
"""Update costs from the current label set.
Args:
labels: Iterable of atomic proposition strings for the current step.
"""
self.step_cost = self.cost_fn(labels)
self.total += self.step_cost
[docs]
def satisfied(self) -> bool:
"""Check whether the episode remains within budget.
Returns:
``True`` iff ``total <= budget``.
"""
return self.total <= self.budget
[docs]
def episode_metric(self) -> Dict[str, float]:
"""End-of-episode metrics.
Returns:
A dict containing:
- ``"cum_cost"``: cumulative cost over the episode,
- ``"satisfied"``: ``1.0`` if within budget else ``0.0``.
"""
return {"cum_cost": self.total, "satisfied": float(self.satisfied())}
[docs]
def step_metric(self) -> Dict[str, float]:
"""Per-step metrics.
Returns:
A dict containing:
- ``"cost"``: instantaneous cost,
- ``"violation"``: 1.0 if the instantaneous cost is considered unsafe
under the local convention ``cost >= 0.5``,
- ``"cum_cost"``: running total.
"""
return {"cost": self.step_cost, "violation": float(self.step_cost >= 0.5), "cum_cost": self.total}
@property
def constraint_type(self) -> str:
"""Stable identifier string: ``"cmdp"``."""
return "cmdp"
[docs]
class CumulativeCostEnv(BaseConstraintEnv):
"""Gymnasium wrapper that attaches :class:`CumulativeCost` to an environment.
Args:
env: Base environment (must be a :class:`~masa.common.labelled_env.LabelledEnv`).
cost_fn: Cost function mapping label sets to float cost.
budget: Cumulative cost budget :math:`B`.
**kw: Extra keyword arguments forwarded to :class:`BaseConstraintEnv`.
"""
def __init__(self, env: gym.Env, cost_fn: CostFn = dummy_cost_fn, budget: float = 20.0, **kw):
super().__init__(env, CumulativeCost(cost_fn=cost_fn, budget=budget), **kw)