user1050619

Reputation: 20856

Python: mocking a SQLAlchemy connection

I have a simple function that connects to a DB and fetches some data.

db.py

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool


def _create_engine(app):
    """Build a SQLAlchemy engine for the database URL stored under ``app['DB']``.

    ``NullPool`` turns connection pooling off, so connections are not kept
    around between uses; this avoids problems with pooled connections that
    time out.

    Args:
        app: Mapping that holds the database URL under the ``'DB'`` key.

    Returns:
        The engine returned by ``create_engine``.
    """
    db_url = app['DB']
    return create_engine(db_url, poolclass=NullPool)


def get_all_pos(app):
    """Return every distinct ``(id, name)`` row of table ``p_t``, ordered by name.

    A fresh engine is created from ``app['DB']`` (via ``_create_engine``) for
    each call and is disposed before the function returns, whether the query
    succeeds or fails.

    Args:
        app: Mapping that holds the database URL under the ``'DB'`` key.

    Returns:
        Whatever the cursor's ``fetchall()`` produces for the query.

    Raises:
        Any exception raised while creating the engine or running the query
        propagates to the caller unchanged.
    """
    engine = _create_engine(app)
    qry = """SELECT DISTINCT id, name FROM p_t ORDER BY name ASC"""
    try:
        # NOTE(review): Engine.execute() is deprecated in SQLAlchemy 1.4 and
        # removed in 2.0; it is kept here because callers/tests mock this call.
        cursor = engine.execute(qry)
        return cursor.fetchall()
    finally:
        # Release the per-call engine even when the query raises; no
        # except clause is needed since errors should simply propagate.
        engine.dispose()

I'm trying to write some test cases that mock this connection:

tests.py

import unittest
from db import get_all_pos
from unittest.mock import patch
from unittest.mock import Mock


class TestPosition(unittest.TestCase):
    """Question's reproduction: calls get_all_pos with a dummy URL."""

    # NOTE(review): db.py does ``from sqlalchemy import create_engine``, so
    # ``_create_engine`` looks up ``create_engine`` in the ``db`` module's own
    # namespace. Patching ``db.sqlalchemy`` and assigning ``.create_engine``
    # on that mock does not replace that name, so the real ``create_engine``
    # still receives the URL 'test' (hence the rfc1738 parse error). The patch
    # target must be where the name is looked up, e.g. 'db.create_engine'.
    @patch('db.sqlalchemy')
    def test_get_all_pos(self, mock_sqlalchemy):
        """Call get_all_pos with config {'DB': 'test'}; makes no assertions."""
        mock_sqlalchemy.create_engine = Mock()
        get_all_pos({'DB': 'test'})




# Run the test suite when this file is executed directly (``python tests.py``).
if __name__ == '__main__':
    unittest.main()

When I run the above file with `python tests.py`, I get the following error:

   "Could not parse rfc1738 URL from string '%s'" % name
sqlalchemy.exc.ArgumentError: Could not parse rfc1738 URL from string 'test'

Shouldn't `mock_sqlalchemy.create_engine = Mock()` give me a mock object and bypass the URL check?


Upvotes: 9

Views: 17347

Answers (1)

gold_cy

Reputation: 14216

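Your patch has no effect because `db.py` does `from sqlalchemy import create_engine`, which binds `create_engine` directly in the `db` module's namespace. Replacing `db.sqlalchemy` therefore never touches the name that `_create_engine` actually calls; you would have to patch `db.create_engine`, i.e. the name where it is looked up. Another option is to mock your `_create_engine` function itself. Since this is a unit test of `get_all_pos`, it shouldn't depend on the behavior of `_create_engine`, so we can simply patch that, like so: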

import unittest
import db
from unittest.mock import patch


class TestPosition(unittest.TestCase):
    """Unit tests for ``db.get_all_pos`` with the engine factory patched out."""

    # patch.object replaces the ``_create_engine`` attribute of the ``db``
    # module, which is the name get_all_pos looks up when it is called.
    @patch.object(db, '_create_engine')
    def test_get_all_pos(self, mock_create_engine):
        """get_all_pos builds its engine exactly once from the given config."""
        args = {'DB': 'test'}
        db.get_all_pos(args)
        # One assertion checks both the call count and the argument passed.
        mock_create_engine.assert_called_once_with({'DB': 'test'})


# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()

If you want to test specific results, you need to configure all of the corresponding attributes on the mocks. I recommend not chaining everything into one call, so that the setup stays readable, as shown below:

import unittest
import db
from unittest.mock import patch
from unittest.mock import Mock


class Cursor:
    """Tiny hand-written stand-in for a result cursor that serves canned rows.

    Note: the test class in this snippet configures a ``Mock`` instead and
    does not use this class.
    """

    def __init__(self, vals):
        # Rows handed back, as-is, by fetchall().
        self.vals = vals

    def fetchall(self):
        """Return the rows supplied at construction time (same object)."""
        canned_rows = self.vals
        return canned_rows


class TestPosition(unittest.TestCase):
    """Unit tests for ``db.get_all_pos`` driven by explicitly configured mocks."""

    @patch.object(db, '_create_engine')
    def test_get_all_pos(self, mock_create_engine):
        """get_all_pos returns exactly the rows produced by cursor.fetchall()."""
        to_test = [1, 2, 3]

        # Cursor returned by engine.execute(); its fetchall() yields the rows.
        mock_cursor = Mock()
        cursor_attrs = {'fetchall.return_value': to_test}
        mock_cursor.configure_mock(**cursor_attrs)

        # Engine returned by _create_engine(); its execute() yields the cursor.
        mock_engine = Mock()
        engine_attrs = {'execute.return_value': mock_cursor}
        mock_engine.configure_mock(**engine_attrs)

        mock_create_engine.return_value = mock_engine

        args = {'DB': 'test'}
        rows = db.get_all_pos(args)

        # Engine built once from the config, query run once, rows passed through.
        mock_create_engine.assert_called_once_with({'DB': 'test'})
        mock_engine.execute.assert_called_once()
        self.assertEqual(to_test, rows)

Upvotes: 5

Related Questions