czr

Reputation: 658

estimator.fit hangs on SageMaker in local mode

I am trying to train a PyTorch model using SageMaker in local mode, but whenever I call estimator.fit the code hangs indefinitely and I have to interrupt the notebook kernel. This happens both on my local machine and in SageMaker Studio. When I use EC2, however, the training runs normally.

Here is the call to the estimator, and the stack trace once I interrupt the kernel:

import sagemaker
from sagemaker.pytorch import PyTorch

bucket = "bucket-name"
role = sagemaker.get_execution_role()
training_input_path = f"s3://{bucket}/dataset/path"

sagemaker_session = sagemaker.LocalSession()
sagemaker_session.config = {"local": {"local_code": True}}

output_path = "file://."

estimator = PyTorch(
    entry_point="train.py",
    source_dir="src",
    hyperparameters={"max-epochs": 1},
    framework_version="1.8",
    py_version="py3",
    instance_count=1,
    instance_type="local",
    role=role,
    output_path=output_path,
    sagemaker_session=sagemaker_session,
)


estimator.fit({"training": training_input_path})

Stack trace:

    ---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-9-35cdd6021288> in <module>
----> 1 estimator.fit({"training": training_input_path})

/opt/conda/lib/python3.7/site-packages/sagemaker/estimator.py in fit(self, inputs, wait, logs, job_name, experiment_config)
    678         self._prepare_for_training(job_name=job_name)
    679 
--> 680         self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
    681         self.jobs.append(self.latest_training_job)
    682         if wait:

/opt/conda/lib/python3.7/site-packages/sagemaker/estimator.py in start_new(cls, estimator, inputs, experiment_config)
   1450         """
   1451         train_args = cls._get_train_args(estimator, inputs, experiment_config)
-> 1452         estimator.sagemaker_session.train(**train_args)
   1453 
   1454         return cls(estimator.sagemaker_session, estimator._current_job_name)

/opt/conda/lib/python3.7/site-packages/sagemaker/session.py in train(self, input_mode, input_config, role, job_name, output_config, resource_config, vpc_config, hyperparameters, stop_condition, tags, metric_definitions, enable_network_isolation, image_uri, algorithm_arn, encrypt_inter_container_traffic, use_spot_instances, checkpoint_s3_uri, checkpoint_local_path, experiment_config, debugger_rule_configs, debugger_hook_config, tensorboard_output_config, enable_sagemaker_metrics, profiler_rule_configs, profiler_config, environment, retry_strategy)
    572         LOGGER.info("Creating training-job with name: %s", job_name)
    573         LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
--> 574         self.sagemaker_client.create_training_job(**train_request)
    575 
    576     def _get_train_request(  # noqa: C901

/opt/conda/lib/python3.7/site-packages/sagemaker/local/local_session.py in create_training_job(self, TrainingJobName, AlgorithmSpecification, OutputDataConfig, ResourceConfig, InputDataConfig, **kwargs)
    184         hyperparameters = kwargs["HyperParameters"] if "HyperParameters" in kwargs else {}
    185         logger.info("Starting training job")
--> 186         training_job.start(InputDataConfig, OutputDataConfig, hyperparameters, TrainingJobName)
    187 
    188         LocalSagemakerClient._training_jobs[TrainingJobName] = training_job

/opt/conda/lib/python3.7/site-packages/sagemaker/local/entities.py in start(self, input_data_config, output_data_config, hyperparameters, job_name)
    219 
    220         self.model_artifacts = self.container.train(
--> 221             input_data_config, output_data_config, hyperparameters, job_name
    222         )
    223         self.end_time = datetime.datetime.now()

/opt/conda/lib/python3.7/site-packages/sagemaker/local/image.py in train(self, input_data_config, output_data_config, hyperparameters, job_name)
    200         data_dir = self._create_tmp_folder()
    201         volumes = self._prepare_training_volumes(
--> 202             data_dir, input_data_config, output_data_config, hyperparameters
    203         )
    204         # If local, source directory needs to be updated to mounted /opt/ml/code path

/opt/conda/lib/python3.7/site-packages/sagemaker/local/image.py in _prepare_training_volumes(self, data_dir, input_data_config, output_data_config, hyperparameters)
    487             os.mkdir(channel_dir)
    488 
--> 489             data_source = sagemaker.local.data.get_data_source_instance(uri, self.sagemaker_session)
    490             volumes.append(_Volume(data_source.get_root_dir(), channel=channel_name))
    491 

/opt/conda/lib/python3.7/site-packages/sagemaker/local/data.py in get_data_source_instance(data_source, sagemaker_session)
     52         return LocalFileDataSource(parsed_uri.netloc + parsed_uri.path)
     53     if parsed_uri.scheme == "s3":
---> 54         return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session)
     55     raise ValueError(
     56         "data_source must be either file or s3. parsed_uri.scheme: {}".format(parsed_uri.scheme)

/opt/conda/lib/python3.7/site-packages/sagemaker/local/data.py in __init__(self, bucket, prefix, sagemaker_session)
    183             working_dir = "/private{}".format(working_dir)
    184 
--> 185         sagemaker.utils.download_folder(bucket, prefix, working_dir, sagemaker_session)
    186         self.files = LocalFileDataSource(working_dir)
    187 

/opt/conda/lib/python3.7/site-packages/sagemaker/utils.py in download_folder(bucket_name, prefix, target, sagemaker_session)
    286                 raise
    287 
--> 288     _download_files_under_prefix(bucket_name, prefix, target, s3)
    289 
    290 

/opt/conda/lib/python3.7/site-packages/sagemaker/utils.py in _download_files_under_prefix(bucket_name, prefix, target, s3)
    314             if exc.errno != errno.EEXIST:
    315                 raise
--> 316         obj.download_file(file_path)
    317 
    318 

/opt/conda/lib/python3.7/site-packages/boto3/s3/inject.py in object_download_file(self, Filename, ExtraArgs, Callback, Config)
    313     return self.meta.client.download_file(
    314         Bucket=self.bucket_name, Key=self.key, Filename=Filename,
--> 315         ExtraArgs=ExtraArgs, Callback=Callback, Config=Config)
    316 
    317 

/opt/conda/lib/python3.7/site-packages/boto3/s3/inject.py in download_file(self, Bucket, Key, Filename, ExtraArgs, Callback, Config)
    171         return transfer.download_file(
    172             bucket=Bucket, key=Key, filename=Filename,
--> 173             extra_args=ExtraArgs, callback=Callback)
    174 
    175 

/opt/conda/lib/python3.7/site-packages/boto3/s3/transfer.py in download_file(self, bucket, key, filename, extra_args, callback)
    305             bucket, key, filename, extra_args, subscribers)
    306         try:
--> 307             future.result()
    308         # This is for backwards compatibility where when retries are
    309         # exceeded we need to throw the same error from boto3 instead of

/opt/conda/lib/python3.7/site-packages/s3transfer/futures.py in result(self)
    107         except KeyboardInterrupt as e:
    108             self.cancel()
--> 109             raise e
    110 
    111     def cancel(self):

/opt/conda/lib/python3.7/site-packages/s3transfer/futures.py in result(self)
    104             # however if a KeyboardInterrupt is raised we want want to exit
    105             # out of this and propogate the exception.
--> 106             return self._coordinator.result()
    107         except KeyboardInterrupt as e:
    108             self.cancel()

/opt/conda/lib/python3.7/site-packages/s3transfer/futures.py in result(self)
    258         # possible value integer value, which is on the scale of billions of
    259         # years...
--> 260         self._done_event.wait(MAXINT)
    261 
    262         # Once done waiting, raise an exception if present or return the

/opt/conda/lib/python3.7/threading.py in wait(self, timeout)
    550             signaled = self._flag
    551             if not signaled:
--> 552                 signaled = self._cond.wait(timeout)
    553             return signaled
    554 

/opt/conda/lib/python3.7/threading.py in wait(self, timeout)
    294         try:    # restore state no matter what (e.g., KeyboardInterrupt)
    295             if timeout is None:
--> 296                 waiter.acquire()
    297                 gotit = True
    298             else:

KeyboardInterrupt: 

Upvotes: 0

Views: 1308

Answers (1)

Sam Edwards

Reputation: 51

SageMaker Studio does not natively support local mode. Studio Apps are themselves Docker containers, so they would need privileged access in order to build and run Docker containers.

As an alternative, you can create a remote Docker host on an EC2 instance and set up Docker on your Studio App. There is quite a bit of networking and package installation involved, but this approach gives you full Docker functionality. Additionally, as of version 2.80.0, the SageMaker Python SDK supports local mode when using a remote Docker host.
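
As a rough illustration of that last point: the Docker CLI and docker-compose (which local mode shells out to) honor the standard DOCKER_HOST environment variable, so once the EC2 Docker host is reachable you can point local mode at it from your notebook. This is only a minimal sketch, not the exact setup from the linked repo: the host address is a placeholder and SSH access from the Studio App to the instance is assumed to already be configured.

import os

import sagemaker
from sagemaker.pytorch import PyTorch

# Route the docker/docker-compose calls made by local mode to the remote EC2
# Docker host. The address is a placeholder; key-based SSH access to the
# instance must already be working from the Studio App.
os.environ["DOCKER_HOST"] = "ssh://ec2-user@<remote-docker-host>"

role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.LocalSession()

estimator = PyTorch(
    entry_point="train.py",
    source_dir="src",
    framework_version="1.8",
    py_version="py3",
    instance_count=1,
    instance_type="local",  # containers now run on the remote Docker host
    role=role,
    sagemaker_session=sagemaker_session,
)

estimator.fit({"training": "s3://bucket-name/dataset/path"})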

sdocker, a SageMaker Studio Docker CLI extension (see this repo), can simplify deploying the above solution in two simple steps (it only works for a Studio Domain in VPCOnly mode), and it has an easy-to-follow example here.

UPDATE: There is now a UI extension (see this repo) which makes the experience much smoother and easier to manage.

Upvotes: 1
