Reputation: 19
I need to do CNN audio classification on insect data. i have a very small dataset of 158 recordings and 9 classes. I need to perform transfer learning using AudioSet pre-trained model. I followed the instructions on tensorflow website: https://www.tensorflow.org/tutorials/audio/transfer_learning_audio but i can't make it work as I get constant low training (40%) and validation acccuracy (10%) in all epochs and 80% in test. I dont understand why is that and what i am doing wrong. i think the problem is in the training validation test splitting. Below is the code:
yamnet_model_handle = 'https://tfhub.dev/google/yamnet/1'
yamnet_model = hub.load(yamnet_model_handle)
def load_wav_16k_mono(filename):
""" Load a WAV file, convert it to a float tensor, resample to 16 kHz single-channel audio. """
file_contents = tf.io.read_file(filename)
wav, sample_rate = tf.audio.decode_wav(
file_contents,
desired_channels=1)
wav = tf.squeeze(wav, axis=-1)
sample_rate = tf.cast(sample_rate, dtype=tf.int64)
wav = tfio.audio.resample(wav, rate_in=sample_rate, rate_out=16000)
return wav
class_map_path = yamnet_model.class_map_path().numpy().decode('utf-8')
class_names =list(pd.read_csv(class_map_path)['display_name'])
for name in class_names[:20]:
print(name)
print('...')
enter code here
metadataf = "C:\\Users\\Thesis_PC\\orthoptera\\metadata.csv"
base_data_path = "C:\\Users\\Thesis_PC\\orthoptera\\resized16\\"
metadata = pd.read_csv(metadataf)
metadata.head()
from sklearn.preprocessing import LabelEncoder
encoder = LabelEncoder()
metadata['species'] = encoder.fit_transform(metadata['species'])
metadata.head()
my_classes = ["Chorthippus biguttulus",
"Chorthippus brunneus",
"Gryllus campestris",
"Nemobius sylvestris",
"Oecanthus pellucens",
"Pholidoptera griseoaptera",
"Pseudochorthippus parallelus" ,
"Roeseliana roeselii",
"Tettigonia viridissima"]
full_path = metadata['id'].apply(lambda row: os.path.join(base_data_path, row))
metadata = metadata.assign(id=full_path)
filenames = metadata['id']
targets = metadata['species']
main_ds = tf.data.Dataset.from_tensor_slices((filenames, targets))
main_ds.element_spec
metadata.head(100)
def load_wav_for_map(filename, label):
return load_wav_16k_mono(filename), label
main_ds = main_ds.map(load_wav_for_map)
main_ds.element_spec
# applies the embedding extraction model to a wav data
def extract_embedding(wav_data, label):
''' run YAMNet to extract embedding from the wav data '''
scores, embeddings, spectrogram = yamnet_model(wav_data)
num_embeddings = tf.shape(embeddings)[0]
return (embeddings,
tf.repeat(label, num_embeddings))
# extract embedding
main_ds = main_ds.map(extract_embedding).unbatch()
main_ds.element_spec
main_ds = main_ds.cache().shuffle(1000)
train_size = 130
val_size = 18
test_size = 10
train_ds = main_ds.take(train_size)
test_ds = main_ds.skip(train_size)
val_ds = test_ds.skip(test_size)
test_ds = test_ds.take(test_size)
train_ds = train_ds.cache().shuffle(15).batch(5).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.cache().batch(5).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.cache().batch(5).prefetch(tf.data.AUTOTUNE)
my_model = tf.keras.Sequential([
tf.keras.layers.Input(shape=(1024), dtype=tf.float32,
name='input_embedding'),
tf.keras.layers.Dense(256, activation='softmax'),
tf.keras.layers.Dense(9)
], name='my_model')
my_model.summary()
my_model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=["accuracy"])
history = my_model.fit(train_ds,
epochs=5,
validation_data=val_ds)
Upvotes: 1
Views: 725
Reputation: 1508
You use categorical cross-entropy as a loss function, so you need to put a softmax activation in the last dense layer of your model :
my_model = tf.keras.Sequential([
tf.keras.layers.Input(shape=(1024), dtype=tf.float32,
name='input_embedding'),
tf.keras.layers.Dense(256, activation='relu'),
tf.keras.layers.Dense(9, activation='softmax')
], name='my_model')
The other option is to put no activation in the last layer and take logits in the loss function, as in the tensorflow code you refer to:
my_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer="adam",
metrics=['accuracy'])
Upvotes: 1