
Commit

Don't use hparams internal value
alealv committed Jun 16, 2023
1 parent 03c4fcb commit 130eb63
Showing 1 changed file with 32 additions and 23 deletions.
55 changes: 32 additions & 23 deletions vocos/experiment.py
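The change is mechanical but worth naming: every read of Lightning's self.hparams.* (populated by save_hyperparameters()) is replaced with a plain instance attribute assigned once in __init__. A minimal sketch of the pattern, assuming a LightningModule with the same constructor arguments (the class body and the 2e-4 default here are illustrative, not taken from the diff):

    import pytorch_lightning as pl

    class VocosExp(pl.LightningModule):
        def __init__(self, initial_learning_rate: float = 2e-4, num_warmup_steps: int = 0):
            super().__init__()
            # Store constructor arguments as explicit attributes, so later code
            # does not depend on save_hyperparameters() having populated
            # self.hparams, and attribute access stays plain and greppable.
            self.initial_learning_rate = initial_learning_rate
            self.num_warmup_steps = num_warmup_steps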
@@ -67,6 +67,15 @@ def __init__(
         self.train_discriminator = False
         self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
 
+        self.initial_learning_rate = initial_learning_rate
+        self.num_warmup_steps = num_warmup_steps
+        self.mrd_loss_coeff = mrd_loss_coeff
+        self.pretrain_mel_steps = pretrain_mel_steps
+        self.decay_mel_coeff = decay_mel_coeff
+        self.evaluate_utmos = evaluate_utmos
+        self.evaluate_pesq = evaluate_pesq
+        self.evaluate_periodicty = evaluate_periodicty
+
     def configure_optimizers(self):
         disc_params = [
             {"params": self.multiperioddisc.parameters()},
@@ -78,15 +87,15 @@ def configure_optimizers(self):
             {"params": self.head.parameters()},
         ]
 
-        opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate)
-        opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate)
+        opt_disc = torch.optim.AdamW(disc_params, lr=self.initial_learning_rate)
+        opt_gen = torch.optim.AdamW(gen_params, lr=self.initial_learning_rate)
 
         max_steps = self.trainer.max_steps // 2  # Max steps per optimizer
         scheduler_disc = transformers.get_cosine_schedule_with_warmup(
-            opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
+            opt_disc, num_warmup_steps=self.num_warmup_steps, num_training_steps=max_steps
         )
         scheduler_gen = transformers.get_cosine_schedule_with_warmup(
-            opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
+            opt_gen, num_warmup_steps=self.num_warmup_steps, num_training_steps=max_steps
         )
 
         return (
@@ -118,7 +127,7 @@ def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
             )
             loss_mp /= len(loss_mp_real)
             loss_mrd /= len(loss_mrd_real)
-            loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd
+            loss = loss_mp + self.mrd_loss_coeff * loss_mrd
 
             self.log("discriminator/total", loss, prog_bar=True)
             self.log("discriminator/multi_period_loss", loss_mp)
@@ -152,9 +161,9 @@ def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
             mel_loss = self.melspec_loss(audio_hat, audio_input)
             loss = (
                 loss_gen_mp
-                + self.hparams.mrd_loss_coeff * loss_gen_mrd
+                + self.mrd_loss_coeff * loss_gen_mrd
                 + loss_fm_mp
-                + self.hparams.mrd_loss_coeff * loss_fm_mrd
+                + self.mrd_loss_coeff * loss_fm_mrd
                 + self.mel_loss_coeff * mel_loss
             )
 
@@ -164,10 +173,10 @@ def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
 
             if self.global_step % 1000 == 0 and self.global_rank == 0:
                 self.logger.experiment.add_audio(
-                    "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate
+                    "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.sample_rate
                 )
                 self.logger.experiment.add_audio(
-                    "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate
+                    "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.sample_rate
                 )
                 with torch.no_grad():
                     mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
@@ -188,7 +197,7 @@ def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
         return loss
 
     def on_validation_epoch_start(self):
-        if self.hparams.evaluate_utmos:
+        if self.evaluate_utmos:
             from metrics.UTMOS import UTMOSScore
 
             if not hasattr(self, "utmos_model"):
@@ -198,22 +207,22 @@ def validation_step(self, batch, batch_idx, **kwargs):
         audio_input = batch
         audio_hat = self(audio_input, **kwargs)
 
-        audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000)
-        audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000)
+        audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.sample_rate, new_freq=16000)
+        audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.sample_rate, new_freq=16000)
 
-        if self.hparams.evaluate_periodicty:
+        if self.evaluate_periodicty:
             from metrics.periodicity import calculate_periodicity_metrics
 
             periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
         else:
             periodicity_loss = pitch_loss = f1_score = 0
 
-        if self.hparams.evaluate_utmos:
+        if self.evaluate_utmos:
             utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
         else:
             utmos_score = torch.zeros(1, device=self.device)
 
-        if self.hparams.evaluate_pesq:
+        if self.evaluate_pesq:
             from pesq import pesq
 
             pesq_score = 0
@@ -243,10 +252,10 @@ def validation_epoch_end(self, outputs):
         if self.global_rank == 0:
             *_, audio_in, audio_pred = outputs[0].values()
             self.logger.experiment.add_audio(
-                "val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
+                "val_in", audio_in.data.cpu().numpy(), self.global_step, self.sample_rate
             )
             self.logger.experiment.add_audio(
-                "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
+                "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.sample_rate
             )
             mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
             mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
@@ -286,22 +295,22 @@ def global_step(self):
         return self.trainer.fit_loop.epoch_loop.total_batch_idx
 
     def on_train_batch_start(self, *args):
-        if self.global_step >= self.hparams.pretrain_mel_steps:
+        if self.global_step >= self.pretrain_mel_steps:
             self.train_discriminator = True
         else:
             self.train_discriminator = False
 
     def on_train_batch_end(self, *args):
         def mel_loss_coeff_decay(current_step, num_cycles=0.5):
             max_steps = self.trainer.max_steps // 2
-            if current_step < self.hparams.num_warmup_steps:
+            if current_step < self.num_warmup_steps:
                 return 1.0
-            progress = float(current_step - self.hparams.num_warmup_steps) / float(
-                max(1, max_steps - self.hparams.num_warmup_steps)
+            progress = float(current_step - self.num_warmup_steps) / float(
+                max(1, max_steps - self.num_warmup_steps)
             )
             return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
 
-        if self.hparams.decay_mel_coeff:
+        if self.decay_mel_coeff:
             self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)


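For reference, the cosine decay that on_train_batch_end applies to the mel loss coefficient can be exercised on its own. A standalone sketch, with max_steps and the warmup count lifted into arguments (the concrete numbers in the usage line are assumptions, not values from the diff):

    import math

    def mel_loss_coeff_decay(current_step, max_steps, num_warmup_steps, num_cycles=0.5):
        # Hold the coefficient at its base value through warmup, then follow
        # half a cosine cycle from 1.0 toward 0.0 over the remaining steps.
        if current_step < num_warmup_steps:
            return 1.0
        progress = (current_step - num_warmup_steps) / max(1, max_steps - num_warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

    print(mel_loss_coeff_decay(250_000, max_steps=500_000, num_warmup_steps=0))  # ~0.5 at halfway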
@@ -365,7 +374,7 @@ def validation_epoch_end(self, outputs):
             self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
             encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
             self.logger.experiment.add_audio(
-                "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate,
+                "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.sample_rate,
             )
 
         super().validation_epoch_end(outputs)
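As context for the configure_optimizers hunk above: each optimizer keeps its own warmup-plus-cosine schedule from the transformers library, sized to trainer.max_steps // 2; only the attribute lookups changed in this commit. A minimal usage sketch (the parameter, learning rate, and step counts are assumed placeholders):

    import torch
    import transformers

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.AdamW(params, lr=2e-4)
    sched = transformers.get_cosine_schedule_with_warmup(
        opt, num_warmup_steps=1_000, num_training_steps=500_000
    )
    for _ in range(3):
        opt.step()    # optimizer update first...
        sched.step()  # ...then advance the learning-rate schedule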
