G. Macia


How to deploy PyTorch in Python via a REST API with Flask?

I am working on AWS SageMaker, and my goal is to follow this tutorial from PyTorch's official documentation.

The original predict function from the tutorial above is the following:

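from flask import Flask, jsonify, request

app = Flask(__name__)

# get_prediction() is defined elsewhere in the tutorial (model loading and image transforms).
@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})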
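Because this route is registered with `methods=['POST']` only, Flask already answers any other method (for example a browser's GET) with 405 Method Not Allowed, so the inner `if request.method == 'POST'` check is redundant. Also, `request.files['file']` makes Flask return a generic 400 page when the field is missing. A minimal sketch of the same route that returns a JSON error instead; `get_prediction` here is a hypothetical stub standing in for the tutorial's real model call.

```python
from flask import Flask, jsonify, request

app = Flask(__name__)


def get_prediction(image_bytes):
    # Hypothetical stub: the tutorial's real get_prediction() runs the image
    # through a torchvision model and returns (class_id, class_name).
    return 'stub-id', 'cat'


@app.route('/predict', methods=['POST'])
def predict():
    # methods=['POST'] already rejects other verbs with 405,
    # so no request.method check is needed inside the view.
    file = request.files.get('file')
    if file is None:
        # Missing multipart field: reply with JSON rather than Flask's default 400 page.
        return jsonify({'error': "expected a multipart field named 'file'"}), 400
    img_bytes = file.read()
    class_id, class_name = get_prediction(image_bytes=img_bytes)
    return jsonify({'class_id': class_id, 'class_name': class_name})
```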

I was getting this error, so I added 'GET' as an allowed method, as suggested here. I also reduced my example to a minimal version:

from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route('/predict', methods=['GET','POST'])
def predict():
    if request.method == 'POST':
        return jsonify({'class_name': 'cat'})
    return 'OK'

if __name__ == '__main__':
    app.run()
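Flask's built-in test client can call this route in-process, with no network or proxy involved, which separates "is my Flask code wrong?" from "is something between the client and Flask wrong?". A minimal sketch, where `check_predict_route` is a hypothetical helper name: if the POST comes back here as `application/json` with `{'class_name': 'cat'}`, the Flask code is fine, and any non-JSON body seen over the network must be coming from something in front of the server.

```python
import io

from flask import Flask, jsonify, request

app = Flask(__name__)


@app.route('/predict', methods=['GET', 'POST'])
def predict():
    if request.method == 'POST':
        return jsonify({'class_name': 'cat'})
    return 'OK'


def check_predict_route(flask_app):
    """Call /predict in-process (no sockets, no proxy) and return the POST and GET responses."""
    client = flask_app.test_client()
    post_resp = client.post(
        '/predict',
        # Same multipart shape as requests.post(..., files={"file": ...}).
        data={'file': (io.BytesIO(b'not really a jpeg'), 'cat.jpg')},
        content_type='multipart/form-data',
    )
    get_resp = client.get('/predict')
    return post_resp, get_resp


if __name__ == '__main__':
    post_resp, get_resp = check_predict_route(app)
    print(post_resp.status_code, post_resp.mimetype, post_resp.get_json())
    print(get_resp.status_code, get_resp.get_data(as_text=True))
```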

I perform requests with the following code:

import requests

resp = requests.post("https://catdogclassifier.notebook.eu-west-1.sagemaker.aws/proxy/5000/predict",
                     files={"file": open('/home/ec2-user/SageMaker/cat.jpg', 'rb')})

resp is <Response [200]>, but resp.json() raises JSONDecodeError: Expecting value: line 1 column 1 (char 0). Finally, resp.url points to a page that just says 'OK'.

Moreover, this is the output of resp.content:

<!DOCTYPE HTML>

<html>
<head>
  <style type="text/css">

#loadingImage {
    margin: 10em auto;
    width: 234px;
    height: 238px;
    background-repeat: no-repeat;
    background-image: url();

    -webkit-animation:spin 4s linear infinite;
    -moz-animation:spin 4s linear infinite;
    animation:spin 4s linear infinite;
}


@-moz-keyframes spin { 100% { -moz-transform: rotate(360deg); } }
@-webkit-keyframes spin { 100% { -webkit-transform: rotate(360deg); } }
@keyframes spin { 100% { -webkit-transform: rotate(360deg); transform:rotate(360deg); } }

  </style>
</head>

<body>

  <div id="loadingImage"></div>

<script type="text/javascript">



var RegionFinder = (function()
{
    function RegionFinder( location ) {
        this.location = location;
    }


    RegionFinder.prototype = {


        getURLWithRegion: function() {

            var isDynamicDefaultRegion = ifPathContains(this.location.pathname, "region/dynamic-default-region");

            var queryArgs = removeURLParameter(this.location.search, "region");

            var hashArgs = this.location.href.split("#")[1] || "";
            if (hashArgs) {
                hashArgs = "#" + hashArgs;
            }

            var region = this._getCurrentRegion();
            var newArgs = "region=" + region;
            if (_shouldAuth()) {
                newArgs = "needs_auth=true";
                region = "nil";
            }

            if (queryArgs &&
                queryArgs != "?") {
                queryArgs += "&" + newArgs;
            } else {
                queryArgs = "?" + newArgs;
            }



            if (!region) {

                var contactUs = "https://portal.aws.amazon.com/gp/aws/html-forms-controller/contactus/aws-report-issue1";

                alert("How embarrassing! There is something wrong with this URL, please contact AWS at " + contactUs);
            }

            var pathname = isDynamicDefaultRegion ?  "/console/home" : this.location.pathname;

            return this.location.protocol + "//" + _getRedirectHostFromAttributes() +
                pathname + queryArgs + hashArgs;
        },


        _getCurrentRegion: function() {

            return _getRegionFromHash( this.location ) ||
                   _getRegionFromAttributes();
        }
    };



    function ifPathContains(url, parameter) {
        return (url.indexOf(parameter) != -1);
    }


    function removeURLParameter(url, parameter) {
        var urlparts= url.split('?');
        if (urlparts.length>=2) {
            var prefix= encodeURIComponent(parameter);
            var pars= urlparts[1].split(/[&;]/g);
            //reverse iteration as may be destructive
            for (var i= pars.length; i-- > 0;) {
                if (pars[i].lastIndexOf(prefix, 0) !== -1) {
                    pars.splice(i, 1);
                }
            }
            url= urlparts[0]+'?'+pars.join('&');
            return url;
        } else {
            return url;
        }
    }


    function _getRegionFromAttributes() {
        return "eu-west-1";
    };

    function _shouldAuth() {
        return "";
    };

    function _getRedirectHostFromAttributes() {
        return "eu-west-1.console.aws.amazon.com";
    }


    function _getRegionFromHash( location ) {

        var hashArgs = "#" + (location.href.split("#")[1] || "");


        var hashRegionArg = "";


        var match = hashArgs.match("region=([a-zA-Z0-9-]+)");
        if (match && match.length > 1 && match[1]) {
            hashRegionArg = match[1];
        }
        return hashRegionArg;
    }

    return RegionFinder;
})();


var regionFinder = new RegionFinder( window.location );

window.location.href = regionFinder.getURLWithRegion();

</script>


</body>
</html>

What am I missing?


Answers (1)

shrmaab

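It looks like the body of your resp is HTML rather than JSON, which is why resp.json() fails. That page is an AWS console redirect script (it sends the browser to eu-west-1.console.aws.amazon.com), not anything your Flask app produced, so the request most likely never reached Flask. This is likely a consequence of how the Jupyter server proxy endpoint you're POSTing to (https://catdogclassifier.notebook.eu-west-1.sagemaker.aws/proxy/5000/predict) is configured.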
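Because the proxy answered with status 200 and an HTML body, resp.json() is the first place the problem shows up, and its JSONDecodeError says nothing about where the HTML came from. A minimal sketch that checks the Content-Type header before parsing; `parse_json_response` and `NotJSONResponse` are hypothetical names, and `resp` is duck-typed (anything with `.headers`, `.text` and `.url`, such as a `requests.Response`).

```python
import json


class NotJSONResponse(ValueError):
    """Raised when a response body is not JSON, e.g. an HTML login or redirect page."""


def parse_json_response(resp):
    """Parse resp as JSON only if the server labelled it application/json."""
    content_type = resp.headers.get('Content-Type', '')
    # Drop parameters such as "; charset=utf-8" before comparing.
    media_type = content_type.split(';')[0].strip().lower()
    if media_type != 'application/json':
        preview = resp.text[:80].replace('\n', ' ')
        raise NotJSONResponse(
            f'expected application/json from {resp.url}, '
            f'got {content_type!r}; body starts with {preview!r}'
        )
    return json.loads(resp.text)
```

With the response from the question, this would raise with the proxy URL and the start of the HTML in the message instead of "Expecting value: line 1 column 1".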

It looks like you're using a SageMaker notebook instance, so you might not have much control over this configuration. A workaround could be to deploy your Flask server as a SageMaker endpoint that runs outside JupyterLab, rather than directly on a notebook instance.
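If you go the endpoint route with your own serving container, SageMaker hosting expects the container's web server to listen on port 8080 and answer GET /ping health checks and POST /invocations requests. A minimal Flask sketch of that contract, assuming a bring-your-own-container setup where the raw image bytes arrive as the request body; `classify` is a hypothetical placeholder for the model call (for example the tutorial's get_prediction()), and `serve` is a hypothetical entry point you would invoke from the container's start command.

```python
from flask import Flask, Response, jsonify, request

SAGEMAKER_PORT = 8080  # port SageMaker sends traffic to inside the container

app = Flask(__name__)


def classify(image_bytes):
    # Hypothetical placeholder: replace with the real model call.
    return 'cat'


@app.route('/ping', methods=['GET'])
def ping():
    # Health check: an empty 200 tells SageMaker the container is ready.
    return Response(status=200)


@app.route('/invocations', methods=['POST'])
def invocations():
    # The body is whatever the caller sent, e.g. JPEG bytes with ContentType='image/jpeg'.
    image_bytes = request.get_data()
    if not image_bytes:
        return jsonify({'error': 'empty request body'}), 400
    return jsonify({'class_name': classify(image_bytes)})


def serve():
    """Container entry point; not called on import."""
    app.run(host='0.0.0.0', port=SAGEMAKER_PORT)
```

Clients would then call the endpoint through the SageMaker Runtime InvokeEndpoint API (for example boto3's `sagemaker-runtime` client and its `invoke_endpoint` method) instead of through the notebook proxy URL.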

If you want to prototype using only a notebook instance, you can alternatively bypass the proxy entirely and call your Flask route on localhost from another notebook tab while the Flask server runs in your main notebook tab:

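import requests

# Plain http://, not https://: app.run() starts Flask's development server without TLS.
with open('/home/ec2-user/SageMaker/cat.jpg', 'rb') as f:
    resp = requests.post("http://localhost:5000/predict", files={"file": f})

print(resp.json())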
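If a second tab is inconvenient, a variant of this idea (a different technique from the two-tab setup above) is to start the server from a background thread in the same notebook, so the cell returns immediately. A minimal sketch, assuming port 5000 is free; it uses Werkzeug's `make_server` so the server can be shut down cleanly, stdlib `urllib` instead of requests to keep dependencies down, and the helper names `start_background_server` and `post_predict` are hypothetical.

```python
import json
import threading
import urllib.request

from flask import Flask, jsonify, request
from werkzeug.serving import make_server

app = Flask(__name__)


@app.route('/predict', methods=['GET', 'POST'])
def predict():
    if request.method == 'POST':
        return jsonify({'class_name': 'cat'})
    return 'OK'


def start_background_server(flask_app, host='127.0.0.1', port=5000):
    """Serve flask_app from a daemon thread; keep the returned server to shut it down later."""
    server = make_server(host, port, flask_app)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def post_predict(port=5000):
    # Plain http:// straight to localhost: no notebook proxy, no TLS.
    req = urllib.request.Request(f'http://127.0.0.1:{port}/predict', data=b'', method='POST')
    with urllib.request.urlopen(req, timeout=5) as resp:
        return json.loads(resp.read().decode('utf-8'))
```

In a notebook cell you would run `server = start_background_server(app)`, then `post_predict()`, and call `server.shutdown()` when you are done so port 5000 is released.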
