exAres

Reputation: 4926

How to train TensorFlow on SageMaker in script mode when the data resides in multiple files on S3?

I have one .npy file for each training example. All of these files are available on S3 in the train_data folder. I want to train a TensorFlow model on these examples. To do that, I want to spin up a separate AWS training instance for each training job, which would access the files from S3 and train the model on them. What changes does the training script need for this?

I have the following config in the training script:

parser.add_argument('--gpu-count', type=int, default=os.environ['SM_NUM_GPUS'])
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
parser.add_argument('--train_channel', type=str, default=os.environ['SM_CHANNELS'])

I created the training estimator in a Jupyter notebook instance as:

tf_estimator = TensorFlow(entry_point = 'my_model.py', 
                          role = role, 
                          train_instance_count = 1, 
                          train_instance_type = 'local_gpu', 
                          framework_version = '1.15.2', 
                          py_version = 'py3', 
                          hyperparameters = {'epochs': 1})

I am calling the fit function of the estimator as:

tf_estimator.fit({'train_channel':'s3://sagemaker-ml/train_data/'})

where the train_data folder on S3 contains the .npy files of the training examples.

But when I call the fit function, I get an error:

FileNotFoundError: [Errno 2] No such file or directory: '["train_channel"]/train_data_12.npy'

I'm not sure what I am missing here, since I can see the file mentioned above on S3.

Upvotes: 0

Views: 444

Answers (1)

lauren

Reputation: 513

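SM_CHANNELS holds a JSON-encoded list of channel names, which is why the literal string '["train_channel"]' ended up at the start of your file path. What you're looking for is SM_CHANNEL_TRAIN_CHANNEL ("SM_CHANNEL_" + your channel name in upper case), which provides the filesystem location of the data for that channel: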

parser.add_argument('--train_channel', type=str, default=os.environ['SM_CHANNEL_TRAIN_CHANNEL'])

docs: https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md#sm_channel_channel_name
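
To make the difference between the two variables concrete, here is a minimal sketch of how a training script might resolve the channel directory and find the .npy files in it. The helper names channel_dir and list_npy_files are hypothetical (they are not part of SageMaker), and the sketch assumes the files sit directly in the channel directory rather than in subfolders:

```python
import glob
import json
import os


def channel_dir(channel_name, environ=os.environ):
    """Return the local directory that holds the data for a channel.

    Hypothetical helper, not a SageMaker API.
    """
    # SM_CHANNELS is a JSON list of channel names, e.g. '["train_channel"]'.
    # It is useful for checking which channels exist, but it is not a path.
    channels = json.loads(environ.get("SM_CHANNELS", "[]"))
    if channel_name not in channels:
        raise KeyError(f"channel {channel_name!r} not found in {channels}")
    # SM_CHANNEL_<NAME> (channel name upper-cased) holds the directory path.
    return environ["SM_CHANNEL_" + channel_name.upper()]


def list_npy_files(directory):
    """Return the .npy files directly inside a directory, sorted by name."""
    # Assumption: files are not nested in subfolders of the channel directory.
    return sorted(glob.glob(os.path.join(directory, "*.npy")))
```

Inside the training script, something like list_npy_files(channel_dir('train_channel')) would then give the local paths of the downloaded files, and each path can be passed to numpy.load.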

Upvotes: 1
