Source code for masa.common.pettingzoo_record_video

from __future__ import annotations

import gc
import os
from typing import Any, Callable, TypeAlias

import gymnasium
import numpy as np
from gymnasium.error import DependencyNotInstalled
from pettingzoo.utils.env import AECEnv, ActionType, AgentID, ObsType, ParallelEnv
from pettingzoo.utils.wrappers.base import BaseWrapper
from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper

RenderFrame: TypeAlias = np.typing.NDArray[Any]


[docs] class RecordVideoParallel(BaseParallelWrapper): """Record videos from a PettingZoo parallel environment.""" def __init__( self, env: ParallelEnv, video_folder: str, episode_trigger: Callable[[int], bool] | None = None, step_trigger: Callable[[int], bool] | None = None, video_length: int = 0, name_prefix: str = "rl-video", fps: int | None = None, disable_logger: bool = True, gc_trigger: Callable[[int], bool] | None = lambda episode: True, ): super().__init__(env) assert isinstance(env, ParallelEnv), "RecordVideoParallel is only compatible with ParallelEnv environments." if env.render_mode in {None, "human", "ansi"}: # type: ignore[attr-defined] raise ValueError( f"Render mode is {env.render_mode}, which is incompatible with RecordVideoParallel. " # type: ignore[attr-defined] "Initialize your environment with a render_mode that returns an image, such as rgb_array." ) if episode_trigger is None and step_trigger is None: episode_trigger = ( lambda episode_id: int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id if episode_id < 1000 else episode_id % 1000 == 0 ) self.episode_trigger = episode_trigger self.step_trigger = step_trigger self.disable_logger = disable_logger self.gc_trigger = gc_trigger self.video_folder = os.path.abspath(video_folder) if os.path.isdir(self.video_folder): gymnasium.logger.warn( f"Overwriting existing videos at {self.video_folder} folder " "(try specifying a different `video_folder` for the `RecordVideoParallel` wrapper if this is not desired)" ) os.makedirs(self.video_folder, exist_ok=True) if fps is None: fps = int(getattr(env, "metadata", {}).get("render_fps", 30)) self.frames_per_sec: int = fps self.name_prefix: str = name_prefix self._video_name: str | None = None self.video_length: int | float = video_length if video_length != 0 else float("inf") self.recording: bool = False self.recorded_frames: list[RenderFrame] = [] self.render_history: list[RenderFrame] = [] self.step_id: int = -1 self.episode_id: int = -1 try: import moviepy # noqa: F401 except ImportError as e: raise DependencyNotInstalled( 'MoviePy is not installed, run `pip install "moviepy>=2.2.1,<3.0.0"`' ) from e def _capture_frame(self): assert self.recording, "Cannot capture a frame, recording wasn't started." frame = self.env.render() if isinstance(frame, list): if len(frame) == 0: return self.render_history += frame frame = frame[-1] if isinstance(frame, np.ndarray): self.recorded_frames.append(frame) else: self.stop_recording() gymnasium.logger.warn( f"Recording stopped: expected type of frame returned by render to be a numpy array, got {type(frame)}." )
[docs] def reset( self, seed: int | None = None, options: dict | None = None ) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]: obs, info = self.env.reset(seed=seed, options=options) self.episode_id += 1 if self.recording and self.video_length == float("inf"): self.stop_recording() if self.episode_trigger and self.episode_trigger(self.episode_id): self.start_recording(f"{self.name_prefix}-episode-{self.episode_id}") if self.recording: self._capture_frame() if len(self.recorded_frames) > self.video_length: self.stop_recording() return obs, info
[docs] def step( self, actions: dict[AgentID, ActionType] ) -> tuple[ dict[AgentID, ObsType], dict[AgentID, float], dict[AgentID, bool], dict[AgentID, bool], dict[AgentID, dict], ]: obs, rew, terminated, truncated, info = self.env.step(actions) self.step_id += 1 if self.step_trigger and self.step_trigger(self.step_id): self.start_recording(f"{self.name_prefix}-step-{self.step_id}") if self.recording: self._capture_frame() if len(self.recorded_frames) > self.video_length: self.stop_recording() return obs, rew, terminated, truncated, info
[docs] def render(self): render_out = self.env.render() if self.recording and isinstance(render_out, list): self.recorded_frames += render_out if len(self.render_history) > 0: tmp_history = self.render_history self.render_history = [] frames = render_out if isinstance(render_out, list) else [render_out] return tmp_history + frames return render_out
[docs] def close(self): super().close() if self.recording: self.stop_recording()
def start_recording(self, video_name: str): if self.recording: self.stop_recording() self.recording = True self._video_name = video_name def stop_recording(self): assert self.recording, "stop_recording was called, but no recording was started" if len(self.recorded_frames) == 0: gymnasium.logger.warn("Ignored saving a video as there were zero frames to save.") else: try: from moviepy.video.io.ImageSequenceClip import ImageSequenceClip except ImportError as e: raise DependencyNotInstalled( 'MoviePy is not installed, run `pip install "moviepy>=2.2.1,<3.0.0"`' ) from e clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) moviepy_logger = None if self.disable_logger else "bar" path = os.path.join(self.video_folder, f"{self._video_name}.mp4") clip.write_videofile(path, logger=moviepy_logger) self.recorded_frames = [] self.recording = False self._video_name = None if self.gc_trigger and self.gc_trigger(self.episode_id): gc.collect() def __del__(self): if len(getattr(self, "recorded_frames", [])) > 0: gymnasium.logger.warn("Unable to save last video! Did you call close()?")
[docs] class RecordVideoAEC(BaseWrapper): """Record videos from a PettingZoo AEC environment.""" def __init__( self, env: AECEnv, video_folder: str, episode_trigger: Callable[[int], bool] | None = None, step_trigger: Callable[[int], bool] | None = None, video_length: int = 0, name_prefix: str = "rl-video", fps: int | None = None, disable_logger: bool = True, gc_trigger: Callable[[int], bool] | None = lambda episode: True, ): super().__init__(env) assert isinstance(env, AECEnv), "RecordVideoAEC is only compatible with AECEnv environments." if env.render_mode in {None, "human", "ansi"}: # type: ignore[attr-defined] raise ValueError( f"Render mode is {env.render_mode}, which is incompatible with RecordVideoAEC. " # type: ignore[attr-defined] "Initialize your environment with a render_mode that returns an image, such as rgb_array." ) if episode_trigger is None and step_trigger is None: episode_trigger = ( lambda episode_id: int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id if episode_id < 1000 else episode_id % 1000 == 0 ) self.episode_trigger = episode_trigger self.step_trigger = step_trigger self.disable_logger = disable_logger self.gc_trigger = gc_trigger self.video_folder = os.path.abspath(video_folder) if os.path.isdir(self.video_folder): gymnasium.logger.warn( f"Overwriting existing videos at {self.video_folder} folder " "(try specifying a different `video_folder` for the `RecordVideoAEC` wrapper if this is not desired)" ) os.makedirs(self.video_folder, exist_ok=True) if fps is None: fps = int(getattr(env, "metadata", {}).get("render_fps", 30)) self.frames_per_sec: int = fps self.name_prefix: str = name_prefix self._video_name: str | None = None self.video_length: int | float = video_length if video_length != 0 else float("inf") self.recording: bool = False self.recorded_frames: list[RenderFrame] = [] self.render_history: list[RenderFrame] = [] self.step_id: int = -1 self.episode_id: int = -1 try: import moviepy # noqa: F401 except ImportError as e: raise DependencyNotInstalled( 'MoviePy is not installed, run `pip install "moviepy>=2.2.1,<3.0.0"`' ) from e def _capture_frame(self): assert self.recording, "Cannot capture a frame, recording wasn't started." frame = self.env.render() if isinstance(frame, list): if len(frame) == 0: return self.render_history += frame frame = frame[-1] if isinstance(frame, np.ndarray): self.recorded_frames.append(frame) else: self.stop_recording() gymnasium.logger.warn( f"Recording stopped: expected type of frame returned by render to be a numpy array, got {type(frame)}." )
[docs] def reset(self, seed: int | None = None, options: dict | None = None): self.env.reset(seed=seed, options=options) self.episode_id += 1 if self.recording and self.video_length == float("inf"): self.stop_recording() if self.episode_trigger and self.episode_trigger(self.episode_id): self.start_recording(f"{self.name_prefix}-episode-{self.episode_id}") if self.recording: self._capture_frame() if len(self.recorded_frames) > self.video_length: self.stop_recording()
[docs] def step(self, action: ActionType): self.env.step(action) self.step_id += 1 if self.step_trigger and self.step_trigger(self.step_id): self.start_recording(f"{self.name_prefix}-step-{self.step_id}") if self.recording: self._capture_frame() if len(self.recorded_frames) > self.video_length: self.stop_recording()
[docs] def render(self): render_out = self.env.render() if self.recording and isinstance(render_out, list): self.recorded_frames += render_out if len(self.render_history) > 0: tmp_history = self.render_history self.render_history = [] frames = render_out if isinstance(render_out, list) else [render_out] return tmp_history + frames return render_out
[docs] def close(self): super().close() if self.recording: self.stop_recording()
def start_recording(self, video_name: str): if self.recording: self.stop_recording() self.recording = True self._video_name = video_name def stop_recording(self): assert self.recording, "stop_recording was called, but no recording was started" if len(self.recorded_frames) == 0: gymnasium.logger.warn("Ignored saving a video as there were zero frames to save.") else: try: from moviepy.video.io.ImageSequenceClip import ImageSequenceClip except ImportError as e: raise DependencyNotInstalled( 'MoviePy is not installed, run `pip install "moviepy>=2.2.1,<3.0.0"`' ) from e clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) moviepy_logger = None if self.disable_logger else "bar" path = os.path.join(self.video_folder, f"{self._video_name}.mp4") clip.write_videofile(path, logger=moviepy_logger) self.recorded_frames = [] self.recording = False self._video_name = None if self.gc_trigger and self.gc_trigger(self.episode_id): gc.collect() def __del__(self): if len(getattr(self, "recorded_frames", [])) > 0: gymnasium.logger.warn("Unable to save last video! Did you call close()?")
RecordVideo = RecordVideoParallel