From bf96730883f423438017d1d12584ec499a662138 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 2 Nov 2023 10:50:46 -0400 Subject: [PATCH 1/2] Add UNet1D model --- bin/run | 7 +- denoisers/__init__.py | 4 +- denoisers/datamodules/unet1d.py | 114 ++++++++ denoisers/modeling/__init__.py | 4 +- denoisers/modeling/unet1d/__init__.py | 5 + denoisers/modeling/unet1d/config.py | 43 +++ denoisers/modeling/unet1d/model.py | 247 ++++++++++++++++++ denoisers/modeling/unet1d/modules.py | 163 ++++++++++++ denoisers/modeling/waveunet/__init__.py | 2 +- denoisers/modeling/waveunet/model.py | 2 +- denoisers/scripts/__init__.py | 1 + denoisers/scripts/train/__init__.py | 1 + denoisers/scripts/train/unet1d.py | 91 +++++++ .../{train.py => scripts/train/waveunet.py} | 0 tests/modeling/test_unet1d_model.py | 80 ++++++ 15 files changed, 757 insertions(+), 7 deletions(-) create mode 100644 denoisers/datamodules/unet1d.py create mode 100644 denoisers/modeling/unet1d/__init__.py create mode 100644 denoisers/modeling/unet1d/config.py create mode 100644 denoisers/modeling/unet1d/model.py create mode 100644 denoisers/modeling/unet1d/modules.py create mode 100644 denoisers/scripts/__init__.py create mode 100644 denoisers/scripts/train/__init__.py create mode 100644 denoisers/scripts/train/unet1d.py rename denoisers/{train.py => scripts/train/waveunet.py} (100%) create mode 100644 tests/modeling/test_unet1d_model.py diff --git a/bin/run b/bin/run index 7454690..4f2efb3 100755 --- a/bin/run +++ b/bin/run @@ -2,6 +2,7 @@ set -e -command=$1 -shift -exec python -m "denoisers.$command" "$@" +subsys=$1 +command=$2 +shift; shift +exec python -m "denoisers.scripts.$subsys.$command" "$@" diff --git a/denoisers/__init__.py b/denoisers/__init__.py index b7347d9..82227d2 100644 --- a/denoisers/__init__.py +++ b/denoisers/__init__.py @@ -1,5 +1,7 @@ """Denoisers for the 1D and 2D cases.""" +from denoisers.modeling.unet1d.config import UNet1DConfig +from denoisers.modeling.unet1d.model import UNet1DModel from 
denoisers.modeling.waveunet.config import WaveUNetConfig from denoisers.modeling.waveunet.model import WaveUNetModel -__all__ = ["WaveUNetConfig", "WaveUNetModel"] +__all__ = ["WaveUNetConfig", "WaveUNetModel", "UNet1DConfig", "UNet1DModel"] diff --git a/denoisers/datamodules/unet1d.py b/denoisers/datamodules/unet1d.py new file mode 100644 index 0000000..9c57eae --- /dev/null +++ b/denoisers/datamodules/unet1d.py @@ -0,0 +1,114 @@ +"""Unet1D Data modules.""" +import os +from typing import List, NamedTuple, Optional + +import numpy as np +import pytorch_lightning as pl +import torch +import torchaudio +from torch import Tensor, nn +from torch.nn import functional as F +from torch.utils.data import DataLoader + +from denoisers import transforms + + +class Batch(NamedTuple): + """Batch of inputs.""" + + audio: Tensor + noisy: Tensor + lengths: Tensor + + +class AudioFromFileDataModule(pl.LightningDataModule): + """LibriTTS DataModule.""" + + def __init__( + self, + dataset: torch.utils.data.Dataset, + batch_size: int = 24, + num_workers: int = os.cpu_count() // 2, # type: ignore + max_length: int = 16384 * 10, + sample_rate: int = 24000, + ) -> None: + super().__init__() + self.save_hyperparameters() + + self._dataset = dataset + self._batch_size = batch_size + self._num_workers = num_workers + # we don't use sample_rate here for divisibility + self._max_length = max_length + self._sample_rate = sample_rate + self._transforms = nn.Sequential( + transforms.ReverbFromSoundboard(p=1.0, sample_rate=sample_rate), + transforms.GaussianNoise(p=1.0), + ) + + def setup(self, stage: Optional[str] = "fit") -> None: + """Setup datasets.""" + train_split = int(np.floor(len(self._dataset) * 0.95)) # type: ignore + val_split = int(np.ceil(len(self._dataset) * 0.05)) # type: ignore + + self.train_dataset, self.val_dataset = torch.utils.data.random_split( + self._dataset, lengths=(train_split, val_split) + ) + + def pad_collate_fn(self, paths: List[str]) -> Batch: + """Pad collate 
function.""" + audios = [] + noisy_audio = [] + lengths = [] + for path in paths: + try: + audio, sr = torchaudio.load(path) + except Exception: + continue + + if sr != self._sample_rate: + audio = torchaudio.functional.resample(audio, sr, self._sample_rate) + + if audio.size(0) > 1: + audio = audio[0].unsqueeze(0) + + audio_length = min(audio.size(-1), self._max_length) + + if audio_length < self._max_length: + pad_length = self._max_length - audio_length + audio = F.pad(audio, (0, pad_length)) + else: + audio = audio[:, : self._max_length] + + noisy = self._transforms(audio.clone()) + + audios.append(audio) + noisy_audio.append(noisy) + lengths.append(torch.tensor(audio_length)) + + return Batch( + audio=torch.stack(audios), + noisy=torch.stack(noisy_audio), + lengths=torch.stack(lengths), + ) + + def train_dataloader(self) -> DataLoader: + """Train dataloader.""" + return DataLoader( + self.train_dataset, + batch_size=self._batch_size, + num_workers=self._num_workers, + collate_fn=self.pad_collate_fn, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader: + """Validation dataloader.""" + return DataLoader( + self.val_dataset, + batch_size=self._batch_size, + num_workers=self._num_workers, + collate_fn=self.pad_collate_fn, + shuffle=False, + drop_last=True, + ) diff --git a/denoisers/modeling/__init__.py b/denoisers/modeling/__init__.py index d942cd3..cde9b9f 100644 --- a/denoisers/modeling/__init__.py +++ b/denoisers/modeling/__init__.py @@ -1,5 +1,7 @@ """Models.""" +from denoisers.modeling.unet1d.config import UNet1DConfig +from denoisers.modeling.unet1d.model import UNet1DModel from denoisers.modeling.waveunet.config import WaveUNetConfig from denoisers.modeling.waveunet.model import WaveUNetModel -__all__ = ["WaveUNetConfig", "WaveUNetModel"] +__all__ = ["WaveUNetConfig", "WaveUNetModel", "UNet1DConfig", "UNet1DModel"] diff --git a/denoisers/modeling/unet1d/__init__.py b/denoisers/modeling/unet1d/__init__.py new file mode 100644 index 
0000000..44064c4 --- /dev/null +++ b/denoisers/modeling/unet1d/__init__.py @@ -0,0 +1,5 @@ +"""UNet1D model.""" +from denoisers.modeling.unet1d.config import UNet1DConfig +from denoisers.modeling.unet1d.model import UNet1DModel + +__all__ = ["UNet1DConfig", "UNet1DModel"] diff --git a/denoisers/modeling/unet1d/config.py b/denoisers/modeling/unet1d/config.py new file mode 100644 index 0000000..84913c3 --- /dev/null +++ b/denoisers/modeling/unet1d/config.py @@ -0,0 +1,43 @@ +"""UNet1D configuration file.""" +from typing import Any, Tuple + +from transformers import PretrainedConfig + + +class UNet1DConfig(PretrainedConfig): + """Configuration class to store the configuration of a `UNet1DModel`.""" + + model_type = "unet1d" + + def __init__( + self, + channels: Tuple[int, ...] = ( + 32, + 64, + 96, + 128, + 160, + 192, + 224, + 256, + 288, + 320, + 352, + 384, + ), + kernel_size: int = 3, + num_groups: int = 32, + dropout: float = 0.1, + activation: str = "silu", + autoencoder: bool = False, + max_length: int = 16384 * 10, + sample_rate: int = 48000, + **kwargs: Any, + ) -> None: + self.channels = channels + self.kernel_size = kernel_size + self.num_groups = num_groups + self.dropout = dropout + self.activation = activation + self.autoencoder = autoencoder + super().__init__(**kwargs, max_length=max_length, sample_rate=sample_rate) diff --git a/denoisers/modeling/unet1d/model.py b/denoisers/modeling/unet1d/model.py new file mode 100644 index 0000000..3d2e86e --- /dev/null +++ b/denoisers/modeling/unet1d/model.py @@ -0,0 +1,247 @@ +"""UNet1D model.""" +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from pytorch_lightning import LightningModule +from pytorch_lightning.utilities import grad_norm +from pytorch_lightning.utilities.memory import garbage_collection_cuda +from torch import Tensor, nn +from torchmetrics.audio import ( + ScaleInvariantSignalDistortionRatio, + ScaleInvariantSignalNoiseRatio, +) +from transformers import PreTrainedModel + 
+from denoisers.datamodules.unet1d import Batch +from denoisers.metrics import calculate_pesq +from denoisers.modeling.unet1d.config import UNet1DConfig +from denoisers.modeling.unet1d.modules import DownBlock1D, MidBlock1D, UpBlock1D +from denoisers.utils import log_audio_batch, plot_image_from_audio + + +class UNet1DLightningModule(LightningModule): + """UNet1D Lightning Module.""" + + def __init__(self, config: UNet1DConfig) -> None: + super().__init__() + self.config = config + self.model = UNet1DModel(config) + self.loss_fn = nn.L1Loss() + self.snr = ScaleInvariantSignalNoiseRatio() + self.sdr = ScaleInvariantSignalDistortionRatio() + self.autoencoder = self.config.autoencoder + self.last_val_batch: Any = {} + + def forward(self, inputs: Tensor) -> Tensor: + """Forward Pass.""" + return self.model(inputs) + + def training_step( + self, batch: Batch, batch_idx: Any + ) -> Union[Tensor, Dict[str, Any]]: + """Train step.""" + outputs = self(batch.noisy) + + if self.autoencoder: + loss = self.loss_fn(outputs.logits, batch.audio) + else: + loss = self.loss_fn(outputs.noise, batch.noisy - batch.audio) + + snr = self.snr(outputs.logits, batch.audio) + sdr = self.sdr(outputs.logits, batch.audio) + + self.log("train_loss", loss, prog_bar=True) + self.log("train_snr", snr) + self.log("train_sdr", sdr) + + return loss + + def validation_step( + self, batch: Any, batch_idx: Any + ) -> Union[Tensor, Dict[str, Any]]: + """Val step.""" + outputs = self(batch.noisy) + + if self.autoencoder: + loss = self.loss_fn(outputs.logits, batch.audio) + else: + loss = self.loss_fn(outputs.noise, batch.noisy - batch.audio) + + snr = self.snr(outputs.logits, batch.audio) + sdr = self.sdr(outputs.logits, batch.audio) + pesq = calculate_pesq(outputs.logits, batch.audio, self.config.sample_rate) + + self.log("val_loss", loss, prog_bar=True) + self.log("val_snr", snr) + self.log("val_sdr", sdr) + self.log("pesq", pesq) + + self.last_val_batch = { + "outputs": ( + batch.audio.detach(), + 
batch.noisy.detach(), + outputs.logits.detach(), + batch.lengths.detach(), + ) + } + + return loss + + def on_validation_epoch_end(self) -> None: + """Val epoch end.""" + outputs = self.last_val_batch["outputs"] + audio, noisy, preds, lengths = outputs + log_audio_batch( + audio, + noisy, + preds, + lengths, + name="val", + sample_rate=self.config.sample_rate, + ) + plot_image_from_audio(audio, noisy, preds, lengths, "val") + self.snr.reset() + self.sdr.reset() + + model_name = self.trainer.default_root_dir.split("/")[-1] + self.model.save_pretrained(self.trainer.default_root_dir + "/" + model_name) + self.model.push_to_hub(model_name) + + garbage_collection_cuda() + + def on_before_optimizer_step(self, optimizer: Any) -> None: + """Before optimizer step.""" + self.log_dict(grad_norm(self, norm_type=1)) + + def configure_optimizers(self) -> Any: + """Set optimizer.""" + optimizer = torch.optim.AdamW( + self.model.parameters(), lr=1e-4, weight_decay=1e-2 + ) + + return optimizer + + +class UNet1DModelOutputs: + """Class for holding model outputs.""" + + def __init__(self, logits: Tensor, noise: Optional[Tensor] = None) -> None: + self.logits = logits + self.noise = noise + + +class UNet1DModel(PreTrainedModel): + """Pretrained UNet1D Model.""" + + config_class = UNet1DConfig + + def __init__(self, config: UNet1DConfig) -> None: + super().__init__(config) + self.config = config + self.model = UNet1D( + channels=config.channels, + kernel_size=config.kernel_size, + num_groups=config.num_groups, + activation=config.activation, + dropout=config.dropout, + ) + + def forward(self, inputs: Tensor) -> UNet1DModelOutputs: + """Forward Pass.""" + if self.config.autoencoder: + logits = self.model(inputs) + return UNet1DModelOutputs(logits=logits) + else: + noise = self.model(inputs) + denoised = inputs - noise + return UNet1DModelOutputs(logits=denoised, noise=noise) + + +class UNet1D(nn.Module): + """UNet1D model.""" + + def __init__( + self, + channels: Tuple[int, ...] 
= ( + 32, + 64, + 96, + 128, + 160, + 192, + 224, + 256, + 288, + 320, + 352, + 384, + ), + kernel_size: int = 3, + num_groups: int = 32, + activation: str = "silu", + dropout: float = 0.1, + ) -> None: + super().__init__() + self.in_conv = nn.Conv1d( + 1, + channels[0], + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + self.encoder_layers = nn.ModuleList( + [ + DownBlock1D( + channels[i], + out_channels=channels[i + 1], + kernel_size=kernel_size, + num_groups=num_groups, + dropout=dropout, + activation=activation, + ) + for i in range(len(channels) - 1) + ] + ) + self.middle = MidBlock1D( + in_channels=channels[-1], + out_channels=channels[-1], + kernel_size=kernel_size, + num_groups=num_groups, + dropout=dropout, + activation=activation, + ) + self.decoder_layers = nn.ModuleList( + [ + UpBlock1D( + channels[i + 1], + out_channels=channels[i], + kernel_size=kernel_size, + num_groups=num_groups, + dropout=dropout, + activation=activation, + ) + for i in reversed(range(len(channels) - 1)) + ] + ) + self.out_conv = nn.Sequential( + nn.Conv1d(channels[0] + 1, 1, kernel_size=1, padding=0), + nn.Tanh(), + ) + + def forward(self, inputs: Tensor) -> Tensor: + """Forward Pass.""" + out = self.in_conv(inputs) + + skips = [] + for layer in self.encoder_layers: + out = layer(out) + skips.append(out) + + out = self.middle(out) + + for skip, layer in zip(reversed(skips), self.decoder_layers): + out = layer(out + skip) + + out = torch.concat([out, inputs], dim=1) + out = self.out_conv(out) + + return out.float() diff --git a/denoisers/modeling/unet1d/modules.py b/denoisers/modeling/unet1d/modules.py new file mode 100644 index 0000000..1bd2c8a --- /dev/null +++ b/denoisers/modeling/unet1d/modules.py @@ -0,0 +1,163 @@ +"""Modules for 1D U-Net.""" +from torch import Tensor, nn + +from denoisers.modeling.modules import Activation, Downsample1D, Upsample1D + + +class DownBlock1D(nn.Module): + """Downsampling Block for 1D data.""" + + def __init__( + self, + in_channels: 
int, + out_channels: int, + kernel_size: int = 3, + num_groups: int = 32, + activation: str = "silu", + dropout: float = 0.0, + ) -> None: + super().__init__() + self.res_block = ResBlock1D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_groups=num_groups, + activation=activation, + dropout=dropout, + ) + self.downsample = Downsample1D( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + use_conv=True, + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward Pass.""" + x = self.res_block(x) + x = self.downsample(x) + return x + + +class UpBlock1D(nn.Module): + """Upsampling Block for 1D data.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + num_groups: int = 32, + activation: str = "silu", + dropout: float = 0.0, + ) -> None: + super().__init__() + self.res_block = ResBlock1D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_groups=num_groups, + activation=activation, + dropout=dropout, + ) + self.upsample = Upsample1D( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + use_conv=True, + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward Pass.""" + x = self.res_block(x) + x = self.upsample(x) + return x + + +class ResBlock1D(nn.Module): + """Residual Block for 1D data.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + num_groups: int = 32, + activation: str = "silu", + dropout: float = 0.0, + ) -> None: + super().__init__() + self.norm_1 = nn.GroupNorm(num_groups, in_channels) + self.activation_1 = Activation(activation) + self.conv_1 = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + padding=kernel_size // 2, + bias=False, + ) + self.norm_2 = nn.GroupNorm(num_groups, out_channels) + self.activation_2 = Activation(activation) + self.dropout = nn.Dropout(dropout) + self.conv_2 = nn.Conv1d( + out_channels, 
+ out_channels, + kernel_size, + padding=kernel_size // 2, + bias=False, + ) + self.residual = nn.Conv1d(in_channels, out_channels, 1) + + def forward(self, x: Tensor) -> Tensor: + """Forward Pass.""" + residual = self.residual(x) + x = self.norm_1(x) + x = self.activation_1(x) + x = self.conv_1(x) + x = self.norm_2(x) + x = self.activation_2(x) + x = self.dropout(x) + x = self.conv_2(x) + return x + residual + + +class MidBlock1D(nn.Module): + """Middle Block for 1D data.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + num_groups: int = 32, + num_heads: int = 8, + activation: str = "silu", + dropout: float = 0.0, + ) -> None: + super().__init__() + self.res_block_1 = ResBlock1D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_groups=num_groups, + activation=activation, + dropout=dropout, + ) + self.attention = nn.MultiheadAttention(out_channels, num_heads=num_heads) + self.res_block_2 = ResBlock1D( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_groups=num_groups, + activation=activation, + dropout=dropout, + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward Pass.""" + x = self.res_block_1(x) + x = self.attention(x.transpose(2, 1), x.transpose(2, 1), x.transpose(2, 1))[ + 0 + ].transpose(2, 1) + x = self.res_block_2(x) + return x diff --git a/denoisers/modeling/waveunet/__init__.py b/denoisers/modeling/waveunet/__init__.py index 23b1ac8..7b8981a 100644 --- a/denoisers/modeling/waveunet/__init__.py +++ b/denoisers/modeling/waveunet/__init__.py @@ -1,4 +1,4 @@ -"""1d UNet model.""" +"""WaveUnet model.""" from denoisers.modeling.waveunet.config import WaveUNetConfig from denoisers.modeling.waveunet.model import WaveUNetModel diff --git a/denoisers/modeling/waveunet/model.py b/denoisers/modeling/waveunet/model.py index b091555..0e1689f 100644 --- a/denoisers/modeling/waveunet/model.py +++ b/denoisers/modeling/waveunet/model.py @@ 
-136,7 +136,7 @@ class WaveUNetModel(PreTrainedModel): config_class = WaveUNetConfig - def __init__(self, config: WaveUNetConfig): + def __init__(self, config: WaveUNetConfig) -> None: super().__init__(config) self.config = config self.model = WaveUNet( diff --git a/denoisers/scripts/__init__.py b/denoisers/scripts/__init__.py new file mode 100644 index 0000000..a9a18d7 --- /dev/null +++ b/denoisers/scripts/__init__.py @@ -0,0 +1 @@ +"""Scripts.""" diff --git a/denoisers/scripts/train/__init__.py b/denoisers/scripts/train/__init__.py new file mode 100644 index 0000000..b6d9cb5 --- /dev/null +++ b/denoisers/scripts/train/__init__.py @@ -0,0 +1 @@ +"""Train scripts.""" diff --git a/denoisers/scripts/train/unet1d.py b/denoisers/scripts/train/unet1d.py new file mode 100644 index 0000000..c71d1a7 --- /dev/null +++ b/denoisers/scripts/train/unet1d.py @@ -0,0 +1,91 @@ +"""Train script.""" +import argparse +from pathlib import Path + +import pytorch_lightning as pl +import torch +from pytorch_lightning import loggers + +from denoisers.datamodules.unet1d import AudioFromFileDataModule +from denoisers.datasets.audio import AudioDataset +from denoisers.modeling.unet1d.config import UNet1DConfig +from denoisers.modeling.unet1d.model import UNet1DLightningModule + +if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + +def main() -> None: + """Main.""" + parser = argparse.ArgumentParser("train parser") + parser.add_argument("name", type=str) + parser.add_argument("data_path", type=Path) + parser.add_argument("--project", default="denoisers", type=str) + parser.add_argument( + "--num_devices", default=1 if torch.cuda.is_available() else None + ) + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--seed", default=1234, type=int) + parser.add_argument("--log_path", default="logs", type=Path) + parser.add_argument("--checkpoint_path", 
default=None, type=Path) + parser.add_argument("--debug", action="store_true") + + args = parser.parse_args() + + pl.seed_everything(args.seed) + + log_path = args.log_path / args.name + log_path.mkdir(exist_ok=True, parents=True) + + config = UNet1DConfig() + model = UNet1DLightningModule(config) + + dataset = AudioDataset(args.data_path) + datamodule = AudioFromFileDataModule( + dataset, + batch_size=args.batch_size, + max_length=config.max_length, + sample_rate=config.sample_rate, + ) + + logger = loggers.WandbLogger( + project=args.project, + save_dir=log_path, + name=args.name, + offline=args.debug, + ) + + checkpoint_callback = pl.callbacks.ModelCheckpoint( + dirpath=log_path, + filename="{step}", + save_last=True, + ) + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step") + + pretrained = args.checkpoint_path + last_checkpoint = pretrained if pretrained else log_path / "last.ckpt" + + trainer = pl.Trainer( + default_root_dir=log_path, + max_epochs=1000, + accelerator="auto", + val_check_interval=0.1, + devices=args.num_devices, + logger=logger, + precision="16-mixed", + accumulate_grad_batches=2, + limit_val_batches=10, + callbacks=[checkpoint_callback, lr_monitor], + ) + + trainer.fit( + model, + datamodule=datamodule, + ckpt_path=last_checkpoint if last_checkpoint.exists() else None, + ) + + +if __name__ == "__main__": + main() diff --git a/denoisers/train.py b/denoisers/scripts/train/waveunet.py similarity index 100% rename from denoisers/train.py rename to denoisers/scripts/train/waveunet.py diff --git a/tests/modeling/test_unet1d_model.py b/tests/modeling/test_unet1d_model.py new file mode 100644 index 0000000..5e72dee --- /dev/null +++ b/tests/modeling/test_unet1d_model.py @@ -0,0 +1,80 @@ +"""Tests for UNet1D model.""" +import torch + +from denoisers.datamodules.unet1d import Batch +from denoisers.modeling.unet1d.model import ( + UNet1DConfig, + UNet1DLightningModule, + UNet1DModel, +) +from denoisers.testing import sine_wave + + 
+def test_config(): + config = UNet1DConfig( + max_length=8192, + sample_rate=16000, + channels=(1, 2, 3, 4, 5, 6), + kernel_size=3, + dropout=0.1, + activation="silu", + autoencoder=False, + ) + assert config.max_length == 8192 + assert config.sample_rate == 16000 + assert config.channels == (1, 2, 3, 4, 5, 6) + assert config.kernel_size == 3 + assert config.dropout == 0.1 + assert config.activation == "silu" + assert config.autoencoder is False + + +def test_model(): + """Test model.""" + config = UNet1DConfig( + max_length=1024, + sample_rate=16000, + channels=(1, 2, 3), + kernel_size=3, + ) + model = UNet1DModel(config) + model.eval() + + audio = sine_wave(800, config.max_length, config.sample_rate)[None] + with torch.no_grad(): + recon = model(audio).logits + + assert isinstance(recon, torch.Tensor) + assert audio.shape == recon.shape + + +def test_lightning_module(): + """Test lightning module.""" + config = UNet1DConfig( + max_length=1024, + sample_rate=16000, + in_channels=(1, 2, 3), + downsample_kernel_size=3, + upsample_kernel_size=3, + ) + model = UNet1DLightningModule(config) + + audio = sine_wave(800, config.max_length, config.sample_rate)[None] + batch = Batch(audio=audio, noisy=audio, lengths=torch.tensor([audio.shape[-1]])) + + # test forward + with torch.no_grad(): + recon = model(audio).logits + + assert isinstance(recon, torch.Tensor) + assert audio.shape == recon.shape + + # test training step + loss = model.training_step(batch, 0) + assert isinstance(loss, torch.Tensor) + assert loss.shape == torch.Size([]) + + # test validation step + loss = model.validation_step(batch, 0) + assert isinstance(loss, torch.Tensor) + assert loss.shape == torch.Size([]) From 4ab67b98d534678c132b2edea850c84a39c6bff1 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 2 Nov 2023 11:20:14 -0400 Subject: [PATCH 2/2] Add UNet1D model --- tests/modeling/test_unet1d_model.py | 18 +++++++++--------- tests/modeling/test_waveunet_model.py | 9 ++++----- 2 files changed, 13 
insertions(+), 14 deletions(-) diff --git a/tests/modeling/test_unet1d_model.py b/tests/modeling/test_unet1d_model.py index 5e72dee..dd59f6e 100644 --- a/tests/modeling/test_unet1d_model.py +++ b/tests/modeling/test_unet1d_model.py @@ -7,7 +7,6 @@ UNet1DLightningModule, UNet1DModel, ) -from denoisers.testing import sine_wave def test_config(): @@ -32,15 +31,16 @@ def test_config(): def test_model(): """Test model.""" config = UNet1DConfig( - max_length=1024, + max_length=16384, sample_rate=16000, - channels=(1, 2, 3), + channels=(2, 4, 6, 8), kernel_size=3, + num_groups=2, ) model = UNet1DModel(config) model.eval() - audio = sine_wave(800, config.max_length, config.sample_rate)[None] + audio = torch.randn(1, 1, config.max_length) with torch.no_grad(): recon = model(audio).logits @@ -51,15 +51,15 @@ def test_model(): def test_lightning_module(): """Test lightning module.""" config = UNet1DConfig( - max_length=1024, + max_length=16384, sample_rate=16000, - in_channels=(1, 2, 3), - downsample_kernel_size=3, - upsample_kernel_size=3, + channels=(2, 4, 6, 8), + kernel_size=3, + num_groups=2, ) model = UNet1DLightningModule(config) - audio = sine_wave(800, config.max_length, config.sample_rate)[None] + audio = torch.randn(1, 1, config.max_length) batch = Batch(audio=audio, noisy=audio, lengths=torch.tensor([audio.shape[-1]])) # test forward diff --git a/tests/modeling/test_waveunet_model.py b/tests/modeling/test_waveunet_model.py index a90b224..2001f0e 100644 --- a/tests/modeling/test_waveunet_model.py +++ b/tests/modeling/test_waveunet_model.py @@ -7,7 +7,6 @@ WaveUNetLightningModule, WaveUNetModel, ) -from denoisers.testing import sine_wave def test_config(): @@ -34,7 +33,7 @@ def test_config(): def test_model(): """Test model.""" config = WaveUNetConfig( - max_length=1024, + max_length=16384, sample_rate=16000, in_channels=(1, 2, 3), downsample_kernel_size=3, @@ -43,7 +42,7 @@ def test_model(): model = WaveUNetModel(config) model.eval() - audio = sine_wave(800, 
config.max_length, config.sample_rate)[None] + audio = torch.randn(1, 1, config.max_length) with torch.no_grad(): recon = model(audio).logits @@ -54,7 +53,7 @@ def test_model(): def test_lightning_module(): """Test lightning module.""" config = WaveUNetConfig( - max_length=1024, + max_length=16384, sample_rate=16000, in_channels=(1, 2, 3), downsample_kernel_size=3, @@ -62,7 +61,7 @@ def test_lightning_module(): ) model = WaveUNetLightningModule(config) - audio = sine_wave(800, config.max_length, config.sample_rate)[None] + audio = torch.randn(1, 1, config.max_length) batch = Batch(audio=audio, noisy=audio, lengths=torch.tensor([audio.shape[-1]])) # test forward