Skip to content

Commit

Permalink
Merge pull request #16 from will-rice/wr-unet1d
Browse files Browse the repository at this point in the history
Add UNet1D model
  • Loading branch information
will-rice committed Nov 2, 2023
2 parents 2d996cf + 4ab67b9 commit 84eeaf7
Show file tree
Hide file tree
Showing 16 changed files with 761 additions and 12 deletions.
7 changes: 4 additions & 3 deletions bin/run
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

set -e

command=$1
shift
exec python -m "denoisers.$command" "$@"
subsys=$1
command=$2
shift; shift
exec python -m "denoisers.scripts.$subsys.$command" "$@"
4 changes: 3 additions & 1 deletion denoisers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Denoisers for the 1D and 2D cases."""
from denoisers.modeling.unet1d.config import UNet1DConfig
from denoisers.modeling.unet1d.model import UNet1DModel
from denoisers.modeling.waveunet.config import WaveUNetConfig
from denoisers.modeling.waveunet.model import WaveUNetModel

__all__ = ["WaveUNetConfig", "WaveUNetModel"]
__all__ = ["WaveUNetConfig", "WaveUNetModel", "UNet1DConfig", "UNet1DModel"]
114 changes: 114 additions & 0 deletions denoisers/datamodules/unet1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Unet1D Data modules."""
import os
from typing import List, NamedTuple, Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torchaudio
from torch import Tensor, nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

from denoisers import transforms


class Batch(NamedTuple):
    """Batch of inputs.

    Produced by ``AudioFromFileDataModule.pad_collate_fn``; all waveforms are
    padded/truncated to a common fixed length before stacking.
    """

    audio: Tensor  # clean mono waveforms, stacked to (batch, 1, max_length)
    noisy: Tensor  # augmented (reverb + Gaussian noise) copies of ``audio``, same shape
    lengths: Tensor  # pre-padding sample counts, one scalar per example, shape (batch,)


class AudioFromFileDataModule(pl.LightningDataModule):
    """DataModule that yields (clean, noisy) audio batches from file paths.

    The wrapped ``dataset`` is expected to yield paths to audio files. Each
    file is loaded, resampled to ``sample_rate``, downmixed to mono, padded or
    truncated to ``max_length`` samples, and augmented (soundboard reverb +
    Gaussian noise) to produce the noisy counterpart.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        batch_size: int = 24,
        num_workers: Optional[int] = None,
        max_length: int = 16384 * 10,
        sample_rate: int = 24000,
    ) -> None:
        """Create the datamodule.

        Args:
            dataset: Dataset yielding audio file paths.
            batch_size: Examples per batch.
            num_workers: DataLoader worker count. Defaults to half the
                available CPUs when ``None``.
            max_length: Fixed sample length every example is padded/cut to.
            sample_rate: Target sample rate; inputs are resampled to match.
        """
        super().__init__()
        self.save_hyperparameters()

        self._dataset = dataset
        self._batch_size = batch_size
        # os.cpu_count() may return None (platform-dependent). Resolving it
        # lazily here — rather than in the default-argument expression, which
        # is evaluated once at import time — avoids a TypeError on such
        # platforms and keeps the effective default (half the CPUs) intact.
        if num_workers is None:
            num_workers = (os.cpu_count() or 2) // 2
        self._num_workers = num_workers
        self._max_length = max_length
        self._sample_rate = sample_rate
        self._transforms = nn.Sequential(
            transforms.ReverbFromSoundboard(p=1.0, sample_rate=sample_rate),
            transforms.GaussianNoise(p=1.0),
        )

    def setup(self, stage: Optional[str] = "fit") -> None:
        """Split the dataset into train (95%) and validation (5%) subsets."""
        total = len(self._dataset)  # type: ignore
        train_split = int(np.floor(total * 0.95))
        # Derive the validation size as the remainder instead of computing
        # ceil(0.05 * total) separately: float rounding can make the two
        # independently-rounded parts disagree with ``total``, which makes
        # torch.utils.data.random_split raise.
        val_split = total - train_split

        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            self._dataset, lengths=(train_split, val_split)
        )

    def pad_collate_fn(self, paths: List[str]) -> Batch:
        """Load, normalize, augment, and stack a list of audio files.

        Unreadable files are skipped; raises ``RuntimeError`` when no file in
        the batch could be loaded (otherwise ``torch.stack`` would fail with
        an opaque error on an empty list).
        """
        audios = []
        noisy_audio = []
        lengths = []
        for path in paths:
            try:
                audio, sr = torchaudio.load(path)
            except Exception:
                # Best-effort loading: drop files torchaudio cannot read.
                continue

            if sr != self._sample_rate:
                audio = torchaudio.functional.resample(audio, sr, self._sample_rate)

            # Downmix multi-channel audio by keeping the first channel only.
            if audio.size(0) > 1:
                audio = audio[0].unsqueeze(0)

            audio_length = min(audio.size(-1), self._max_length)

            # Pad short clips / truncate long ones to a fixed length so the
            # batch can be stacked into a single tensor.
            if audio_length < self._max_length:
                pad_length = self._max_length - audio_length
                audio = F.pad(audio, (0, pad_length))
            else:
                audio = audio[:, : self._max_length]

            noisy = self._transforms(audio.clone())

            audios.append(audio)
            noisy_audio.append(noisy)
            lengths.append(torch.tensor(audio_length))

        if not audios:
            raise RuntimeError("No audio file in the batch could be loaded.")

        return Batch(
            audio=torch.stack(audios),
            noisy=torch.stack(noisy_audio),
            lengths=torch.stack(lengths),
        )

    def train_dataloader(self) -> DataLoader:
        """Train dataloader."""
        return DataLoader(
            self.train_dataset,
            batch_size=self._batch_size,
            num_workers=self._num_workers,
            collate_fn=self.pad_collate_fn,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Validation dataloader."""
        return DataLoader(
            self.val_dataset,
            batch_size=self._batch_size,
            num_workers=self._num_workers,
            collate_fn=self.pad_collate_fn,
            shuffle=False,
            drop_last=True,
        )
