Raven Cheuk

Reputation: 3053

Implementing STFT with PyTorch gives a slightly different result than the STFT with Librosa

I am trying to implement the STFT with PyTorch, but the output of the PyTorch implementation is slightly off compared with the output of Librosa's implementation.

Librosa version

import numpy as np
from librosa.core import stft
import matplotlib.pyplot as plt

np.random.seed(3)
y = np.sin(2*np.pi*50*np.linspace(0,10,2048))+np.sin(2*np.pi*20*np.linspace(0,10,2048)) + np.random.normal(scale=1,size=2048)

S_stft = np.abs(stft(y, hop_length=512, n_fft=2048,center=False))

plt.plot(S_stft)

PyTorch version

import torch
from torch.autograd import Variable
from torch.nn.functional import conv1d

from scipy.signal.windows import hann

stride = 512

def create_filters(d,k,low=50,high=6000):
    x = np.arange(0, d, 1)
    wsin = np.empty((k,1,d), dtype=np.float32)
    wcos = np.empty((k,1,d), dtype=np.float32)
    start_freq = low
    end_freq = high
    # num_cycles = start_freq*d/44000.
    # scaling_ind = np.log(end_freq/start_freq)/k

    window_mask = hann(d, sym=False) # periodic Hann of the filter length, same as 0.5-0.5*np.cos(2*np.pi*x/d)
    for ind in range(k):
        wsin[ind,0,:] = window_mask*np.sin(2*np.pi*ind/k*x)
        wcos[ind,0,:] = window_mask*np.cos(2*np.pi*ind/k*x)

    return wsin,wcos

wsin, wcos = create_filters(2048,2048)

wsin_var = Variable(torch.from_numpy(wsin), requires_grad=False)
wcos_var = Variable(torch.from_numpy(wcos),requires_grad=False)

network_input = torch.from_numpy(y).float()
network_input = network_input.reshape(1,-1)

zx = torch.sqrt(conv1d(network_input[:,None,:], wsin_var, stride=stride).pow(2)+conv1d(network_input[:,None,:], wcos_var, stride=stride).pow(2))
pytorch_Xs = zx.cpu().numpy()
plt.plot(pytorch_Xs[0,:1025,0])

My Question

The two plots look the same, but comparing the two outputs with np.allclose shows that they differ slightly.

np.allclose(S_stft, pytorch_Xs[0,:1025,0].reshape(1025,1))
output >>> False

Only when I loosen the absolute tolerance to 1e-5 does it return True:

np.allclose(S_stft, pytorch_Xs[0,:1025,0].reshape(1025,1),atol=1e-5)
output >>> True

What causes the difference in values? Is it the data conversion in torch.from_numpy(y).float()?

I would like the difference to be less than 1e-7; 1e-8 would be even better.

Upvotes: 2

Views: 2159

Answers (1)

Hirotoshi Takeuchi

Reputation: 11

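The difference comes from the default floating-point precision of the two libraries. NumPy uses 64-bit floats (float64) by default, while PyTorch uses 32-bit floats (float32). In the question's code, the filters are created with dtype=np.float32 and the input is cast with .float(), so the whole convolution runs in single precision.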
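The precision effect can be checked with a small sketch like the one below. It rebuilds the convolution-based STFT from the question with a selectable dtype and compares the float32 and float64 results against a double-precision reference. Note the assumptions: conv_stft_magnitude is a hypothetical helper name, the periodic Hann window is written out by hand, and np.fft.rfft stands in for Librosa as the reference so the snippet does not depend on Librosa's own output precision.

```python
import numpy as np
import torch
from torch.nn.functional import conv1d


def conv_stft_magnitude(y, n_fft=2048, hop=512, dtype=np.float64):
    """Magnitude STFT via conv1d with Hann-windowed sine/cosine filters.

    Hypothetical helper for illustration; `dtype` selects the precision
    used for the filters, the input, and the convolution itself.
    """
    n = np.arange(n_fft)                      # sample index within a frame
    k = np.arange(n_fft // 2 + 1)[:, None]    # non-negative frequency bins
    window = 0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft)  # periodic Hann
    phase = 2 * np.pi * k * n / n_fft
    # Build the filters in float64, then cast to the requested precision.
    wsin = torch.from_numpy((window * np.sin(phase)).astype(dtype))[:, None, :]
    wcos = torch.from_numpy((window * np.cos(phase)).astype(dtype))[:, None, :]
    x = torch.from_numpy(np.asarray(y, dtype=dtype))[None, None, :]
    real = conv1d(x, wcos, stride=hop)
    imag = conv1d(x, wsin, stride=hop)
    return torch.sqrt(real.pow(2) + imag.pow(2))[0].numpy()


# Same kind of test signal as in the question (two sines plus noise).
rng = np.random.RandomState(3)
t = np.linspace(0, 10, 2048)
y = np.sin(2 * np.pi * 50 * t) + np.sin(2 * np.pi * 20 * t) + rng.normal(scale=1, size=2048)

# Double-precision reference: FFT of the single Hann-windowed frame.
n = np.arange(2048)
window = 0.5 - 0.5 * np.cos(2 * np.pi * n / 2048)
reference = np.abs(np.fft.rfft(window * y))[:, None]

out32 = conv_stft_magnitude(y, dtype=np.float32)
out64 = conv_stft_magnitude(y, dtype=np.float64)
err32 = np.max(np.abs(out32 - reference))
err64 = np.max(np.abs(out64 - reference))
print("max abs error, float32:", err32)
print("max abs error, float64:", err64)
```

The float64 run should land well inside the 1e-7 target from the question, while the float32 run should not. float32 carries roughly 7 significant decimal digits (its machine epsilon is about 1.19e-7), so an absolute error below 1e-7 is out of reach in single precision once the magnitudes are much larger than 1.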

Upvotes: 1
