diff --git a/denoisers/modeling/unet1d/model.py b/denoisers/modeling/unet1d/model.py index 74d5fce..2cfbb3e 100644 --- a/denoisers/modeling/unet1d/model.py +++ b/denoisers/modeling/unet1d/model.py @@ -245,7 +245,7 @@ def forward(self, inputs: Tensor) -> Tensor: out = self.middle(out) for skip, layer in zip(reversed(skips), self.decoder_layers): - out = layer(out + skip) + out = layer(out[..., : skip.size(-1)] + skip) out = torch.concat([out, inputs], dim=1) out = self.out_conv(out) diff --git a/denoisers/modeling/waveunet/model.py b/denoisers/modeling/waveunet/model.py index cfab524..c4e7c4e 100644 --- a/denoisers/modeling/waveunet/model.py +++ b/denoisers/modeling/waveunet/model.py @@ -247,7 +247,7 @@ def forward(self, inputs: Tensor) -> Tensor: out = self.middle(out) for skip, layer in zip(reversed(skips), self.decoder_layers): - out = torch.concat([out, skip], dim=1) + out = torch.concat([out[..., : skip.size(-1)], skip], dim=1) out = layer(out) out = torch.concat([out, inputs], dim=1)