diff --git a/denoisers/data/waveunet.py b/denoisers/data/waveunet.py index ccc3cf7..f977edd 100644 --- a/denoisers/data/waveunet.py +++ b/denoisers/data/waveunet.py @@ -29,11 +29,8 @@ def __init__( dataset: torch.utils.data.Dataset, batch_size: int = 24, num_workers: int = os.cpu_count() // 2, # type: ignore - max_length: int = 10, + max_length: int = 16384 * 10, sample_rate: int = 24000, - n_fft: int = 2048, - win_length: int = 1024, - hop_length: int = 256, ) -> None: super().__init__() self.save_hyperparameters() @@ -42,11 +39,8 @@ def __init__( self._batch_size = batch_size self._num_workers = num_workers # we don't use sample_rate here for divisibility - self._max_length = 16384 * max_length + self._max_length = max_length self._sample_rate = sample_rate - self._n_fft = n_fft - self._win_length = win_length - self._hop_length = hop_length self._transforms = nn.Sequential( transforms.ReverbFromSoundboard(p=1.0), transforms.GaussianNoise(p=1.0), diff --git a/denoisers/metrics.py b/denoisers/metrics.py new file mode 100644 index 0000000..7256da3 --- /dev/null +++ b/denoisers/metrics.py @@ -0,0 +1,21 @@ +"""Metrics for denoising.""" +import torch +import torchaudio +import torchmetrics +from pesq import cypesq +from torch import Tensor + + +def calculate_pesq(pred: Tensor, true: Tensor, sample_rate: int = 24000) -> Tensor: + """Calculate PESQ.""" + pred_resample = torchaudio.functional.resample(pred, sample_rate, 16000) + true_resample = torchaudio.functional.resample(true, sample_rate, 16000) + + try: + pesq = torchmetrics.functional.audio.pesq.perceptual_evaluation_speech_quality( + pred_resample, true_resample, 16000, "wb" + ) + except cypesq.NoUtterancesError: + pesq = torch.tensor(0.0) + + return pesq.mean() diff --git a/denoisers/modeling/waveunet/config.py b/denoisers/modeling/waveunet/config.py index d1bf31c..ba3fdbc 100644 --- a/denoisers/modeling/waveunet/config.py +++ b/denoisers/modeling/waveunet/config.py @@ -25,15 +25,19 @@ def __init__( 264, 288, ), - kernel_size: int = 15, - dropout: float = 0.0, + downsample_kernel_size: int = 15, + upsample_kernel_size: int = 5, + dropout: float = 0.1, activation: str = "leaky_relu", autoencoder: bool = False, + max_length: int = 16384 * 10, + sample_rate: int = 48000, **kwargs: Any, ) -> None: self.in_channels = in_channels - self.kernel_size = kernel_size + self.downsample_kernel_size = downsample_kernel_size + self.upsample_kernel_size = upsample_kernel_size self.dropout = dropout self.activation = activation self.autoencoder = autoencoder - super().__init__(**kwargs) + super().__init__(**kwargs, max_length=max_length, sample_rate=sample_rate) diff --git a/denoisers/modeling/waveunet/model.py b/denoisers/modeling/waveunet/model.py index 8d1fc14..8634bd4 100644 --- a/denoisers/modeling/waveunet/model.py +++ b/denoisers/modeling/waveunet/model.py @@ -1,30 +1,35 @@ """Wave UNet Model.""" from typing import Any, Dict, Optional, Tuple, Union -import pytorch_lightning as pl import torch +from pytorch_lightning import LightningModule from pytorch_lightning.utilities import grad_norm from pytorch_lightning.utilities.memory import garbage_collection_cuda from torch import Tensor, nn -from torchmetrics.audio import SignalNoiseRatio +from torchmetrics.audio import ( + ScaleInvariantSignalDistortionRatio, + ScaleInvariantSignalNoiseRatio, +) from transformers import PreTrainedModel from denoisers.data.waveunet import Batch +from denoisers.metrics import calculate_pesq from denoisers.modeling.modules import Activation, DownsampleBlock1D, UpsampleBlock1D from denoisers.modeling.waveunet.config import WaveUNetConfig from denoisers.utils import log_audio_batch, plot_image_from_audio -class WaveUNetLightningModule(pl.LightningModule): +class WaveUNetLightningModule(LightningModule): """WaveUNet Model.""" - def __init__(self) -> None: + def __init__(self, config: WaveUNetConfig) -> None: super().__init__() self.save_hyperparameters() - self.config = WaveUNetConfig() + self.config = config self.model = WaveUNetModel(self.config) self.loss_fn = nn.L1Loss() - self.snr = SignalNoiseRatio() + self.snr = ScaleInvariantSignalNoiseRatio() + self.sdr = ScaleInvariantSignalDistortionRatio() self.autoencoder = self.config.autoencoder self.last_val_batch: Any = {} @@ -40,13 +45,15 @@ def training_step( if self.autoencoder: loss = self.loss_fn(outputs.logits, batch.audio) - snr = self.snr(outputs.logits, batch.audio) else: loss = self.loss_fn(outputs.noise, batch.noisy - batch.audio) - snr = self.snr(outputs.logits, batch.audio) + + snr = self.snr(outputs.logits, batch.audio) + sdr = self.sdr(outputs.logits, batch.audio) self.log("train_loss", loss, prog_bar=True) self.log("train_snr", snr) + self.log("train_sdr", sdr) return loss @@ -58,21 +65,23 @@ def validation_step( if self.autoencoder: loss = self.loss_fn(outputs.logits, batch.audio) - snr = self.snr(outputs.logits, batch.audio) - pred = outputs.logits else: loss = self.loss_fn(outputs.noise, batch.noisy - batch.audio) - snr = self.snr(outputs.logits, batch.audio) - pred = outputs.logits + + snr = self.snr(outputs.logits, batch.audio) + sdr = self.sdr(outputs.logits, batch.audio) + pesq = calculate_pesq(outputs.logits, batch.audio, self.config.sample_rate) self.log("val_loss", loss, prog_bar=True) self.log("val_snr", snr) + self.log("val_sdr", sdr) + self.log("pesq", pesq) self.last_val_batch = { "outputs": ( batch.audio.detach(), batch.noisy.detach(), - pred.detach(), + outputs.logits.detach(), batch.lengths.detach(), ) } @@ -83,9 +92,17 @@ def on_validation_epoch_end(self) -> None: """Val epoch end.""" outputs = self.last_val_batch["outputs"] audio, noisy, preds, lengths = outputs - log_audio_batch(audio, noisy, preds, lengths, name="val") + log_audio_batch( + audio, + noisy, + preds, + lengths, + name="val", + sample_rate=self.config.sample_rate, + ) plot_image_from_audio(audio, noisy, preds, lengths, "val") self.snr.reset() + self.sdr.reset() model_name = self.trainer.default_root_dir.split("/")[-1] self.model.save_pretrained(self.trainer.default_root_dir + "/" + model_name) @@ -121,7 +138,11 @@ def __init__(self, config: WaveUNetConfig): super().__init__(config) self.config = config self.model = WaveUNet( - config.in_channels, config.kernel_size, config.dropout, config.activation + in_channels=config.in_channels, + downsample_kernel_size=config.downsample_kernel_size, + upsample_kernel_size=config.upsample_kernel_size, + dropout=config.dropout, + activation=config.activation, ) def forward(self, inputs: Tensor) -> WaveUNetModelOutputs: @@ -154,20 +175,24 @@ def __init__( 264, 288, ), - kernel_size: int = 15, + downsample_kernel_size: int = 15, + upsample_kernel_size: int = 5, dropout: float = 0.0, activation: str = "leaky_relu", ) -> None: super().__init__() self.in_conv = nn.Conv1d( - 1, in_channels[0], kernel_size=kernel_size, padding=kernel_size // 2 + 1, + in_channels[0], + kernel_size=downsample_kernel_size, + padding=downsample_kernel_size // 2, ) self.encoder_layers = nn.ModuleList( [ DownsampleBlock1D( in_channels[i], out_channels=in_channels[i + 1], - kernel_size=kernel_size, + kernel_size=downsample_kernel_size, dropout=dropout, activation=activation, ) @@ -178,8 +203,8 @@ def __init__( nn.Conv1d( in_channels[-1], in_channels[-1], - kernel_size=kernel_size, - padding=kernel_size // 2, + kernel_size=downsample_kernel_size, + padding=downsample_kernel_size // 2, ), nn.BatchNorm1d(in_channels[-1]), Activation(activation), @@ -190,7 +215,7 @@ def __init__( UpsampleBlock1D( 2 * in_channels[i + 1], out_channels=in_channels[i], - kernel_size=kernel_size, + kernel_size=upsample_kernel_size, dropout=dropout, activation=activation, ) diff --git a/denoisers/train.py b/denoisers/train.py index 7f993c6..5f57cc6 100644 --- a/denoisers/train.py +++ b/denoisers/train.py @@ -8,6 +8,7 @@ from denoisers.data.waveunet import AudioFromFileDataModule from denoisers.datasets.audio import AudioDataset +from denoisers.modeling.waveunet.config import WaveUNetConfig from denoisers.modeling.waveunet.model import WaveUNetLightningModule if torch.cuda.is_available(): @@ -38,10 +39,16 @@ def main() -> None: log_path = args.log_path / args.name log_path.mkdir(exist_ok=True, parents=True) - model = WaveUNetLightningModule() + config = WaveUNetConfig() + model = WaveUNetLightningModule(config) dataset = AudioDataset(args.data_path) - datamodule = AudioFromFileDataModule(dataset, batch_size=args.batch_size) + datamodule = AudioFromFileDataModule( + dataset, + batch_size=args.batch_size, + max_length=config.max_length, + sample_rate=config.sample_rate, + ) logger = loggers.WandbLogger( project=args.project, diff --git a/denoisers/transforms.py b/denoisers/transforms.py index bda47b2..f902388 100644 --- a/denoisers/transforms.py +++ b/denoisers/transforms.py @@ -30,9 +30,11 @@ def forward(self, x: Tensor) -> Tensor: class GaussianNoise(nn.Module): """Gaussian Noise Transform.""" - def __init__(self, p: float = 0.5) -> None: + def __init__(self, p: float = 0.5, db_min=1, db_max=30) -> None: super().__init__() self.p = p + self.db_min = db_min + self.db_max = db_max def forward(self, x: Union[Tensor, np.ndarray]) -> Tensor: """Forward Pass.""" @@ -40,9 +42,8 @@ def forward(self, x: Union[Tensor, np.ndarray]) -> Tensor: x = torch.from_numpy(x) if random.random() < self.p: - intensity = random.random() - noise = torch.randn_like(x) * intensity - x += noise + db = torch.randint(self.db_min, self.db_max, (1,)) + x = torchaudio.functional.add_noise(x, torch.randn_like(x), snr=db) return x diff --git a/denoisers/utils.py b/denoisers/utils.py index baaac0e..14bc7a8 100644 --- a/denoisers/utils.py +++ b/denoisers/utils.py @@ -97,6 +97,7 @@ def log_audio_batch( preds: torch.Tensor, lengths: torch.Tensor, name: str, + sample_rate: int = 24000, ) -> None: """Log a batch of audio to wandb.""" np_clean = clean.squeeze(1).cpu().detach().numpy()[0][: int(lengths[0])] @@ -106,9 +107,9 @@ def log_audio_batch( wandb.log( { f"{name}_audio": { - f"{name}_clean": wandb.Audio(np_clean, sample_rate=24000), - f"{name}_noisy": wandb.Audio(np_noisy, sample_rate=24000), - f"{name}_pred": wandb.Audio(np_preds, sample_rate=24000), + f"{name}_clean": wandb.Audio(np_clean, sample_rate=sample_rate), + f"{name}_noisy": wandb.Audio(np_noisy, sample_rate=sample_rate), + f"{name}_pred": wandb.Audio(np_preds, sample_rate=sample_rate), } } )