-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Diffusion policy baseline (#493)
* init with tongzhou's code * work * work * work * bug fix with IK on CPU * Update README.md * work * Update baselines.md * Update README.md
- Loading branch information
1 parent
899b174
commit f565bc1
Showing
15 changed files
with
991 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
__pycache__/ | ||
runs/ | ||
wandb/ | ||
*.egg-info/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Diffusion Policy | ||
|
||
Code for running the Diffusion Policy algorithm is adapted from [CleanRL](https://github.com/vwxyzjn/cleanrl/). It is written to be single-file and easy to follow/read, and supports state-based RL and visual-based RL code. | ||
|
||
## Installation | ||
|
||
To get started, we recommend using conda/mamba to create a new environment and install the dependencies | ||
|
||
```bash | ||
conda create -n diffusion-policy-ms python=3.9 | ||
conda activate diffusion-policy-ms | ||
pip install -e . | ||
``` | ||
|
||
## Demonstration Download and Preprocessing | ||
|
||
By default for fast downloads and smaller file sizes, ManiSkill demonstrations are stored in a highly reduced/compressed format which includes not keeping any observation data. Run the command to download the demonstration and convert it to a format that includes observation data and the desired action space. | ||
|
||
```bash | ||
python -m mani_skill.utils.download_demo "PickCube-v1" | ||
``` | ||
|
||
```bash | ||
env_id="PickCube-v1" | ||
python -m mani_skill.trajectory.replay_trajectory \ | ||
--traj-path ~/.maniskill/demos/${env_id}/motionplanning/trajectory.h5 \ | ||
--use-first-env-state \ | ||
-c pd_ee_delta_pose -o state \ | ||
--save-traj --num-procs 10 | ||
``` | ||
|
||
## Training | ||
|
||
We further add a `--max_episode_steps` argument to the training script to allow for longer demonstrations to be learned from (such as motionplanning / teleoperated demonstrations). By default the max episode steps of most environments are tuned lower so reinforcement learning agents can learn faster. | ||
|
||
```bash | ||
seed=42 | ||
demos=100 | ||
env_id="PickCube-v1" | ||
python train.py --env-id ${env_id} --max_episode_steps 100 \ | ||
--control-mode "pd_joint_delta_pos" --num-demos ${demos} --seed ${seed} \ | ||
--demo-path ~/.maniskill/demos/${env_id}/motionplanning/trajectory.state.pd_joint_delta_pos.h5 \ | ||
--exp-name diffusion_policy-${env_id}-state-${demos}_motionplanning_demos-${seed} \ | ||
--demo_type="motionplanning" --track # additional tag for logging purposes on wandb | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
for env_id in PickCube-v1 | ||
do | ||
python train.py --env-id ${env_id} --max_episode_steps 100 \ | ||
--control-mode "pd_joint_delta_pos" \ | ||
--demo-path ~/.maniskill/demos/${env_id}/motionplanning/trajectory.state.pd_joint_delta_pos.h5 \ | ||
--track | ||
done |
264 changes: 264 additions & 0 deletions
264
examples/baselines/diffusion_policy/diffusion_policy/conditional_unet1d.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,264 @@ | ||
#@markdown ### **Network** | ||
#@markdown | ||
#@markdown Defines a 1D UNet architecture `ConditionalUnet1D` | ||
#@markdown as the noies prediction network | ||
#@markdown | ||
#@markdown Components | ||
#@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k | ||
#@markdown - `Downsample1d` Strided convolution to reduce temporal resolution | ||
#@markdown - `Upsample1d` Transposed convolution to increase temporal resolution | ||
#@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish | ||
#@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \ | ||
#@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection. | ||
#@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning. | ||
|
||
""" | ||
Note: This is copied from the colab notebook. | ||
The main difference with the github repo code is in `class ConditionalUnet1D` -- this version makes some simplifications. | ||
""" | ||
|
||
|
||
from typing import Union | ||
|
||
import torch | ||
import torch.nn as nn | ||
import math | ||
|
||
class SinusoidalPosEmb(nn.Module): | ||
def __init__(self, dim): | ||
super().__init__() | ||
self.dim = dim | ||
|
||
def forward(self, x): | ||
device = x.device | ||
half_dim = self.dim // 2 | ||
emb = math.log(10000) / (half_dim - 1) | ||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb) | ||
emb = x[:, None] * emb[None, :] | ||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1) | ||
return emb | ||
|
||
|
||
class Downsample1d(nn.Module): | ||
def __init__(self, dim): | ||
super().__init__() | ||
self.conv = nn.Conv1d(dim, dim, 3, 2, 1) | ||
|
||
def forward(self, x): | ||
return self.conv(x) | ||
|
||
class Upsample1d(nn.Module): | ||
def __init__(self, dim): | ||
super().__init__() | ||
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) | ||
|
||
def forward(self, x): | ||
return self.conv(x) | ||
|
||
class Conv1dBlock(nn.Module): | ||
''' | ||
Conv1d --> GroupNorm --> Mish | ||
''' | ||
|
||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): | ||
super().__init__() | ||
|
||
self.block = nn.Sequential( | ||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), | ||
nn.GroupNorm(n_groups, out_channels), | ||
nn.Mish(), | ||
) | ||
|
||
def forward(self, x): | ||
return self.block(x) | ||
|
||
|
||
class ConditionalResidualBlock1D(nn.Module): | ||
def __init__(self, | ||
in_channels, | ||
out_channels, | ||
cond_dim, | ||
kernel_size=3, | ||
n_groups=8): | ||
super().__init__() | ||
|
||
self.blocks = nn.ModuleList([ | ||
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), | ||
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), | ||
]) | ||
|
||
# FiLM modulation https://arxiv.org/abs/1709.07871 | ||
# predicts per-channel scale and bias | ||
cond_channels = out_channels * 2 | ||
self.out_channels = out_channels | ||
self.cond_encoder = nn.Sequential( | ||
nn.Mish(), | ||
nn.Linear(cond_dim, cond_channels), | ||
nn.Unflatten(-1, (-1, 1)) | ||
) | ||
|
||
# make sure dimensions compatible | ||
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ | ||
if in_channels != out_channels else nn.Identity() | ||
|
||
def forward(self, x, cond): | ||
''' | ||
x : [ batch_size x in_channels x horizon ] | ||
cond : [ batch_size x cond_dim] | ||
returns: | ||
out : [ batch_size x out_channels x horizon ] | ||
''' | ||
out = self.blocks[0](x) | ||
embed = self.cond_encoder(cond) | ||
|
||
embed = embed.reshape( | ||
embed.shape[0], 2, self.out_channels, 1) | ||
scale = embed[:,0,...] | ||
bias = embed[:,1,...] | ||
out = scale * out + bias | ||
|
||
out = self.blocks[1](out) | ||
out = out + self.residual_conv(x) | ||
return out | ||
|
||
|
||
class ConditionalUnet1D(nn.Module): | ||
def __init__(self, | ||
input_dim, | ||
global_cond_dim, | ||
diffusion_step_embed_dim=256, | ||
down_dims=[256,512,1024], | ||
kernel_size=5, | ||
n_groups=8 | ||
): | ||
""" | ||
input_dim: Dim of actions. | ||
global_cond_dim: Dim of global conditioning applied with FiLM | ||
in addition to diffusion step embedding. This is usually obs_horizon * obs_dim | ||
diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k | ||
down_dims: Channel size for each UNet level. | ||
The length of this array determines numebr of levels. | ||
kernel_size: Conv kernel size | ||
n_groups: Number of groups for GroupNorm | ||
""" | ||
|
||
super().__init__() | ||
all_dims = [input_dim] + list(down_dims) | ||
start_dim = down_dims[0] | ||
|
||
dsed = diffusion_step_embed_dim | ||
diffusion_step_encoder = nn.Sequential( | ||
SinusoidalPosEmb(dsed), | ||
nn.Linear(dsed, dsed * 4), | ||
nn.Mish(), | ||
nn.Linear(dsed * 4, dsed), | ||
) | ||
cond_dim = dsed + global_cond_dim | ||
|
||
in_out = list(zip(all_dims[:-1], all_dims[1:])) | ||
mid_dim = all_dims[-1] | ||
self.mid_modules = nn.ModuleList([ | ||
ConditionalResidualBlock1D( | ||
mid_dim, mid_dim, cond_dim=cond_dim, | ||
kernel_size=kernel_size, n_groups=n_groups | ||
), | ||
ConditionalResidualBlock1D( | ||
mid_dim, mid_dim, cond_dim=cond_dim, | ||
kernel_size=kernel_size, n_groups=n_groups | ||
), | ||
]) | ||
|
||
down_modules = nn.ModuleList([]) | ||
for ind, (dim_in, dim_out) in enumerate(in_out): | ||
is_last = ind >= (len(in_out) - 1) | ||
down_modules.append(nn.ModuleList([ | ||
ConditionalResidualBlock1D( | ||
dim_in, dim_out, cond_dim=cond_dim, | ||
kernel_size=kernel_size, n_groups=n_groups), | ||
ConditionalResidualBlock1D( | ||
dim_out, dim_out, cond_dim=cond_dim, | ||
kernel_size=kernel_size, n_groups=n_groups), | ||
Downsample1d(dim_out) if not is_last else nn.Identity() | ||
])) | ||
|
||
up_modules = nn.ModuleList([]) | ||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): | ||
is_last = ind >= (len(in_out) - 1) | ||
up_modules.append(nn.ModuleList([ | ||
ConditionalResidualBlock1D( | ||
dim_out*2, dim_in, cond_dim=cond_dim, | ||
kernel_size=kernel_size, n_groups=n_groups), | ||
ConditionalResidualBlock1D( | ||
dim_in, dim_in, cond_dim=cond_dim, | ||
kernel_size=kernel_size, n_groups=n_groups), | ||
Upsample1d(dim_in) if not is_last else nn.Identity() | ||
])) | ||
|
||
final_conv = nn.Sequential( | ||
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), | ||
nn.Conv1d(start_dim, input_dim, 1), | ||
) | ||
|
||
self.diffusion_step_encoder = diffusion_step_encoder | ||
self.up_modules = up_modules | ||
self.down_modules = down_modules | ||
self.final_conv = final_conv | ||
|
||
n_params = sum(p.numel() for p in self.parameters()) | ||
print(f"number of parameters: {n_params / 1e6:.2f}M") | ||
|
||
def forward(self, | ||
sample: torch.Tensor, | ||
timestep: Union[torch.Tensor, float, int], | ||
global_cond=None): | ||
""" | ||
x: (B,T,input_dim) | ||
timestep: (B,) or int, diffusion step | ||
global_cond: (B,global_cond_dim) | ||
output: (B,T,input_dim) | ||
""" | ||
# (B,T,C) | ||
sample = sample.moveaxis(-1,-2) | ||
# (B,C,T) | ||
|
||
# 1. time | ||
timesteps = timestep | ||
if not torch.is_tensor(timesteps): | ||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can | ||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) | ||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: | ||
timesteps = timesteps[None].to(sample.device) | ||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | ||
timesteps = timesteps.expand(sample.shape[0]) | ||
|
||
global_feature = self.diffusion_step_encoder(timesteps) | ||
|
||
if global_cond is not None: | ||
global_feature = torch.cat([ | ||
global_feature, global_cond | ||
], axis=-1) | ||
|
||
x = sample | ||
h = [] | ||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): | ||
x = resnet(x, global_feature) | ||
x = resnet2(x, global_feature) | ||
h.append(x) | ||
x = downsample(x) | ||
|
||
for mid_module in self.mid_modules: | ||
x = mid_module(x, global_feature) | ||
|
||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): | ||
x = torch.cat((x, h.pop()), dim=1) | ||
x = resnet(x, global_feature) | ||
x = resnet2(x, global_feature) | ||
x = upsample(x) | ||
|
||
x = self.final_conv(x) | ||
|
||
# (B,C,T) | ||
x = x.moveaxis(-1,-2) | ||
# (B,T,C) | ||
return x |
32 changes: 32 additions & 0 deletions
32
examples/baselines/diffusion_policy/diffusion_policy/make_env.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import Optional | ||
import gymnasium as gym | ||
import mani_skill.envs | ||
from mani_skill.utils.wrappers import RecordEpisode | ||
from mani_skill.utils.wrappers.gymnasium import CPUGymWrapper | ||
from diffusion_policy.wrappers import SeqActionWrapper | ||
|
||
|
||
def make_env(env_id, num_envs: int, sim_backend: str, seed: int, env_kwargs: dict, other_kwargs: dict,video_dir: Optional[str] = None): | ||
if sim_backend == "cpu": | ||
def cpu_make_env(env_id, seed, video_dir=None, env_kwargs = dict(), other_kwargs = dict()): | ||
def thunk(): | ||
env = gym.make(env_id, **env_kwargs) | ||
env = CPUGymWrapper(env) | ||
if video_dir: | ||
env = RecordEpisode(env, output_dir=video_dir, save_trajectory=False, info_on_video=True) | ||
|
||
env = gym.wrappers.RecordEpisodeStatistics(env) | ||
env = gym.wrappers.ClipAction(env) | ||
env = gym.wrappers.FrameStack(env, other_kwargs['obs_horizon']) | ||
env = SeqActionWrapper(env) | ||
|
||
env.action_space.seed(seed) | ||
env.observation_space.seed(seed) | ||
return env | ||
|
||
return thunk | ||
vector_cls = gym.vector.SyncVectorEnv if num_envs == 1 else lambda x : gym.vector.AsyncVectorEnv(x, context="forkserver") | ||
env = vector_cls([cpu_make_env(env_id, seed, video_dir if seed == 0 else None, env_kwargs, other_kwargs) for seed in range(num_envs)]) | ||
else: | ||
env = gym.make(env_id, num_envs=num_envs, sim_backend=sim_backend, **env_kwargs) | ||
return env |
Oops, something went wrong.