| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| import torchaudio |
| from transformers import AutoModel |
|
|
|
|
class SpectralConvergengeLoss(torch.nn.Module):
    """Relative L1 distance between a predicted and a reference spectrogram.

    The historical spelling of the class name ("Convergenge") is kept because
    other code refers to the module by this name.
    """

    def __init__(self):
        """Create the module; it holds no parameters or buffers."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute ``||y_mag - x_mag||_1 / ||y_mag||_1``.

        Both norms are taken over every element of the tensors (no ``dim``
        is passed to ``torch.norm``), so the result is a scalar.

        Args:
            x_mag (Tensor): Spectrogram of the predicted signal.
            y_mag (Tensor): Spectrogram of the ground-truth signal.

        Returns:
            Tensor: Scalar ratio of the residual 1-norm to the reference 1-norm.
        """
        residual_norm = torch.norm(y_mag - x_mag, p=1)
        reference_norm = torch.norm(y_mag, p=1)
        return residual_norm / reference_norm
|
|
|
|
class STFTLoss(torch.nn.Module):
    """Single-resolution spectral loss on normalised log-mel spectrograms.

    Despite its name, the module does not compare raw STFT magnitudes: both
    signals go through ``torchaudio.transforms.MelSpectrogram``, are
    log-compressed and affinely normalised, and are then compared with
    ``SpectralConvergengeLoss``.
    """

    # Fixed constants of the log-mel normalisation:
    # (log(_LOG_EPS + mel) - _LOG_MEAN) / _LOG_STD
    _LOG_EPS = 1e-5
    _LOG_MEAN = -4
    _LOG_STD = 4

    def __init__(
        self,
        fft_size=1024,
        shift_size=120,
        win_length=600,
        window=torch.hann_window,
        sample_rate=24000,
    ):
        """Initialize the loss for one analysis resolution.

        Args:
            fft_size (int): FFT size handed to the mel transform (``n_fft``).
            shift_size (int): Hop length between frames, in samples.
            win_length (int): Analysis window length, in samples.
            window (Callable): Window factory such as ``torch.hann_window``;
                forwarded as ``window_fn``.
            sample_rate (int): Sample rate handed to the mel transform.
                Defaults to 24000, the value that used to be hard-coded.
        """
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.sample_rate = sample_rate
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=fft_size,
            win_length=win_length,
            hop_length=shift_size,
            window_fn=window,
        )
        self.spectral_convergenge_loss = SpectralConvergengeLoss()

    def _log_mel(self, wav):
        """Return the normalised log-mel spectrogram of ``wav``."""
        mel = self.to_mel(wav)
        return (torch.log(self._LOG_EPS + mel) - self._LOG_MEAN) / self._LOG_STD

    def forward(self, x, y):
        """Compare two waveforms at this resolution.

        Args:
            x (Tensor): Predicted signal.
            y (Tensor): Ground-truth signal.

        Returns:
            Tensor: Spectral convergence loss between the normalised log-mel
            spectrograms of ``x`` and ``y`` (a single value; no separate
            log-magnitude term is computed).
        """
        x_mag = self._log_mel(x)
        y_mag = self._log_mel(y)
        return self.spectral_convergenge_loss(x_mag, y_mag)
