feat: implement octahedral lifting and separable kernels
dgcnz committed May 21, 2024
1 parent f93d740 commit 220c2a4
Showing 7 changed files with 325 additions and 5 deletions.
1 change: 1 addition & 0 deletions gconv/gnn/kernels/__init__.py
@@ -5,6 +5,7 @@
GLiftingKernel,
GSubgroupKernel,
RGLiftingKernel,
RGSeparableKernel,
)
from .kernel_sen import *
from .kernel_en import *
135 changes: 134 additions & 1 deletion gconv/gnn/kernels/kernel.py
@@ -12,6 +12,7 @@
import torch
import torch.nn as nn
import torch.nn.init as init
from gconv.utils import unsqueeze_like

from torch import Tensor

@@ -276,7 +277,10 @@ def forward(self, H) -> Tensor:
)
weight = torch.sum(
self.relaxed_weights.view(
self.num_filter_banks, 1, self.relaxed_weights.shape[1], *product_dims,
self.num_filter_banks,
1,
self.relaxed_weights.shape[1],
*product_dims,
)
* weight,
dim=0,
@@ -401,6 +405,135 @@ def forward(self, in_H: Tensor, out_H: Tensor) -> tuple[Tensor, Tensor]:
return weight_H, weight


class RGSeparableKernel(GroupKernel):
def __init__(
self,
in_channels: int,
out_channels: int,
num_filter_banks: int,
kernel_size: tuple,
group_kernel_size: tuple,
grid_H: Tensor,
grid_Rn: Tensor,
groups: int = 1,
mask: Tensor | None = None,
det_H: Callable | None = None,
inverse_H: Callable | None = None,
left_apply_to_H: Callable | None = None,
left_apply_to_Rn: Callable | None = None,
sample_H: Callable | None = None,
sample_Rn: Callable | None = None,
sample_H_kwargs: dict = {},
sample_Rn_kwargs: dict = {},
) -> None:
"""
The separable kernel manages the group and weights for
separable group convolutions, returning weights for
subgroup H and Rn separately.
"""
super().__init__(
in_channels,
out_channels,
kernel_size,
group_kernel_size=group_kernel_size,
grid_H=grid_H,
grid_Rn=grid_Rn,
groups=groups,
mask=mask,
det_H=det_H,
inverse_H=inverse_H,
left_apply_to_H=left_apply_to_H,
left_apply_to_Rn=left_apply_to_Rn,
sample_H=sample_H,
sample_Rn=sample_Rn,
sample_H_kwargs=sample_H_kwargs,
sample_Rn_kwargs=sample_Rn_kwargs,
)

if len(group_kernel_size) != 1:
raise NotImplementedError(
"Relaxed separable group convolutions only support group_kernel_size tuples of length 1"
)
self.weight_H = nn.Parameter(
torch.empty(
num_filter_banks,
self._group_kernel_dim,
out_channels,
in_channels // groups,
)
)
self.num_filter_banks = num_filter_banks
self.rweight_H = nn.Parameter(
torch.ones(num_filter_banks, self._group_kernel_dim)
)

self.weight = nn.Parameter(torch.empty(out_channels, 1, *kernel_size))
self.weight_dims = (1,) * len(kernel_size)

self.reset_parameters()

def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight_H, a=math.sqrt(5))
init.kaiming_uniform_(self.weight, a=math.sqrt(5))

def forward(self, in_H: Tensor, out_H: Tensor) -> tuple[Tensor, Tensor]:
num_in_H, num_out_H = in_H.shape[0], out_H.shape[0]
out_H_inverse = self.inverse_H(out_H)

H_product_H = self.left_apply_to_H(out_H_inverse, in_H)
H_product_Rn = self.left_apply_to_Rn(out_H_inverse, self.grid_Rn)

product_dims = (1,) * (H_product_Rn.ndim - 1)
# TODO: vectorize this for loop
weight_H = torch.stack(
[
self.sample_H(
H_product_H.flatten(0, 1),
self.weight_H[i].flatten(1, -1),
self.grid_H,
**self.sample_H_kwargs,
)
.view(
num_in_H,
num_out_H,
self.in_channels // self.groups,
self.out_channels,
*self.weight_dims,
)
.transpose(0, 3)
.transpose(1, 3)
for i in range(self.num_filter_banks)
],
dim=0,
)

# linear combination of filters
weight_H = torch.sum(
unsqueeze_like(self.rweight_H[:, None, :], weight_H) * weight_H,
dim=0,
)
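# For context: for each group element, the reduction above forms a linear
# combination of the num_filter_banks stacked filters. A standalone sketch of
# the broadcast (illustrative shapes, not the actual dimensions used here):
#
#   banks, nH, C_out, C_in = 2, 4, 8, 3
#   w = torch.randn(banks, nH, C_out, C_in)
#   r = torch.randn(banks, nH)
#   combined = torch.sum(r.view(banks, nH, 1, 1) * w, dim=0)  # (nH, C_out, C_in)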

# sample R3
weight = self.sample_Rn(
self.weight.repeat_interleave(num_out_H, dim=0),
H_product_Rn.repeat(self.out_channels, *product_dims),
**self.sample_Rn_kwargs,
).view(
self.out_channels,
num_out_H,
1,
*self.kernel_size,
)

if self.mask is not None:
weight = self.mask * weight

if self.det_H is not None:
weight = weight / self.det_H(out_H).view(-1, 1, *self.weight_dims)

return weight_H, weight
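
# How a separable group convolution could consume the (weight_H, weight) pair
# returned above: weight_H mixes input channels and group elements, weight is
# a purely spatial stencil per output pair. A sketch under assumed shapes and
# layouts (illustrative, not gconv's internals):
#
#   import torch.nn.functional as F
#   B, C_in, C_out, nH, k, S = 1, 3, 8, 4, 5, 10
#   x = torch.randn(B, C_in, nH, S, S, S)        # (B, C_in, |H|, X, Y, Z)
#   weight_H = torch.randn(C_out, nH, C_in, nH)  # group-mixing weights
#   weight = torch.randn(C_out, nH, k, k, k)     # spatial kernel per (C_out, |H_out|)
#
#   # 1) mix channels and input group elements per output group element
#   x = torch.einsum("bihxyz,oHih->boHxyz", x, weight_H)
#   # 2) depthwise 3D convolution per (C_out, |H_out|) pair
#   x = F.conv3d(
#       x.flatten(1, 2),                         # (B, C_out * nH, S, S, S)
#       weight.reshape(C_out * nH, 1, k, k, k),
#       groups=C_out * nH,
#       padding=k // 2,
#   ).view(B, C_out, nH, S, S, S)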


class GSubgroupKernel(GroupKernel):
def __init__(
self,
2 changes: 1 addition & 1 deletion gconv/gnn/kernels/kernel_sen/__init__.py
@@ -1,3 +1,3 @@
from .kernel_se2 import GLiftingKernelSE2, GSeparableKernelSE2, GKernelSE2, RGLiftingKernelSE2
from .kernel_se3 import GLiftingKernelSE3, GSeparableKernelSE3, GKernelSE3, RGLiftingKernelSE3
from .kernel_se3 import GLiftingKernelSE3, GSeparableKernelSE3, GKernelSE3, RGLiftingKernelSE3, RGSeparableKernelSE3
from .kernel_so3 import GSubgroupKernelSO3
77 changes: 76 additions & 1 deletion gconv/gnn/kernels/kernel_sen/kernel_se3.py
@@ -8,7 +8,7 @@

from typing import Optional

from gconv.gnn.kernels import GKernel, GSeparableKernel, GLiftingKernel, RGLiftingKernel
from gconv.gnn.kernels import GKernel, GSeparableKernel, GLiftingKernel, RGLiftingKernel, RGSeparableKernel

from torch import Tensor

@@ -208,6 +208,81 @@ def __init__(
sample_Rn_kwargs=sample_Rn_kwargs,
)


class RGSeparableKernelSE3(RGSeparableKernel):
def __init__(
self,
in_channels: int,
out_channels: int,
num_filter_banks: int,
kernel_size: int,
group_kernel_size: int = 4,
groups: int = 1,
group_sampling_mode: str = "rbf",
group_sampling_width: float = 0.0,
spatial_sampling_mode: str = "bilinear",
spatial_sampling_padding_mode: str = "border",
mask: bool = True,
grid_H: Optional[Tensor] = None,
) -> None:
"""
:param in_channels: int denoting the number of input channels.
:param out_channels: int denoting the number of output channels.
:param num_filter_banks: int denoting the number of filter banks for relaxed equivariance.
:param kernel_size: int denoting the spatial kernel size.
:param group_kernel_size: int denoting the group kernel size.
:param groups: number of groups for depth-wise separability.
:param group_sampling_mode: str indicating the sampling mode. Supports rbf (default)
or nearest.
:param group_sampling_width: float denoting the width of the Gaussian rbf kernels.
If 0.0 (default, recommended), width will be initialized
based on grid_H density.
:param spatial_sampling_mode: str indicating the sampling mode. Supports bilinear (default)
or nearest.
:param spatial_sampling_padding_mode: str indicating padding mode for sampling. Default
border.
:param mask: bool that, if true, will initialize a spherical mask.
:param grid_H: tensor of reference grid used for interpolation. If not
provided, a uniform grid of size group_kernel_size will be
generated. If provided, group_kernel_size will be set to its size.
"""
if grid_H is None:
grid_H = so3.uniform_grid(group_kernel_size, "matrix")

grid_Rn = gF.create_grid_R3(kernel_size)

if not group_sampling_width:
group_sampling_width = 0.8 * so3.nearest_neighbour_distance(grid_H).mean()

sample_H_kwargs = {
"mode": group_sampling_mode,
"width": group_sampling_width,
}
sample_Rn_kwargs = {
"mode": spatial_sampling_mode,
"padding_mode": spatial_sampling_padding_mode,
}

mask = gF.create_spherical_mask_R3(kernel_size) if mask else None

super().__init__(
in_channels,
out_channels,
num_filter_banks,
(kernel_size, kernel_size, kernel_size),
(grid_H.shape[0],),
grid_H,
grid_Rn,
groups,
mask=mask,
inverse_H=so3.matrix_inverse,
left_apply_to_H=so3.left_apply_to_matrix,
left_apply_to_Rn=so3.left_apply_to_R3,
sample_H=so3.grid_sample,
sample_Rn=gF.grid_sample,
sample_H_kwargs=sample_H_kwargs,
sample_Rn_kwargs=sample_Rn_kwargs,
)
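
# A hypothetical usage sketch for RGSeparableKernelSE3 (signatures taken from
# the definitions above; the concrete values are illustrative):
#
#   from gconv.gnn.kernels import RGSeparableKernelSE3
#   from gconv.geometry import so3
#
#   kernel = RGSeparableKernelSE3(
#       in_channels=3, out_channels=8, num_filter_banks=2,
#       kernel_size=5, group_kernel_size=4,
#   )
#   H = so3.uniform_grid(4, "matrix")
#   weight_H, weight = kernel(H, H)  # subgroup weights and spatial weights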


class GKernelSE3(GKernel):
def __init__(
2 changes: 1 addition & 1 deletion gconv/gnn/modules/gconv_sen/__init__.py
@@ -1,2 +1,2 @@
from .gconv_se2 import GLiftingConvSE2, GSeparableConvSE2, GConvSE2, RGLiftingConvSE2
from .gconv_se3 import GLiftingConvSE3, GSeparableConvSE3, GConvSE3, RGLiftingConvSE3
from .gconv_se3 import GLiftingConvSE3, GSeparableConvSE3, GConvSE3, RGLiftingConvSE3, RGSeparableConvSE3
98 changes: 97 additions & 1 deletion gconv/gnn/modules/gconv_sen/gconv_se3.py
@@ -2,7 +2,7 @@
from torch import Tensor

from gconv.gnn.modules.gconv import GLiftingConv3d, GSeparableConv3d, GConv3d
from gconv.gnn.kernels import GLiftingKernelSE3, GSeparableKernelSE3, GKernelSE3, RGLiftingKernelSE3
from gconv.gnn.kernels import GLiftingKernelSE3, GSeparableKernelSE3, GKernelSE3, RGLiftingKernelSE3, RGSeparableKernelSE3

from gconv.geometry import so3

@@ -266,6 +266,102 @@ def forward(
)

return super().forward(input, in_H, out_H)


class RGSeparableConvSE3(GSeparableConv3d):
""" Implements SE3 separable group convolution."""
def __init__(
self,
in_channels: int,
out_channels: int,
num_filter_banks: int,
kernel_size: int,
group_kernel_size: int = 4,
groups: int = 1,
stride: int = 1,
padding: int | str = 0,
dilation: int = 1,
padding_mode: str = "zeros",
permute_output_grid: bool = True,
group_sampling_mode: str = "rbf",
group_sampling_width: float = 0.0,
spatial_sampling_mode: str = "bilinear",
spatial_sampling_padding_mode: str = "border",
mask: bool = True,
bias: bool = False,
grid_H: Optional[Tensor] = None,
) -> None:
"""
:param in_channels: int denoting the number of input channels.
:param out_channels: int denoting the number of output channels.
:param num_filter_banks: int denoting the number of filter banks.
:param kernel_size: int denoting the spatial kernel size.
:param groups: int denoting the number of groups for depth-wise separability.
:param stride: int denoting the stride.
:param padding: int or str denoting padding.
:param dilation: int denoting dilation.
:param group_kernel_size: int denoting the group kernel size (default 4).
:param grid_H: tensor of shape (N, 3, 3) of SO3 elements (rotation matrices). If
not provided, a uniform grid of size group_kernel_size will be
initialized. If provided, group_kernel_size will be set to N.
:param padding_mode: str denoting the padding mode.
:param permute_output_grid: bool that, if true, will randomly permute the output group grid
for approximating continuous groups.
:param group_sampling_mode: str denoting mode used for sampling group weights. Supports
rbf (default) or nearest.
:param group_sampling_width: float denoting width of Gaussian rbf kernel when using rbf sampling.
If set to 0.0 (default, recommended), width will be initialized based
on the density of grid_H.
:param spatial_sampling_mode: str denoting mode used for sampling spatial weights. Supports
bilinear (default) or nearest.
:param spatial_sampling_padding_mode: str denoting padding mode for spatial weight sampling,
border (default) is recommended.
:param bias: bool that, if true, will initialize bias parameters.
:param mask: bool that if true, will initialize spherical mask applied to spatial weights.
"""
kernel = RGSeparableKernelSE3(
in_channels,
out_channels,
num_filter_banks,
kernel_size,
group_kernel_size,
groups=groups,
group_sampling_mode=group_sampling_mode,
group_sampling_width=group_sampling_width,
spatial_sampling_mode=spatial_sampling_mode,
spatial_sampling_padding_mode=spatial_sampling_padding_mode,
mask=mask,
grid_H=grid_H,
)

self.permute_output_grid = permute_output_grid

super().__init__(
in_channels,
out_channels,
kernel_size,
group_kernel_size,
kernel,
groups,
stride,
padding,
dilation,
padding_mode,
bias,
)

def forward(
self, input: Tensor, in_H: Tensor, out_H: Optional[Tensor] = None
) -> tuple[Tensor, Tensor]:
if out_H is None:
out_H = in_H

if self.permute_output_grid:
out_H = so3.left_apply_matrix(
so3.random_matrix(1, device=input.device), out_H
)

return super().forward(input, in_H, out_H)
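
# A hypothetical end-to-end sketch (the lifting layer, import path, and
# shapes are assumptions based on the surrounding gconv API, not verified):
#
#   import torch
#   from gconv.gnn import RGLiftingConvSE3, RGSeparableConvSE3
#
#   x = torch.randn(1, 3, 32, 32, 32)
#   lift = RGLiftingConvSE3(3, 8, num_filter_banks=2, kernel_size=5)
#   conv = RGSeparableConvSE3(8, 16, num_filter_banks=2, kernel_size=5)
#
#   x, H = lift(x)     # lift the R3 signal onto the group grid H
#   x, H = conv(x, H)  # relaxed separable group convolution on SE3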


class GConvSE3(GConv3d):
15 changes: 15 additions & 0 deletions gconv/utils.py
@@ -0,0 +1,15 @@
import torch

def unsqueeze_like(tensor: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
""" Unsqueeze last dimensions of tensor to match another tensor's number of dimensions.
:param tensor: tensor to unsqueeze
:param like: tensor whose dimensions to match
"""
n_unsqueezes = like.ndim - tensor.ndim
if n_unsqueezes < 0:
raise ValueError(f"tensor.ndim={tensor.ndim} > like.ndim={like.ndim}")
elif n_unsqueezes == 0:
return tensor
else:
return tensor[(...,) + (None,) * n_unsqueezes]
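
# Usage sketch (illustrative shapes):
#
#   w = torch.randn(4, 6)                 # e.g. (num_filter_banks, group_kernel_dim)
#   like = torch.randn(4, 6, 8, 3, 3)
#   unsqueeze_like(w, like).shape         # torch.Size([4, 6, 1, 1, 1])
#   out = unsqueeze_like(w, like) * like  # broadcasts over the trailing dims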
