Skip to content

Commit

Permalink
Merge pull request #6 from will-rice/wr-metrics
Browse files Browse the repository at this point in the history
Add pesq and SignalDistortionRatio
  • Loading branch information
will-rice committed Oct 10, 2023
2 parents 5de49cb + f3294b0 commit 9266bb3
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 42 deletions.
10 changes: 2 additions & 8 deletions denoisers/data/waveunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@ def __init__(
dataset: torch.utils.data.Dataset,
batch_size: int = 24,
num_workers: int = os.cpu_count() // 2, # type: ignore
max_length: int = 10,
max_length: int = 16384 * 10,
sample_rate: int = 24000,
n_fft: int = 2048,
win_length: int = 1024,
hop_length: int = 256,
) -> None:
super().__init__()
self.save_hyperparameters()
Expand All @@ -42,11 +39,8 @@ def __init__(
self._batch_size = batch_size
self._num_workers = num_workers
# we don't use sample_rate here for divisibility
self._max_length = 16384 * max_length
self._max_length = max_length
self._sample_rate = sample_rate
self._n_fft = n_fft
self._win_length = win_length
self._hop_length = hop_length
self._transforms = nn.Sequential(
transforms.ReverbFromSoundboard(p=1.0),
transforms.GaussianNoise(p=1.0),
Expand Down
21 changes: 21 additions & 0 deletions denoisers/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Metrics for denoising."""
import torch
import torchaudio
import torchmetrics
from pesq import cypesq
from torch import Tensor


def calculate_pesq(pred: Tensor, true: Tensor, sample_rate: int = 24000) -> Tensor:
    """Return the mean PESQ score of ``pred`` measured against ``true``.

    Both signals are resampled from ``sample_rate`` to 16 kHz and scored with
    torchmetrics' PESQ implementation in ``"wb"`` mode.

    Args:
        pred: Predicted (denoised) audio.
        true: Reference audio that ``pred`` is compared to.
        sample_rate: Sampling rate of the incoming signals, in Hz.

    Returns:
        The mean of the PESQ scores, or ``0.0`` when the PESQ backend raises
        ``NoUtterancesError``.
    """
    # PESQ is evaluated at 16 kHz here, so everything is brought to that rate.
    target_rate = 16000
    resampled_pred, resampled_true = (
        torchaudio.functional.resample(signal, sample_rate, target_rate)
        for signal in (pred, true)
    )

    try:
        scores = torchmetrics.functional.audio.pesq.perceptual_evaluation_speech_quality(
            resampled_pred, resampled_true, target_rate, "wb"
        )
    except cypesq.NoUtterancesError:
        # The backend found no utterances to score; fall back to a zero score.
        scores = torch.tensor(0.0)

    return scores.mean()
12 changes: 8 additions & 4 deletions denoisers/modeling/waveunet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,19 @@ def __init__(
264,
288,
),
kernel_size: int = 15,
dropout: float = 0.0,
downsample_kernel_size: int = 15,
upsample_kernel_size: int = 5,
dropout: float = 0.1,
activation: str = "leaky_relu",
autoencoder: bool = False,
max_length: int = 16384 * 10,
sample_rate: int = 48000,
**kwargs: Any,
) -> None:
self.in_channels = in_channels
self.kernel_size = kernel_size
self.downsample_kernel_size = downsample_kernel_size
self.upsample_kernel_size = upsample_kernel_size
self.dropout = dropout
self.activation = activation
self.autoencoder = autoencoder
super().__init__(**kwargs)
super().__init__(**kwargs, max_length=max_length, sample_rate=sample_rate)
67 changes: 46 additions & 21 deletions denoisers/modeling/waveunet/model.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,35 @@
"""Wave UNet Model."""
from typing import Any, Dict, Optional, Tuple, Union

import pytorch_lightning as pl
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import grad_norm
from pytorch_lightning.utilities.memory import garbage_collection_cuda
from torch import Tensor, nn
from torchmetrics.audio import SignalNoiseRatio
from torchmetrics.audio import (
ScaleInvariantSignalDistortionRatio,
ScaleInvariantSignalNoiseRatio,
)
from transformers import PreTrainedModel

from denoisers.data.waveunet import Batch
from denoisers.metrics import calculate_pesq
from denoisers.modeling.modules import Activation, DownsampleBlock1D, UpsampleBlock1D
from denoisers.modeling.waveunet.config import WaveUNetConfig
from denoisers.utils import log_audio_batch, plot_image_from_audio


class WaveUNetLightningModule(pl.LightningModule):
class WaveUNetLightningModule(LightningModule):
"""WaveUNet Model."""

def __init__(self) -> None:
def __init__(self, config: WaveUNetConfig) -> None:
super().__init__()
self.save_hyperparameters()
self.config = WaveUNetConfig()
self.config = config
self.model = WaveUNetModel(self.config)
self.loss_fn = nn.L1Loss()
self.snr = SignalNoiseRatio()
self.snr = ScaleInvariantSignalNoiseRatio()
self.sdr = ScaleInvariantSignalDistortionRatio()
self.autoencoder = self.config.autoencoder
self.last_val_batch: Any = {}

Expand All @@ -40,13 +45,15 @@ def training_step(

if self.autoencoder:
loss = self.loss_fn(outputs.logits, batch.audio)
snr = self.snr(outputs.logits, batch.audio)
else:
loss = self.loss_fn(outputs.noise, batch.noisy - batch.audio)
snr = self.snr(outputs.logits, batch.audio)

snr = self.snr(outputs.logits, batch.audio)
sdr = self.sdr(outputs.logits, batch.audio)

self.log("train_loss", loss, prog_bar=True)
self.log("train_snr", snr)
self.log("train_sdr", sdr)

return loss

Expand All @@ -58,21 +65,23 @@ def validation_step(

if self.autoencoder:
loss = self.loss_fn(outputs.logits, batch.audio)
snr = self.snr(outputs.logits, batch.audio)
pred = outputs.logits
else:
loss = self.loss_fn(outputs.noise, batch.noisy - batch.audio)
snr = self.snr(outputs.logits, batch.audio)
pred = outputs.logits

snr = self.snr(outputs.logits, batch.audio)
sdr = self.sdr(outputs.logits, batch.audio)
pesq = calculate_pesq(outputs.logits, batch.audio, self.config.sample_rate)

self.log("val_loss", loss, prog_bar=True)
self.log("val_snr", snr)
self.log("val_sdr", sdr)
self.log("pesq", pesq)

self.last_val_batch = {
"outputs": (
batch.audio.detach(),
batch.noisy.detach(),
pred.detach(),
outputs.logits.detach(),
batch.lengths.detach(),
)
}
Expand All @@ -83,9 +92,17 @@ def on_validation_epoch_end(self) -> None:
"""Val epoch end."""
outputs = self.last_val_batch["outputs"]
audio, noisy, preds, lengths = outputs
log_audio_batch(audio, noisy, preds, lengths, name="val")
log_audio_batch(
audio,
noisy,
preds,
lengths,
name="val",
sample_rate=self.config.sample_rate,
)
plot_image_from_audio(audio, noisy, preds, lengths, "val")
self.snr.reset()
self.sdr.reset()

model_name = self.trainer.default_root_dir.split("/")[-1]
self.model.save_pretrained(self.trainer.default_root_dir + "/" + model_name)
Expand Down Expand Up @@ -121,7 +138,11 @@ def __init__(self, config: WaveUNetConfig):
super().__init__(config)
self.config = config
self.model = WaveUNet(
config.in_channels, config.kernel_size, config.dropout, config.activation
in_channels=config.in_channels,
downsample_kernel_size=config.downsample_kernel_size,
upsample_kernel_size=config.upsample_kernel_size,
dropout=config.dropout,
activation=config.activation,
)

def forward(self, inputs: Tensor) -> WaveUNetModelOutputs:
Expand Down Expand Up @@ -154,20 +175,24 @@ def __init__(
264,
288,
),
kernel_size: int = 15,
downsample_kernel_size: int = 15,
upsample_kernel_size: int = 5,
dropout: float = 0.0,
activation: str = "leaky_relu",
) -> None:
super().__init__()
self.in_conv = nn.Conv1d(
1, in_channels[0], kernel_size=kernel_size, padding=kernel_size // 2
1,
in_channels[0],
kernel_size=downsample_kernel_size,
padding=downsample_kernel_size // 2,
)
self.encoder_layers = nn.ModuleList(
[
DownsampleBlock1D(
in_channels[i],
out_channels=in_channels[i + 1],
kernel_size=kernel_size,
kernel_size=downsample_kernel_size,
dropout=dropout,
activation=activation,
)
Expand All @@ -178,8 +203,8 @@ def __init__(
nn.Conv1d(
in_channels[-1],
in_channels[-1],
kernel_size=kernel_size,
padding=kernel_size // 2,
kernel_size=downsample_kernel_size,
padding=downsample_kernel_size // 2,
),
nn.BatchNorm1d(in_channels[-1]),
Activation(activation),
Expand All @@ -190,7 +215,7 @@ def __init__(
UpsampleBlock1D(
2 * in_channels[i + 1],
out_channels=in_channels[i],
kernel_size=kernel_size,
kernel_size=upsample_kernel_size,
dropout=dropout,
activation=activation,
)
Expand Down
11 changes: 9 additions & 2 deletions denoisers/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from denoisers.data.waveunet import AudioFromFileDataModule
from denoisers.datasets.audio import AudioDataset
from denoisers.modeling.waveunet.config import WaveUNetConfig
from denoisers.modeling.waveunet.model import WaveUNetLightningModule

if torch.cuda.is_available():
Expand Down Expand Up @@ -38,10 +39,16 @@ def main() -> None:
log_path = args.log_path / args.name
log_path.mkdir(exist_ok=True, parents=True)

model = WaveUNetLightningModule()
config = WaveUNetConfig()
model = WaveUNetLightningModule(config)

dataset = AudioDataset(args.data_path)
datamodule = AudioFromFileDataModule(dataset, batch_size=args.batch_size)
datamodule = AudioFromFileDataModule(
dataset,
batch_size=args.batch_size,
max_length=config.max_length,
sample_rate=config.sample_rate,
)

logger = loggers.WandbLogger(
project=args.project,
Expand Down
9 changes: 5 additions & 4 deletions denoisers/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,20 @@ def forward(self, x: Tensor) -> Tensor:
class GaussianNoise(nn.Module):
"""Gaussian Noise Transform."""

def __init__(self, p: float = 0.5) -> None:
def __init__(self, p: float = 0.5, db_min=1, db_max=30) -> None:
super().__init__()
self.p = p
self.db_min = db_min
self.db_max = db_max

def forward(self, x: Union[Tensor, np.ndarray]) -> Tensor:
"""Forward Pass."""
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)

if random.random() < self.p:
intensity = random.random()
noise = torch.randn_like(x) * intensity
x += noise
db = torch.randint(self.db_min, self.db_max, (1,))
x = torchaudio.functional.add_noise(x, torch.randn_like(x), snr=db)

return x

Expand Down
7 changes: 4 additions & 3 deletions denoisers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def log_audio_batch(
preds: torch.Tensor,
lengths: torch.Tensor,
name: str,
sample_rate: int = 24000,
) -> None:
"""Log a batch of audio to wandb."""
np_clean = clean.squeeze(1).cpu().detach().numpy()[0][: int(lengths[0])]
Expand All @@ -106,9 +107,9 @@ def log_audio_batch(
wandb.log(
{
f"{name}_audio": {
f"{name}_clean": wandb.Audio(np_clean, sample_rate=24000),
f"{name}_noisy": wandb.Audio(np_noisy, sample_rate=24000),
f"{name}_pred": wandb.Audio(np_preds, sample_rate=24000),
f"{name}_clean": wandb.Audio(np_clean, sample_rate=sample_rate),
f"{name}_noisy": wandb.Audio(np_noisy, sample_rate=sample_rate),
f"{name}_pred": wandb.Audio(np_preds, sample_rate=sample_rate),
}
}
)

0 comments on commit 9266bb3

Please sign in to comment.