Merge pull request #61 from kaseris/fix/sttf
Fix/sttf
kaseris committed Jan 4, 2024
2 parents aacebd7 + 4b21b0a commit 8142108
Showing 3 changed files with 9 additions and 13 deletions.
9 changes: 3 additions & 6 deletions configs/sttf_base.yaml
@@ -14,9 +14,6 @@ transforms:
   - name: MinMaxScaleTransform
     args:
       feature_scale: [0.0, 1.0]
-  - name: CartToQuaternionTransform
-    args:
-      parents: null

 loss:
   name: SmoothL1Loss
@@ -45,7 +42,7 @@ model:
   args:
     n_joints: 25
     d_model: 256
-    n_blocks: 8
+    n_blocks: 3
     n_heads: 8
     d_head: 16
     mlp_dim: 512
@@ -54,8 +51,8 @@ model:
 runner:
   name: Runner
   args:
-    train_batch_size: 1024
-    val_batch_size: 1024
+    train_batch_size: 32
+    val_batch_size: 32
     block_size: 8
     log_gradient_info: true
     device: cuda
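Taken together, the config changes shrink the model (3 transformer blocks instead of 8), drop the Cartesian-to-quaternion transform so training runs on min-max-scaled Cartesian coordinates, and cut both batch sizes from 1024 to 32. A minimal sketch of reading the updated values back, assuming PyYAML is available (illustrative only, not skelcast's actual config loader):

import yaml

# Hypothetical check that the new hyperparameters are what the runner will see.
with open('configs/sttf_base.yaml') as f:
    cfg = yaml.safe_load(f)

assert cfg['model']['args']['n_blocks'] == 3             # was 8
assert cfg['runner']['args']['train_batch_size'] == 32   # was 1024
assert cfg['runner']['args']['val_batch_size'] == 32     # was 1024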
4 changes: 2 additions & 2 deletions src/skelcast/experiments/runner.py
@@ -186,7 +186,7 @@ def training_step(self, train_batch: NTURGBDSample):
         x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
         x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
         self.model.train()
-        out = self.model.training_step(x, y, mask) # TODO: Make the other models accept a mask as well
+        out = self.model.training_step(x=x, y=y, mask=mask) # TODO: Make the other models accept a mask as well
         loss = out['loss']
         outputs = out['out']
         # Calculate the saturation of the tanh output
@@ -229,7 +229,7 @@ def validation_step(self, val_batch: NTURGBDSample):
         x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
         x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
         self.model.eval()
-        out = self.model.validation_step(x, y, mask)
+        out = self.model.validation_step(x=x, y=y, mask=mask)
         loss = out['loss']
         self.validation_loss_per_step.append(loss.item())
         # Log it to the logger
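The keyword-only call sites matter because the model's hook is declared as training_step(self, **kwargs) (visible in the sttf.py hunk below): positional arguments are not captured by **kwargs, so the old calls would raise a TypeError. A minimal sketch of the failure mode, using an illustrative class rather than skelcast code:

import torch

class Model:
    # Same **kwargs calling convention as the transformer's training_step below.
    def training_step(self, **kwargs) -> dict:
        x, y = kwargs['x'], kwargs['y']
        return {'loss': (x - y).abs().mean(), 'out': x}

m = Model()
x, y = torch.zeros(2, 3), torch.ones(2, 3)
# m.training_step(x, y)          # TypeError: takes 1 positional argument but 3 were given
out = m.training_step(x=x, y=y)  # OK: both tensors land in kwargs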
9 changes: 4 additions & 5 deletions src/skelcast/models/transformers/sttf.py
@@ -37,7 +37,7 @@ def forward(self, x: torch.Tensor):
         q_proj = self.q(x)
         k_proj = self.k(x)
         v_proj = self.v(x)
-        mask = self.get_mask(seq_len, batch_size)
+        mask = self.get_mask(seq_len, batch_size).to(x.device)
         attn_prod_ = torch.bmm(q_proj, k_proj.permute(0, 2, 1)) * (self.d_model) ** -0.5

         attn_temporal = F.softmax(attn_prod_ + mask, dim=-1)
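Moving the mask with .to(x.device) fixes the usual CPU/CUDA mismatch: a freshly constructed tensor lives on the CPU by default, and adding it to CUDA attention scores raises a RuntimeError. A small sketch of the pattern; this get_mask is a stand-in, not the module's actual implementation:

import torch

def get_mask(seq_len: int, batch_size: int) -> torch.Tensor:
    # Stand-in causal mask: future positions get -inf, built on the CPU by default.
    mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
    return mask.unsqueeze(0).expand(batch_size, -1, -1)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
scores = torch.randn(4, 8, 8, device=device)             # (batch, seq, seq) attention logits
mask = get_mask(seq_len=8, batch_size=4).to(scores.device)
attn = torch.softmax(scores + mask, dim=-1)              # no device-mismatch error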
@@ -198,16 +198,15 @@ def __init__(self, n_joints,
         self.linear_out = nn.Linear(in_features=d_model, out_features=3, bias=False)

     def forward(self, x: torch.Tensor):
+        if x.ndim > 4:
+            x = x.squeeze(2)
         batch_size, seq_len, n_joints, dims = x.shape
         input_ = x.view(batch_size, seq_len, n_joints * dims)
         o = self.embedding(input_)
-        print(f'o shape after embedding: {o.shape}')
         o = self.pe.pe.repeat(batch_size, 1, 1)[:, :seq_len, :] + o
-        print(f'o shape after positional encoding: {o.shape}')
         o = self.pre_dropout(o)
         o = o.view(batch_size, seq_len, n_joints, self.d_model)
         o = self.transformer(o)
-        print(f'o shape after transformer: {o.shape}')
         out = self.linear_out(o) + x
         return out

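The new ndim guard normalizes a 5-D batch, e.g. (batch, seq_len, 1, n_joints, 3), down to the 4-D layout the rest of forward unpacks; without it, the four-name tuple assignment on the next line would fail. A rough shape walkthrough under that assumed input layout:

import torch

x = torch.randn(32, 8, 1, 25, 3)     # (batch, seq_len, singleton, n_joints, dims)
if x.ndim > 4:
    x = x.squeeze(2)                 # -> (32, 8, 25, 3)
batch_size, seq_len, n_joints, dims = x.shape
flat = x.view(batch_size, seq_len, n_joints * dims)  # -> (32, 8, 75) for the embedding
# linear_out later maps d_model back to 3 per joint, so the residual
# `self.linear_out(o) + x` requires the output and x to share the (32, 8, 25, 3) shape.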
@@ -217,7 +216,7 @@ def training_step(self, **kwargs) -> dict:
         # Forward pass
         out = self(x)
         # Compute the loss
-        loss = self.loss_fn(out, y)
+        loss = self.loss_fn(out, y.squeeze(2))
         return {'loss': loss, 'out': out}

     def validation_step(self, *args, **kwargs):
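The y.squeeze(2) mirrors the input guard above: the target arrives with the same leftover singleton dimension, while the model output is 4-D, and SmoothL1Loss cannot compare the two shapes element-wise. A small sketch with the shapes assumed from the hunks above:

import torch
import torch.nn as nn

loss_fn = nn.SmoothL1Loss()
out = torch.randn(32, 8, 25, 3)      # model prediction
y = torch.randn(32, 8, 1, 25, 3)     # target with a leftover singleton dim
# loss_fn(out, y) fails: (32, 8, 25, 3) does not broadcast against (32, 8, 1, 25, 3)
loss = loss_fn(out, y.squeeze(2))    # shapes now match: (32, 8, 25, 3)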
