Reputation: 20856
I have a simple function that connects to a DB and fetches some data.
db.py
from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool
def _create_engine(app):
    """Build a SQLAlchemy engine for the database URL stored in ``app['DB']``.

    The engine is created with ``NullPool``, i.e. connection pooling is
    disabled: every connection is opened fresh and closed when released,
    so a long-idle pooled connection can never be handed out after it
    has timed out.
    """
    db_url = app['DB']
    return create_engine(db_url, poolclass=NullPool)
def get_all_pos(app):
    """Return every distinct ``(id, name)`` row of table ``p_t``, ordered by name.

    :param app: mapping whose ``'DB'`` entry is the database URL passed on to
        ``_create_engine``.
    :returns: whatever ``fetchall()`` yields for the query result.

    Errors raised while creating the engine or running the query propagate
    unchanged to the caller.
    """
    engine = _create_engine(app)
    qry = """SELECT DISTINCT id, name FROM p_t ORDER BY name ASC"""
    # NOTE(review): ``Engine.execute`` is the SQLAlchemy 1.x API (removed in
    # 2.0). It is kept because the unit tests mock ``execute`` on the engine.
    cursor = engine.execute(qry)
    return cursor.fetchall()
I'm trying to write some test cases by mocking this connection:
tests.py
import unittest
from db import get_all_pos
from unittest.mock import patch
from unittest.mock import Mock
class TestPosition(unittest.TestCase):
    # NOTE(review): db.py does ``from sqlalchemy import create_engine`` and
    # calls that bare name inside ``_create_engine``. Setting
    # ``create_engine`` on a patched ``db.sqlalchemy`` object therefore
    # cannot replace what ``_create_engine`` calls. The name that would need
    # patching is ``db.create_engine``.
    @patch('db.sqlalchemy')
    def test_get_all_pos(self, sqlalchemy_mock):
        """Call get_all_pos with a dummy URL while ``db.sqlalchemy`` is patched."""
        sqlalchemy_mock.configure_mock(create_engine=Mock())
        get_all_pos({'DB': 'test'})


if __name__ == '__main__':
    unittest.main()
When I run the above file with `python tests.py`, I get the following error:
"Could not parse rfc1738 URL from string '%s'" % name
sqlalchemy.exc.ArgumentError: Could not parse rfc1738 URL from string 'test'
Shouldn't `mock_sqlalchemy.create_engine = Mock()` give me a mock object and bypass the URL check?
Upvotes: 9
Views: 17347
Reputation: 14216
Another option would be to mock your `_create_engine` function. Since this is a unit test of `get_all_pos`, we shouldn't need to rely on the behavior of `_create_engine`, so we can just patch it like so:
import unittest
import db
from unittest.mock import patch
class TestPosition(unittest.TestCase):
    @patch.object(db, '_create_engine')
    def test_get_all_pos(self, mock_create_engine):
        """get_all_pos builds its engine via ``_create_engine(app)`` exactly once.

        ``_create_engine`` is replaced for the duration of the test, so the
        real ``create_engine`` (and its URL parsing) is never reached.
        ``get_all_pos`` runs against the ``Mock`` that the patched function
        returns.
        """
        args = {'DB': 'test'}
        db.get_all_pos(args)
        mock_create_engine.assert_called_once_with(args)
        # The query must be sent through the engine that was handed back.
        mock_create_engine.return_value.execute.assert_called_once()


if __name__ == '__main__':
    unittest.main()
If you want to test specific results, you will need to set all the corresponding attributes on the mocks. I would recommend not chaining them into one call, so that the setup is more readable, as shown below:
import unittest
import db
from unittest.mock import patch
from unittest.mock import Mock
class Cursor:
    """Minimal fake cursor that exposes only ``fetchall()``.

    NOTE(review): not referenced by the TestPosition shown below, which
    builds its cursor from ``Mock`` instead.
    """

    def __init__(self, vals):
        # Rows to hand back, stored as given (no copy is made).
        self.vals = vals

    def fetchall(self):
        """Return the rows supplied at construction time."""
        stored_rows = self.vals
        return stored_rows
class TestPosition(unittest.TestCase):
    @patch.object(db, '_create_engine')
    def test_get_all_pos(self, mock_create_engine):
        """get_all_pos returns exactly the rows fetched from the engine's cursor."""
        to_test = [1, 2, 3]

        # Fake cursor: what ``engine.execute(...)`` hands back.
        mock_cursor = Mock()
        cursor_attrs = {'fetchall.return_value': to_test}
        mock_cursor.configure_mock(**cursor_attrs)

        # Fake engine: what the patched ``_create_engine(...)`` returns.
        mock_engine = Mock()
        engine_attrs = {'execute.return_value': mock_cursor}
        mock_engine.configure_mock(**engine_attrs)
        mock_create_engine.return_value = mock_engine

        args = {'DB': 'test'}
        rows = db.get_all_pos(args)

        mock_create_engine.assert_called_once_with(args)
        mock_engine.execute.assert_called_once()
        mock_cursor.fetchall.assert_called_once_with()
        self.assertEqual(to_test, rows)
Upvotes: 5