From 89a77090b6633883af447a654cf60c769b291e6a Mon Sep 17 00:00:00 2001
From: Michail Kaseris
Date: Thu, 4 Jan 2024 14:08:15 +0200
Subject: [PATCH 1/3] Needed to lower the batch size due to the model's memory
 requirements

---
 configs/sttf_base.yaml | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/configs/sttf_base.yaml b/configs/sttf_base.yaml
index 0086f7e..d099aba 100644
--- a/configs/sttf_base.yaml
+++ b/configs/sttf_base.yaml
@@ -14,9 +14,6 @@ transforms:
   - name: MinMaxScaleTransform
     args:
       feature_scale: [0.0, 1.0]
-  - name: CartToQuaternionTransform
-    args:
-      parents: null
 
 loss:
   name: SmoothL1Loss
@@ -45,7 +42,7 @@ model:
   args:
     n_joints: 25
     d_model: 256
-    n_blocks: 8
+    n_blocks: 3
     n_heads: 8
     d_head: 16
     mlp_dim: 512
@@ -54,8 +51,8 @@ runner:
   name: Runner
   args:
-    train_batch_size: 1024
-    val_batch_size: 1024
+    train_batch_size: 32
+    val_batch_size: 32
     block_size: 8
     log_gradient_info: true
     device: cuda
 

From be58c10a6cc0782ab1de9174c188800a8af60095 Mon Sep 17 00:00:00 2001
From: Michail Kaseris
Date: Thu, 4 Jan 2024 14:08:44 +0200
Subject: [PATCH 2/3] Pass the inputs as kwargs for the model training steps

---
 src/skelcast/experiments/runner.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/skelcast/experiments/runner.py b/src/skelcast/experiments/runner.py
index 6ae305e..7d0a13b 100644
--- a/src/skelcast/experiments/runner.py
+++ b/src/skelcast/experiments/runner.py
@@ -186,7 +186,7 @@ def training_step(self, train_batch: NTURGBDSample):
         x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
         x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
         self.model.train()
-        out = self.model.training_step(x, y, mask)  # TODO: Make the other models accept a mask as well
+        out = self.model.training_step(x=x, y=y, mask=mask)  # TODO: Make the other models accept a mask as well
         loss = out['loss']
         outputs = out['out']
         # Calculate the saturation of the tanh output
@@ -229,7 +229,7 @@ def validation_step(self, val_batch: NTURGBDSample):
         x, y, mask = x.to(torch.float32), y.to(torch.float32), mask.to(torch.float32)
         x, y, mask = x.to(self.device), y.to(self.device), mask.to(self.device)
         self.model.eval()
-        out = self.model.validation_step(x, y, mask)
+        out = self.model.validation_step(x=x, y=y, mask=mask)
         loss = out['loss']
         self.validation_loss_per_step.append(loss.item())
         # Log it to the logger
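Note on [PATCH 2/3]: the switch to keyword arguments matches the model-side
entry points, which the next patch shows declared as
"def training_step(self, **kwargs) -> dict:" -- a positional call cannot bind
to a **kwargs-only signature. A minimal runnable sketch of the failure mode,
where DummyModel, the tensor shapes, and the stand-in loss are illustrative
assumptions rather than the project's code:

    import torch
    import torch.nn.functional as F

    class DummyModel:
        # Mirrors the signature shown in patch 3: inputs bind only via **kwargs.
        def training_step(self, **kwargs) -> dict:
            x, y, mask = kwargs['x'], kwargs['y'], kwargs['mask']
            out = x * mask  # stand-in for the real forward pass
            return {'loss': F.smooth_l1_loss(out, y), 'out': out}

    model = DummyModel()
    x, y = torch.randn(32, 8, 25, 3), torch.randn(32, 8, 25, 3)
    mask = torch.ones(32, 8, 25, 3)

    # model.training_step(x, y, mask)  # TypeError: takes 1 positional argument but 4 were given
    out = model.training_step(x=x, y=y, mask=mask)  # binds through **kwargs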
From 4b21b0a7b35b1cb037306d8434fe85b2b7b42a77 Mon Sep 17 00:00:00 2001
From: Michail Kaseris
Date: Thu, 4 Jan 2024 14:09:41 +0200
Subject: [PATCH 3/3] Dimensionality fixes

---
 src/skelcast/models/transformers/sttf.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/skelcast/models/transformers/sttf.py b/src/skelcast/models/transformers/sttf.py
index 9651896..f08839e 100644
--- a/src/skelcast/models/transformers/sttf.py
+++ b/src/skelcast/models/transformers/sttf.py
@@ -37,7 +37,7 @@ def forward(self, x: torch.Tensor):
         q_proj = self.q(x)
         k_proj = self.k(x)
         v_proj = self.v(x)
-        mask = self.get_mask(seq_len, batch_size)
+        mask = self.get_mask(seq_len, batch_size).to(x.device)
         attn_prod_ = torch.bmm(q_proj, k_proj.permute(0, 2, 1)) * (self.d_model) ** -0.5
         attn_temporal = F.softmax(attn_prod_ + mask, dim=-1)
 
@@ -198,16 +198,15 @@ def __init__(self, n_joints,
         self.linear_out = nn.Linear(in_features=d_model, out_features=3, bias=False)
 
     def forward(self, x: torch.Tensor):
+        if x.ndim > 4:
+            x = x.squeeze(2)
         batch_size, seq_len, n_joints, dims = x.shape
         input_ = x.view(batch_size, seq_len, n_joints * dims)
         o = self.embedding(input_)
-        print(f'o shape after embedding: {o.shape}')
         o = self.pe.pe.repeat(batch_size, 1, 1)[:, :seq_len, :] + o
-        print(f'o shape after positional encoding: {o.shape}')
         o = self.pre_dropout(o)
         o = o.view(batch_size, seq_len, n_joints, self.d_model)
         o = self.transformer(o)
-        print(f'o shape after transformer: {o.shape}')
         out = self.linear_out(o) + x
         return out
 
@@ -217,7 +216,7 @@ def training_step(self, **kwargs) -> dict:
         # Forward pass
         out = self(x)
         # Compute the loss
-        loss = self.loss_fn(out, y)
+        loss = self.loss_fn(out, y.squeeze(2))
         return {'loss': loss, 'out': out}
 
     def validation_step(self, *args, **kwargs):
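Note on [PATCH 3/3]: the two fixes are independent. squeeze(2) drops a
singleton third axis from 5-D batches so forward() can unpack exactly four
dimensions (and y.squeeze(2) aligns the target the same way before
SmoothL1Loss), while .to(x.device) moves the attention mask, which is
allocated on the CPU inside forward(), onto the same device as the
activations before the addition. A small sketch under assumed shapes -- the
bodies=1 axis and the triangular mask are illustrative guesses, not taken
from the repository:

    import torch

    # Assumed incoming shape: (batch, seq_len, bodies=1, n_joints, dims).
    x = torch.randn(32, 8, 1, 25, 3)
    if x.ndim > 4:
        x = x.squeeze(2)  # -> (32, 8, 25, 3), the layout forward() unpacks
    batch_size, seq_len, n_joints, dims = x.shape

    # Tensors allocated inside forward() default to the CPU; without .to(x.device)
    # the addition attn_prod_ + mask fails as soon as x lives on CUDA.
    mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
    mask = mask.to(x.device)  # no-op on CPU, required under device: cuda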