BerndKarlsb

Load a custom dataset like MNIST (TensorFlow, Python)

I'm experimenting with a clustering model from https://github.com/astirn/IIC (I have already tried to contact the author about this).

Like most research papers, it uses the MNIST dataset. The code first sets the dataset name to 'mnist', which is enough for TensorFlow Datasets to fetch MNIST from its catalog of standard online datasets. It then loads the dataset with the tensorflow_datasets.load() function.

I have created a TFRecord file for my dataset, and now I just need to change the part where the aforementioned script points to 'mnist' (the DATA_SET = 'mnist' line in the code below) so that it points to my local dataset instead.

Can I simply replace 'mnist' with the file path in that line?
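For reference: tfds.load() resolves its name argument against the datasets registered with TensorFlow Datasets, so a file path in place of 'mnist' is not treated as a location on disk. A TFRecord file you wrote yourself can instead be read directly with tf.data.TFRecordDataset. The sketch below is only an illustration under stated assumptions: the feature keys 'image' and 'label', the PNG encoding, and the helper names (FEATURE_SPEC, parse_example, load_tfrecord, write_demo_tfrecord) are hypothetical, and the parsing spec has to match however your records were actually written.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical feature schema: assumes each tf.train.Example stores a PNG-encoded
# grayscale image under 'image' and an int64 class id under 'label'.
FEATURE_SPEC = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}


def parse_example(serialized):
    """Decode one serialized tf.train.Example into an MNIST-like dict."""
    parsed = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    image = tf.io.decode_png(parsed['image'], channels=1)  # uint8, shape [H, W, 1]
    return {'image': image, 'label': parsed['label']}


def load_tfrecord(path):
    """Return a tf.data.Dataset of {'image', 'label'} dicts read from a TFRecord file."""
    return tf.data.TFRecordDataset(path).map(parse_example)


def write_demo_tfrecord(path, num_examples=3):
    """Write a tiny TFRecord file of random 28x28 images, only for trying the reader."""
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(num_examples):
            pixels = np.random.randint(0, 256, size=(28, 28, 1), dtype=np.uint8)
            png = tf.io.encode_png(pixels).numpy()
            example = tf.train.Example(features=tf.train.Features(feature={
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[png])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i % 10])),
            }))
            writer.write(example.SerializeToString())


if __name__ == '__main__':
    path = os.path.join(tempfile.mkdtemp(), 'my_dataset.tfrecord')
    write_demo_tfrecord(path)
    for record in load_tfrecord(path):
        print(record['image'].shape, int(record['label']))
```

Returning dicts with 'image' and 'label' keys mirrors the structure of the tfds MNIST examples, which may let the rest of the pipeline stay unchanged; whether it does depends on what the repository's load() does after calling tfds.load().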

Code from the actual model training file:

if __name__ == '__main__':
    # pick a data set
    DATA_SET = 'mnist'

    # define splits
    DS_CONFIG = {
        # mnist data set parameters
        'mnist': {
            'batch_size': 700,
            'num_repeats': 5,
            'mdl_input_dims': [24, 24, 1]}
    }

    # load the data set (load() is defined in the repository's data preparation file)
    TRAIN_SET, TEST_SET, SET_INFO = load(data_set_name=DATA_SET, **DS_CONFIG[DATA_SET])

    # configure the common model elements
    MDL_CONFIG = {
        # mnist hyper-parameters
        'mnist': {
            'num_classes': SET_INFO.features['label'].num_classes,
            'learning_rate': 1e-4,
            'num_repeats': DS_CONFIG[DATA_SET]['num_repeats'],
            'save_dir': None},
    }
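If load() is swapped for a custom loader, the training script above still looks up DS_CONFIG[DATA_SET] and MDL_CONFIG[DATA_SET], and still reads SET_INFO.features['label'].num_classes. A minimal sketch, assuming a hypothetical dataset key 'my_dataset' (MDL_CONFIG would need a matching entry too) and a loader that returns a lightweight stand-in for the tfds info object. The helper make_set_info is hypothetical and provides only the one attribute path the script reads; code that uses more of tfds.core.DatasetInfo would need more.

```python
from types import SimpleNamespace


def make_set_info(num_classes):
    """Build a minimal stand-in for the tfds info object used by the training script.

    Hypothetical helper: only SET_INFO.features['label'].num_classes is provided.
    """
    label_feature = SimpleNamespace(num_classes=num_classes)
    return SimpleNamespace(features={'label': label_feature})


# Hypothetical entry for a custom dataset; the values are placeholders copied from 'mnist'.
DS_CONFIG = {
    'my_dataset': {
        'batch_size': 700,
        'num_repeats': 5,
        'mdl_input_dims': [24, 24, 1]},
}

SET_INFO = make_set_info(num_classes=10)
print(SET_INFO.features['label'].num_classes)  # → 10
```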

Code from the data preparation file, where the dataset is loaded with tensorflow_datasets.load (imported as tfds):

import tensorflow_datasets as tfds


def load(data_set_name, **kwargs):
    """
    :param data_set_name: data set name--call tfds.list_builders() for options
    :return:
        train_ds: TensorFlow Dataset object for the training data
        test_ds: TensorFlow Dataset object for the testing data
        info: data set info object
    """
    # get data and its info
    ds, info = tfds.load(name=data_set_name, split=tfds.Split.ALL, with_info=True)
    # ... (rest of the function not included in the question)
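The docstring says load() returns separate train and test datasets, while tfds.load() is called with tfds.Split.ALL, so the split presumably happens later in the part not shown. If you build the dataset yourself (for example from a TFRecord file), one simple way to hold out a test set is take()/skip() on the tf.data.Dataset. This is a generic sketch with a hypothetical helper, split_dataset, not the repository's own splitting logic.

```python
import tensorflow as tf


def split_dataset(ds, num_test):
    """Hold out the first num_test elements as test data; the rest is training data.

    If the records are ordered, shuffle once beforehand with a fixed seed and
    reshuffle_each_iteration=False so the two parts do not overlap across epochs.
    """
    test_ds = ds.take(num_test)
    train_ds = ds.skip(num_test)
    return train_ds, test_ds


if __name__ == '__main__':
    ds = tf.data.Dataset.range(10)
    train_ds, test_ds = split_dataset(ds, num_test=3)
    print([int(x) for x in test_ds])   # [0, 1, 2]
    print([int(x) for x in train_ds])  # [3, 4, 5, 6, 7, 8, 9]
```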

Thanks for the help.

Answers (1)

artona

According to the docs, you need to set the download parameter to False and set data_dir to the directory that holds the data:

ds, info = tfds.load(name=data_set_name, split=tfds.Split.ALL, with_info=True, download=False, data_dir="/path/to/directory")
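Note that download=False only stops TFDS from fetching data; tfds.load() then looks for data that TFDS has already prepared under data_dir, in a <data_dir>/<name>/<version>/ directory that also contains a dataset_info.json file (datasets with builder configs add a config level between name and version). A standalone TFRecord file placed elsewhere is not picked up this way, and name must still be a dataset TFDS knows about. A small, hypothetical check for that layout, with placeholder names and versions:

```python
import os
import tempfile


def has_prepared_tfds_data(data_dir, name, version):
    """Return True if <data_dir>/<name>/<version>/dataset_info.json exists.

    Hypothetical helper: it only checks the directory layout TFDS writes for
    datasets without builder configs; it does not validate the data files.
    """
    version_dir = os.path.join(data_dir, name, version)
    return os.path.isfile(os.path.join(version_dir, 'dataset_info.json'))


if __name__ == '__main__':
    root = tempfile.mkdtemp()
    # Placeholder dataset name and version, only to build an example layout.
    version_dir = os.path.join(root, 'my_dataset', '1.0.0')
    os.makedirs(version_dir)
    print(has_prepared_tfds_data(root, 'my_dataset', '1.0.0'))  # False: no dataset_info.json yet
    open(os.path.join(version_dir, 'dataset_info.json'), 'w').close()
    print(has_prepared_tfds_data(root, 'my_dataset', '1.0.0'))  # True
```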