4 changes: 3 additions & 1 deletion denoisers/modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Models."""
from denoisers.modeling.unet1d.config import UNet1DConfig
from denoisers.modeling.unet1d.model import UNet1DModel
from denoisers.modeling.waveunet.config import WaveUNetConfig
from denoisers.modeling.waveunet.model import WaveUNetModel

__all__ = ["WaveUNetConfig", "WaveUNetModel"]
__all__ = ["WaveUNetConfig", "WaveUNetModel", "UNet1DConfig", "UNet1DModel"]
5 changes: 5 additions & 0 deletions denoisers/modeling/unet1d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""UNet1D model."""
from denoisers.modeling.unet1d.config import UNet1DConfig
from denoisers.modeling.unet1d.model import UNet1DModel

__all__ = ["UNet1DConfig", "UNet1DModel"]
43 changes: 43 additions & 0 deletions denoisers/modeling/unet1d/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""UNet1D configuration file."""
from typing import Any, Tuple

from transformers import PretrainedConfig


class UNet1DConfig(PretrainedConfig):
    """Hyperparameter container for a `UNet1DModel`.

    Args:
        channels: Output channel count of each encoder/decoder stage.
        kernel_size: Convolution kernel size used throughout the network.
        num_groups: Group count for group normalization.
        dropout: Dropout probability.
        activation: Name of the activation function (e.g. ``"silu"``).
        autoencoder: Whether the model runs as a plain autoencoder.
        max_length: Maximum input length in samples.
        sample_rate: Expected audio sample rate.
        **kwargs: Forwarded to `PretrainedConfig`.
    """

    model_type = "unet1d"

    def __init__(
        self,
        channels: Tuple[int, ...] = (
            32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384,
        ),
        kernel_size: int = 3,
        num_groups: int = 32,
        dropout: float = 0.1,
        activation: str = "silu",
        autoencoder: bool = False,
        max_length: int = 16384 * 10,
        sample_rate: int = 48000,
        **kwargs: Any,
    ) -> None:
        self.activation = activation
        self.autoencoder = autoencoder
        self.channels = channels
        self.dropout = dropout
        self.kernel_size = kernel_size
        self.num_groups = num_groups
        # max_length and sample_rate are handed to the base class so that
        # PretrainedConfig records them alongside the remaining kwargs.
        super().__init__(max_length=max_length, sample_rate=sample_rate, **kwargs)
Loading

0 comments on commit 84eeaf7

Please sign in to comment.