user8082934

Reputation: 451

Mocking database calls using SQLAlchemy in Python using pytest-mock

I have this function, which makes a call to the database using the SQLAlchemy ORM.

def get_cust_id(self, _session_id):
    cust_id: str = self.sqlalchemy_orm.get_customer_id_sqlalchemy(_session_id)
    
    if cust_id is None:
        return _session_id
    return cust_id


def get_customer_id_sqlalchemy(self, _session_id: int) -> str:
    try:
        if self.session is None:
            self.session = self._connection.get_session()
        
        _res_id = self.session.query(Users).with_entities(Users.cust_id).filter(
            Users.age > 20).order_by(Users.id.desc()).one_or_none()

        if _res_id is None:
            return _res_id
        else:
            return _res_id[0]

    finally:
        if not self.persistent_connection:
            self.session.close()

I want to unit test get_cust_id() and mock the database calls. How should I do it using pytest and mock?

Upvotes: 3

Views: 9336

Answers (1)

gold_cy

Reputation: 14236

Without knowing much else about your class, mocking the database call inside get_cust_id is pretty straightforward. The method only calls an attribute of another object, so all we need to do is mock self.sqlalchemy_orm. The last piece is to check the two branches of logic present here: when cust_id is None and when it is not. To check both scenarios with the same test, we can use pytest.mark.parametrize.

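import pytest

from unittest.mock import MagicMock

# MyDB stands for the class that defines get_cust_id; import it from your own module.

@pytest.mark.parametrize("val,expected", [(None, 3), (5, 5)])
def test_get_cust_id(val, expected):
    db = MyDB()
    mock_session = MagicMock()

    # Make the mocked ORM method return the parametrized value.
    mock_session.configure_mock(
        **{
            "get_customer_id_sqlalchemy.return_value": val
        }
    )
    # Swap the real ORM object for the mock so no database call is made.
    db.sqlalchemy_orm = mock_session
    id_ = db.get_cust_id(expected)

    assert id_ == expected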

Running this yields the following output.

=========================================== test session starts ===========================================
platform darwin -- Python 3.9.1, pytest-6.2.2, py-1.10.0, pluggy-0.13.1
cachedir: .pytest_cache
collected 2 items                                                                                         

test_db.py::test_get_cust_id[None-3] PASSED                                                         [ 50%]
test_db.py::test_get_cust_id[5-5] PASSED                                                            [100%]

============================================ 2 passed in 0.06s ============================================
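One caveat with the parametrized test above: in the (5, 5) case the session id passed in and the mocked return value are the same number, so the test would pass whichever branch ran. The following is a minimal, self-contained sketch that keeps the two values distinct and also checks that the mocked ORM method was called with the session id. The MyDB class here is a hypothetical stand-in (only get_cust_id is taken from the question), and run_case and the "c-42" value are made up for illustration.

```python
from unittest.mock import MagicMock


class MyDB:
    """Hypothetical stand-in for the real class; only get_cust_id mirrors the question."""

    def __init__(self, sqlalchemy_orm):
        self.sqlalchemy_orm = sqlalchemy_orm

    def get_cust_id(self, _session_id):
        cust_id = self.sqlalchemy_orm.get_customer_id_sqlalchemy(_session_id)
        if cust_id is None:
            return _session_id
        return cust_id


def run_case(db_value, session_id):
    """Call get_cust_id with the ORM layer mocked to return db_value."""
    mock_orm = MagicMock()
    mock_orm.get_customer_id_sqlalchemy.return_value = db_value
    db = MyDB(mock_orm)

    result = db.get_cust_id(session_id)

    # The mocked ORM method should have been called exactly once with the session id.
    mock_orm.get_customer_id_sqlalchemy.assert_called_once_with(session_id)
    return result


def test_falls_back_to_session_id():
    # No customer id in the "database": the session id comes back.
    assert run_case(None, 3) == 3


def test_returns_customer_id():
    # A customer id distinct from the session id, so the branch taken is unambiguous.
    assert run_case("c-42", 3) == "c-42"
```

Because the mock replaces the whole sqlalchemy_orm object, neither test touches a session or a database, and both can be collected by pytest as they are.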

Upvotes: 3