|
|
|
|
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Average of several ``STFTLoss`` modules with different resolutions."""

    def __init__(
        self,
        fft_sizes=(1024, 2048, 512),
        hop_sizes=(120, 240, 50),
        win_lengths=(600, 1200, 240),
        window=torch.hann_window,
    ):
        """Initialize one ``STFTLoss`` per resolution.

        The defaults are tuples rather than lists so that no mutable object
        is shared between calls; lists are still accepted.

        Args:
            fft_sizes (Sequence[int]): FFT size of each resolution.
            hop_sizes (Sequence[int]): Hop size of each resolution.
            win_lengths (Sequence[int]): Window length of each resolution.
            window (Callable): Window factory (e.g. ``torch.hann_window``)
                passed to every ``STFTLoss``.
        """
        super().__init__()
        # zip() would silently truncate to the shortest sequence, so the three
        # configurations must have matching lengths.
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList(
            [
                STFTLoss(fs, ss, wl, window)
                for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths)
            ]
        )

    def forward(self, x, y):
        """Average the per-resolution losses.

        Args:
            x (Tensor): Predicted signal.
            y (Tensor): Ground-truth signal.

        Returns:
            Tensor: Mean over all resolutions of the value returned by each
            ``STFTLoss`` (a single value).
        """
        sc_loss = 0.0
        for stft_loss in self.stft_losses:
            sc_loss = sc_loss + stft_loss(x, y)
        return sc_loss / len(self.stft_losses)
|
|
|
|
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss between two nested collections of feature maps.

    ``fmap_r`` and ``fmap_g`` are iterated in lockstep twice (outer, then
    inner); every pair of inner tensors contributes its mean absolute
    difference.

    Args:
        fmap_r: Nested iterable of tensors (first set of feature maps).
        fmap_g: Nested iterable of tensors (second set of feature maps).

    Returns:
        Twice the sum of the per-pair mean absolute differences; the plain
        integer ``0`` when there is nothing to compare.
    """
    total = sum(
        torch.mean(torch.abs(real_layer - fake_layer))
        for real_maps, fake_maps in zip(fmap_r, fmap_g)
        for real_layer, fake_layer in zip(real_maps, fake_maps)
    )
    return total * 2
