# Source code for masa.common.constraints.multi_agent.cmg

r"""
Overview
--------

Constraint monitors for labelled PettingZoo parallel environments.

This module provides a constrained Markov game monitor based on cumulative
cost budgets. Each agent receives its own label set through
``infos[agent]["labels"]`` and incurs a step cost via ``cost_fn(labels)``.
Budgets are then evaluated over subsets of agents:

.. math::

   C_t^{(i)} = \mathrm{cost}(L_t^{(i)}),

.. math::

   B_t^{(k)} = \sum_{i \in \mathcal{G}_k} C_t^{(i)},

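where :math:`\mathcal{G}_k` is the subset of agents assigned to budget
:math:`k`. Budgets may overlap, so the same agent cost can contribute to
multiple budget totals. Budget :math:`k` counts as satisfied while its running
total :math:`\sum_{s \le t} B_s^{(k)}` is at most the budget's ``amount``.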
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any
from collections.abc import Iterable, Mapping, Sequence

from pettingzoo import ParallelEnv

from masa.common.constraints.base import CostFn
from masa.common.dummy import cost_fn as dummy_cost_fn
from masa.common.labelled_pz_env import LabelledParallelEnv


[docs] @dataclass(frozen=True, slots=True) class Budget: """Shared cumulative-cost budget over a subset of agents. Args: amount: Maximum allowed cumulative cost for this budget. agents: Subset of agents from ``env.possible_agents`` covered by the budget. name: Optional metric prefix. If omitted, a generated name is used. Notes: Agent memberships are deduplicated while preserving order. Budgets may overlap, so a single agent may contribute to more than one budget. """ amount: float agents: tuple[str, ...] name: str | None = None def __post_init__(self): agents = tuple(dict.fromkeys(self.agents)) if not agents: raise ValueError("Budget agents must be non-empty.") object.__setattr__(self, "amount", float(self.amount)) object.__setattr__(self, "agents", agents)
class ConstrainedMarkovGame:
    """Cumulative-cost monitor for a labelled parallel PettingZoo environment.

    Each call to :meth:`update` turns every tracked agent's label set into a
    step cost via ``cost_fn`` and adds it to that agent's running total. Each
    budget's step cost is the sum of its member agents' step costs, and is
    added to the budget's running total.

    Args:
        possible_agents: Agent ids to track. Duplicates are dropped (first
            occurrence kept) so that an agent is charged once per update.
        budgets: Budgets to evaluate. Every budget may only reference agents
            from ``possible_agents``.
        cost_fn: Callable applied to an agent's label set; its result is
            converted with ``float`` and used as the step cost.

    Raises:
        ValueError: If ``possible_agents`` or ``budgets`` is empty, if a budget
            references an agent that is not in ``possible_agents``, or if two
            budgets resolve to the same metric key.
    """

    # A per-step agent cost at or above this value is reported as a violation
    # in :meth:`step_metric`.
    _VIOLATION_THRESHOLD = 0.5

    def __init__(
        self,
        possible_agents: Sequence[str],
        budgets: Sequence[Budget],
        cost_fn: CostFn = dummy_cost_fn,
    ):
        # Order-preserving de-duplication: iterating a tuple with repeated ids
        # in ``update`` would add the same agent's cost to its total twice.
        self.possible_agents = tuple(dict.fromkeys(possible_agents))
        self._possible_agent_set = set(self.possible_agents)
        self.budgets = tuple(budgets)
        self.cost_fn = cost_fn

        if not self.possible_agents:
            raise ValueError("possible_agents must be non-empty.")
        if not self.budgets:
            raise ValueError("budgets must be non-empty.")

        budget_keys = []
        seen_budget_keys = set()
        for index, budget in enumerate(self.budgets):
            invalid_agents = set(budget.agents) - self._possible_agent_set
            if invalid_agents:
                raise ValueError(
                    f"Budget agents must exist in possible_agents. Invalid agents: {sorted(invalid_agents)}"
                )

            # Unnamed (or empty-named) budgets get a positional metric key.
            key = budget.name or f"budget_{index}"
            if key in seen_budget_keys:
                raise ValueError(f"Budget names must be unique. Duplicate name: {key}")
            seen_budget_keys.add(key)
            budget_keys.append(key)

        # Parallel to ``self.budgets``: ``_budget_keys[i]`` names ``budgets[i]``.
        self._budget_keys = tuple(budget_keys)
        self.reset()

    def reset(self):
        """Reset per-agent and per-budget cumulative costs for a new episode."""
        self.agent_step_costs = {agent: 0.0 for agent in self.possible_agents}
        self.agent_totals = {agent: 0.0 for agent in self.possible_agents}
        self.budget_step_costs = {key: 0.0 for key in self._budget_keys}
        self.budget_totals = {key: 0.0 for key in self._budget_keys}

    def update(self, labels_by_agent: Mapping[str, Iterable[str]]):
        """Update the monitor from a mapping of agent ids to active labels.

        Args:
            labels_by_agent: Label sets keyed by agent id. Each value must be a
                ``set`` or ``frozenset``. Tracked agents missing from the
                mapping are evaluated with an empty label set.

        Raises:
            ValueError: If the mapping contains an agent that is not tracked,
                or if an agent's labels are not a ``set``/``frozenset``.
        """
        unknown_agents = set(labels_by_agent) - self._possible_agent_set
        if unknown_agents:
            raise ValueError(f"Unknown agents in labels_by_agent: {sorted(unknown_agents)}")

        # Compute every cost before mutating any state, so an invalid label set
        # or a failing ``cost_fn`` cannot leave the totals partially updated.
        step_costs = {}
        for agent in self.possible_agents:
            labels = labels_by_agent.get(agent, set())
            if not isinstance(labels, (set, frozenset)):
                raise ValueError(
                    f"Expected labels for agent '{agent}' to be a set of atomic propositions, "
                    f"got {type(labels).__name__}"
                )
            step_costs[agent] = float(self.cost_fn(labels))

        for agent, step_cost in step_costs.items():
            self.agent_step_costs[agent] = step_cost
            self.agent_totals[agent] += step_cost

        for key, budget in zip(self._budget_keys, self.budgets):
            budget_step_cost = float(sum(self.agent_step_costs[agent] for agent in budget.agents))
            self.budget_step_costs[key] = budget_step_cost
            self.budget_totals[key] += budget_step_cost

    def _within_budget(self, key: str, budget: Budget) -> bool:
        """Return whether the running total for ``key`` is at most ``budget.amount``."""
        return self.budget_totals[key] <= budget.amount

    def satisfied(self) -> bool:
        """Return ``True`` when every budget remains within its cap."""
        return all(
            self._within_budget(key, budget)
            for key, budget in zip(self._budget_keys, self.budgets)
        )

    def step_metric(self) -> dict[str, float]:
        """Return per-step metrics for agents and budgets.

        Per agent: ``{agent}_cost``, ``{agent}_violation`` and
        ``{agent}_cum_cost``. Per budget: ``{key}_cost``, ``{key}_cum_cost``
        and ``{key}_satisfied``. A global ``satisfied`` flag is added last.
        Boolean values are reported as ``0.0``/``1.0``.
        """
        metrics: dict[str, float] = {}
        for agent in self.possible_agents:
            step_cost = self.agent_step_costs[agent]
            metrics[f"{agent}_cost"] = step_cost
            metrics[f"{agent}_violation"] = float(step_cost >= self._VIOLATION_THRESHOLD)
            metrics[f"{agent}_cum_cost"] = self.agent_totals[agent]
        for key, budget in zip(self._budget_keys, self.budgets):
            metrics[f"{key}_cost"] = self.budget_step_costs[key]
            metrics[f"{key}_cum_cost"] = self.budget_totals[key]
            metrics[f"{key}_satisfied"] = float(self._within_budget(key, budget))
        metrics["satisfied"] = float(self.satisfied())
        return metrics

    def episode_metric(self) -> dict[str, float]:
        """Return end-of-episode cumulative metrics for agents and budgets.

        Contains ``{agent}_cum_cost`` per agent, ``{key}_cum_cost`` and
        ``{key}_satisfied`` per budget, and a global ``satisfied`` flag.
        """
        metrics: dict[str, float] = {}
        for agent in self.possible_agents:
            metrics[f"{agent}_cum_cost"] = self.agent_totals[agent]
        for key, budget in zip(self._budget_keys, self.budgets):
            metrics[f"{key}_cum_cost"] = self.budget_totals[key]
            metrics[f"{key}_satisfied"] = float(self._within_budget(key, budget))
        metrics["satisfied"] = float(self.satisfied())
        return metrics

    @property
    def constraint_type(self) -> str:
        """Short identifier for this constraint kind (always ``"cmg"``)."""
        return "cmg"
class ConstrainedMarkovGameEnv(ParallelEnv):
    """PettingZoo parallel wrapper that updates a :class:`ConstrainedMarkovGame`.

    The wrapper forwards ``reset``/``step`` to the wrapped environment, reads
    ``infos[agent]["labels"]`` for every possible agent and feeds those label
    sets to an internal :class:`ConstrainedMarkovGame`. Attributes that are not
    defined on the wrapper are looked up on the wrapped environment.

    Args:
        env: Environment to wrap; must be a ``LabelledParallelEnv``.
        budgets: Budgets handed to the internal constraint monitor.
        cost_fn: Cost function handed to the internal constraint monitor.
        **kw: Accepted for call-site compatibility; not used by this wrapper.

    Raises:
        TypeError: If ``env`` is not a ``LabelledParallelEnv``.
    """

    def __init__(
        self,
        env: ParallelEnv,
        budgets: Sequence[Budget],
        cost_fn: CostFn = dummy_cost_fn,
        **kw: Any,
    ):
        if not isinstance(env, LabelledParallelEnv):
            raise TypeError(
                f"{self.__class__.__name__} must wrap a LabelledParallelEnv, but got {type(env).__name__}."
            )

        self.env = env
        self.metadata = getattr(env, "metadata", {})
        self.possible_agents = tuple(env.possible_agents)
        # Fall back to all possible agents when the wrapped env does not
        # expose ``agents``.
        self.agents = list(getattr(env, "agents", self.possible_agents))
        self.label_fn = getattr(env, "label_fn", None)
        self.cost_fn = cost_fn
        self._constraint = ConstrainedMarkovGame(
            possible_agents=self.possible_agents,
            budgets=budgets,
            cost_fn=cost_fn,
        )

    def __getattr__(self, name: str):
        """Delegate lookups that fail on the wrapper to the wrapped env.

        Raises:
            AttributeError: If ``env`` itself has not been set on the instance,
                or the wrapped env does not have ``name``.
        """
        # ``__getattr__`` only runs after normal lookup failed. If ``env`` is
        # the missing attribute (instance created through ``__new__`` by
        # copy/pickle, or ``__init__`` aborted early), evaluating ``self.env``
        # below would re-enter this method until RecursionError.
        if name == "env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.env, name)

    def reset(self, seed: int | None = None, options: dict[str, Any] | None = None):
        """Reset the wrapped env and seed the constraint from initial agent labels.

        Returns:
            The ``(obs, infos)`` pair returned by the wrapped environment.
        """
        obs, infos = self.env.reset(seed=seed, options=options)
        self.agents = list(getattr(self.env, "agents", self.possible_agents))
        self._constraint.reset()
        # Labels reported at reset are charged immediately.
        self._constraint.update(self._labels_by_agent(infos))
        return obs, infos

    def step(self, actions):
        """Step the wrapped env and update the constraint from per-agent labels.

        Returns:
            The ``(obs, rewards, terminations, truncations, infos)`` tuple
            returned by the wrapped environment, unmodified.
        """
        obs, rewards, terminations, truncations, infos = self.env.step(actions)
        self.agents = list(getattr(self.env, "agents", self.possible_agents))
        self._constraint.update(self._labels_by_agent(infos))
        return obs, rewards, terminations, truncations, infos

    def state(self):
        """Return ``state()`` of the wrapped environment."""
        return self.env.state()

    def render(self):
        """Return ``render()`` of the wrapped environment."""
        return self.env.render()

    def close(self):
        """Return ``close()`` of the wrapped environment."""
        return self.env.close()

    def observation_space(self, agent):
        """Return the wrapped environment's observation space for ``agent``."""
        return self.env.observation_space(agent)

    def action_space(self, agent):
        """Return the wrapped environment's action space for ``agent``."""
        return self.env.action_space(agent)

    def constraint_step_metrics(self) -> dict[str, float]:
        """Return the constraint monitor's per-step metrics."""
        return self._constraint.step_metric()

    def constraint_episode_metrics(self) -> dict[str, float]:
        """Return the constraint monitor's end-of-episode metrics."""
        return self._constraint.episode_metric()

    @property
    def constraint_type(self) -> str:
        """Constraint type identifier reported by the internal monitor."""
        return self._constraint.constraint_type

    def _labels_by_agent(
        self, infos: Mapping[str, Mapping[str, Any] | None]
    ) -> dict[str, set[str] | frozenset[str]]:
        """Extract and validate ``infos[agent]['labels']`` for all possible agents.

        Agents without an info entry, with a ``None`` entry, or without a
        ``"labels"`` key are given an empty label set.

        Raises:
            ValueError: If ``infos`` contains an agent outside
                ``possible_agents``, an agent's info is not a mapping, or its
                labels are not a ``set``/``frozenset``.
        """
        unknown_agents = set(infos) - set(self.possible_agents)
        if unknown_agents:
            raise ValueError(f"Unknown agents in infos: {sorted(unknown_agents)}")

        labels_by_agent: dict[str, set[str] | frozenset[str]] = {}
        for agent in self.possible_agents:
            agent_info = infos.get(agent, {})
            if agent_info is None:
                agent_info = {}
            if not isinstance(agent_info, Mapping):
                raise ValueError(
                    f"Expected info for agent '{agent}' to be a mapping, got {type(agent_info).__name__}"
                )

            labels = agent_info.get("labels", set())
            if not isinstance(labels, (set, frozenset)):
                raise ValueError(
                    f"Expected labels for agent '{agent}' to be a set of atomic propositions, "
                    f"got {type(labels).__name__}"
                )
            labels_by_agent[agent] = labels
        return labels_by_agent