Reputation: 78
I am trying to create a webhook endpoint with FastAPI, and write any JSON request body that arrives there to RabbitMQ.
I cannot figure out how to connect to RabbitMQ, create the channel, and keep it all alive by being hooked into the FastAPI asyncio loop. None of the other questions or answers here on SO are helping.
My current 'solution' is to start a 2nd thread in a FastAPI app startup method, and use a queue.SimpleQueue to talk between the threads. So the FastAPI path method writes the object to this SimpleQueue and the 2nd thread reads the object from it and publishes it to RabbitMQ.
The problem with this simple-minded approach is that, because the 2nd thread is blocking on reading the SimpleQueue, the RabbitMQ connection is not getting the keepalives sent over it and the server closes it. My code catches exceptions on write to RabbitMQ, reconnects and tries again, but that's pretty ugly.
I cannot understand how the async examples in either pika or aio-pika map to a FastAPI app.
Is anyone able to show me how, given the trivial FastAPI app below, I can open a RabbitMQ connection that will have the necessary keepalives sent over it to keep it open, such that I can publish via the connection in the path method?
from fastapi import FastAPI, Response
from typing import Any, Dict
app = FastAPI()


@app.on_event("startup")
async def startup() -> None:
    """Placeholder for the RabbitMQ setup steps listed below; currently a no-op.

    The docstring gives the function a real body. A body made only of
    comments is not valid Python.
    """
    # Connect to RabbitMQ
    # Create channel
    # Declare queue
# Alias for an arbitrary JSON object body: string keys, any values.
JSONObject = Dict[str, Any]


@app.post("/webhook")
async def webhook_endpoint(msg: JSONObject):
    """Accept a JSON object POSTed to /webhook and answer 204 No Content.

    TODO: write ``msg`` to the RabbitMQ channel; right now the body is
    accepted and ignored.
    """
    no_content = Response(status_code=204)
    return no_content
My other idea was to still use a thread but have the thread perform a blocking read on the RabbitMQ connection in the hope that will keep sending keepalives over it, and that it won't interfere with publishing over the same connection from the path method. But that is obviously a hack. I'd prefer to do it the 'proper' way and use async code.
EDIT: There does not seem to be a way to do a blocking read, so it would have to be an endless loop with channel.basic_get() and sleep instead. Even more hacky.
Upvotes: 1
Views: 9585
Reputation: 78
I have something working by adapting the pika async publisher example.
I changed the example so it creates an AsyncioConnection rather than using SelectConnection, because FastAPI has already started the standard asyncio event loop and I want pika to use that rather than whatever SelectConnection decided to use.
This means the reconnection logic from the example doesn't work as coded, so I need to fix that up; it has been removed from the code below.
However, this code does keep the connection alive in the background and I can publish messages in response to hits to the webhook URL. Changing the log level to DEBUG shows the heartbeats being sent and received.
This is not the final code - I'll be changing queues etc., but it does show that the AsyncioConnection can be hooked up to FastAPI's already-running event loop OK.
import asyncio
import functools
import json
import logging
import os
import queue
import threading
import time
from typing import Any, Dict
from urllib.parse import quote

import pika
from fastapi import FastAPI, Response
from pika.adapters.asyncio_connection import AsyncioConnection
from pika.exchange_type import ExchangeType
# Root logging: INFO and above, ISO-8601 timestamps with UTC offset.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(name)s: %(message)s',
    datefmt='%Y-%m-%dT%H:%M:%S%z',
)