|
|
|
|
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss summed over paired outputs.

    For each pair, the first output is pushed towards 1 via
    ``mean((1 - dr) ** 2)`` and the second towards 0 via ``mean(dg ** 2)``.

    Args:
        disc_real_outputs: Iterable of tensors scored on real data.
        disc_generated_outputs: Iterable of tensors scored on generated data.

    Returns:
        tuple: ``(loss, r_losses, g_losses)`` where ``loss`` is the summed
        loss and the two lists hold the per-pair real / generated terms as
        Python floats (detached via ``.item()``).
    """
    total = 0
    real_terms, fake_terms = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1 - real_out) ** 2)
        fake_term = torch.mean(fake_out**2)
        real_terms.append(real_term.item())
        fake_terms.append(fake_term.item())
        total = total + (real_term + fake_term)
    return total, real_terms, fake_terms
|
|
|
|
def generator_loss(disc_outputs):
    """Least-squares generator loss over a collection of discriminator outputs.

    Each output contributes ``mean((1 - dg) ** 2)``, i.e. it is pushed
    towards 1.

    Args:
        disc_outputs: Iterable of tensors scored on generated data.

    Returns:
        tuple: ``(loss, gen_losses)`` where ``gen_losses`` is the list of
        per-output loss tensors and ``loss`` is their sum (the integer ``0``
        when the input is empty).
    """
    per_output = [torch.mean((1 - fake_out) ** 2) for fake_out in disc_outputs]
    return sum(per_output), per_output
|
|
|
|
# Reference (presumably for the TPRLS losses defined below):
# https://dl.acm.org/doi/abs/10.1145/3573834.3574506
|
|
|
|
def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
    """Truncated relativistic least-squares loss for the discriminator side.

    For every (real, generated) output pair the difference ``dr - dg`` is
    compared with its median; the squared deviation from the median is
    averaged over the elements where ``dr < dg + median`` and the result is
    capped at ``tau`` (``tau - relu(tau - L)`` equals ``min(L, tau)``).

    When no element satisfies the mask (e.g. constant differences, or one or
    two elements, since ``torch.median`` returns the lower median), the pair
    contributes 0 instead of the NaN that a mean over an empty selection
    would produce.

    Args:
        disc_real_outputs: Iterable of tensors scored on real data.
        disc_generated_outputs: Iterable of tensors scored on generated data.

    Returns:
        Sum of the capped per-pair losses (the integer ``0`` for empty input).
    """
    tau = 0.04
    loss = 0
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        diff = dr - dg
        m_DG = torch.median(diff)
        selected = ((diff - m_DG) ** 2)[dr < dg + m_DG]
        # sum() of an empty tensor is 0 and keeps dtype/device/graph, whereas
        # mean() of an empty tensor is NaN.
        L_rel = selected.mean() if selected.numel() > 0 else selected.sum()
        loss += tau - F.relu(tau - L_rel)
    return loss
|
|
|
|
def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
    """Truncated relativistic least-squares loss for the generator side.

    Same computation as the discriminator variant but with the roles
    reversed: the difference is ``generated - real``. Its squared deviation
    from the median is averaged over the elements where
    ``generated < real + median`` and capped at ``tau``
    (``tau - relu(tau - L)`` equals ``min(L, tau)``).

    When no element satisfies the mask (e.g. constant differences, or one or
    two elements, since ``torch.median`` returns the lower median), the pair
    contributes 0 instead of the NaN that a mean over an empty selection
    would produce.

    Args:
        disc_real_outputs: Iterable of tensors scored on real data.
        disc_generated_outputs: Iterable of tensors scored on generated data.

    Returns:
        Sum of the capped per-pair losses (the integer ``0`` for empty input).
    """
    tau = 0.04
    loss = 0
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        diff = fake_out - real_out
        m_DG = torch.median(diff)
        selected = ((diff - m_DG) ** 2)[fake_out < real_out + m_DG]
        # sum() of an empty tensor is 0 and keeps dtype/device/graph, whereas
        # mean() of an empty tensor is NaN.
        L_rel = selected.mean() if selected.numel() > 0 else selected.sum()
        loss += tau - F.relu(tau - L_rel)
    return loss
|
|
|
|
class GeneratorLoss(torch.nn.Module):
    """Combined generator objective built from two discriminators.

    Sums, for both discriminators, the least-squares adversarial loss, the
    feature-matching loss and the TPRLS relativistic loss.
    """

    def __init__(self, mpd, msd):
        """Store the two discriminators.

        Args:
            mpd: Callable ``(y, y_hat)`` returning real outputs, generated
                outputs, real feature maps and generated feature maps
                (presumably a multi-period discriminator -- confirm).
            msd: Second discriminator with the same call contract
                (presumably multi-scale -- confirm).
        """
        super().__init__()
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        """Return the total generator loss for reference ``y`` and output ``y_hat``."""
        mpd_real, mpd_fake, mpd_fmap_real, mpd_fmap_fake = self.mpd(y, y_hat)
        msd_real, msd_fake, msd_fmap_real, msd_fmap_fake = self.msd(y, y_hat)

        fm_mpd = feature_loss(mpd_fmap_real, mpd_fmap_fake)
        fm_msd = feature_loss(msd_fmap_real, msd_fmap_fake)

        # Only the summed adversarial terms are used; per-output lists are dropped.
        adv_mpd, _ = generator_loss(mpd_fake)
        adv_msd, _ = generator_loss(msd_fake)

        relativistic = generator_TPRLS_loss(mpd_real, mpd_fake) + generator_TPRLS_loss(
            msd_real, msd_fake
        )

        total = adv_msd + adv_mpd + fm_msd + fm_mpd + relativistic
        return total.mean()
|
|
|
|
class DiscriminatorLoss(torch.nn.Module):
    """Combined discriminator objective built from two discriminators.

    Sums, for both discriminators, the least-squares real/generated loss and
    the TPRLS relativistic loss.
    """

    def __init__(self, mpd, msd):
        """Store the two discriminators.

        Args:
            mpd: Callable ``(y, y_hat)`` whose first two return values are the
                real and generated outputs (presumably a multi-period
                discriminator -- confirm).
            msd: Second discriminator with the same call contract
                (presumably multi-scale -- confirm).
        """
        super().__init__()
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        """Return the total discriminator loss for reference ``y`` and output ``y_hat``."""
        # Feature maps (3rd/4th return values) are not needed here.
        mpd_real, mpd_fake, _, _ = self.mpd(y, y_hat)
        lsq_mpd, _, _ = discriminator_loss(mpd_real, mpd_fake)

        msd_real, msd_fake, _, _ = self.msd(y, y_hat)
        lsq_msd, _, _ = discriminator_loss(msd_real, msd_fake)

        relativistic = discriminator_TPRLS_loss(
            mpd_real, mpd_fake
        ) + discriminator_TPRLS_loss(msd_real, msd_fake)

        total = lsq_msd + lsq_mpd + relativistic
        return total.mean()
|
|
|
|
class WavLMLoss(torch.nn.Module):
    """Losses computed on hidden states of a pretrained speech model.

    Waveforms are resampled from ``model_sr`` to ``slm_sr``, passed through a
    model loaded with ``AutoModel.from_pretrained`` and compared either
    directly (``forward``) or through the discriminator ``wd``.

    NOTE(review): the source this was restored from had lost its indentation,
    so the extent of each ``torch.no_grad()`` block below is reconstructed --
    confirm against the reference implementation.
    """

    def __init__(self, model, wd, model_sr, slm_sr=16000):
        """Load the speech model and build the resampler.

        Args:
            model: Identifier or path accepted by ``AutoModel.from_pretrained``.
            wd: Discriminator applied to the flattened hidden states.
            model_sr: Sample rate of the incoming waveforms.
            slm_sr: Sample rate the waveforms are resampled to (default 16000).
        """
        super().__init__()
        self.wavlm = AutoModel.from_pretrained(model)
        self.wd = wd
        self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)

    @staticmethod
    def _stack_hidden_states(hidden_states):
        """Stack hidden states on dim 1, swap the last two dims, merge dims 1-2.

        Assuming each hidden state is (batch, frames, channels), the result
        is (batch, layers * channels, frames) -- TODO confirm shapes.
        """
        stacked = torch.stack(hidden_states, dim=1)
        return stacked.transpose(-1, -2).flatten(start_dim=1, end_dim=2)

    def forward(self, wav, y_rec):
        """Sum of per-layer mean absolute differences between hidden states.

        Args:
            wav: Reference waveform.
            y_rec: Reconstructed waveform.

        Returns:
            Tensor: Feature distance between the two sets of hidden states.
        """
        # The reference branch needs no gradient.
        with torch.no_grad():
            real_16 = self.resample(wav)
            real_states = self.wavlm(
                input_values=real_16, output_hidden_states=True
            ).hidden_states

        # The reconstruction branch stays in the graph so the loss can be
        # back-propagated into whatever produced ``y_rec``.
        fake_16 = self.resample(y_rec)
        fake_states = self.wavlm(
            # NOTE(review): squeeze() drops every size-1 dim, including a
            # batch dim of 1 -- confirm callers never rely on that.
            input_values=fake_16.squeeze(),
            output_hidden_states=True,
        ).hidden_states

        floss = 0
        for real_layer, fake_layer in zip(real_states, fake_states):
            floss = floss + torch.mean(torch.abs(real_layer - fake_layer))

        return floss.mean()

    def generator(self, y_rec):
        """Least-squares generator loss of ``wd`` on ``y_rec``'s hidden states.

        Args:
            y_rec: Reconstructed waveform.

        Returns:
            Tensor: ``mean((1 - wd(features)) ** 2)``.
        """
        fake_16 = self.resample(y_rec)
        fake_states = self.wavlm(
            input_values=fake_16, output_hidden_states=True
        ).hidden_states
        fake_features = self._stack_hidden_states(fake_states)

        fake_score = self.wd(fake_features)
        return torch.mean((1 - fake_score) ** 2)

    def discriminator(self, wav, y_rec):
        """Least-squares discriminator loss of ``wd`` on real vs. reconstructed audio.

        Args:
            wav: Reference waveform.
            y_rec: Reconstructed waveform.

        Returns:
            Tensor: ``mean((1 - wd(real)) ** 2) + mean(wd(fake) ** 2)``.
        """
        # Feature extraction is gradient-free; only ``wd`` is evaluated with
        # gradients enabled.
        with torch.no_grad():
            real_16 = self.resample(wav)
            real_states = self.wavlm(
                input_values=real_16, output_hidden_states=True
            ).hidden_states
            fake_16 = self.resample(y_rec)
            fake_states = self.wavlm(
                input_values=fake_16, output_hidden_states=True
            ).hidden_states

            real_features = self._stack_hidden_states(real_states)
            fake_features = self._stack_hidden_states(fake_states)

        real_score = self.wd(real_features)
        fake_score = self.wd(fake_features)

        real_term = torch.mean((1 - real_score) ** 2)
        fake_term = torch.mean(fake_score**2)

        return (real_term + fake_term).mean()

    def discriminator_forward(self, wav):
        """Return the raw output of ``wd`` on the hidden states of ``wav``.

        Args:
            wav: Input waveform.

        Returns:
            Whatever ``wd`` returns for the flattened hidden states.
        """
        with torch.no_grad():
            wav_16 = self.resample(wav)
            hidden_states = self.wavlm(
                input_values=wav_16, output_hidden_states=True
            ).hidden_states
            features = self._stack_hidden_states(hidden_states)

        return self.wd(features)
|
|