Navneeth G

Reputation: 7305

Making Tornado serve a request on a separate thread

I have a web service written in Flask, wrapped in a WSGIContainer and served by Tornado through its FallbackHandler mechanism. One of the routes in my Flask web service runs a very long operation (it takes around 5 minutes to complete), and while that route is being handled, every call to any other route is blocked until the operation completes. How do I get around this issue?

Here is how my Flask application is served using Tornado:

parse_command_line()

frontend_path = os.path.join(os.path.dirname(__file__), "..", "webapp")

rest_app = WSGIContainer(app)
tornado_app = Application(
    [
        (r"/api/(.*)", FallbackHandler, dict(fallback=rest_app)),
        (r"/app/(.*)", StaticFileHandler, dict(path=frontend_path))
    ]
)
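For illustration (the actual handler isn't shown above), a blocking route like the following is enough to stall everything, because WSGIContainer runs the entire WSGI call on the single IOLoop thread; run_long_computation is a hypothetical stand-in for the 5-minute operation:

@app.route("/long")
def long_operation():
    # Blocking call, ~5 minutes. While it runs, Tornado's single-threaded
    # IOLoop cannot serve any other request, WSGI or otherwise.
    result = run_long_computation()  # hypothetical; returns a response string
    return result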

Upvotes: 3

Views: 3454

Answers (5)

honours

Reputation: 1

When Tornado serves a Flask app, take a look at the source code of the WSGIContainer class; the example below subclasses it so that the blocking WSGI call runs on a thread pool.

from concurrent.futures import ThreadPoolExecutor
import tornado.gen
from tornado.wsgi import WSGIContainer
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado import escape
from tornado import httputil
from typing import List, Tuple, Optional, Callable, Any, Dict, Type
from types import TracebackType

__all__ = ("WSGIContainer_With_Thread",)


class WSGIContainer_With_Thread(WSGIContainer):

    executor = ThreadPoolExecutor(30)

    @tornado.gen.coroutine
    def __call__(self, request):
        data = {}  # type: Dict[str, Any]
        response = []  # type: List[bytes]

        def start_response(
                status: str,
                _headers: List[Tuple[str, str]],
                exc_info: Optional[
                    Tuple[
                        "Optional[Type[BaseException]]",
                        Optional[BaseException],
                        Optional[TracebackType],
                    ]
                ] = None,
        ) -> Callable[[bytes], Any]:
            data["status"] = status
            data["headers"] = _headers
            return response.append

        loop = IOLoop.current()
        app_response = yield loop.run_in_executor(
            self.executor, self.wsgi_application, WSGIContainer.environ(request), start_response
        )

        # The stock WSGIContainer calls self.wsgi_application(...) synchronously
        # at this point; running it on the executor above keeps the IOLoop free.

        try:
            response.extend(app_response)
            body = b"".join(response)
        finally:
            if hasattr(app_response, "close"):
                app_response.close()  # type: ignore
        if not data:
            raise Exception("WSGI app did not call start_response")

        status_code_str, reason = data["status"].split(" ", 1)
        status_code = int(status_code_str)
        headers = data["headers"]  # type: List[Tuple[str, str]]
        header_set = set(k.lower() for (k, v) in headers)
        body = escape.utf8(body)
        if status_code != 304:
            if "content-length" not in header_set:
                headers.append(("Content-Length", str(len(body))))
            if "content-type" not in header_set:
                headers.append(("Content-Type", "text/html; charset=UTF-8"))
        if "server" not in header_set:
            headers.append(("Server", "TornadoServer/%s" % tornado.version))

        start_line = httputil.ResponseStartLine("HTTP/1.1", status_code, reason)
        header_obj = httputil.HTTPHeaders()
        for key, value in headers:
            header_obj.add(key, value)
        assert request.connection is not None
        request.connection.write_headers(start_line, header_obj, chunk=body)
        request.connection.finish()
        self._log(status_code, request)


if __name__ == '__main__':

    from flask import Flask
    import time
    from tornado.ioloop import IOLoop

    app = Flask(__name__)

    @app.route('/1')
    def index1():
        time.sleep(5)
        return f'OK 1 - {int(time.time())}'


    @app.route('/2')
    def index2():
        time.sleep(5)
        return f'OK 2 - {int(time.time())}'


    @app.route('/3')
    def index3():
        return f'OK 3 - {int(time.time())}'

    http_server = HTTPServer(WSGIContainer_With_Thread(app))
    http_server.listen(5000)
    IOLoop.current().start()

When you run this example, a Tornado app listens on port 5000, and we can run a few tests:

  1. Request route '/1' and route '/2' at the same time; you should get both responses at about the same time (both within roughly 5 seconds).

  2. Request route '/1' and route '/3' at the same time; you should get the response from '/3' immediately and the response from '/1' after about 5 seconds.

  3. Request route '/1' twice at the same time (e.g. in two different browser tabs); you should get the first response after about 5 seconds and the second after about 10 seconds.
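A quick way to reproduce test 2 from a client (a hypothetical check; it assumes the server above is running on localhost:5000 and that the requests package is installed):

import threading
import time
import requests

def fetch(path):
    start = time.time()
    body = requests.get('http://localhost:5000' + path).text
    print('%s -> %r after %.1fs' % (path, body, time.time() - start))

# '/3' should come back immediately while '/1' is still sleeping.
threads = [threading.Thread(target=fetch, args=(p,)) for p in ('/1', '/3')]
for t in threads:
    t.start()
for t in threads:
    t.join()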

Upvotes: 0

skrause

Reputation: 999

I created a custom WSGIHandler which supports multi-threaded requests for WSGI apps in Tornado by using a ThreadPoolExecutor. All calls into the WSGI app are performed in separate threads, so the main loop stays free even if your WSGI response takes a long time. The following code is based on this Gist and extended so that:

  • You can stream a response (using an iterator response) or large files directly from the WSGI app to the client, so you can keep memory usage low even when generating large responses.
  • You can upload large files. If the request body exceeds 1 MB, it is dumped into a temporary file which is then passed to the WSGI app.

Currently the code has only been tested with Python 3.4, so I don't know if it works with Python 2.7. It also hasn't been stress tested yet, but it seems to work fine so far.

# tornado_wsgi.py

import itertools
import logging
import sys
import tempfile
from concurrent import futures
from io import BytesIO

from tornado import escape, gen, web
from tornado.iostream import StreamClosedError
from tornado.wsgi import to_wsgi_str

_logger = logging.getLogger(__name__)


@web.stream_request_body
class WSGIHandler(web.RequestHandler):
    thread_pool_size = 20

    def initialize(self, wsgi_application):
        self.wsgi_application = wsgi_application

        self.body_chunks = []
        self.body_tempfile = None

    def environ(self, request):
        """
        Converts a `tornado.httputil.HTTPServerRequest` to a WSGI environment.
        """
        hostport = request.host.split(":")
        if len(hostport) == 2:
            host = hostport[0]
            port = int(hostport[1])
        else:
            host = request.host
            port = 443 if request.protocol == "https" else 80

        if self.body_tempfile is not None:
            body = self.body_tempfile
            body.seek(0)
        elif self.body_chunks:
            body = BytesIO(b''.join(self.body_chunks))
        else:
            body = BytesIO()

        environ = {
            "REQUEST_METHOD": request.method,
            "SCRIPT_NAME": "",
            "PATH_INFO": to_wsgi_str(escape.url_unescape(request.path, encoding=None, plus=False)),
            "QUERY_STRING": request.query,
            "REMOTE_ADDR": request.remote_ip,
            "SERVER_NAME": host,
            "SERVER_PORT": str(port),
            "SERVER_PROTOCOL": request.version,
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": request.protocol,
            "wsgi.input": body,
            "wsgi.errors": sys.stderr,
            "wsgi.multithread": False,
            "wsgi.multiprocess": True,
            "wsgi.run_once": False,
        }
        if "Content-Type" in request.headers:
            environ["CONTENT_TYPE"] = request.headers.pop("Content-Type")
        if "Content-Length" in request.headers:
            environ["CONTENT_LENGTH"] = request.headers.pop("Content-Length")
        for key, value in request.headers.items():
            environ["HTTP_" + key.replace("-", "_").upper()] = value
        return environ

    def prepare(self):
        # Accept up to 2GB upload data.
        self.request.connection.set_max_body_size(2 << 30)

    @gen.coroutine
    def data_received(self, chunk):
        if self.body_tempfile is not None:
            yield self.executor.submit(lambda: self.body_tempfile.write(chunk))
        else:
            self.body_chunks.append(chunk)

            # When the request body grows larger than 1 MB we dump all received chunks into
            # a temporary file to prevent high memory use. All subsequent body chunks will
            # be directly written into the tempfile.
            if sum(len(c) for c in self.body_chunks) > (1 << 20):
                self.body_tempfile = tempfile.NamedTemporaryFile('w+b')
                def copy_to_file():
                    for c in self.body_chunks:
                        self.body_tempfile.write(c)
                    # Remove the chunks to clear the memory.
                    self.body_chunks[:] = []
                yield self.executor.submit(copy_to_file)

    @web.asynchronous
    @gen.coroutine
    def get(self):
        data = {}
        response = []

        def start_response(status, response_headers, exc_info=None):
            data['status'] = status
            data['headers'] = response_headers
            return response.append

        environ = self.environ(self.request)
        app_response = yield self.executor.submit(self.wsgi_application, environ, start_response)
        app_response = iter(app_response)

        if not data:
            raise Exception('WSGI app did not call start_response')

        try:
            exhausted = object()

            def next_chunk():
                try:
                    return next(app_response)
                except StopIteration:
                    return exhausted

            for i in itertools.count():
                chunk = yield self.executor.submit(next_chunk)
                if i == 0:
                    status_code, reason = data['status'].split(None, 1)
                    status_code = int(status_code)
                    headers = data['headers']
                    self.set_status(status_code, reason)
                    for key, value in headers:
                        self.set_header(key, value)
                    c = b''.join(response)
                    if c:
                        self.write(c)
                        yield self.flush()
                if chunk is not exhausted:
                    self.write(chunk)
                    yield self.flush()
                else:
                    break
        except StreamClosedError:
            _logger.debug('stream closed early')
        finally:
            # Close the temporary file to make sure that it gets deleted.
            if self.body_tempfile is not None:
                try:
                    self.body_tempfile.close()
                except OSError as e:
                    _logger.warning(e)

            if hasattr(app_response, 'close'):
                yield self.executor.submit(app_response.close)

    post = put = delete = head = options = get

    @property
    def executor(self):
        cls = type(self)
        if not hasattr(cls, '_executor'):
            cls._executor = futures.ThreadPoolExecutor(cls.thread_pool_size)
        return cls._executor

The following is a simple Flask app which demonstrates the WSGIHandler. The hello() function blocks for one second, so if your ThreadPoolExecutor uses 20 threads you will be able to serve 20 concurrent requests in about one second.

The stream() function creates an iterator response and streams 50 chunks of data to the client within 5 seconds. Note that it will probably not be possible to use Flask's stream_with_context decorator here: since each read from the iterator results in a new executor.submit(), different chunks of the streaming response are very likely to be loaded from different threads, breaking Flask's use of thread-locals. A possible workaround is sketched below.
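If you need request data inside a streamed response, one workaround (my own sketch, assuming a Flask app object as in the demo below) is to copy everything you need from the request before creating the generator, so the generator itself never touches Flask's thread-locals:

from flask import Response, request

@app.route('/echo-stream')
def echo_stream():
    # Read from the request *here*, while the request context is still
    # active on the thread that invoked the WSGI app...
    name = request.args.get('name', 'anonymous')

    def generate(n):
        # ...so this generator is context-free and can safely be advanced
        # from any executor thread.
        for i in range(50):
            yield '%s: %d\n' % (n, i)

    return Response(generate(name), mimetype='text/plain')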

import time
from flask import Flask, Response
from tornado import ioloop, log, web
from tornado_wsgi import WSGIHandler

def main():
    app = Flask(__name__)

    @app.route("/")
    def hello():
        time.sleep(1)
        return "Hello World!"

    @app.route("/stream")
    def stream():
        def generate():
            for i in range(50):
                time.sleep(0.1)
                yield '%d\n' % i
        return Response(generate(), mimetype='text/plain')

    application = web.Application([
        (r'/.*', WSGIHandler, {'wsgi_application': app}),
    ])

    log.enable_pretty_logging()
    application.listen(8888)
    ioloop.IOLoop.instance().start()

if __name__ == '__main__':
    main()
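To see the streaming behaviour from the client side, something like this should print one line roughly every 0.1 seconds rather than everything at once (a hypothetical check; it assumes the requests package and the server above on localhost:8888):

import requests

with requests.get('http://localhost:8888/stream', stream=True) as resp:
    for line in resp.iter_lines():
        print(line.decode())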

Upvotes: 7

Jakob Simon-Gaarde

Reputation: 725

You could use Ladon's task-type methods for these long-duration operations.

It provides a framework solution for this type of situation.

Ladon Tasks documentation

Upvotes: 0

Ben Darnell

Reputation: 22154

Tornado's WSGI container is not very scalable and should only be used when you have a specific reason to combine WSGI and Tornado applications in the same process. Tornado does not support long-running WSGI requests without blocking; anything that may take a long time needs to use Tornado's native asynchronous interfaces instead of WSGI.

See the warning in the docs:

WSGI is a synchronous interface, while Tornado’s concurrency model is based on single-threaded asynchronous execution. This means that running a WSGI app with Tornado’s WSGIContainer is less scalable than running the same app in a multi-threaded WSGI server like gunicorn or uwsgi. Use WSGIContainer only when there are benefits to combining Tornado and WSGI in the same process that outweigh the reduced scalability.
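A minimal sketch of that approach (my illustration, not from the answer; it assumes Tornado 4.x-style coroutines, and run_long_computation is a hypothetical stand-in for the 5-minute operation):

from concurrent.futures import ThreadPoolExecutor

from tornado import gen, web
from tornado.concurrent import run_on_executor

class LongOperationHandler(web.RequestHandler):
    executor = ThreadPoolExecutor(max_workers=4)

    @run_on_executor
    def _long_operation(self):
        return run_long_computation()  # hypothetical blocking call, returns a string

    @gen.coroutine
    def get(self):
        # The IOLoop stays free to serve other requests while the blocking
        # call runs on an executor thread.
        result = yield self._long_operation()
        self.write(result)

A route like (r"/api/long", LongOperationHandler) can then be listed ahead of the FallbackHandler route, so Tornado handles the slow endpoint natively while everything else still falls through to Flask.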

Upvotes: 3

mehdix

Reputation: 5164

You can consider using tornado-threadpool; in that case your request will return immediately and the task will complete in the background.

from thread_pool import in_thread_pool
from flask import flash, render_template

@app.route('/wait')
def wait():
    time_consuming_task()
    flash('Time consuming task running in background...')
    return render_template('index.html')

@in_thread_pool
def time_consuming_task():
    import time
    time.sleep(5)
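If you would rather not add a dependency, a similar fire-and-forget effect can be had with the standard library's ThreadPoolExecutor (a sketch; note that, as above, the client is not told when the task finishes):

from concurrent.futures import ThreadPoolExecutor
from flask import flash, render_template

_pool = ThreadPoolExecutor(max_workers=4)

@app.route('/wait')
def wait():
    _pool.submit(time_consuming_task)  # schedules the work and returns at once
    flash('Time consuming task running in background...')
    return render_template('index.html')

def time_consuming_task():
    import time
    time.sleep(5)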

Upvotes: 2
