Wrapper Stack
This tutorial shows that make_env(...) is a convenience function around a concrete wrapper stack. You will build the same colour_grid_world CMDP environment in two ways:
- with make_env(...);
- manually, applying each wrapper in order.
Runnable notebook: notebooks/tutorials/03_wrapper_stack.ipynb
CPU-First Setup
Use the same CPU-first setup as the earlier tutorials:
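```python
import os

# Run JAX on the CPU and hide TensorFlow's C++ INFO and WARNING logs,
# unless the shell already sets these variables.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
```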
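Because the setup uses setdefault, a value already exported in your shell takes precedence over the CPU default. The sketch below illustrates this with a hypothetical helper, apply_cpu_defaults, that repeats the two setdefault calls on a plain dict so it does not modify your real process environment. As a rule of thumb, these variables should be set before JAX is imported.

```python
def apply_cpu_defaults(environ):
    """Hypothetical helper: mirror the tutorial's setup on any mutable mapping."""
    environ.setdefault("JAX_PLATFORMS", "cpu")
    environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
    return environ


# Fresh environment: both defaults are applied.
fresh = apply_cpu_defaults({})
print(fresh)  # {'JAX_PLATFORMS': 'cpu', 'TF_CPP_MIN_LOG_LEVEL': '2'}

# A user who already chose a platform keeps that choice.
preset = apply_cpu_defaults({"JAX_PLATFORMS": "cuda"})
print(preset["JAX_PLATFORMS"])  # cuda
```

If you want to force the CPU regardless of the shell, assign os.environ["JAX_PLATFORMS"] = "cpu" instead of using setdefault.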
Imports
The manual path uses the same pieces that make_env applies internally.
```python
from pathlib import Path
from pprint import pprint
from shutil import rmtree

from gymnasium.wrappers import RecordVideo

from masa.plugins.helpers import load_plugins
from masa.common.constraints.cmdp import CumulativeCostEnv
from masa.common.labelled_env import LabelledEnv
from masa.common.utils import make_env
from masa.common.wrappers import (
    ConstraintMonitor,
    RewardMonitor,
    TimeLimit,
    get_wrapped,
    is_wrapped,
)
from masa.envs.tabular.colour_grid_world import ColourGridWorld, cost_fn, label_fn

# Load MASA plugins before anything is looked up by registry name.
load_plugins()
```
Build with make_env
make_env looks up the environment and constraint by registry name, constructs the base environment, and applies MASA’s standard wrapper order.
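```python
def build_factory_env():
    return make_env(
        "colour_grid_world",  # environment registry name
        "cmdp",               # constraint registry name
        5,                    # episode step limit (applied via TimeLimit)
        label_fn=label_fn,
        cost_fn=cost_fn,
        budget=0.0,           # cost budget for the CMDP constraint
    )
```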
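To make the three steps above concrete (look up names, construct the base environment, apply a fixed wrapper order), here is a toy factory in plain Python. It is a minimal sketch, not MASA's make_env: the registries, ToyEnv, Wrapper, and every Toy* wrapper are hypothetical stand-ins that only record what wraps what.

```python
from functools import reduce


class ToyEnv:
    """Hypothetical stand-in for a base environment such as ColourGridWorld."""

    def __init__(self, name):
        self.name = name


class Wrapper:
    """Hypothetical wrapper base: keeps a reference to the inner environment."""

    def __init__(self, env):
        self.env = env


class ToyTimeLimit(Wrapper):
    def __init__(self, env, max_steps):
        super().__init__(env)
        self.max_steps = max_steps


class ToyLabelled(Wrapper):
    pass


class ToyConstraint(Wrapper):
    pass


class ToyConstraintMonitor(Wrapper):
    pass


class ToyRewardMonitor(Wrapper):
    pass


# Hypothetical registries: name -> constructor.
ENV_REGISTRY = {"colour_grid_world": lambda: ToyEnv("colour_grid_world")}
CONSTRAINT_REGISTRY = {"cmdp": ToyConstraint}


def toy_make_env(env_id, constraint_id, max_steps):
    # 1. Look up both names and construct the base environment.
    base = ENV_REGISTRY[env_id]()
    constraint_cls = CONSTRAINT_REGISTRY[constraint_id]
    # 2. Fixed wrapper order, innermost first.
    wrappers = [
        lambda env: ToyTimeLimit(env, max_steps),
        ToyLabelled,
        constraint_cls,
        ToyConstraintMonitor,
        ToyRewardMonitor,
    ]
    # 3. Fold: each constructor wraps the result of the previous one.
    return reduce(lambda env, wrap: wrap(env), wrappers, base)


def chain(env):
    """Class names from the outermost wrapper down to the base environment."""
    names = []
    while isinstance(env, Wrapper):
        names.append(type(env).__name__)
        env = env.env
    names.append(type(env).__name__)
    return names


env = toy_make_env("colour_grid_world", "cmdp", 5)
print(chain(env))
# ['ToyRewardMonitor', 'ToyConstraintMonitor', 'ToyConstraint', 'ToyLabelled', 'ToyTimeLimit', 'ToyEnv']
```

Writing the order down once as a list is what makes a factory reproducible: every caller gets the same chain, which is the property the tutorial later checks on the real wrappers.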
Build the Same Stack Manually
The equivalent manual stack is:
ColourGridWorld
-> TimeLimit
-> LabelledEnv
-> CumulativeCostEnv
-> ConstraintMonitor
-> RewardMonitor
In code:
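```python
def build_manual_env():
    env = ColourGridWorld()                                    # base environment
    env = TimeLimit(env, 5)                                    # truncate after 5 steps
    env = LabelledEnv(env, label_fn)                           # adds info["labels"]
    env = CumulativeCostEnv(env, cost_fn=cost_fn, budget=0.0)  # CMDP cost constraint
    env = ConstraintMonitor(env)                               # writes info["constraint"]
    env = RewardMonitor(env)                                   # reward and episode-length metrics
    return env
```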
Inspect the Wrapper Chain
is_wrapped answers whether a wrapper appears anywhere in the chain. get_wrapped returns the first matching wrapper object.
```python
WRAPPERS = (TimeLimit, LabelledEnv, CumulativeCostEnv, ConstraintMonitor, RewardMonitor)


def summarize_wrappers(env):
    # For each expected wrapper class, record whether it appears in the chain
    # and the type of the first matching wrapper object.
    return {
        wrapper.__name__: {
            "present": is_wrapped(env, wrapper),
            "found_type": type(get_wrapped(env, wrapper)).__name__,
        }
        for wrapper in WRAPPERS
    }


factory_env = build_factory_env()
manual_env = build_manual_env()

factory_summary = summarize_wrappers(factory_env)
manual_summary = summarize_wrappers(manual_env)

print("factory stack")
pprint(factory_summary)
print("manual stack")
pprint(manual_summary)

# Both construction paths should contain the same wrappers.
assert factory_summary == manual_summary

factory_env.close()
manual_env.close()
```
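The is_wrapped/get_wrapped behaviour described above can be approximated by walking the chain through each wrapper's .env attribute, the attribute Gymnasium wrappers use for the inner environment. This is a hedged sketch: is_wrapped_sketch and get_wrapped_sketch are hypothetical, and the outermost-first search order and the None result for a missing wrapper are assumptions, not documented MASA behaviour.

```python
class BaseEnv:
    """Hypothetical innermost environment (has no .env attribute)."""


class Wrapper:
    def __init__(self, env):
        self.env = env  # inner environment, as Gymnasium wrappers expose it


class Monitor(Wrapper):
    pass


class Limit(Wrapper):
    pass


class Unused(Wrapper):
    pass


def get_wrapped_sketch(env, wrapper_cls):
    """Return the first wrapper_cls instance, searching outermost to innermost."""
    current = env
    while current is not None:
        if isinstance(current, wrapper_cls):
            return current
        current = getattr(current, "env", None)
    return None  # assumption: absence is reported as None


def is_wrapped_sketch(env, wrapper_cls):
    return get_wrapped_sketch(env, wrapper_cls) is not None


inner_monitor = Monitor(BaseEnv())
env = Monitor(Limit(inner_monitor))

print(is_wrapped_sketch(env, Limit))            # True
print(is_wrapped_sketch(env, Unused))           # False
print(get_wrapped_sketch(env, Monitor) is env)  # True: the outer Monitor is found first
```

When a wrapper class appears more than once, "first matching" matters: under this assumption the outermost instance is returned and inner duplicates are never reached.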
Compare Behaviour
Both environments should emit the same observations, rewards, labels, constraint metrics, and done flags for the same seed and actions.
```python
ACTION_NAMES = {0: "left", 1: "right", 2: "down", 3: "up", 4: "stay"}


def rollout(build_env, actions, *, seed):
    """Run a fixed action sequence and record what the stack reports."""
    env = build_env()
    obs, info = env.reset(seed=seed)
    rows = [
        {
            "event": "reset",
            "obs": int(obs),
            "labels": sorted(info["labels"]),
            "constraint": info["constraint"],
        }
    ]
    for step, action in enumerate(actions, start=1):
        obs, reward, terminated, truncated, info = env.step(action)
        rows.append(
            {
                "event": f"step_{step}",
                "action": ACTION_NAMES[action],
                "obs": int(obs),
                "reward": float(reward),
                "terminated": bool(terminated),
                "truncated": bool(truncated),
                "labels": sorted(info["labels"]),
                "constraint": info["constraint"],
                "metrics": info.get("metrics"),
            }
        )
        if terminated or truncated:
            break
    env.close()
    return rows


actions = [2, 2, 2, 2]  # move down four times
factory_rows = rollout(build_factory_env, actions, seed=1)
manual_rows = rollout(build_manual_env, actions, seed=1)

pprint(factory_rows)
assert factory_rows == manual_rows
```
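A bare assert factory_rows == manual_rows only tells you that something differs. When debugging a custom stack, it helps to report the first differing row and key instead. first_difference below is a hypothetical helper, not part of MASA; it works on any two lists of dicts shaped like the rollout rows, and the rows in the example are hand-written.

```python
def first_difference(rows_a, rows_b):
    """Return (row_index, key, value_a, value_b) for the first mismatch, or None."""
    for index, (row_a, row_b) in enumerate(zip(rows_a, rows_b)):
        # Check keys present in either row, in a stable order.
        for key in sorted(row_a.keys() | row_b.keys()):
            value_a = row_a.get(key, "<missing>")
            value_b = row_b.get(key, "<missing>")
            if value_a != value_b:
                return index, key, value_a, value_b
    # All shared rows match; one rollout may still have ended earlier.
    if len(rows_a) != len(rows_b):
        return min(len(rows_a), len(rows_b)), "<length>", len(rows_a), len(rows_b)
    return None


# Hand-written rows in the same shape as rollout(...) output.
a = [{"event": "reset", "obs": 0}, {"event": "step_1", "obs": 5, "reward": 0.0}]
b = [{"event": "reset", "obs": 0}, {"event": "step_1", "obs": 5, "reward": 1.0}]
print(first_difference(a, b))  # (1, 'reward', 0.0, 1.0)
print(first_difference(a, a))  # None
```

With the tutorial's rollouts you would call first_difference(factory_rows, manual_rows) and expect None.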
The final row reaches the blue state, so both environments should report:
- labels == ["blue"]
- constraint["step"]["cost"] == 1.0
- constraint["step"]["violation"] == 1.0
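The expectations above index the final rollout row. As a hedged sketch, the hypothetical helper check_reached_blue turns them into assertions. Here it runs on a hand-written row that only copies the expected values, so it demonstrates the indexing, not real MASA output; any other keys in info["constraint"] are not shown.

```python
def check_reached_blue(rows):
    """Hypothetical helper: assert the tutorial's expectations on the last row."""
    final = rows[-1]
    assert final["labels"] == ["blue"], final["labels"]
    step = final["constraint"]["step"]
    assert step["cost"] == 1.0, step
    assert step["violation"] == 1.0, step
    return True


# Hand-written final row; real rows would come from rollout(...).
example_rows = [
    {"labels": ["blue"], "constraint": {"step": {"cost": 1.0, "violation": 1.0}}},
]
print(check_reached_blue(example_rows))  # True
```

With the tutorial's data you would call check_reached_blue(factory_rows) and check_reached_blue(manual_rows).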
Record the Finished Stack
RecordVideo is not part of the semantic MASA stack. When record_video=True, make_env wraps the completed stack with Gymnasium’s video recorder, so labels, constraints, and monitors behave the same while frames are saved from render().
Path("videos/tutorial_wrapper_stack") is a relative path, so recordings land under the repo-local videos/ directory when the tutorial is run from the project root. The example prints the exact directory and MP4 paths, and clears this tutorial subdirectory before recording so that reruns do not mix old and new videos.
record_video_episode_trigger plays the role of Gymnasium’s episode_trigger: it receives the zero-based episode id, and that episode is recorded when it returns True. Common schedules are:
```python
record_every_episode = lambda episode_id: True
record_every_5_episodes_from_zero = lambda episode_id: episode_id % 5 == 0
record_human_episodes_5_10_15 = lambda episode_id: (episode_id + 1) % 5 == 0
```
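The difference between the last two schedules is only the zero-based versus one-based counting. Evaluating them over the first 16 episode ids makes this visible; recorded_ids is a hypothetical helper for the illustration, and the schedules are written as functions instead of lambdas.

```python
def record_every_episode(episode_id):
    return True


def record_every_5_episodes_from_zero(episode_id):
    return episode_id % 5 == 0


def record_human_episodes_5_10_15(episode_id):
    # Episode ids are zero-based, so id 4 is the fifth episode.
    return (episode_id + 1) % 5 == 0


def recorded_ids(trigger, num_episodes):
    """Zero-based episode ids for which the trigger returns True."""
    return [episode_id for episode_id in range(num_episodes) if trigger(episode_id)]


print(recorded_ids(record_every_5_episodes_from_zero, 16))  # [0, 5, 10, 15]
print(recorded_ids(record_human_episodes_5_10_15, 16))      # [4, 9, 14]
print(len(recorded_ids(record_every_episode, 16)))          # 16
```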
Gymnasium’s step_trigger suits fixed-length clips that start at a given global environment step, rather than waiting for the next episode reset. This starts a 500-frame clip every 500 environment steps:
```python
video_kwargs={
    "step_trigger": lambda step_id: step_id > 0 and step_id % 500 == 0,
    "video_length": 500,
}
```
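The step_id > 0 guard matters because 0 % 500 == 0: without it, a clip would also start at the very first step. Listing which global step ids satisfy the trigger makes this explicit. The sketch only evaluates the predicate; it does not model how RecordVideo captures frames, and INTERVAL is a name introduced here for illustration.

```python
INTERVAL = 500


def step_trigger(step_id):
    # Same rule as the video_kwargs example above.
    return step_id > 0 and step_id % INTERVAL == 0


starts = [step_id for step_id in range(2001) if step_trigger(step_id)]
print(starts)           # [500, 1000, 1500, 2000]
print(step_trigger(0))  # False
```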
If you specifically want the next complete episode after each 500-step boundary, use a small stateful episode trigger and update it from your rollout loop. The example script exposes this as --trigger-mode step --trigger-value 500.
```python
class RecordNextEpisodeEveryNSteps:
    """Episode trigger that records the next full episode after every `interval` steps."""

    def __init__(self, interval):
        self.interval = interval
        self.total_steps = 0
        self.next_threshold = interval
        self.pending_recordings = 0

    def observe_step(self):
        # Count one environment step and queue one recording per threshold crossed.
        self.total_steps += 1
        while self.total_steps >= self.next_threshold:
            self.pending_recordings += 1
            self.next_threshold += self.interval

    def __call__(self, episode_id):
        # Called with the zero-based episode id; record if a recording is queued.
        if self.pending_recordings < 1:
            return False
        self.pending_recordings -= 1
        return True


trigger = RecordNextEpisodeEveryNSteps(500)
env = make_env(
    ...,
    record_video=True,
    record_video_episode_trigger=trigger,
)

# Inside the rollout loop, after each env.step(...):
trigger.observe_step()
```
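To see which episodes this trigger selects, the sketch below drives it with a fake rollout: the trigger is queried once at the start of each episode, as an episode trigger is at reset, and observe_step is called after every step. The episode lengths are assumptions chosen for illustration, simulate is a hypothetical helper, and the class is repeated so the block runs on its own.

```python
class RecordNextEpisodeEveryNSteps:
    def __init__(self, interval):
        self.interval = interval
        self.total_steps = 0
        self.next_threshold = interval
        self.pending_recordings = 0

    def observe_step(self):
        self.total_steps += 1
        while self.total_steps >= self.next_threshold:
            self.pending_recordings += 1
            self.next_threshold += self.interval

    def __call__(self, episode_id):
        if self.pending_recordings < 1:
            return False
        self.pending_recordings -= 1
        return True


def simulate(trigger, episode_lengths):
    """Return the episode ids the trigger would record for the given lengths."""
    recorded = []
    for episode_id, length in enumerate(episode_lengths):
        if trigger(episode_id):  # queried at reset, before the episode runs
            recorded.append(episode_id)
        for _ in range(length):
            trigger.observe_step()  # called after each env.step(...)
    return recorded


# Nine episodes of 200 steps: boundaries fall inside episodes 2, 4 and 7.
evenly = simulate(RecordNextEpisodeEveryNSteps(500), [200] * 9)
print(evenly)  # [3, 5, 8]

# One long episode crosses two boundaries, so two recordings are queued.
long_first = simulate(RecordNextEpisodeEveryNSteps(500), [1200, 100, 100, 100])
print(long_first)  # [1, 2]
```

Episode 2 crosses step 500 but is not recorded; the trigger waits for the next complete episode. Because the pending count is a queue rather than a flag, a single long episode can schedule several recordings in a row.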
```python
video_dir = Path("videos/tutorial_wrapper_stack")
rmtree(video_dir, ignore_errors=True)  # start from an empty directory on reruns
print("video directory", video_dir)

video_env = make_env(
    "colour_grid_world",
    "cmdp",
    len(actions),
    label_fn=label_fn,
    cost_fn=cost_fn,
    budget=0.0,
    env_kwargs={
        "render_mode": "rgb_array",  # RecordVideo captures frames from render()
        "render_window_size": 96,
    },
    record_video=True,
    record_video_episode_trigger=lambda episode_id: True,  # record every episode
    video_folder=str(video_dir),
)
# RecordVideo wraps the completed MASA stack, so it is the outermost wrapper.
assert isinstance(video_env, RecordVideo)

try:
    video_env.reset(seed=2)
    for action in actions:
        _, _, terminated, truncated, _ = video_env.step(action)
        if terminated or truncated:
            break
finally:
    # Closing the environment also stops any in-progress recording.
    video_env.close()

recorded_videos = sorted(video_dir.glob("*.mp4"))
for path in recorded_videos:
    print(path)
assert recorded_videos
```
Why Order Matters
- TimeLimit comes first, so truncation is part of the base interaction before any safety monitoring.
- LabelledEnv must run before the constraint wrapper because constraints consume info["labels"].
- CumulativeCostEnv updates the stateful safety monitor.
- ConstraintMonitor reads the constraint and writes info["constraint"].
- RewardMonitor is last here, so it can add reward and episode-length metrics without changing safety logic.
- When enabled, RecordVideo sits outside the completed stack and observes rendered frames without changing MASA metadata.
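The dependency between LabelledEnv and the constraint wrapper can be reproduced with toy classes. This is a hedged sketch: ToyLabeller and ToyCost are hypothetical stand-ins that only mimic the info["labels"] hand-off, not MASA's implementations. "Before" means "inner": an outer wrapper sees info only after every inner wrapper has written to it.

```python
class ToyEnv:
    """Hypothetical base env that always returns observation 0 (the "blue" cell)."""

    def step(self, action):
        return 0, 0.0, False, False, {}


class Wrapper:
    def __init__(self, env):
        self.env = env


class ToyLabeller(Wrapper):
    """Mimics LabelledEnv in simplified form: adds info["labels"]."""

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        info["labels"] = {"blue"} if obs == 0 else set()
        return obs, reward, terminated, truncated, info


class ToyCost(Wrapper):
    """Mimics a constraint wrapper in simplified form: reads info["labels"]."""

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        info["cost"] = 1.0 if "blue" in info["labels"] else 0.0
        return obs, reward, terminated, truncated, info


# Correct order: the labeller is inner, so labels exist when the cost wrapper runs.
correct = ToyCost(ToyLabeller(ToyEnv()))
info = correct.step(0)[4]
print(info)  # {'labels': {'blue'}, 'cost': 1.0}

# Wrong order: the cost wrapper is inner and sees info before labels are added.
wrong = ToyLabeller(ToyCost(ToyEnv()))
try:
    wrong.step(0)
    error_message = None
except KeyError as err:
    error_message = f"KeyError: {err}"
print(error_message)  # KeyError: 'labels'
```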
Most users should call make_env. Manual construction is useful when you need to understand, debug, or extend the stack.