kubicwerke

Reputation: 45

How can I change the dataset at every epoch in Keras?

Assume that I have noisy_features and clean_features. As the names suggest, noisy_features are extracted from noisy input samples and clean_features are extracted from clean input samples. mixture_rate is an arbitrary number that defines what fraction of the features is taken from noisy_features (the rest come from clean_features). I want to train my Keras network according to the following pseudocode:

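import random

def train_with_noise():
    clean_features = get_features(clean_inputs)
    noisy_features = get_features(noisy_inputs)
    number_of_mix_features = int(mixture_rate * number_of_features)

    for epoch in range(number_of_epochs):
        # a new random subset of samples for every epoch
        random_indices = random.sample(range(number_of_features), number_of_mix_features)
        # copy, otherwise clean_features itself is overwritten and noise accumulates
        input_features = clean_features.copy()
        input_features[random_indices] = noisy_features[random_indices]
        train_network(input_features)  # one epoch of training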

The function train_network() corresponds to model.fit() in Keras, but model.fit() is not appropriate in this context because the epoch loop is already inside model.fit().

So the problem is that I can't change the input features at every epoch in Keras. If I change input_features before using model.fit(), random_indices will stay the same during the whole training.
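The mixing step of the pseudocode can be isolated as a small NumPy function that is called once per epoch. This is a minimal sketch, assuming the features are NumPy arrays with samples along the first axis; the name `mix_features` and its `rng` parameter are hypothetical, not part of Keras. It returns a new array, so clean_features is never modified.

```python
import numpy as np


def mix_features(clean_features, noisy_features, mixture_rate, rng=None):
    """Return a copy of clean_features in which a random `mixture_rate`
    fraction of the samples (rows) is replaced by their noisy versions.

    Hypothetical helper: call it once per epoch to get a fresh mixture.
    """
    if clean_features.shape != noisy_features.shape:
        raise ValueError("clean and noisy features must have the same shape")
    rng = np.random.default_rng() if rng is None else rng

    number_of_features = clean_features.shape[0]
    number_of_mix_features = round(mixture_rate * number_of_features)

    # Draw distinct sample indices (no replacement) for this epoch.
    random_indices = rng.choice(number_of_features, size=number_of_mix_features,
                                replace=False)

    # Work on a copy so clean_features keeps its original values.
    input_features = clean_features.copy()
    input_features[random_indices] = noisy_features[random_indices]
    return input_features
```

Reusing the same `rng` at the start of every epoch gives a different subset each time, e.g. `x_epoch = mix_features(clean, noisy, 0.2, rng)`.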

I tried using callbacks:

from keras.callbacks import Callback

class Noisify(Callback):
    def __init__(self, mixture_rate=0.2):
        super().__init__()
        self.mixture_rate = mixture_rate

    def on_epoch_begin(self, epoch, logs=None):
        # get input_tensor

        # mix noisy and clean features randomly here

        # set input_tensor
        pass

But I don't know of a method to get and set the input tensor, analogous to model.get_layer() or get_weights().
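Rather than setting the input tensor from a callback, one workaround is to pass model.fit() a Python generator that re-mixes the features at the start of every pass over the data. This is a sketch assuming the data fits in memory as NumPy arrays; `noisy_mix_generator` and `batches_per_epoch` are hypothetical names. When `steps_per_epoch` equals the number of batches in one pass, each Keras epoch consumes exactly one freshly mixed pass. `keras.utils.Sequence`, whose `on_epoch_end()` method Keras calls after every epoch, is a comparable alternative.

```python
import math

import numpy as np


def noisy_mix_generator(clean_features, noisy_features, labels,
                        mixture_rate=0.2, batch_size=32, seed=None):
    """Yield (x, y) batches forever. The clean/noisy mixture is re-drawn at
    the start of every full pass over the data, i.e. every epoch when
    steps_per_epoch == batches_per_epoch(n, batch_size)."""
    rng = np.random.default_rng(seed)
    n = clean_features.shape[0]
    number_of_mix_features = round(mixture_rate * n)
    while True:
        # New random subset of samples that use their noisy features.
        random_indices = rng.choice(n, size=number_of_mix_features, replace=False)
        input_features = clean_features.copy()
        input_features[random_indices] = noisy_features[random_indices]
        # Walk through the mixed data once, batch by batch.
        for start in range(0, n, batch_size):
            stop = start + batch_size
            yield input_features[start:stop], labels[start:stop]


def batches_per_epoch(n, batch_size=32):
    # Number of batches in one full pass over n samples.
    return math.ceil(n / batch_size)
```

For example: `model.fit(noisy_mix_generator(clean, noisy, y, mixture_rate=0.2, batch_size=32), steps_per_epoch=batches_per_epoch(len(y), 32), epochs=10)`. Older standalone Keras versions use `fit_generator()` with the same arguments.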

Are there any suggestions?

Upvotes: 2

Views: 1729

Answers (1)

gauravtolani

Reputation: 130

A straightforward method I can think of is to sample the data with the pandas .sample() function to build the mixed dataset, and then use a for loop that calls model.fit() once per epoch.

So you would have something like:

def data_resampler():
    # do data sampling using the pandas .sample() function
    ...

for epoch in range(number_of_epochs):
    mixed_data = data_resampler()
    # train for exactly one epoch on this epoch's mixture
    model.fit(mixed_data, labels, epochs=1)

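A fleshed-out version of this idea, as a sketch: it assumes the clean and noisy features are pandas DataFrames with the same (unique) index and columns, and uses `DataFrame.sample(frac=...)` to pick the rows that are swapped for their noisy versions. Passing `epochs=epoch + 1` and `initial_epoch=epoch` to `fit()` trains exactly one epoch per call while keeping Keras' epoch counter increasing, which matters for callbacks and learning-rate schedules. The function names are hypothetical, and `model` stands for any compiled Keras model.

```python
import numpy as np
import pandas as pd


def data_resampler(clean_df, noisy_df, mixture_rate, random_state=None):
    """Copy of clean_df with a random `mixture_rate` fraction of rows
    replaced by the matching rows of noisy_df."""
    mixed = clean_df.copy()
    # sample(frac=...) draws that fraction of rows without replacement.
    noisy_rows = noisy_df.sample(frac=mixture_rate, random_state=random_state)
    # Assignment aligns on index labels, so each row gets its own noisy twin.
    mixed.loc[noisy_rows.index] = noisy_rows
    return mixed


def train_with_resampling(model, clean_df, noisy_df, labels,
                          mixture_rate=0.2, number_of_epochs=10, seed=None):
    rng = np.random.RandomState(seed)
    for epoch in range(number_of_epochs):
        # A fresh mixture for every epoch.
        mixed_data = data_resampler(clean_df, noisy_df, mixture_rate,
                                    random_state=rng)
        # epochs/initial_epoch make this call run exactly one epoch,
        # numbered `epoch`, so Keras' epoch count keeps increasing.
        model.fit(mixed_data.to_numpy(), labels,
                  epochs=epoch + 1, initial_epoch=epoch, verbose=0)
    return model
```

Compared with a generator, this makes one fit() call per epoch, which adds some per-call overhead but keeps the data handling in plain pandas.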
Comment if you need help with the code.

Upvotes: 1
