Labels, Costs, and Infos¶
This tutorial slows down the MASA environment loop. Instead of training an agent, you will step through colour_grid_world manually and inspect:
obs, reward, info["labels"], cost values, info["constraint"], info["metrics"], terminated, and truncated.
Runnable notebook: notebooks/tutorials/02_labels_costs_and_infos.ipynb
CPU-First Setup¶
Set these before importing MASA/JAX modules:
import os
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
This tutorial does not train an agent, but the same CPU-first convention keeps all tutorials portable.
Labels and Costs¶
MASA separates environment observations from semantic safety signals:
observation -> label_fn -> labels -> cost_fn -> scalar cost
For colour_grid_world, the blue state is the unsafe labelled state.
from pprint import pprint
from masa.envs.tabular.colour_grid_world import (
    BLUE_STATE,
    GOAL_STATE,
    START_STATE,
    cost_fn,
    label_fn,
)

representative_states = {
    "start": START_STATE,
    "blue": BLUE_STATE,
    "goal": GOAL_STATE,
}

# For each representative state, record its labels and the resulting cost.
label_cost_table = []
for name, obs in representative_states.items():
    labels = label_fn(obs)
    label_cost_table.append(
        {
            "name": name,
            "obs": int(obs),
            "labels": sorted(labels),
            "cost": float(cost_fn(labels)),
        }
    )

pprint(label_cost_table)
The important convention is that labels describe what is true about the current observation, while the cost function decides which labels matter for a particular safety constraint.
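To make that split concrete, here is a minimal, self-contained sketch of the observation -> labels -> cost pipeline. The state indices and the names `toy_label_fn` and `toy_cost_fn` are hypothetical stand-ins chosen for illustration; they are not MASA's actual definitions for colour_grid_world.

```python
# Hypothetical state indices, chosen only for this sketch.
TOY_BLUE_STATE = 5
TOY_GOAL_STATE = 9


def toy_label_fn(obs):
    """Return the set of labels that hold in the given observation."""
    labels = set()
    if obs == TOY_BLUE_STATE:
        labels.add("blue")
    if obs == TOY_GOAL_STATE:
        labels.add("goal")
    return labels


def toy_cost_fn(labels):
    """Charge cost 1.0 only when the unsafe label is present."""
    return 1.0 if "blue" in labels else 0.0


for obs in (0, TOY_BLUE_STATE, TOY_GOAL_STATE):
    labels = toy_label_fn(obs)
    print(obs, sorted(labels), toy_cost_fn(labels))
```

Note that the goal state carries a label but no cost here: the label function reports it, and this particular cost function chooses to ignore it. A different constraint could reuse the same labels with a different cost function.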
Build the Environment¶
Use a CMDP-style cumulative cost constraint with a zero budget. That makes a single blue-state visit immediately visible in the metrics.
from masa.plugins.helpers import load_plugins
from masa.common.utils import make_env
load_plugins()
def build_colour_env(max_episode_steps=20, budget=0.0):
    return make_env(
        "colour_grid_world",
        "cmdp",
        max_episode_steps,
        label_fn=label_fn,
        cost_fn=cost_fn,
        budget=budget,
    )

env = build_colour_env()
obs, info = env.reset(seed=0)
print("reset obs:", obs)
print('info["labels"]:', info["labels"])
print('info["constraint"]:')
pprint(info["constraint"])
env.close()
At reset, the LabelledEnv wrapper has already populated info["labels"], and ConstraintMonitor has populated the initial constraint step metrics.
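When exploring an unfamiliar info dict, it can help to print its shape before reading individual values. The helper below is a generic sketch, not part of MASA; the name `describe_info` and the hand-built `fake_info` dict are hypothetical and only mimic the keys mentioned above.

```python
def describe_info(info):
    """Map each top-level info key to its sub-keys (or its value type)."""
    summary = {}
    for key, value in info.items():
        if isinstance(value, dict):
            summary[key] = sorted(value)
        else:
            summary[key] = type(value).__name__
    return summary


# Hand-built stand-in for a reset-time info dict (values are made up).
fake_info = {"labels": set(), "constraint": {"step": {"cost": 0.0}}}
print(describe_info(fake_info))
```

Calling `describe_info(info)` on the real dict returned by `env.reset` or `env.step` should show which nested entries are available at that point in the episode.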
A Rollout Helper¶
This helper records the fields you should inspect when debugging a MASA environment. It stops when either terminated or truncated becomes true.
ACTION_NAMES = {0: "left", 1: "right", 2: "down", 3: "up", 4: "stay"}
def run_rollout(actions, *, seed, max_episode_steps=20, budget=0.0):
    env = build_colour_env(max_episode_steps=max_episode_steps, budget=budget)
    obs, info = env.reset(seed=seed)
    # The reset row has no action, reward, or termination flags.
    rows = [
        {
            "event": "reset",
            "obs": int(obs),
            "labels": sorted(info["labels"]),
            "constraint_step": info["constraint"]["step"],
        }
    ]
    for step, action in enumerate(actions, start=1):
        obs, reward, terminated, truncated, info = env.step(action)
        row = {
            "event": f"step_{step}",
            "action": ACTION_NAMES[action],
            "obs": int(obs),
            "reward": float(reward),
            "terminated": bool(terminated),
            "truncated": bool(truncated),
            "labels": sorted(info["labels"]),
            "constraint_step": info["constraint"]["step"],
            "constraint_episode": info["constraint"].get("episode"),
            "metric_step": info["metrics"]["step"],
            "metric_episode": info.get("metrics", {}).get("episode"),
        }
        rows.append(row)
        # Stop as soon as the episode ends, even if actions remain.
        if terminated or truncated:
            break
    env.close()
    return rows
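Once a rollout is recorded as a list of row dicts, small query helpers make the rows easier to inspect than a full pprint. The following is a sketch under the assumption that rows have the shape built by run_rollout; the helper names `first_row_with_label` and `total_reward` are hypothetical, and the example rows are hand-written rather than produced by the environment.

```python
def first_row_with_label(rows, label):
    """Return the first recorded row whose labels include `label`, else None."""
    for row in rows:
        if label in row["labels"]:
            return row
    return None


def total_reward(rows):
    """Sum rewards over step rows; the reset row has no "reward" key."""
    return sum(row.get("reward", 0.0) for row in rows)


# Hand-written rows in the same shape as run_rollout output (values made up).
example_rows = [
    {"event": "reset", "obs": 0, "labels": []},
    {"event": "step_1", "obs": 1, "reward": 0.0, "labels": []},
    {"event": "step_2", "obs": 2, "reward": 0.0, "labels": ["blue"]},
]
hit = first_row_with_label(example_rows, "blue")
print(hit["event"], total_reward(example_rows))
```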
Cost Rollout¶
With seed 1, four down actions reach the blue state. The blue label has cost 1.0.
cost_rows = run_rollout([2, 2, 2, 2], seed=1, max_episode_steps=20, budget=0.0)
pprint(cost_rows)
On the final row, expect:
labels == ["blue"], constraint_step["cost"] == 1.0, constraint_step["violation"] == 1.0, and constraint_step["cum_cost"] == 1.0.
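These values are consistent with a simple cumulative-cost budget check. The sketch below is an assumption about how such a check could work, not ConstraintMonitor's actual code: the class name `ToyBudgetMonitor` and the rule "violation once cumulative cost exceeds the budget" are hypothetical.

```python
class ToyBudgetMonitor:
    """Track cumulative cost and flag when it exceeds a budget."""

    def __init__(self, budget=0.0):
        self.budget = budget
        self.cum_cost = 0.0

    def reset(self):
        """Clear the running total at the start of an episode."""
        self.cum_cost = 0.0

    def update(self, cost):
        """Add one step's cost and return a per-step metrics dict."""
        self.cum_cost += cost
        # Assumed rule: flag a violation once cumulative cost exceeds the budget.
        violation = 1.0 if self.cum_cost > self.budget else 0.0
        return {"cost": cost, "cum_cost": self.cum_cost, "violation": violation}


monitor = ToyBudgetMonitor(budget=0.0)
for cost in (0.0, 0.0, 0.0, 1.0):  # three safe steps, then one unsafe step
    step_metrics = monitor.update(cost)
print(step_metrics)
```

With a zero budget, the first nonzero cost pushes the running total over the limit, which is why a single blue-state visit is enough to show up as a violation.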
Termination Rollout¶
terminated means the environment task ended. With seed 4, this scripted path reaches the goal state and receives reward 1.0.
termination_actions = [2] * 8 + [1] * 8
termination_rows = run_rollout(termination_actions, seed=4, max_episode_steps=40, budget=0.0)
pprint(termination_rows)
On the final row, expect:
labels == ["goal"], reward == 1.0, terminated is True, truncated is False, and metric_episode["ep_reward"] == 1.0.
Truncation Rollout¶
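truncated means an external limit stopped the episode. Here the environment has max_episode_steps=3, so the episode is truncated after the third step, before reaching a terminal state. The fourth scripted action is never executed because the rollout helper stops at the first terminated or truncated flag.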
truncation_rows = run_rollout([1, 1, 1, 1], seed=0, max_episode_steps=3, budget=0.0)
pprint(truncation_rows)
On the final row, expect:
terminated is False, truncated is True, and metric_episode["ep_length"] == 3.
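The difference between the two flags can be reproduced without any environment at all. The toy loop below is a hypothetical illustration, not MASA code: the "task" ends when a step counter reaches `goal_step`, and an external step limit cuts the episode short otherwise. Letting termination take precedence when both would fire on the same step is an assumption of this sketch.

```python
def run_toy_episode(goal_step, max_episode_steps):
    """Step a counter 'environment' and report how the episode ended.

    terminated: the task itself ended (the counter reached goal_step).
    truncated: the step limit was hit before the task ended.
    """
    steps = 0
    while True:
        steps += 1
        terminated = steps == goal_step
        # Assumed convention: termination takes precedence on the final step.
        truncated = (not terminated) and steps >= max_episode_steps
        if terminated or truncated:
            return {"steps": steps, "terminated": terminated, "truncated": truncated}


print(run_toy_episode(goal_step=5, max_episode_steps=40))  # task ends first
print(run_toy_episode(goal_step=5, max_episode_steps=3))   # limit hits first
```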
What to Remember¶
obs and reward remain the environment's task interface.
info["labels"] is the semantic bridge from observations to safety logic.
cost_fn(info["labels"]) is what cost-based constraints consume.
info["constraint"]["step"] is the per-step safety view.
info["constraint"]["episode"] and info["metrics"]["episode"] summarize the episode.
terminated is task completion/failure from the environment; truncated is an external cutoff such as a time limit.