Skip to content

Commit

Permalink
Add save_video util and deprecate RecordVideo in favor of it (#3016)
Browse files Browse the repository at this point in the history
* init

* add save_video util

* simplify API @pseudo-rnd-thoughts

* fix video_length and remove folder warning

* remove RecordVideo deprecation warnings

* add test record video

* avoid test failing cascade
  • Loading branch information
younik committed Aug 29, 2022
1 parent 44e9475 commit 2a9853f
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 0 deletions.
109 changes: 109 additions & 0 deletions gym/utils/save_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Utility functions to save rendering videos."""
import os
from typing import Callable, Optional

import gym
from gym import logger

try:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ImportError:
raise gym.error.DependencyNotInstalled(
"MoviePy is not installed, run `pip install moviepy`"
)


def capped_cubic_video_schedule(episode_id: int) -> bool:
"""The default episode trigger.
This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
Args:
episode_id: The episode number
Returns:
If to apply a video schedule number
"""
if episode_id < 1000:
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
else:
return episode_id % 1000 == 0


def save_video(
frames: list,
video_folder: str,
episode_trigger: Callable[[int], bool] = None,
step_trigger: Callable[[int], bool] = None,
video_length: Optional[int] = None,
name_prefix: str = "rl-video",
episode_index: int = 0,
step_starting_index: int = 0,
**kwargs,
):
"""Save videos from rendering frames.
This function extract video from a list of render frame episodes.
Args:
frames (List[RenderFrame]): A list of frames to compose the video.
video_folder (str): The folder where the recordings will be stored
episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
video_length (int): The length of recorded episodes. If it isn't specified, the entire episode is recorded.
Otherwise, snippets of the specified length are captured.
name_prefix (str): Will be prepended to the filename of the recordings.
episode_index (int): The index of the current episode.
step_starting_index (int): The step index of the first frame.
**kwargs: The kwargs that will be passed to moviepy's ImageSequenceClip.
You need to specify either fps or duration.
Example:
>>> import gym
>>> from gym.utils.save_video import save_video
>>> env = gym.make("FrozenLake-v1", render_mode="rgb_array")
>>> env.reset()
>>> step_starting_index = 0
>>> episode_index = 0
>>> for step_index in range(199):
... action = env.action_space.sample()
... _, _, done, _ = env.step(action)
... if done:
... save_video(
... env.render(),
... "videos",
... fps=env.metadata["render_fps"],
... step_starting_index=step_starting_index,
... episode_index=episode_index
... )
... step_starting_index = step_index + 1
... episode_index += 1
... env.reset()
>>> env.close()
"""
if not isinstance(frames, list):
logger.error(
f"Expected a list of frames, got a {frames.__class__.__name__} instead."
)
if episode_trigger is None and step_trigger is None:
episode_trigger = capped_cubic_video_schedule

video_folder = os.path.abspath(video_folder)
os.makedirs(video_folder, exist_ok=True)
path_prefix = f"{video_folder}/{name_prefix}"

if episode_trigger is not None and episode_trigger(episode_index):
clip = ImageSequenceClip(frames[:video_length], **kwargs)
clip.write_videofile(f"{path_prefix}-episode-{episode_index}.mp4")

if step_trigger is not None:
# skip the first frame since it comes from reset
for step_index, frame_index in enumerate(
range(1, len(frames)), start=step_starting_index
):
if step_trigger(step_index):
end_index = (
frame_index + video_length if video_length is not None else None
)
clip = ImageSequenceClip(frames[frame_index:end_index], **kwargs)
clip.write_videofile(f"{path_prefix}-step-{step_index}.mp4")
104 changes: 104 additions & 0 deletions tests/utils/test_save_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os
import shutil

import gym
from gym.utils.save_video import capped_cubic_video_schedule, save_video


def test_record_video_using_default_trigger():
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)

env.reset()
step_starting_index = 0
episode_index = 0
for step_index in range(199):
action = env.action_space.sample()
_, _, done, _ = env.step(action)
if done:
save_video(
env.render(),
"videos",
fps=env.metadata["render_fps"],
step_starting_index=step_starting_index,
episode_index=episode_index,
)
step_starting_index = step_index + 1
episode_index += 1
env.reset()

env.close()
assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
shutil.rmtree("videos")
assert len(mp4_files) == sum(
capped_cubic_video_schedule(i) for i in range(episode_index)
)


def modulo_step_trigger(mod: int):
def step_trigger(step_index):
return step_index % mod == 0

return step_trigger


def test_record_video_step_trigger():
env = gym.make("CartPole-v1", render_mode="rgb_array")
env._max_episode_steps = 20

env.reset()
step_starting_index = 0
episode_index = 0
for step_index in range(199):
action = env.action_space.sample()
_, _, done, _ = env.step(action)
if done:
save_video(
env.render(),
"videos",
fps=env.metadata["render_fps"],
step_trigger=modulo_step_trigger(100),
step_starting_index=step_starting_index,
episode_index=episode_index,
)
step_starting_index = step_index + 1
episode_index += 1
env.reset()
env.close()

assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
shutil.rmtree("videos")
assert len(mp4_files) == 2


def test_record_video_within_vector():
envs = gym.vector.make(
"CartPole-v1", num_envs=2, asynchronous=True, render_mode="rgb_array"
)
envs = gym.wrappers.RecordEpisodeStatistics(envs)
envs.reset()
episode_frames = []
step_starting_index = 0
episode_index = 0
for step_index in range(199):
_, _, _, infos = envs.step(envs.action_space.sample())
episode_frames.extend(envs.call("render")[0])

if "episode" in infos and infos["_episode"][0]:
save_video(
episode_frames,
"videos",
fps=envs.metadata["render_fps"],
step_trigger=modulo_step_trigger(100),
step_starting_index=step_starting_index,
episode_index=episode_index,
)
episode_frames = []
step_starting_index = step_index + 1
episode_index += 1

assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
shutil.rmtree("videos")
assert len(mp4_files) == 2

0 comments on commit 2a9853f

Please sign in to comment.