from __future__ import annotations
from typing import Any, Dict, List, Union
import gymnasium as gym
from gymnasium import spaces
from masa.common.constraints.base import BaseConstraintEnv
from masa.common.ltl import DFACostFn, DFA, ShapedCostFn
from masa.common.running_mean_std import RunningMeanStd
import numpy as np
from collections import deque
from tqdm import tqdm
def is_wrapped(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> bool:
r"""
Check whether ``env`` is wrapped (anywhere in its wrapper chain) by
``wrapper_class``.
This helper walks through typical wrapper chains:
* Gymnasium-style wrappers via ``.env``.
* Vector-env style wrappers via a ``.venv`` attribute (commonly used by
vectorized environments and some third-party libraries).
Cycle protection is included: if the wrapper chain loops, this function
returns ``False`` rather than looping forever.
Args:
env: Environment or wrapper to inspect.
wrapper_class: Wrapper type to search for.
Returns:
``True`` if an instance of ``wrapper_class`` appears in the wrapper chain;
``False`` otherwise.
"""
current = env
visited = set()
while True:
if id(current) in visited:
return False
visited.add(id(current))
if isinstance(current, wrapper_class):
return True
if hasattr(current, "venv"):
current = current.venv
continue
if isinstance(current, gym.Wrapper):
current = current.env
continue
return False
def get_wrapped(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> gym.Env | None:
r"""
Return the first wrapper instance of type ``wrapper_class`` found in ``env``'s
wrapper chain.
The traversal rules match :func:`is_wrapped`.
Args:
env: Environment or wrapper to inspect.
wrapper_class: Wrapper type to retrieve.
Returns:
The first encountered instance of ``wrapper_class`` in the wrapper chain,
or ``None`` if it is not present (or if a cycle is detected).
"""
current = env
visited = set()
while True:
if id(current) in visited:
return None
visited.add(id(current))
if isinstance(current, wrapper_class):
return current
if hasattr(current, "venv"):
current = current.venv
continue
if isinstance(current, gym.Wrapper):
current = current.env
continue
return None
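# Usage sketch (illustrative only): the helpers walk nested wrapper chains, so
# a wrapper buried under others is still found. `base_env` stands for any
# hypothetical plain Gymnasium environment supplied by the caller.
def _example_wrapper_lookup(base_env: gym.Env) -> None:
    env = TimeLimit(RewardMonitor(base_env), max_episode_steps=100)
    # RewardMonitor sits below TimeLimit in the chain but is still detected.
    assert is_wrapped(env, RewardMonitor)
    assert isinstance(get_wrapped(env, RewardMonitor), RewardMonitor)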
class ConstraintPersistentWrapper(gym.Wrapper):
"""
Base wrapper that *persists* access to constraint-related attributes.
Many Gymnasium wrappers shadow attributes by changing ``self.env``. This
wrapper provides stable properties for:
* :attr:`_constraint` (if present on the underlying env)
* :attr:`cost_fn` (if exposed by the constraint)
* :attr:`label_fn` (if present on the underlying env)
Subclasses can rely on these properties even when stacked with additional
wrappers.
Args:
env: Base environment to wrap.
"""
def __init__(self, env: gym.Env):
super().__init__(env)
@property
def _constraint(self):
"""
The underlying constraint object, if present.
Returns:
The object stored in ``self.env._constraint`` if it exists, otherwise
``None``.
"""
return getattr(self.env, "_constraint", None)
@property
def cost_fn(self):
"""
Cost function exposed by the underlying constraint, if available.
If ``self._constraint`` exists and it exposes ``cost_fn``, this returns
that callable-like object (often a :class:`masa.common.ltl.DFACostFn`).
Otherwise returns ``None``.
Returns:
A cost function-like object or ``None``.
"""
if self._constraint is not None:
    return getattr(self._constraint, "cost_fn", None)
return None
@property
def label_fn(self):
"""
Labelling function exposed by the underlying environment, if available.
Returns:
The object stored in ``self.env.label_fn`` if it exists, otherwise
``None``.
"""
return getattr(self.env, "label_fn", None)
class ConstraintPersistentObsWrapper(ConstraintPersistentWrapper):
"""
Base class for wrappers that transform observations while preserving constraint access.
Subclasses must implement :meth:`_get_obs` which maps raw observations to the
wrapped observation representation.
Args:
env: Base environment to wrap.
"""
def __init__(self, env: gym.Env):
super().__init__(env)
def _get_obs(self, obs: Any) -> Any:
"""
Transform a raw observation into the wrapped observation.
Subclasses must implement this.
Args:
obs: Raw observation from the underlying environment.
Returns:
Transformed observation.
Raises:
NotImplementedError: If not implemented by the subclass.
"""
raise NotImplementedError
def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None):
"""
Reset the environment and transform the returned observation.
Args:
seed: Random seed forwarded to the underlying environment.
options: Reset options forwarded to the underlying environment.
Returns:
A tuple ``(obs, info)`` where ``obs`` is transformed via :meth:`_get_obs`.
"""
obs, info = self.env.reset(seed=seed, options=options)
return self._get_obs(obs), info
def step(self, action):
"""
Step the environment and transform the returned observation.
Args:
action: Action forwarded to the underlying environment.
Returns:
A 5-tuple ``(obs, reward, terminated, truncated, info)`` where ``obs``
is transformed via :meth:`_get_obs`.
"""
obs, rew, term, trunc, info = self.env.step(action)
return self._get_obs(obs), rew, term, trunc, info
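# Minimal subclass sketch (illustrative only): a concrete observation wrapper
# only needs to implement `_get_obs`; the reset/step plumbing is inherited.
class _ScaleObsWrapper(ConstraintPersistentObsWrapper):
    """Hypothetical example: rescale array observations by a constant factor."""

    def __init__(self, env: gym.Env, scale: float = 0.1):
        super().__init__(env)
        self._scale = scale

    def _get_obs(self, obs: Any) -> Any:
        # Constraint attributes (_constraint, cost_fn, label_fn) remain reachable.
        return np.asarray(obs, dtype=np.float32) * self._scale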
class TimeLimit(ConstraintPersistentWrapper):
"""
Episode time-limit wrapper compatible with constraint persistence.
This is a minimal time-limit wrapper similar in spirit to Gymnasium's
time-limit handling. It sets the ``truncated`` flag to ``True`` once the
number of elapsed steps reaches :attr:`_max_episode_steps`.
Args:
env: Base environment to wrap.
max_episode_steps: Maximum number of steps per episode.
Attributes:
_max_episode_steps: Configured time limit in steps.
_elapsed_steps: Counter of steps elapsed in the current episode.
"""
def __init__(self, env: gym.Env, max_episode_steps: int):
super().__init__(env)
self._max_episode_steps = max_episode_steps
self._elapsed_steps = None
def step(self, action):
"""
Step the environment and apply time-limit truncation.
Args:
action: Action forwarded to the underlying environment.
Returns:
A 5-tuple ``(observation, reward, terminated, truncated, info)``. If
the time limit is reached, ``truncated`` is forced to ``True``.
"""
observation, reward, terminated, truncated, info = self.env.step(action)
self._elapsed_steps += 1
if self._elapsed_steps >= self._max_episode_steps:
truncated = True
return observation, reward, terminated, truncated, info
def reset(self, **kwargs):
"""
Reset the environment and the elapsed step counter.
Args:
**kwargs: Forwarded to the underlying environment's ``reset``.
Returns:
The underlying environment's ``reset`` return value.
"""
self._elapsed_steps = 0
return self.env.reset(**kwargs)
class ConstraintMonitor(ConstraintPersistentWrapper):
"""
Monitor that injects constraint metadata and metrics into ``info``.
This wrapper requires the wrapped environment to be a
:class:`masa.common.constraints.base.BaseConstraintEnv`, so it can query:
* :attr:`masa.common.constraints.base.BaseConstraintEnv.constraint_type`
* :meth:`masa.common.constraints.base.BaseConstraintEnv.constraint_step_metrics`
* :meth:`masa.common.constraints.base.BaseConstraintEnv.constraint_episode_metrics`
On each step, the wrapper writes:
* ``info["constraint"]["type"]``: the constraint type string
* ``info["constraint"]["step"]``: step-level metrics (cheap, safe)
* ``info["constraint"]["episode"]``: episode-level metrics (when available)
Args:
env: Constraint environment to wrap.
Raises:
TypeError: If ``env`` is not a :class:`~masa.common.constraints.base.BaseConstraintEnv`.
"""
def __init__(self, env: gym.Env):
super().__init__(env)
if not isinstance(env, BaseConstraintEnv): # type: ignore[arg-type]
raise TypeError(
"ConstraintMonitor requires env to implement BaseConstraintEnv "
"(wrap your env with CumulativeCostEnv/StepWiseProbabilisticEnv/...)."
)
self._constraint_env: BaseConstraintEnv = env # type: ignore[assignment]
def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None):
"""
Reset and populate initial constraint metadata in ``info``.
Args:
seed: Random seed forwarded to the underlying environment.
options: Reset options forwarded to the underlying environment.
Returns:
A tuple ``(obs, info)``. The returned ``info`` includes
``info["constraint"]["type"]`` and ``info["constraint"]["step"]``.
"""
obs, info = self.env.reset(seed=seed, options=options)
info = dict(info or {})
info.setdefault("constraint", {})["type"] = self._constraint_env.constraint_type
info["constraint"]["step"] = self._step_metrics()
return obs, info
def step(self, action):
"""
Step and populate constraint metrics in ``info``.
Args:
action: Action forwarded to the underlying environment.
Returns:
A 5-tuple ``(observation, reward, terminated, truncated, info)``.
The returned ``info`` includes ``constraint`` fields described in the
class docstring.
"""
observation, reward, terminated, truncated, info = self.env.step(action)
info = dict(info or {})
info.setdefault("constraint", {})["type"] = self._constraint_env.constraint_type
info["constraint"]["step"] = self._step_metrics()
info["constraint"]["episode"] = self._episode_metrics()
if terminated or truncated:
info["constraint"]["episode"] = self._episode_metrics()
return observation, reward, terminated, truncated, info
def _step_metrics(self) -> Dict[str, float]:
"""
Read step-level constraint metrics.
Returns:
A dictionary of step-level metrics. If the underlying constraint raises
an exception, returns an empty dictionary.
"""
try:
return dict(self._constraint_env.constraint_step_metrics())
except Exception:
return {}
def _episode_metrics(self) -> Dict[str, float]:
"""
Read episode-level constraint metrics.
Returns:
A dictionary of episode-level metrics. If the underlying constraint raises
an exception, returns an empty dictionary.
"""
try:
return dict(self._constraint_env.constraint_episode_metrics())
except Exception:
return {}
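# Consumption sketch (illustrative only): how a rollout loop might read the
# fields injected by ConstraintMonitor. `monitored_env` is a hypothetical env
# already wrapped in ConstraintMonitor.
def _example_read_constraint_info(monitored_env: gym.Env) -> None:
    obs, info = monitored_env.reset()
    assert "type" in info["constraint"] and "step" in info["constraint"]
    action = monitored_env.action_space.sample()
    obs, rew, term, trunc, info = monitored_env.step(action)
    step_cost = info["constraint"]["step"].get("cost", 0.0)  # cheap per-step metric
    if term or trunc:
        episode_metrics = info["constraint"]["episode"]  # written at episode end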
class RewardMonitor(ConstraintPersistentWrapper):
"""
Monitor that injects reward/length metrics into ``info``.
This wrapper tracks:
* per-step immediate reward in ``info["metrics"]["step"]["reward"]``
* episode return/length at episode end in ``info["metrics"]["episode"]``
Args:
env: Base environment to wrap.
Attributes:
total_reward: Accumulated episode reward since last reset.
total_steps: Number of steps taken since last reset.
"""
def __init__(self, env: gym.Env):
    super().__init__(env)
    self.total_reward = 0.0
    self.total_steps = 0
def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None):
"""
Reset reward counters and forward ``reset`` to the underlying env.
Args:
seed: Random seed forwarded to the underlying environment.
options: Reset options forwarded to the underlying environment.
Returns:
A tuple ``(obs, info)`` from the underlying environment.
"""
obs, info = self.env.reset(seed=seed, options=options)
self.total_reward = 0.0
self.total_steps = 0
return obs, info
def step(self, action):
"""
Step the environment and update reward metrics.
Args:
action: Action forwarded to the underlying environment.
Returns:
A 5-tuple ``(observation, reward, terminated, truncated, info)``.
On episode end, ``info["metrics"]["episode"]`` is populated with
episode return and length.
"""
observation, reward, terminated, truncated, info = self.env.step(action)
self.total_reward += reward
self.total_steps += 1
info = dict(info or {})
info.setdefault("metrics", {})
info["metrics"]["step"] = {"reward": reward}
if terminated or truncated:
info["metrics"]["episode"] = self._episode_metrics()
return observation, reward, terminated, truncated, info
def _episode_metrics(self):
"""
Compute episode-level reward metrics.
Returns:
A dictionary with keys:
* ``"ep_reward"``: total episode reward.
* ``"ep_length"``: episode length in steps.
"""
return {"ep_reward": self.total_reward, "ep_length": self.total_steps}
class RewardShapingWrapper(ConstraintPersistentWrapper):
r"""
Potential-based reward shaping wrapper for DFA-based safety constraints.
If the wrapped environment's constraint exposes a :class:`~masa.common.ltl.DFACostFn`,
this wrapper constructs a shaped cost function :class:`~masa.common.ltl.ShapedCostFn`
and updates the step ``cost`` entry inside ``info["constraint"]["step"]`` using:
.. math::
c'_t \;=\; c_t \;+\; \gamma \Phi(q_{t+1}) \;-\; \Phi(q_t).
The potential :math:`\Phi` depends on ``impl``:
* ``"none"``: :math:`\Phi(q)=0` (no shaping)
* ``"vi"``: approximate value iteration over DFA graph to derive potentials
* ``"cycle"``: graph-distance based shaping using a reverse-reachability BFS
Notes:
This wrapper assumes the wrapped environment is already producing
``info["automaton_state"]`` and a constraint monitor-like structure
``info["constraint"]["step"]["cost"]``. If these keys are absent, the
wrapper will fall back to default values (state ``0`` and cost ``0.0``).
Args:
env: Base environment to wrap.
gamma: Discount used in the shaping term :math:`\gamma \Phi(q_{t+1})`.
impl: Shaping implementation. One of ``{"none", "vi", "cycle"}``.
Attributes:
shaped_cost_fn: The cost function exposed by :attr:`cost_fn` after shaping.
potential_fn: Callable :math:`\Phi(q)` mapping DFA states to potentials.
_last_potential: Potential at the previous step's DFA state.
_gamma: Shaping discount factor.
_impl: Shaping implementation identifier.
"""
def __init__(self, env: gym.Env, gamma: float = 0.99, impl: str = "none"):
super().__init__(env)
self._last_potential = 0.0
self._gamma = gamma
self._impl = impl
self._setup_potential_fn()
self._setup_cost_fn()
def _setup_cost_fn(self):
"""
Create :attr:`shaped_cost_fn` if DFA-based constraints are available.
If the underlying constraint exposes a :class:`~masa.common.ltl.DFACostFn`,
constructs a :class:`~masa.common.ltl.ShapedCostFn`. Otherwise, uses a
trivial zero-cost function.
Returns:
``None``. This method sets :attr:`shaped_cost_fn` as a side effect.
"""
if isinstance(getattr(self._constraint, "cost_fn", None), DFACostFn):
    self.shaped_cost_fn = ShapedCostFn(
        self._constraint.cost_fn.dfa, self.potential_fn, gamma=self._gamma
    )
else:
    self.shaped_cost_fn = lambda q: 0.0
def _setup_potential_fn(self):
"""
Construct the potential function :attr:`potential_fn` for shaping.
For ``impl="none"``, :attr:`potential_fn` is identically zero.
For ``impl="vi"``, a small fixed number of value-iteration steps are run
over DFA states, treating accepting states as having a constant terminal
value and propagating backward through reachable transitions.
For ``impl="cycle"``, a reverse-graph BFS is used to find a "furthest"
state from the accepting set and then compute distances to that target,
yielding a shaping potential based on distance.
Returns:
``None``. This method sets :attr:`potential_fn` and may store
intermediate tables such as :attr:`V` or :attr:`dist_to_furthest`.
Raises:
AssertionError: If ``impl != "none"`` but no DFA-based cost function
is available on the underlying constraint.
"""
if self._impl != "none":
assert hasattr(self._constraint, "cost_fn"), \
("RewardShapingWrapper requires env to implement a BaseConstraintEnv that exposes a cost_fn")
assert isinstance(getattr(self._constraint, "cost_fn", None), DFACostFn), \
    ("RewardShapingWrapper requires env to implement an LTLSafetyEnv whose cost_fn is a DFACostFn")
dfa: DFA = self._constraint.cost_fn.dfa
else:
self.potential_fn = lambda q: 0.0
if self._impl == "vi":
VI_STEPS = 100
GAMMA = 0.9
self.V = {q: 0.0 for q in dfa.states}
assert GAMMA <= self._gamma
print("Reward shaping DFA ...")
for i in tqdm(range(VI_STEPS)):
diff = 0.0
for u in dfa.states:
V_u = self.V[u]
self.V[u] = 1.0/(1.0 - GAMMA) if u in dfa.accepting else \
np.max([GAMMA * self.V[v] for v in dfa.edges[u].keys()])
diff = max(diff, np.abs(V_u - self.V[u]))
if diff < 1e-6:
break
self.potential_fn = lambda q: self.V[q]
if self._impl == "cycle":
print("Reward shaping DFA ...")
edges_rev = {v: set() for v in dfa.states}
for u in dfa.states:
if u in dfa.edges:
reachable_states = set(dfa.edges[u].keys())
for v in reachable_states:
edges_rev[v].add(u)
dist_to_accepting = {u: np.inf for u in dfa.states}
queue = deque()
for a in dfa.accepting:
dist_to_accepting[a] = 0.0
queue.append(a)
max_dist = -1
furthest_state = None
while queue:
current = queue.popleft()
current_dist = dist_to_accepting[current]
if current_dist > max_dist:
max_dist = current_dist
furthest_state = current
for w in edges_rev.get(current, []):
if dist_to_accepting[w] == np.inf:
dist_to_accepting[w] = current_dist + 1
queue.append(w)
if furthest_state is None:
furthest_state = dfa.initial
self.dist_to_furthest = {u: np.inf for u in dfa.states}
max_finite_dist = 0.0
# furthest_state is guaranteed non-None here (it falls back to dfa.initial above)
u_target = furthest_state
self.dist_to_furthest[u_target] = 0.0
queue = deque([u_target])
while queue:
current = queue.popleft()
current_dist = self.dist_to_furthest[current]
if current_dist > max_finite_dist:
max_finite_dist = current_dist
for w in edges_rev.get(current, []):
if self.dist_to_furthest[w] == np.inf:
self.dist_to_furthest[w] = current_dist + 1
queue.append(w)
replacement_value = max_finite_dist + 1.0
for u in dfa.states:
if self.dist_to_furthest[u] == np.inf:
self.dist_to_furthest[u] = replacement_value
self.potential_fn = lambda q: self.dist_to_furthest[q]
def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None):
"""
Reset the environment and initialize shaping state.
Args:
seed: Random seed forwarded to the underlying environment.
options: Reset options forwarded to the underlying environment.
Returns:
A tuple ``(obs, info)`` from the underlying environment.
Notes:
This wrapper reads ``info["automaton_state"]`` to initialize the
previous potential :attr:`_last_potential`. If the key is missing,
it assumes DFA state ``0``.
"""
obs, info = self.env.reset(seed=seed, options=options)
automaton_state = info.get("automaton_state", 0)
self._last_potential = self.potential_fn(automaton_state)
return obs, info
def step(self, action: Any):
"""
Step the environment and apply potential-based shaping to the step cost.
Args:
action: Action forwarded to the underlying environment.
Returns:
A 5-tuple ``(observation, reward, terminated, truncated, info)``.
Side effects:
Updates ``info["constraint"]["step"]["cost"]`` in-place with the shaped
cost and updates :attr:`_last_potential`.
Notes:
If the underlying ``info`` does not contain constraint metrics,
this method assumes an unshaped step cost of ``0.0`` and will
still attempt to write back into ``info["constraint"]["step"]``.
"""
observation, reward, terminated, truncated, info = self.env.step(action)
step_metrics = info.setdefault("constraint", {}).setdefault("step", {})
cost = step_metrics.get("cost", 0.0)
automaton_state = info.get("automaton_state", 0)
potential = self.potential_fn(automaton_state)
step_metrics["cost"] = cost + self._gamma * potential - self._last_potential
self._last_potential = potential
return observation, reward, terminated, truncated, info
@property
def cost_fn(self):
"""
Expose the shaped cost function.
Returns:
The shaped cost function constructed in :meth:`_setup_cost_fn`.
"""
return self.shaped_cost_fn
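# Numerical sketch (illustrative only): potential-based shaping telescopes, so
# the discounted shaped cost differs from the discounted raw cost only by the
# boundary terms gamma^T * Phi(q_T) - Phi(q_0). All values below are made up.
def _example_shaping_telescopes() -> None:
    gamma = 0.99
    phi = {0: 2.0, 1: 1.0, 2: 0.0}   # hypothetical potentials Phi(q)
    states = [0, 1, 2, 2]            # DFA states q_0 .. q_T
    costs = [1.0, 0.0, 1.0]          # raw step costs c_t
    shaped = [
        c + gamma * phi[states[t + 1]] - phi[states[t]]
        for t, c in enumerate(costs)
    ]
    lhs = sum(gamma ** t * s for t, s in enumerate(shaped))
    rhs = (
        sum(gamma ** t * c for t, c in enumerate(costs))
        + gamma ** len(costs) * phi[states[-1]]
        - phi[states[0]]
    )
    assert abs(lhs - rhs) < 1e-9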
class NormWrapper(ConstraintPersistentWrapper):
"""
Normalize observations and/or rewards for a *single* (non-vectorized) environment.
This wrapper maintains running mean/variance estimates and applies:
* Observation normalization (elementwise): :math:`(x - \\mu) / \\sqrt{\\sigma^2 + \\varepsilon}`
* Reward normalization using a running variance estimate over discounted returns.
This wrapper is intended for non-vectorized environments. For vectorized
environments, use :class:`VecNormWrapper`.
Args:
env: Base (non-vectorized) environment.
norm_obs: Whether to normalize observations.
norm_rew: Whether to normalize rewards.
training: If ``True``, update running statistics; otherwise, statistics are frozen.
clip_obs: Clip normalized observations to ``[-clip_obs, clip_obs]``.
clip_rew: Clip normalized rewards to ``[-clip_rew, clip_rew]``.
gamma: Discount factor for the running return used in reward normalization.
eps: Small constant :math:`\\varepsilon` for numerical stability.
Attributes:
norm_obs: See Args.
norm_rew: See Args.
training: See Args.
clip_obs: See Args.
clip_rew: See Args.
gamma: See Args.
eps: See Args.
obs_rms: :class:`masa.common.running_mean_std.RunningMeanStd` for observations.
rew_rms: :class:`masa.common.running_mean_std.RunningMeanStd` for returns.
returns: Discounted return accumulator used for reward normalization.
"""
def __init__(
self,
env: gym.Env,
norm_obs: bool = True,
norm_rew: bool = True,
training: bool = True,
clip_obs: float = 10.0,
clip_rew: float = 10.0,
gamma: float = 0.99,
eps: float = 1e-8
):
assert not isinstance(
env, VecEnvWrapperBase
), "NormWrapper does not expect a vectorized environment (DummyVecWrapper / VecWrapper). Please use VecNormWrapper instead"
assert not norm_obs or isinstance(
    env.observation_space, spaces.Box
), "NormWrapper only supports Box observation spaces when norm_obs=True."
super().__init__(env)
self.norm_obs = norm_obs
self.norm_rew = norm_rew
self.training = training
self.clip_obs = clip_obs
self.clip_rew = clip_rew
self.gamma = gamma
self.eps = eps
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.rew_rms = RunningMeanStd(shape=())
self.returns = np.zeros(1, dtype=np.float32)
def _normalize_obs(self, obs: np.ndarray) -> np.ndarray:
"""
Normalize (and clip) a single observation.
Args:
obs: Raw observation.
Returns:
Normalized observation.
"""
return np.clip(
(obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.eps),
-self.clip_obs,
self.clip_obs
)
def _normalize_rew(self, rew: float) -> float:
"""
Normalize (and clip) a single reward.
Args:
rew: Raw reward.
Returns:
Normalized reward.
"""
return np.clip(
rew / np.sqrt(self.rew_rms.var + self.eps),
-self.clip_rew,
self.clip_rew,
)
def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None):
"""
Reset the environment and (optionally) update normalization statistics.
Args:
seed: Random seed forwarded to the underlying environment.
options: Reset options forwarded to the underlying environment.
Returns:
A tuple ``(obs, info)`` where ``obs`` may be normalized.
"""
obs, info = self.env.reset(seed=seed, options=options)
if self.norm_obs and self.training:
self.obs_rms.update(obs)
self.returns[:] = 0.0
if self.norm_obs:
obs = self._normalize_obs(obs)
return obs, info
def step(self, action):
"""
Step the environment and apply observation/reward normalization.
Args:
action: Action forwarded to the underlying environment.
Returns:
A 5-tuple ``(obs, rew, terminated, truncated, info)`` where ``obs`` and/or
``rew`` may be normalized.
"""
obs, rew, term, trunc, info = self.env.step(action)
if self.norm_obs and self.training:
self.obs_rms.update(obs)
if self.norm_rew:
self.returns = self.returns * self.gamma + rew
if self.training:
self.rew_rms.update(self.returns)
rew = self._normalize_rew(rew)
if self.norm_obs:
obs = self._normalize_obs(obs)
return obs, rew, term, trunc, info
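# Usage sketch (illustrative only): update statistics during training, then
# freeze them for evaluation. `make_env` is a hypothetical env factory.
def _example_norm_wrapper(make_env) -> None:
    env = NormWrapper(make_env(), norm_obs=True, norm_rew=True, training=True)
    obs, info = env.reset(seed=0)
    for _ in range(10):
        obs, rew, term, trunc, info = env.step(env.action_space.sample())
        if term or trunc:
            obs, info = env.reset()
    env.training = False  # freeze obs_rms/rew_rms for evaluation rollouts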
class OneHotObsWrapper(ConstraintPersistentObsWrapper):
"""
One-hot encode :class:`gymnasium.spaces.Discrete` observations.
Supported input observation spaces:
* :class:`gymnasium.spaces.Discrete`: returns a 1D one-hot vector of length ``n``.
* :class:`gymnasium.spaces.Dict`: one-hot encodes any Discrete subspaces and
passes through non-Discrete subspaces.
* Otherwise: passes observations through unchanged.
The wrapper updates :attr:`gymnasium.Env.observation_space` accordingly.
Args:
env: Base environment to wrap.
Attributes:
_orig_obs_space: The original observation space of the wrapped env.
_mode: One of ``{"discrete", "dict", "pass"}`` describing the encoding mode.
"""
def __init__(self, env: gym.Env):
super().__init__(env)
self._orig_obs_space = self.env.observation_space
if isinstance(self._orig_obs_space, spaces.Discrete):
self._mode = "discrete"
n = self._orig_obs_space.n
self.observation_space = spaces.Box(
low=0.0,
high=1.0,
shape=(n,),
dtype=np.float32,
)
elif isinstance(self._orig_obs_space, spaces.Dict):
self._mode = "dict"
new_spaces: Dict[str, spaces.Space] = {}
for key, subspace in self._orig_obs_space.spaces.items():
if isinstance(subspace, spaces.Discrete):
n = subspace.n
new_spaces[key] = spaces.Box(
low=0.0,
high=1.0,
shape=(n,),
dtype=np.float32,
)
else:
# Preserve non-Discrete subspace as-is
new_spaces[key] = subspace
self.observation_space = spaces.Dict(new_spaces)
else:
self._mode = "pass"
self.observation_space = self._orig_obs_space
@staticmethod
def _one_hot_scalar(idx: int, n: int) -> np.ndarray:
"""
One-hot encode an integer index.
Args:
idx: Index in ``{0, 1, ..., n-1}``.
n: Vector length.
Returns:
A float32 vector ``v`` with ``v[idx] = 1`` and zeros elsewhere.
"""
one_hot = np.zeros(n, dtype=np.float32)
one_hot[idx] = 1.0
return one_hot
def _get_obs(self, obs: Union[int, Dict[str, Any], np.ndarray]) -> Union[np.ndarray, Dict[str, Any]]:
"""
Transform an observation according to the wrapper's configured mode.
Args:
obs: Raw observation.
Returns:
One-hot encoded observation (or dict containing one-hot fields) when applicable,
otherwise the original observation.
"""
if self._mode == "discrete":
# Original obs_space is Discrete; obs is an int-like
idx = int(obs)
n = self._orig_obs_space.n
return self._one_hot_scalar(idx, n)
elif self._mode == "dict":
assert isinstance(obs, dict), (
f"Expected dict observation for Dict space, got {type(obs)}"
)
new_obs: Dict[str, Any] = {}
for key, subspace in self._orig_obs_space.spaces.items():
value = obs[key]
if isinstance(subspace, spaces.Discrete):
idx = int(value)
new_obs[key] = self._one_hot_scalar(idx, subspace.n)
else:
# Leave non-Discrete parts unchanged
new_obs[key] = value
return new_obs
else:
# pass
return obs
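# Behavior sketch (illustrative only): a Discrete(4) observation of 2 becomes
# the float32 vector [0, 0, 1, 0]. The check below assumes gymnasium's
# toy-text FrozenLake-v1 (Discrete(16) observations) is available.
def _example_one_hot() -> None:
    env = OneHotObsWrapper(gym.make("FrozenLake-v1"))
    obs, info = env.reset(seed=0)
    assert obs.shape == (16,) and obs.sum() == 1.0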
class FlattenDictObsWrapper(ConstraintPersistentObsWrapper):
"""
Flatten a :class:`gymnasium.spaces.Dict` observation into a 1D :class:`~gymnasium.spaces.Box`.
The wrapper creates a deterministic key ordering (alphabetical) and concatenates
each sub-observation in that order.
Supported Dict subspaces:
* :class:`gymnasium.spaces.Box`: flattened via ``reshape(-1)``.
* :class:`gymnasium.spaces.Discrete`: represented as a length-``n`` one-hot
segment for the purposes of bounds (note: the current implementation of
:meth:`_get_obs` expects Box values; see Notes).
Args:
env: Base environment with Dict observation space.
Attributes:
_orig_obs_space: Original Dict observation space.
_key_slices: Mapping from key to slice in the flattened vector.
Raises:
TypeError: If the underlying observation space is not a Dict, or contains
unsupported subspaces.
"""
def __init__(self, env: gym.Env):
super().__init__(env)
self._orig_obs_space = self.env.observation_space
if not isinstance(self._orig_obs_space, spaces.Dict):
raise TypeError(
f"FlattenDictObsWrapper requires Dict observation space, got {type(self._orig_obs_space)}"
)
# To be able to reconstruct if needed, keep slices for each key
self._key_slices: dict[str, slice] = {}
low_parts = []
high_parts = []
offset = 0
# Sort keys alphabetically for deterministic ordering
for key in sorted(self._orig_obs_space.spaces.keys()):
subspace = self._orig_obs_space.spaces[key]
if isinstance(subspace, spaces.Box):
# Flatten Box
low = np.asarray(subspace.low, dtype=np.float32).reshape(-1)
high = np.asarray(subspace.high, dtype=np.float32).reshape(-1)
length = low.shape[0]
low_parts.append(low)
high_parts.append(high)
elif isinstance(subspace, spaces.Discrete):
# One-hot will be in [0, 1]
length = subspace.n
low_parts.append(np.zeros(length, dtype=np.float32))
high_parts.append(np.ones(length, dtype=np.float32))
else:
raise TypeError(
f"Unsupported subspace type for key '{key}': {type(subspace)}"
)
self._key_slices[key] = slice(offset, offset + length)
offset += length
low = np.concatenate(low_parts).astype(np.float32)
high = np.concatenate(high_parts).astype(np.float32)
self.observation_space = spaces.Box(
low=low,
high=high,
dtype=np.float32,
)
def _get_obs(self, obs: Dict[str, Any]) -> np.ndarray:
"""
Flatten a Dict observation into a 1D vector.
Args:
obs: Dict observation keyed the same way as the original Dict space.
Returns:
A 1D float32 array created by concatenating flattened sub-observations.
Raises:
TypeError: If any subspace is not a Box.
Notes:
Although the constructor supports Discrete subspaces when building bounds,
this implementation currently enforces Box-only subspaces at runtime.
"""
assert isinstance(obs, dict), (
f"Expected dict observation for Dict space, got {type(obs)}"
)
parts = []
for key in sorted(self._orig_obs_space.spaces.keys()):
subspace = self._orig_obs_space.spaces[key]
value = obs[key]
if not isinstance(subspace, spaces.Box):
raise TypeError(
f"FlattenDictObsWrapper only supports Box subspaces, "
f"got {type(subspace)} for key '{key}'"
)
arr = np.asarray(value, dtype=np.float32).reshape(-1)
parts.append(arr)
return np.concatenate(parts, axis=0).astype(np.float32)
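# Behavior sketch (illustrative only): keys are concatenated alphabetically.
def _example_flatten_ordering() -> None:
    space = spaces.Dict({
        "b": spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
        "a": spaces.Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32),
    })
    # A hypothetical env exposing `space` would, after wrapping, produce a
    # flattened Box of shape (5,): obs[0:3] from "a", obs[3:5] from "b".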
class VecEnvWrapperBase(ConstraintPersistentWrapper):
"""
Base class for simple Python-list vector environment wrappers.
Vector environments in this file expose:
* :attr:`n_envs`: number of parallel environments
* :meth:`reset`: returns ``(obs_list, info_list)``
* :meth:`step`: returns ``(obs_list, rew_list, term_list, trunc_list, info_list)``
* :meth:`reset_done`: reset only environments indicated by a ``dones`` mask
Args:
env: For :class:`DummyVecWrapper`, this is the single underlying env.
For :class:`VecWrapper`, this is set to ``envs[0]`` to preserve a
Gymnasium-like API surface.
Attributes:
n_envs: Number of environments.
"""
n_envs: int
def __init__(self, env: gym.Env):
# For DummyVecWrapper: env is the single env
# For VecWrapper: env is envs[0]
# For VecNormWrapper: env is a VecEnvWrapperBase
super().__init__(env)
def reset_done(
self,
dones: Union[List[bool], np.ndarray],
*,
seed: int | None = None,
options: Dict[str, Any] | None = None
):
"""
Reset only the environments indicated by ``dones``.
Args:
dones: Boolean mask/list of length :attr:`n_envs`. Entries set to
``True`` are reset.
seed: Optional base seed. Implementations may offset by environment index.
options: Reset options forwarded to underlying environments.
Returns:
A tuple ``(reset_obs, reset_infos)`` where:
* ``reset_obs`` is a list of length :attr:`n_envs` containing reset
observations at indices that were reset, and ``None`` elsewhere.
* ``reset_infos`` is a list of length :attr:`n_envs` containing reset
info dicts at indices that were reset, and empty dicts elsewhere.
Raises:
NotImplementedError: If not implemented by a subclass.
"""
raise NotImplementedError
class DummyVecWrapper(VecEnvWrapperBase):
"""
Wrap a single environment with a vector-environment API (``n_envs=1``).
This wrapper is useful for code paths that expect list-based vector outputs,
while still running a single environment instance.
Args:
env: Base environment.
Attributes:
n_envs: Always ``1``.
envs: List containing the single wrapped environment.
"""
def __init__(self, env: gym.Env):
super().__init__(env)
self.n_envs = 1
self.envs: List[gym.Env] = [env]
def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None):
"""
Reset and return vectorized lists of length 1.
Args:
seed: Random seed forwarded to the underlying environment.
options: Reset options forwarded to the underlying environment.
Returns:
``([obs], [info])``.
"""
obs, info = self.env.reset(seed=seed, options=options)
return [obs], [info]
def reset_done(
self,
dones: Union[List[bool], np.ndarray],
*,
seed: int | None = None,
options: Dict[str, Any] | None = None
):
"""
Conditionally reset the single environment.
Args:
dones: A length-1 mask. If ``dones[0]`` is ``True``, reset.
seed: Random seed forwarded to the underlying environment.
options: Reset options forwarded to the underlying environment.
Returns:
A pair ``(reset_obs, reset_infos)`` as described by
:meth:`VecEnvWrapperBase.reset_done`.
"""
dones = list(dones)
assert len(dones) == 1
if dones[0]:
return self.reset(seed=seed, options=options)
else:
    return [None], [{}]
def step(self, action):
"""
Step and return vectorized lists of length 1.
Args:
action: Action for the single environment.
Returns:
``([obs], [rew], [terminated], [truncated], [info])``.
"""
obs, rew, term, trunc, info = self.env.step(action)
return [obs], [rew], [term], [trunc], [info]
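# Usage sketch (illustrative only): DummyVecWrapper gives a single env the same
# list-based API as VecWrapper, so one code path serves both.
def _example_dummy_vec(base_env: gym.Env) -> None:
    venv = DummyVecWrapper(base_env)
    obs_list, info_list = venv.reset(seed=0)
    assert venv.n_envs == 1 and len(obs_list) == 1
    obs_list, rews, terms, truncs, infos = venv.step(venv.action_space.sample())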
class VecWrapper(VecEnvWrapperBase):
"""
Wrap a list of environments with a simple vector-environment API.
Each underlying environment is reset/stepped sequentially in Python, and
results are returned as Python lists.
Args:
envs: Non-empty list of environments.
Attributes:
envs: The list of wrapped environments.
n_envs: Number of wrapped environments.
"""
def __init__(self, envs: List[gym.Env]):
assert len(envs) > 0, "VecWrapper requires at least one environment"
super().__init__(envs[0]) # maintain API compatibility
self.envs: List[gym.Env] = envs
self.n_envs = len(envs)
def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None):
"""
Reset all environments and return lists.
Args:
seed: Optional base seed. If provided, environment ``i`` receives ``seed + i``.
options: Reset options forwarded to each environment.
Returns:
A pair ``(obs_list, info_list)`` of length :attr:`n_envs`.
"""
obs_list, info_list = [], []
for i, env in enumerate(self.envs):
s = None if seed is None else seed + i
obs, info = env.reset(seed=s, options=options)
obs_list.append(obs)
info_list.append(info)
return obs_list, info_list
def reset_done(
self,
dones: Union[List[bool], np.ndarray],
*,
seed: int | None = None,
options: Dict[str, Any] | None = None
):
"""
Reset only environments whose done flag is True.
Args:
dones: Boolean mask/list of length :attr:`n_envs`.
seed: Optional base seed. If provided, environment ``i`` receives ``seed + i``.
options: Reset options forwarded to environments being reset.
Returns:
A tuple ``(reset_obs, reset_infos)`` where non-reset indices contain
``None`` and ``{}`` respectively.
"""
dones = list(dones)
assert len(dones) == self.n_envs
reset_obs = [None] * self.n_envs
reset_infos = [{} for _ in range(self.n_envs)]
for i, done in enumerate(dones):
if done:
s = None if seed is None else seed + i
obs, info = self.envs[i].reset(seed=s, options=options)
reset_obs[i] = obs
reset_infos[i] = info
return reset_obs, reset_infos
def step(self, actions):
"""
Step all environments.
Args:
actions: Iterable of actions of length :attr:`n_envs`.
Returns:
A 5-tuple of lists ``(obs_list, rew_list, term_list, trunc_list, info_list)``.
Notes:
    The loop expects exactly one action per environment; a length mismatch
    between ``actions`` and :attr:`n_envs` raises ``ValueError`` via
    ``zip(..., strict=True)``.
"""
obs_list, rew_list, term_list, trunc_list, info_list = [], [], [], [], []
for env, action in zip(self.envs, actions, strict=True):
obs, rew, term, trunc, info = env.step(action)
obs_list.append(obs)
rew_list.append(rew)
term_list.append(term)
trunc_list.append(trunc)
info_list.append(info)
return obs_list, rew_list, term_list, trunc_list, info_list
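# Rollout sketch (illustrative only): step all envs, then reset just the
# finished ones in place. `env_fns` is a hypothetical list of env factories.
def _example_vec_rollout(env_fns) -> None:
    venv = VecWrapper([fn() for fn in env_fns])
    obs_list, infos = venv.reset(seed=0)
    actions = [env.action_space.sample() for env in venv.envs]
    obs_list, rews, terms, truncs, infos = venv.step(actions)
    dones = [t or tr for t, tr in zip(terms, truncs)]
    if any(dones):
        reset_obs, reset_infos = venv.reset_done(dones)
        obs_list = [ro if d else o for o, ro, d in zip(obs_list, reset_obs, dones)]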
class VecNormWrapper(VecEnvWrapperBase):
"""
Normalize observations and/or rewards for a vectorized environment.
This wrapper expects an environment implementing :class:`VecEnvWrapperBase`
(e.g., :class:`DummyVecWrapper` or :class:`VecWrapper`) and applies the same
normalization logic as :class:`NormWrapper`, but over batches.
Observation normalization uses running statistics of the stacked observation
array (shape ``(n_envs, *obs_shape)``). Reward normalization uses running
statistics of discounted returns per environment.
Args:
env: A vectorized environment implementing :class:`VecEnvWrapperBase`.
norm_obs: Whether to normalize observations.
norm_rew: Whether to normalize rewards.
training: If ``True``, update running statistics; otherwise, statistics are frozen.
clip_obs: Clip normalized observations to ``[-clip_obs, clip_obs]``.
clip_rew: Clip normalized rewards to ``[-clip_rew, clip_rew]``.
gamma: Discount factor for the running return used in reward normalization.
eps: Small constant :math:`\\varepsilon` for numerical stability.
Attributes:
n_envs: Copied from the wrapped vector environment.
obs_rms: :class:`masa.common.running_mean_std.RunningMeanStd` for observations.
rew_rms: :class:`masa.common.running_mean_std.RunningMeanStd` for returns.
returns: Vector of length :attr:`n_envs` storing per-env discounted returns.
"""
def __init__(
self,
env: VecEnvWrapperBase,
norm_obs: bool = True,
norm_rew: bool = True,
training: bool = True,
clip_obs: float = 10.0,
clip_rew: float = 10.0,
gamma: float = 0.99,
eps: float = 1e-8
):
assert isinstance(
env, VecEnvWrapperBase
), "VecNormWrapper expects a vectorized environment (DummyVecWrapper / VecWrapper)."
assert not norm_obs or isinstance(
    env.observation_space, spaces.Box
), "VecNormWrapper only supports Box observation spaces when norm_obs=True."
super().__init__(env)
self.n_envs = env.n_envs
self.norm_obs = norm_obs
self.norm_rew = norm_rew
self.training = training
self.clip_obs = clip_obs
self.clip_rew = clip_rew
self.gamma = gamma
self.eps = eps
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.rew_rms = RunningMeanStd(shape=())
self.returns = np.zeros(self.n_envs, dtype=np.float32)
def _normalize_obs(self, obs_list: List[np.ndarray]) -> List[np.ndarray]:
"""
Normalize and clip a list of observations.
Args:
obs_list: List of raw observations of length :attr:`n_envs`.
Returns:
List of normalized observations.
"""
obs_arr = np.asarray(obs_list, dtype=np.float32)
norm = (obs_arr - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.eps)
norm = np.clip(norm, -self.clip_obs, self.clip_obs)
return norm.tolist()
def _normalize_rew(self, rew_list: List[float]) -> List[float]:
"""
Normalize and clip a list of rewards.
Args:
rew_list: List of raw rewards of length :attr:`n_envs`.
Returns:
List of normalized rewards.
"""
rew_arr = np.asarray(rew_list, dtype=np.float32)
norm = rew_arr / np.sqrt(self.rew_rms.var + self.eps)
norm = np.clip(norm, -self.clip_rew, self.clip_rew)
return norm.tolist()
def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None):
"""
Reset all environments and normalize observations.
Args:
seed: Optional base seed forwarded to the underlying vector env.
options: Reset options forwarded to the underlying vector env.
Returns:
A pair ``(obs_list, info_list)``. Observations may be normalized.
"""
obs_list, info_list = self.env.reset(seed=seed, options=options)
if self.norm_obs and self.training:
self.obs_rms.update(np.asarray(obs_list, dtype=np.float32))
self.returns[:] = 0.0
if self.norm_obs:
obs_list = self._normalize_obs(obs_list)
return obs_list, info_list
def reset_done(
self,
dones: Union[List[bool], np.ndarray],
*,
seed: int | None = None,
options: Dict[str, Any] | None = None
):
"""
Reset only environments indicated by ``dones`` and normalize those observations.
Args:
dones: Boolean mask/list of length :attr:`n_envs`.
seed: Optional base seed forwarded to the underlying vector env.
options: Reset options forwarded to the underlying vector env.
Returns:
A tuple ``(reset_obs, reset_infos)`` as described by
:meth:`VecEnvWrapperBase.reset_done`, with reset observations optionally
normalized.
"""
reset_obs, reset_infos = self.env.reset_done(
dones, seed=seed, options=options
)
fresh = [o for o in reset_obs if o is not None]
obs_arr = np.asarray(fresh, dtype=np.float32) if fresh else None
if self.norm_obs and self.training and obs_arr is not None:
self.obs_rms.update(obs_arr)
for i, done in enumerate(dones):
if done:
self.returns[i] = 0.0
if self.norm_obs:
norm_reset_obs: List[Any] = list(reset_obs)
# Only normalize indices that were reset
for i, done in enumerate(dones):
if done and reset_obs[i] is not None:
o = np.asarray(reset_obs[i], dtype=np.float32)
norm = (o - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.eps)
norm = np.clip(norm, -self.clip_obs, self.clip_obs)
norm_reset_obs[i] = norm
reset_obs = norm_reset_obs
return reset_obs, reset_infos
def step(self, actions):
"""
Step all environments and apply observation/reward normalization.
Args:
actions: Iterable of actions of length :attr:`n_envs`.
Returns:
A 5-tuple ``(obs_list, rew_list, term_list, trunc_list, infos)``, where
observations and/or rewards may be normalized.
"""
obs_list, rew_list, term_list, trunc_list, infos = self.env.step(actions)
obs_arr = np.asarray(obs_list, dtype=np.float32)
rew_arr = np.asarray(rew_list, dtype=np.float32)
if self.norm_obs and self.training:
self.obs_rms.update(obs_arr)
if self.norm_rew:
self.returns = self.returns * self.gamma + rew_arr
if self.training:
self.rew_rms.update(self.returns)
rew_list = self._normalize_rew(rew_list)
if self.norm_obs:
obs_list = self._normalize_obs(obs_list)
return obs_list, rew_list, term_list, trunc_list, infos
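# Composition sketch (illustrative only): a typical wrapper stack, assuming a
# hypothetical factory `make_constraint_env` that returns a BaseConstraintEnv
# with Box observations.
def _example_pipeline(make_constraint_env, n_envs: int = 4) -> "VecNormWrapper":
    def build_one() -> gym.Env:
        env = ConstraintMonitor(make_constraint_env())  # must wrap the constraint env directly
        env = RewardMonitor(env)
        return TimeLimit(env, max_episode_steps=500)

    venv = VecWrapper([build_one() for _ in range(n_envs)])
    return VecNormWrapper(venv, norm_obs=True, norm_rew=True, training=True)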