# Module logger used by the publisher class and the FastAPI hooks.
logger = logging.getLogger(__name__)
class AsyncioRabbitMQ(object):
    """Publish JSON messages to RabbitMQ from inside an already-running asyncio loop.

    Adapted from pika's asynchronous publisher example. It uses
    ``AsyncioConnection`` so pika runs on the event loop FastAPI already
    started, and connection heartbeats are serviced in the background by
    that loop.

    Setup is a chain of pika callbacks kicked off by ``connect()``:
    connection open -> channel open -> exchange declared -> queue declared
    -> queue bound -> publisher confirms enabled. ``publish_message`` only
    sends once that chain has completed. Reconnection is not implemented:
    after the channel or connection closes, messages are dropped (and a
    warning is logged).
    """

    EXCHANGE = 'message'
    EXCHANGE_TYPE = ExchangeType.topic
    PUBLISH_INTERVAL = 1  # not used by this class; kept from the pika example
    QUEUE = 'text'
    ROUTING_KEY = 'example.text'

    def __init__(self, amqp_url):
        """Remember the broker URL; no network I/O happens until ``connect()``.

        :param str amqp_url: AMQP URL understood by ``pika.URLParameters``.
        """
        self._connection = None
        self._channel = None
        # Delivery tags that were published but not yet acked/nacked.
        self._deliveries = []
        self._acked = 0
        self._nacked = 0
        # Number of published messages; also the delivery tag the broker
        # assigns to the latest one, since confirms start on a fresh channel.
        self._message_number = 0
        self._stopping = False
        # True only once exchange/queue/binding exist and confirms are on.
        self._ready = False
        self._url = amqp_url

    def connect(self):
        """Start opening an ``AsyncioConnection`` on the current event loop.

        :returns: the (still opening) connection. Progress is reported via
            the ``on_connection_*`` callbacks.
        """
        params = pika.URLParameters(self._url)
        # Log only host/port: the URL also carries the username and password.
        logger.info('Connecting to %s:%s', params.host, params.port)
        self._connection = AsyncioConnection(
            params,
            on_open_callback=self.on_connection_open,
            on_open_error_callback=self.on_connection_open_error,
            on_close_callback=self.on_connection_closed)
        return self._connection

    def on_connection_open(self, connection):
        """Record the open connection and request a channel on it."""
        logger.info('Connection opened')
        self._connection = connection
        logger.info('Creating a new channel')
        self._connection.channel(on_open_callback=self.on_channel_open)

    def on_connection_open_error(self, _unused_connection, err):
        """Log a failed connection attempt (no retry is made)."""
        logger.error('Connection open failed: %s', err)

    def on_connection_closed(self, _unused_connection, reason):
        """Forget the channel and stop publishing once the connection closes."""
        logger.warning('Connection closed: %s', reason)
        self._channel = None
        self._ready = False

    def on_channel_open(self, channel):
        """Keep the new channel, watch for its closure, and declare the exchange."""
        logger.info('Channel opened')
        self._channel = channel
        self.add_on_channel_close_callback()
        self.setup_exchange(self.EXCHANGE)

    def add_on_channel_close_callback(self):
        """Register ``on_channel_closed`` on the current channel."""
        logger.info('Adding channel close callback')
        self._channel.add_on_close_callback(self.on_channel_closed)

    def on_channel_closed(self, channel, reason):
        """Stop publishing and close the connection unless it is already going down."""
        logger.warning('Channel %i was closed: %s', channel, reason)
        self._channel = None
        self._ready = False
        connection = self._connection
        if self._stopping or connection is None:
            return
        # pika raises ConnectionWrongStateError when close() is called on a
        # connection that is already closing/closed (e.g. the channel was
        # closed because the connection itself went away).
        if connection.is_closing or connection.is_closed:
            return
        connection.close()

    def setup_exchange(self, exchange_name):
        """Declare the exchange; continues in ``on_exchange_declareok``."""
        logger.info('Declaring exchange %s', exchange_name)
        # Note: using functools.partial is not required, it is demonstrating
        # how arbitrary data can be passed to the callback when it is called
        cb = functools.partial(self.on_exchange_declareok, userdata=exchange_name)
        self._channel.exchange_declare(exchange=exchange_name, exchange_type=self.EXCHANGE_TYPE, callback=cb)

    def on_exchange_declareok(self, _unused_frame, userdata):
        """Exchange exists; declare the queue next."""
        logger.info('Exchange declared: %s', userdata)
        self.setup_queue(self.QUEUE)

    def setup_queue(self, queue_name):
        """Declare the queue; continues in ``on_queue_declareok``."""
        logger.info('Declaring queue %s', queue_name)
        self._channel.queue_declare(queue=queue_name, callback=self.on_queue_declareok)

    def on_queue_declareok(self, _unused_frame):
        """Queue exists; bind it to the exchange with ROUTING_KEY."""
        logger.info('Binding %s to %s with %s', self.EXCHANGE, self.QUEUE, self.ROUTING_KEY)
        self._channel.queue_bind(self.QUEUE, self.EXCHANGE, routing_key=self.ROUTING_KEY, callback=self.on_bindok)

    def on_bindok(self, _unused_frame):
        """Binding done; enable publisher confirms."""
        logger.info('Queue bound')
        self.start_publishing()

    def start_publishing(self):
        """Enable publisher confirms and allow ``publish_message`` to send."""
        logger.info('Issuing Confirm.Select RPC command')
        self._channel.confirm_delivery(self.on_delivery_confirmation)
        # A channel's commands reach the broker in order, so anything
        # published after Confirm.Select is covered by confirms.
        self._ready = True

    def on_delivery_confirmation(self, method_frame):
        """Update ack/nack bookkeeping from a Basic.Ack / Basic.Nack frame.

        Handles ``multiple=True``, where one frame confirms every outstanding
        delivery tag up to and including ``delivery_tag``.
        """
        method = method_frame.method
        confirmation_type = method.NAME.split('.')[1].lower()
        delivery_tag = method.delivery_tag
        logger.info('Received %s for delivery tag: %i', confirmation_type, delivery_tag)
        if method.multiple:
            confirmed = [tag for tag in self._deliveries if tag <= delivery_tag]
            self._deliveries = [tag for tag in self._deliveries if tag > delivery_tag]
        elif delivery_tag in self._deliveries:
            confirmed = [delivery_tag]
            self._deliveries.remove(delivery_tag)
        else:
            confirmed = []
            logger.warning('Confirmation for unknown delivery tag: %i', delivery_tag)
        if confirmation_type == 'ack':
            self._acked += len(confirmed)
        elif confirmation_type == 'nack':
            self._nacked += len(confirmed)
        logger.info(
            'Published %i messages, %i have yet to be confirmed, '
            '%i were acked and %i were nacked', self._message_number,
            len(self._deliveries), self._acked, self._nacked)

    def publish_message(self, message):
        """Publish ``message`` as JSON to EXCHANGE with ROUTING_KEY.

        :param message: JSON-serialisable object.
        :returns: True if the message was handed to pika (broker confirmation
            arrives later via ``on_delivery_confirmation``), False if it was
            dropped because setup has not finished or the channel is closed.
        """
        if not self._ready or self._channel is None or not self._channel.is_open:
            logger.warning('RabbitMQ channel not ready; message dropped')
            return False
        hdrs = {"a": "b"}
        properties = pika.BasicProperties(
            app_id='example-publisher',
            content_type='application/json',
            headers=hdrs)
        self._channel.basic_publish(self.EXCHANGE, self.ROUTING_KEY,
                                    json.dumps(message, ensure_ascii=False),
                                    properties)
        self._message_number += 1
        self._deliveries.append(self._message_number)
        logger.info('Published message # %i', self._message_number)
        return True
