Bisnu Sarkar

Reputation: 21

Getting NaN in the loss function when training with a multi-GPU setup in TensorFlow

I have recently been training a deep learning model on a multi-GPU setup (dual 3090 Ti) with tf.distribute.MirroredStrategy().

My loss function has two components: a waveform loss and a multi-resolution STFT loss. Previously I trained with only the waveform loss on the multi-GPU setup, and it trained properly.

Now I am integrating the multi-resolution STFT loss alongside the waveform loss. With a single-GPU setup it works properly (I have dual GPUs and have tested each one individually in a single-GPU setup; both work correctly). But when I try the same thing with the multi-GPU setup, I get NaN values after the tf.signal.stft call that computes the STFT, even though the inputs contain no NaN values. I could not find a proper reason why this happens.
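
To confirm this, I checked the tensors around the STFT call with tf.debugging.check_numerics, roughly like the sketch below (x, win_length, hop_size, and fft_size refer to the arguments of the stft() helper shown next); it raises an InvalidArgumentError naming the first tensor that contains NaN or Inf:

# Rough sketch of the NaN check around the STFT call.
x = tf.debugging.check_numerics(x, "stft input")
x_stft = tf.signal.stft(x, frame_length=win_length, frame_step=hop_size,
                        fft_length=fft_size, window_fn=tf.signal.hann_window)
real = tf.debugging.check_numerics(tf.math.real(x_stft), "stft real part")
imag = tf.debugging.check_numerics(tf.math.imag(x_stft), "stft imag part")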

Here is my loss function code. custom_loss is the class where the overall loss is calculated.

import tensorflow as tf

epsilon = 1e-07
def stft(x, fft_size, hop_size, win_length, window):
    """Performs STFT and computes magnitude spectrogram."""
    window_fn = tf.signal.hann_window
    pad_amount = fft_size // 2
    x = tf.pad(x, [[0, 0], [pad_amount, pad_amount]], mode='REFLECT')

    x_stft = tf.signal.stft(x, frame_length=win_length, frame_step=hop_size,
                            fft_length=fft_size, window_fn=window_fn)

    real = tf.math.real(x_stft)
    imag = tf.math.imag(x_stft)
    # Floor the squared magnitude at epsilon before the sqrt; sqrt(0) has
    # an infinite gradient, which would itself produce NaNs in backprop.
    magnitude = tf.sqrt(tf.maximum(real**2 + imag**2, epsilon))

    # tf.print("Stft ", magnitude)
    return magnitude

class SpectralConvergenceLoss(tf.keras.layers.Layer):
    """Spectral convergence loss."""

    def call(self, x_mag, y_mag):
        numerator = tf.norm(y_mag - x_mag, ord=2)
        denominator = tf.norm(y_mag, ord=2)

        loss = numerator / denominator
        return loss

class LogSTFTMagnitudeLoss(tf.keras.layers.Layer):
    """Log STFT magnitude loss."""

    def call(self, x_mag, y_mag):
        loss = tf.reduce_mean(tf.abs(tf.math.log(y_mag) - tf.math.log(x_mag)))
        return loss

class STFTLoss(tf.keras.layers.Layer):
    """STFT loss."""

    def __init__(self, fft_size, hop_size, win_length, window='hann', batch_size=1):
        super().__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.win_length = win_length
        self.window = window
        self.spectral_convergence_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def call(self, x, y):
        x_mag = stft(x, self.fft_size, self.hop_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.hop_size, self.win_length, self.window)
        sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
        return sc_loss, mag_loss

class MultiResolutionSTFTLoss(tf.keras.layers.Layer):
    """Multi resolution STFT loss."""

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann", factor_sc=0.1, factor_mag=0.1,batch_size=1):
        super().__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = [
            STFTLoss(fs, ss, wl, window, batch_size)
            for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths)
        ]
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag

    def call(self, x, y):
        sc_loss = 0.0
        mag_loss = 0.0
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss += sc_l
            mag_loss += mag_l
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)
        # Note: these assignments override the factors passed to __init__
        # on every call.
        self.factor_sc = 0.125
        self.factor_mag = 0.125
        return self.factor_sc * sc_loss, self.factor_mag * mag_loss

class custom_loss(tf.keras.losses.Loss):
    def __init__(self, BATCH_SIZE=-1, extra=0.0, **kwargs):
        super().__init__(**kwargs)
        self.BATCH_SIZE = BATCH_SIZE
        self.mrstft_loss = MultiResolutionSTFTLoss(batch_size=BATCH_SIZE)


    def call(self, y_true, y_pred):
        tf.print("Inout true ", y_true)
        tf.print("Inout pred ", y_pred)

        y_true = tf.squeeze(y_true, [-1])
        y_pred = tf.squeeze(y_pred, [-1])
      

        # wavefrom_loss = tf.keras.metrics.mean_absolute_error(y_true, y_pred)
        # wavefrom_loss = tf.nn.compute_average_loss(wavefrom_loss, global_batch_size=self.BATCH_SIZE)

        wavefrom_loss = tf.reduce_mean(tf.abs(y_true - y_pred))
        #return tf.nn.compute_average_loss(loss, global_batch_size=self.BATCH_SIZE)
        sc_loss, mag_loss = self.mrstft_loss(y_true, y_pred)
        #tf.print("Batch size inside ",self.BATCH_SIZE)
        tf.print("losses : ", wavefrom_loss, sc_loss, mag_loss)
     

        return tf.reduce_mean(wavefrom_loss + sc_loss + mag_loss)
        

    def get_config(self):
        config = super().get_config().copy()
        config.update({'BATCH_SIZE': self.BATCH_SIZE})
        return config

Multi-GPU setup training code:

strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = args.batch_size
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

with strategy.scope():
    model = make_or_restore_model(args, BATCH_SIZE_PER_REPLICA)
    train_dataset = get_dataset(args, BATCH_SIZE)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(args.checkpoints_dir, steps + "_epoch-{epoch:03d}_loss-{loss:.6f}.h5"),
        save_best_only=False,
        save_weights_only=False,
        save_freq=110,
    )
]

Can you please help me find the probable reason for getting NaN when training with the multi-GPU setup, and how to solve it?

Upvotes: 1

Views: 257

Answers (1)

A few probable causes of NaN loss in this situation:

- Divide-by-zero error
- Incorrect output size of the model
- System configuration issue
- Empty batch creation
- Loss function and activation
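
Under tf.distribute.MirroredStrategy, the first and fourth causes often go together: if the global batch does not divide evenly across replicas, one replica can receive an empty (or zero-padded) batch, tf.norm(y_mag) becomes 0, and the spectral-convergence division produces 0/0 = NaN, which then poisons the cross-replica reduction. A minimal sketch of one way to guard that division, assuming the SpectralConvergenceLoss class from the question:

import tensorflow as tf

class SpectralConvergenceLoss(tf.keras.layers.Layer):
    """Spectral convergence loss with a guarded division."""

    def call(self, x_mag, y_mag):
        numerator = tf.norm(y_mag - x_mag, ord=2)
        denominator = tf.norm(y_mag, ord=2)
        # divide_no_nan returns 0 instead of NaN/Inf when the denominator
        # is 0, so an empty per-replica batch cannot poison the loss.
        return tf.math.divide_no_nan(numerator, denominator)

Batching the input pipeline with drop_remainder=True, e.g. dataset.batch(BATCH_SIZE, drop_remainder=True), also avoids empty per-replica batches in the first place.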
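
The commented-out tf.nn.compute_average_loss calls in the question point at the loss-reduction side: in a custom training loop under tf.distribute, TensorFlow's guidance is to keep per-example losses unreduced and average over the global batch size yourself (model.fit applies this scaling automatically). A generic sketch of that pattern, with tf.keras.losses.MeanAbsoluteError standing in for the waveform loss and an assumed GLOBAL_BATCH_SIZE:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8 * strategy.num_replicas_in_sync  # assumed per-replica batch of 8

with strategy.scope():
    # Reduction.NONE keeps one loss value per example so the averaging
    # stays under our control rather than Keras's per-replica default.
    base_loss = tf.keras.losses.MeanAbsoluteError(
        reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(y_true, y_pred):
    per_example = base_loss(y_true, y_pred)
    # Average over the GLOBAL batch size so gradients sum correctly
    # when added across replicas.
    return tf.nn.compute_average_loss(per_example,
                                      global_batch_size=GLOBAL_BATCH_SIZE)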

Upvotes: -1
