Skip to content

Commit

Permalink
Merge pull request #53 from kaseris/fix/transforms
Browse files Browse the repository at this point in the history
Handle the case of tensor inputs and squeeze the singleton dimensions
  • Loading branch information
kaseris committed Dec 14, 2023
2 parents cdd1abe + b3caa21 commit 34cd2c1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
14 changes: 11 additions & 3 deletions src/skelcast/data/transforms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import torch
from typing import Any, Tuple

Expand Down Expand Up @@ -50,8 +52,11 @@ def __init__(self, parents: list = None) -> None:
else:
self.parents = parents

def __call__(self, x) -> torch.Tensor:
return xyz_to_expmap(x, self.parents)
def __call__(self, x) -> torch.Tensor: # Really? Torch tensor? Check it again
logging.info(f'The type of x is: {type(x)}')
if isinstance(x, torch.Tensor):
return xyz_to_expmap(x.squeeze(1).numpy(), self.parents)
return xyz_to_expmap(x.squeeze(1), self.parents)


@TRANSFORMS.register_module()
Expand All @@ -74,8 +79,11 @@ def __init__(self, parents: list = None) -> None:
self.parents = parents

def __call__(self, x) -> Any:
if isinstance(x, torch.Tensor):
x = x.squeeze(1).numpy()
_exps = xyz_to_expmap(x, self.parents)
return exps_to_quats(_exps)
# Because it returns a tensor we need to convert it to numpy
return exps_to_quats(_exps.squeeze(1).numpy())

class Compose:
def __init__(self, transforms: list) -> None:
Expand Down
5 changes: 3 additions & 2 deletions src/skelcast/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Copied from https://github.com/qxcv/pose-prediction/blob/master/expmap.py"""
import torch
import numpy as np


Expand Down Expand Up @@ -85,7 +86,7 @@ def xyz_to_expmap(xyz_seq, parents):
root = toposorted[0]
exp_seq[1:, root] = xyz_seq[1:, root] - xyz_seq[:-1, root]

return exp_seq
return torch.from_numpy(exp_seq).unsqueeze(1)

def exp_to_rotmat(exp):
"""Convert rotation paramterised as exponential map into ordinary 3x3
Expand Down Expand Up @@ -133,4 +134,4 @@ def exps_to_quats(exps):
rv_flat[nonzero_mask, 1:] = nonzero_normed * sines[..., None]

rv_shape = exps.shape[:-1] + (4, )
return rv_flat.reshape(rv_shape)
return torch.from_numpy(rv_flat.reshape(rv_shape)).unsqueeze(1)

0 comments on commit 34cd2c1

Please sign in to comment.