# FastAPI application serving the webhook.
app = FastAPI()
# Module-level publisher: None until the startup hook creates it.
ep = None
@app.on_event("startup")
async def startup() -> None:
    """Create the RabbitMQ publisher and start connecting it.

    Reads RABBITMQ_DEFAULT_USER, RABBITMQ_DEFAULT_PASS, RABBITMQ_HOST and
    RABBITMQ_PORT from the environment (raises KeyError if any is missing).
    ``connect()`` only starts the connection. The exchange/queue setup and
    publisher confirms finish later through pika callbacks on the same
    event loop, so the channel is not yet usable when this hook returns.
    """
    global ep
    # Fixed delay, presumably to let the broker finish starting.
    await asyncio.sleep(10)  # Wait for MQ
    user = os.environ['RABBITMQ_DEFAULT_USER']
    passwd = os.environ['RABBITMQ_DEFAULT_PASS']
    host = os.environ['RABBITMQ_HOST']
    port = os.environ['RABBITMQ_PORT']
    # Percent-encode credentials: '@', ':', '/' or '%' in a password would
    # otherwise corrupt the URL. pika's URLParameters unquotes them again.
    user_q = quote(user, safe='')
    passwd_q = quote(passwd, safe='')
    ep = AsyncioRabbitMQ(f'amqp://{user_q}:{passwd_q}@{host}:{port}/%2F')
    ep.connect()
# Alias for an arbitrary JSON object body: string keys, any values.
JSONObject = Dict[str, Any]


@app.post("/webhook")
async def webhook_endpoint(msg: JSONObject) -> Response:
    """Publish the JSON body of a webhook call to RabbitMQ.

    Returns 204 No Content when the message was handed to the publisher.
    Returns 503 Service Unavailable when no publisher has been created, or
    when ``publish_message`` explicitly returns False. Only an explicit
    False counts as failure; a ``publish_message`` that returns nothing
    still yields 204.
    """
    # ``ep`` is only read here, so no ``global`` statement is needed.
    if ep is None or ep.publish_message(msg) is False:
        return Response(status_code=503)
    return Response(status_code=204)
Upvotes: 3