Support gymnasium (#226)
* Upgrade from gym to gymnasium

* Support vector env

* Update dependencies

* Split state into multiple components

* Remove comment

* Add reset_to_demo option for gym env

* Revert task env step change and clean up dependencies

* Install gym deps for ci

* Fix typo

* Fix tests

* Remove imports

* Add import

* Add test for example scripts

* Comment out test

* Update pyrep dependency with cffi fix branch

* Typo

* Point to commit hash

* Comment out flaky test

* Move gym envs to rlbench/__init__.py

* Include package data

* Move dataset_generator.py into package; add rlbench-generate-dataset entry point

* Clean up old code and sort imports

* Update setup.py

---------

Co-authored-by: Stephen James <stepjamuk@gmail.com>
eugeneteoh and stepjam committed Jul 2, 2024
1 parent f2c625f commit 299dddc
Showing 17 changed files with 348 additions and 302 deletions.
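Taken together, the diffs below move RLBench from the legacy `gym` package to Gymnasium: environments are registered under a `rlbench/` namespace at import time, and `step` returns Gymnasium's five-tuple. A minimal sketch of the resulting usage, assuming the wrapper follows the standard Gymnasium `reset`/`step` signatures:

```python
import gymnasium as gym
import rlbench  # importing rlbench registers the rlbench/... env IDs

env = gym.make('rlbench/reach_target-state-v0')
obs, info = env.reset(seed=0)
for _ in range(40):
    # Gymnasium returns (obs, reward, terminated, truncated, info)
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```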
.github/workflows/task_tests.yml (2 additions, 2 deletions)

@@ -37,6 +37,6 @@ jobs:
 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT:$COPPELIASIM_ROOT/platforms
 export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
-pip install ".[dev]"
+pip install ".[gym,dev]"
 pip install "pytest-xdist[psutil]"
-pytest -v -n auto tests/unit
+pytest -v -n auto tests/demos
.github/workflows/unit_tests.yml (1 addition, 1 deletion)

@@ -38,6 +38,6 @@ jobs:
 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT:$COPPELIASIM_ROOT/platforms
 export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
-pip install ".[dev]"
+pip install ".[gym,dev]"
 pip install "pytest-xdist[psutil]"
 pytest -v -n auto tests/unit
README.md (1 addition, 1 deletion)

@@ -309,7 +309,7 @@ the observation mode: 'state' or 'vision'.
 
 ```python
 import gym
-import rlbench.gym
+import rlbench
 
 env = gym.make('reach_target-state-v0')
 # Alternatively, for vision:
[…]
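Note that the registration loop added in rlbench/__init__.py (below) namespaces every ID under `rlbench/`, so the collapsed remainder of this README hunk presumably brings the make calls in line with something like:

```python
import gymnasium as gym
import rlbench

env = gym.make('rlbench/reach_target-state-v0')
# Alternatively, for vision:
env = gym.make('rlbench/reach_target-vision-v0')
```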
examples/rlbench_gym.py (9 additions, 4 deletions)

@@ -1,16 +1,21 @@
-import gym
-import rlbench.gym
+import gymnasium as gym
+from gymnasium.utils.performance import benchmark_step
+import rlbench
 
-env = gym.make('reach_target-state-v0', render_mode='human')
+env = gym.make('rlbench/reach_target-vision-v0', render_mode="rgb_array")
 
 training_steps = 120
 episode_length = 40
 for i in range(training_steps):
     if i % episode_length == 0:
         print('Reset Episode')
         obs = env.reset()
-    obs, reward, terminate, _ = env.step(env.action_space.sample())
+    obs, reward, terminate, _, _ = env.step(env.action_space.sample())
     env.render()  # Note: rendering increases step time.
 
 print('Done')
+
+fps = benchmark_step(env, target_duration=10)
+print(f"FPS: {fps:.2f}")
 env.close()
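The two extra values unpacked from `env.step` reflect Gymnasium's five-tuple return of `(obs, reward, terminated, truncated, info)`. A more explicit version of the loop body, under the same standard-API assumption, would be:

```python
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
if terminated or truncated:  # episode ended by task outcome or by time limit
    obs, info = env.reset()
```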
examples/rlbench_gym_vector.py (new file, 48 additions)

@@ -0,0 +1,48 @@
import time
import gymnasium as gym
import rlbench


def benchmark_vector_step(env, target_duration: int = 5, seed=None) -> float:
    steps = 0
    end = 0.0
    env.reset(seed=seed)
    env.action_space.sample()
    start = time.monotonic()

    while True:
        action = env.action_space.sample()
        _, _, terminal, truncated, _ = env.step(action)
        steps += terminal.shape[0]

        # if terminal or truncated:
        #     env.reset()

        if time.monotonic() - start > target_duration:
            end = time.monotonic()
            break

    length = end - start

    steps_per_time = steps / length
    return steps_per_time


if __name__ == "__main__":
    # Only works with spawn (multiprocessing) context
    env = gym.make_vec('rlbench/reach_target-vision-v0', num_envs=2, vectorization_mode="async", vector_kwargs={"context": "spawn"})

    training_steps = 120
    episode_length = 40
    for i in range(training_steps):
        if i % episode_length == 0:
            print('Reset Episode')
            obs = env.reset()
        obs, reward, terminate, _, _ = env.step(env.action_space.sample())
        env.render()  # Note: rendering increases step time.

    print('Done')

    fps = benchmark_vector_step(env, target_duration=10)
    print(f"FPS: {fps:.2f}")

    env.close()
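The commented-out reset reflects a Gymnasium vector-env detail: sub-environments autoreset by default, and step results come back batched, so `terminal` is an array of shape `(num_envs,)` (hence `terminal.shape[0]` counts steps across all sub-envs). A sketch of inspecting the per-env flags, assuming the default autoreset behaviour:

```python
import numpy as np
import gymnasium as gym
import rlbench

env = gym.make_vec('rlbench/reach_target-state-v0', num_envs=2,
                   vectorization_mode="async", vector_kwargs={"context": "spawn"})
obs, info = env.reset(seed=0)
obs, rewards, terminated, truncated, infos = env.step(env.action_space.sample())
done = np.logical_or(terminated, truncated)  # boolean mask, shape (num_envs,)
print(f"{int(done.sum())} of {done.size} sub-envs finished this step")
env.close()
```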
rlbench/__init__.py (44 additions, 13 deletions)

@@ -1,17 +1,48 @@
-__version__ = '1.2.0'
+__version__ = "1.2.0"
 
-import numpy as np
-import pyrep
-
-pr_v = np.array(pyrep.__version__.split('.'), dtype=int)
-if pr_v.size < 4 or np.any(pr_v < np.array([4, 1, 0, 2])):
-    raise ImportError(
-        'PyRep version must be greater than 4.1.0.2. Please update PyRep.')
+import os
+
+from gymnasium import register
+
+import rlbench.backend.task as task
+from rlbench.action_modes.action_mode import (
+    ActionMode,
+    ArmActionMode,
+    GripperActionMode,
+)
 from rlbench.environment import Environment
-from rlbench.action_modes.action_mode import ActionMode, ArmActionMode, GripperActionMode
-from rlbench.observation_config import ObservationConfig
-from rlbench.observation_config import CameraConfig
-from rlbench.sim2real.domain_randomization import RandomizeEvery
-from rlbench.sim2real.domain_randomization import VisualRandomizationConfig
+from rlbench.observation_config import CameraConfig, ObservationConfig
+from rlbench.sim2real.domain_randomization import (
+    RandomizeEvery,
+    VisualRandomizationConfig,
+)
+from rlbench.utils import name_to_task_class
+
+__all__ = [
+    "ActionMode",
+    "ArmActionMode",
+    "GripperActionMode",
+    "CameraConfig",
+    "Environment",
+    "ObservationConfig",
+    "RandomizeEvery",
+    "VisualRandomizationConfig",
+]
+
+TASKS = [
+    t for t in os.listdir(task.TASKS_PATH) if t != "__init__.py" and t.endswith(".py")
+]
+
+for task_file in TASKS:
+    task_name = task_file.split(".py")[0]
+    task_class = name_to_task_class(task_name)
+    for obs_mode in ["state", "vision"]:
+        register(
+            id=f"rlbench/{task_name}-{obs_mode}-v0",
+            entry_point="rlbench.gym:RLBenchEnv",
+            kwargs={
+                "task_class": task_class,
+                "observation_mode": obs_mode,
+            },
+            nondeterministic=True,
+        )
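Since the Gymnasium IDs are now registered as a side effect of `import rlbench`, one quick sanity check is to list them from Gymnasium's registry (a sketch; `gym.registry` is Gymnasium's global table of registered env specs):

```python
import gymnasium as gym
import rlbench  # side effect: registers state/vision variants of every task

rlbench_ids = sorted(env_id for env_id in gym.registry if env_id.startswith("rlbench/"))
print(f"{len(rlbench_ids)} RLBench envs registered, e.g. {rlbench_ids[:3]}")
```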
rlbench/assets/__init__.py (new, empty file)
tools/dataset_generator.py → rlbench/dataset_generator.py (renamed; 39 additions, 48 deletions)

@@ -1,43 +1,21 @@
-from multiprocessing import Process, Manager
+import argparse
+import os
+import pickle
+from multiprocessing import Manager, Process
 
+import numpy as np
+from PIL import Image
 from pyrep.const import RenderMode
 
+import rlbench.backend.task as task
 from rlbench import ObservationConfig
 from rlbench.action_modes.action_mode import MoveArmThenGripper
 from rlbench.action_modes.arm_action_modes import JointVelocity
 from rlbench.action_modes.gripper_action_modes import Discrete
-from rlbench.backend.utils import task_file_to_task_class
-from rlbench.environment import Environment
-import rlbench.backend.task as task
-
-import os
-import pickle
-from PIL import Image
 from rlbench.backend import utils
 from rlbench.backend.const import *
-import numpy as np
-
-from absl import app
-from absl import flags
-
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string('save_path',
-                    '/tmp/rlbench_data/',
-                    'Where to save the demos.')
-flags.DEFINE_list('tasks', [],
-                  'The tasks to collect. If empty, all tasks are collected.')
-flags.DEFINE_list('image_size', [128, 128],
-                  'The size of the images tp save.')
-flags.DEFINE_enum('renderer', 'opengl3', ['opengl', 'opengl3'],
-                  'The renderer to use. opengl does not include shadows, '
-                  'but is faster.')
-flags.DEFINE_integer('processes', 1,
-                     'The number of parallel processes during collection.')
-flags.DEFINE_integer('episodes_per_task', 10,
-                     'The number of episodes to collect per task.')
-flags.DEFINE_integer('variations', -1,
-                     'Number of variations to collect per task. -1 for all.')
+from rlbench.backend.utils import task_file_to_task_class
+from rlbench.environment import Environment
 
 
 def check_and_make(dir):
@@ -166,15 +144,15 @@ def save_demo(demo, example_path):
         pickle.dump(demo, f)
 
 
-def run(i, lock, task_index, variation_count, results, file_lock, tasks):
+def run(i, lock, task_index, variation_count, results, file_lock, tasks, args):
     """Each thread will choose one task and variation, and then gather
     all the episodes_per_task for that variation."""
 
    # Initialise each thread with random seed
    np.random.seed(None)
    num_tasks = len(tasks)
 
-    img_size = list(map(int, FLAGS.image_size))
+    img_size = list(map(int, args.image_size))
 
    obs_config = ObservationConfig()
    obs_config.set_all(True)
@@ -198,13 +176,13 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
     obs_config.wrist_camera.masks_as_one_channel = False
     obs_config.front_camera.masks_as_one_channel = False
 
-    if FLAGS.renderer == 'opengl':
+    if args.renderer == 'opengl':
         obs_config.right_shoulder_camera.render_mode = RenderMode.OPENGL
         obs_config.left_shoulder_camera.render_mode = RenderMode.OPENGL
         obs_config.overhead_camera.render_mode = RenderMode.OPENGL
         obs_config.wrist_camera.render_mode = RenderMode.OPENGL
         obs_config.front_camera.render_mode = RenderMode.OPENGL
-    elif FLAGS.renderer == 'opengl3':
+    elif args.renderer == 'opengl3':
         obs_config.right_shoulder_camera.render_mode = RenderMode.OPENGL3
         obs_config.left_shoulder_camera.render_mode = RenderMode.OPENGL3
         obs_config.overhead_camera.render_mode = RenderMode.OPENGL3
@@ -233,8 +211,8 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
             t = tasks[task_index.value]
             task_env = rlbench_env.get_task(t)
             var_target = task_env.variation_count()
-            if FLAGS.variations >= 0:
-                var_target = np.minimum(FLAGS.variations, var_target)
+            if args.variations >= 0:
+                var_target = np.minimum(args.variations, var_target)
             if my_variation_count >= var_target:
                 # If we have reached the required number of variations for this
                 # task, then move on to the next task.
@@ -252,7 +230,7 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
         descriptions, _ = task_env.reset()
 
         variation_path = os.path.join(
-            FLAGS.save_path, task_env.get_name(),
+            args.save_path, task_env.get_name(),
             VARIATIONS_FOLDER % my_variation_count)
 
         check_and_make(variation_path)
@@ -265,7 +243,7 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
         check_and_make(episodes_path)
 
         abort_variation = False
-        for ex_idx in range(FLAGS.episodes_per_task):
+        for ex_idx in range(args.episodes_per_task):
             print('Process', i, '// Task:', task_env.get_name(),
                   '// Variation:', my_variation_count, '// Demo:', ex_idx)
             attempts = 10
@@ -300,16 +278,29 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
     rlbench_env.shutdown()
 
 
-def main(argv):
+def parse_args():
+    parser = argparse.ArgumentParser(description="RLBench Dataset Generator")
+    parser.add_argument('--save_path', type=str, default='/tmp/rlbench_data/', help='Where to save the demos.')
+    parser.add_argument('--tasks', nargs='*', default=[], help='The tasks to collect. If empty, all tasks are collected.')
+    parser.add_argument('--image_size', nargs=2, type=int, default=[128, 128], help='The size of the images to save.')
+    parser.add_argument('--renderer', type=str, choices=['opengl', 'opengl3'], default='opengl3', help='The renderer to use. opengl does not include shadows, but is faster.')
+    parser.add_argument('--processes', type=int, default=1, help='The number of parallel processes during collection.')
+    parser.add_argument('--episodes_per_task', type=int, default=10, help='The number of episodes to collect per task.')
+    parser.add_argument('--variations', type=int, default=-1, help='Number of variations to collect per task. -1 for all.')
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
 
     task_files = [t.replace('.py', '') for t in os.listdir(task.TASKS_PATH)
                   if t != '__init__.py' and t.endswith('.py')]
 
-    if len(FLAGS.tasks) > 0:
-        for t in FLAGS.tasks:
+    if len(args.tasks) > 0:
+        for t in args.tasks:
             if t not in task_files:
                 raise ValueError('Task %s not recognised!.' % t)
-        task_files = FLAGS.tasks
+        task_files = args.tasks
 
     tasks = [task_file_to_task_class(t) for t in task_files]
 
@@ -322,20 +313,20 @@ def main(argv):
     variation_count = manager.Value('i', 0)
     lock = manager.Lock()
 
-    check_and_make(FLAGS.save_path)
+    check_and_make(args.save_path)
 
     processes = [Process(
         target=run, args=(
             i, lock, task_index, variation_count, result_dict, file_lock,
-            tasks))
-        for i in range(FLAGS.processes)]
+            tasks, args))
+        for i in range(args.processes)]
     [t.start() for t in processes]
     [t.join() for t in processes]
 
     print('Data collection done!')
-    for i in range(FLAGS.processes):
+    for i in range(args.processes):
         print(result_dict[i])
 
 
 if __name__ == '__main__':
-    app.run(main)
+    main()
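With absl replaced by argparse and the script moved into the package, the commit message also names a `rlbench-generate-dataset` console entry point. A hypothetical invocation, shown programmatically for illustration (the entry point name comes from the commit message, the flag names from `parse_args` above; the assumed shell equivalent is `rlbench-generate-dataset --tasks reach_target --episodes_per_task 5`):

```python
import sys
from rlbench.dataset_generator import main

# Equivalent to: rlbench-generate-dataset --tasks reach_target --episodes_per_task 5
sys.argv = ["rlbench-generate-dataset",
            "--tasks", "reach_target",
            "--episodes_per_task", "5",
            "--processes", "1"]
main()
```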