Rajan Sharma
Rajan Sharma

Reputation: 2273

Python mock multiple queries in a function using pytest_mock

I am writing a unit test case for a function that runs multiple SQL queries. I am using the psycopg2 module and trying to mock the cursor.

app.py

import psycopg2

def my_function():
    """Run the customer query and the product query on the same cursor.

    Each result set is converted to a list of dicts keyed by the column
    names taken from ``cursor.description``. Only the product rows are
    returned.
    """
    # all connection related code goes here ...

    cursor.execute("SELECT name,phone FROM customer WHERE name='shanky'")
    customer_fields = [entry[0] for entry in cursor.description]
    customer_response = [
        dict(zip(customer_fields, record)) for record in cursor.fetchall()
    ]

    cursor.execute("SELECT name,id FROM product WHERE name='soap'")
    product_fields = [entry[0] for entry in cursor.description]
    product_response = [
        dict(zip(product_fields, record)) for record in cursor.fetchall()
    ]

    return product_response

test.py

from pytest_mock import mocker
import psycopg2

def test_my_function(mocker):
    """Attempt to mock both queries of app.my_function with two cursor mocks.

    NOTE(review): both "cursors" below are obtained from the same expression
    (``psycopg2.connect.return_value.cursor.return_value``), so they refer to
    one mock object and the second configuration overwrites the first.
    """
    from my_module import app
    mocker.patch('psycopg2.connect')

    #first query
    mocked_cursor_one = psycopg2.connect.return_value.cursor.return_value
    mocked_cursor_one.description = [['name'],['phone']]
    mocked_cursor_one.fetchall.return_value = [('shanky', '347539593')]
    # NOTE(review): this is a bare `==` comparison whose result is discarded;
    # it neither configures the mock nor asserts anything.
    mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'"

    #second query
    # Same expression as above, so this is the same object as mocked_cursor_one.
    mocked_cursor_two = psycopg2.connect.return_value.cursor.return_value
    mocked_cursor_two.description = [['name'],['id']]
    mocked_cursor_two.fetchall.return_value = [('nirma', 12313)]
    # NOTE(review): bare comparison again, result discarded.
    mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'"

    ret = app.my_function()
    # NOTE(review): my_function as shown in app.py returns a list of dicts,
    # while this compares against a single dict.
    assert ret == {'name' : 'nirma', 'id' : 12313}

But the mocker always takes the last mock object (the second query). I have already tried multiple hacks, but none of them worked. How can I mock multiple queries in one function and successfully pass the unit test case? Is it possible to write a unit test case in this fashion, or do I need to split the queries into different functions?

Upvotes: 7

Views: 10805

Answers (3)

Max Sirwa
Max Sirwa

Reputation: 151

As I have mentioned in an earlier comment, the best way to make unit testing portable is to develop a complete Mock of your database's behavior. I've done it for MySQL but it's pretty much the same for all databases.

First of all, I like using wrapper classes over the packages I'm using, it helps quickly change the database at one place instead of changing it everywhere in the code.

Here's a sample of what I use as a wrapper:

Now, you would need to Mock this MySQL class:

# _database.py
# -----------------------------------------------------------------------------
# Database Base Class
# -----------------------------------------------------------------------------
"""Base class for Database implementations.
"""
# -----------------------------------------------------------------------------


import logging


logger = logging.getLogger(__name__)


class Database:
    """Common base class for database wrappers.

    This is an ordinary base class (not a metaclass): subclasses supply the
    connect function and inherit ``execute`` and the connection clean-up.

    Attributes:
        connection (obj): Object returned by ``connect_func(**kwargs)``.
    """

    def __init__(self, connect_func, **kwargs):
        """Open the connection.

        Args:
            connect_func (callable): Factory called as
                ``connect_func(**kwargs)``; its return value is stored as
                ``self.connection``.
            **kwargs: Passed through to ``connect_func`` unchanged.
        """
        self.connection = connect_func(**kwargs)

    def execute(self, statement, fetchall=True):
        """Execute a statement.

        Execute the statement passed as argument on a new cursor.

        Args:
            statement (str): SQL Query or Command to execute.
            fetchall (bool): When true, return ``cursor.fetchall()``;
                otherwise return ``cursor.fetchone()``.

        Returns:
            Whatever the cursor's ``fetchall()`` / ``fetchone()`` returns.
        """
        cursor = self.connection.cursor()
        # %-style arguments defer string formatting until the record is emitted.
        logger.debug("Executing: %s", statement)
        cursor.execute(statement)
        if fetchall:
            return cursor.fetchall()
        return cursor.fetchone()

    def __del__(self):
        """Close connection on object deletion."""
        # __del__ also runs when __init__ raised before `connection` was
        # assigned (or when a subclass never set it), so look it up defensively.
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()

And the mysql module:

# mysql.py
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# MySQL Database Class
# -----------------------------------------------------------------------------
"""Class for MySQL Database connection."""
# -----------------------------------------------------------------------------


import logging
import mysql.connector

from . import _database


logger = logging.getLogger(__name__)


class MySQL(_database.Database):
    """MySQL Database Class Wrapper.

    Attributes:
        connection (obj): Object returned from mysql.connector.connect
    """

    def __init__(self, autocommit=True, **kwargs):
        """Connect through the base class and set the autocommit flag.

        Args:
            autocommit (bool): Value assigned to ``connection.autocommit``.
            **kwargs: Forwarded to ``mysql.connector.connect`` by the base
                class constructor.
        """
        super().__init__(connect_func=mysql.connector.connect, **kwargs)
        self.connection.autocommit = autocommit

Instantiate like: db = MySQL(user='...', password='...', ...)

Here's the data file:

# database_mock_data.json
{
    "customer": {
        "name": [
            "shanky",
            "nirma"
        ],
        "phone": [
            123123123,
            232342342
        ]
    },
    "product": {
        "name": [
            "shanky",
            "nirma"
        ],
        "id": [
            1,
            2
        ]
    }
}

The mocks.py

# mocks.py
import json
import re
from . import mysql
_MOCK_DATA_PATH = 'database_mock_data.json'


class MockDatabase(mysql.MySQL):
    """
    MySQL wrapper whose connection is a MockConnection.

    The parent constructor is intentionally not called, so no real database
    connection is opened. Keyword arguments are accepted and ignored.
    """
    # The module is imported as `from . import mysql`, so the base class has
    # to be referenced as `mysql.MySQL`; a bare `MySQL` is not defined here.
    def __init__(self, **kwargs):
        self.connection = MockConnection()


class MockConnection:
    """
    Mock the connection object by returning a mock cursor.
    """
    @staticmethod
    def cursor():
        """Return a fresh MockCursor."""
        return MockCursor()

    @staticmethod
    def close():
        """Do nothing; there is no real connection to release.

        Present because the Database base class calls ``connection.close()``
        in ``__del__``, which would otherwise fail on this mock.
        """
        return None


class MockCursor:
    """
    The Mocked Cursor

    A call to execute() will initiate the read on the json data file and will set
    the description object (containing the column names usually).

    You could implement an update function like `_json_sql_update()`
    """
    def __init__(self):
        self.description = []
        self.__result = None

    def execute(self, statement):
        """Load the mock data and, for SELECT statements, store the result.

        Statements that are not SELECTs are ignored.
        """
        data = _read_json_file(_MOCK_DATA_PATH)
        # lstrip() so that queries written as indented or triple-quoted
        # strings (leading whitespace/newline) are still recognised.
        if statement.lstrip().upper().startswith('SELECT'):
            self.__result, self.description = _json_sql_select(data, statement)

    def fetchall(self):
        """Return the rows stored by the last SELECT (None if none ran)."""
        return self.__result

    def fetchone(self):
        """Return the first stored row, or None when there is no row.

        Returning None instead of failing on an empty or missing result
        follows the DB-API convention for fetchone().
        """
        if not self.__result:
            return None
        return self.__result[0]


def _json_sql_select(data, query):
    """
    Takes a dictionary and returns the values selected by a sql query.
    NOTE: It does not work with other where clauses than '='.
          Also, note that a where statement is expected.
          Conditions are separated on the word AND (any case), so a value
          containing the standalone word "and" is not supported.
          Table names, column names and compared values are matched
          case-insensitively.
    :param (dict) data: Dictionary with the following structure:
                        {
                            'tablename': {
                                'column_name_1': ['value1', 'value2'],
                                'column_name_2': ['value1', 'value2'],
                                ...
                            },
                            ...
                        }
    :param (str) query: A select sql query as:
                        `select column_name_1 from TABLENAME
                        where column_name_2='value1'`
    :return: List of list of values and header description (a list of
             tuples whose first item is the column name; explicitly listed
             columns are reported upper-cased, `*` reports the data's keys)
    :raises AttributeError: when the query does not match the
                            select/from/where pattern
    :raises KeyError: when the table or a column is not in `data`
    """
    try:
        match = (re.search("select(.*)from(.*)where(.*)[;]?", query,
                 re.IGNORECASE | re.DOTALL).groups())
    except AttributeError:
        print("Select Query pattern mismatch... {}".format(query))
        raise

    # Parse values from the select query
    tablename = match[1].strip().upper()

    # The query parts are upper-cased, so index the data by upper-cased names
    # as well; otherwise lower-case keys in the data file are never found.
    tables = {name.upper(): table for name, table in data.items()}
    raw_table = tables[tablename]
    table = {name.upper(): values for name, values in raw_table.items()}

    columns = [col.strip().upper() for col in match[0].split(",")]
    if columns == ['*']:
        columns = list(raw_table)

    # The greedy where-group also swallows a trailing ';' - drop it.
    where_clause = match[2].strip().rstrip(';')
    where = [cmd.upper().strip().replace(' ', '')
             for cmd in re.split(r'\band\b', where_clause,
                                 flags=re.IGNORECASE)]

    # Each condition is `KEY=VALUE`; parse them once, not once per row.
    conditions = [_clean_string(condition).split('=', 1)
                  for condition in where]

    # Select values
    selected_values = []
    column_values = list(table.values())
    nb_lines = len(column_values[0]) if column_values else 0
    for i in range(nb_lines):
        # str() so that non-string values (e.g. numbers) can be compared too
        is_match = all(str(table[key][i]).upper() == value
                       for key, value in conditions)
        if is_match:
            selected_values.append([table[column.upper()][i]
                                    for column in columns])

    # Usual descriptor has nested list; materialise it so it can be read
    # more than once (a bare zip object is exhausted after one pass).
    description = list(zip(columns, ['...'] * len(columns)))

    return selected_values, description


def _read_json_file(file_path):
    with open(file_path, 'r') as f_in:
        data = json.load(f_in)
    return data

And then you have your test in a test_module_yourfunction.py

import pytest

def my_function(db, query):
    """Placeholder for the function under test.

    Args:
        db: Database wrapper to run the query on (the test below passes the
            object returned by the ``db_connection`` fixture).
        query (str): SQL query to execute.
    """
    # Code goes here
    # A body consisting only of a comment is a syntax error, so fail loudly
    # until the real implementation is filled in.
    raise NotImplementedError("replace with the code under test")

@pytest.fixture
def db_connection():
    """Provide a new MockDatabase instance to each test that requests it."""
    # NOTE(review): MockDatabase is not imported in this snippet - presumably
    # it comes from the mocks module shown earlier; confirm the import.
    return MockDatabase()


@pytest.mark.parametrize(
    ("query", "expected"),
    [
        # NOTE(review): this query selects name and phone for customer
        # 'shanky', but the expected value holds name 'nirma' and an 'id'
        # key - looks copied from the product example; confirm.
        ("SELECT name,phone FROM customer WHERE name='shanky'", {'name' : 'nirma', 'id' : 12313}),
        # Placeholders to be replaced with a real query/result pair.
        ("<second query goes here>", "<second result goes here>")
    ]
)
def test_my_function(db_connection, query, expected):
    """Check my_function's result for each (query, expected) pair."""
    assert my_function(db_connection, query) == expected

Now I'm sorry if you can't copy/paste this code and make it work, but you get the feeling :) just trying to help

Upvotes: 2

Rajan Sharma
Rajan Sharma

Reputation: 2273

After drilling a lot through the documentation, I was able to achieve this with the help of the unittest mock decorator and side_effect, which was suggested by @Pavel Vergeev. I was able to write a unit test case that is good enough to test the functionality.

from unittest import mock
from my_module import app

@mock.patch('psycopg2.connect')
def test_my_function(mocked_db):
    """Check both queries of app.my_function against one mocked cursor.

    The same cursor mock serves both queries; ``side_effect`` lists make
    ``description`` and ``fetchall()`` yield a different value for each
    query, in order.
    """
    mocked_cursor = mocked_db.return_value.cursor.return_value

    # `description` is read as an attribute, so a PropertyMock on the mock's
    # type is needed to return a different value on each access.
    description_mock = mock.PropertyMock()
    type(mocked_cursor).description = description_mock
    description_mock.side_effect = [
        [['name'], ['phone']],
        [['name'], ['id']],
    ]

    mocked_cursor.fetchall.side_effect = [
        [('shanky', '347539593')],
        [('nirma', 12313)],
    ]

    ret = app.my_function()

    # assert each query's rows were fetched exactly once
    # (asserting against `fetchall.side_effect` would pass vacuously: by now
    # it is an exhausted iterator, i.e. an empty list of expected calls)
    assert mocked_cursor.fetchall.call_count == 2

    # assert db query count is 2
    execute_calls = mocked_cursor.execute.call_args_list
    assert mocked_cursor.execute.call_count == 2

    # first query - must equal the string passed to execute() exactly,
    # including whitespace
    query1 = "SELECT name,phone FROM customer WHERE name='shanky'"
    assert execute_calls[0][0][0] == query1

    # second query
    query2 = "SELECT name,id FROM product WHERE name='soap'"
    assert execute_calls[1][0][0] == query2

    # assert the data of response: my_function returns a list with one dict
    # per fetched row of the second (product) query
    assert ret == [{'name': 'nirma', 'id': 12313}]

In addition to this if there are dynamic parameters in the query, that can be asserted too by the following method.

# `==` compares; a single `=` inside an assert is a syntax error.
assert mocked_db.return_value.cursor.return_value.execute.call_args_list[0][0][1] == (parameter_name,)

So when the first query is executed as cursor.execute(query, (parameter_name,)), the query can be obtained and asserted at call_args_list[0][0][0], and the parameter tuple (parameter_name,) can be obtained at call_args_list[0][0][1]. Similarly, by incrementing the first index, all the other queries and their parameters can be obtained and asserted.

Upvotes: 3

Pavel Vergeev
Pavel Vergeev

Reputation: 3380

Try side_effect argument of mocker.patch:

from unittest.mock import MagicMock
from pytest_mock import mocker
import psycopg2

def test_my_function(mocker):
    """Try to give each query its own cursor via ``side_effect`` on connect.

    NOTE(review): with a list as ``side_effect`` every call to
    ``psycopg2.connect`` consumes one item. The two calls made in this test
    already consume both MagicMocks, so a later ``connect()`` inside
    ``app.my_function`` would presumably find the iterator exhausted
    (StopIteration) - verify. The app.py shown in the question also runs
    both queries on one cursor, so per-connection mocks would not separate
    them; confirm against the real connection code.
    """
    from my_module import app
    mocker.patch('psycopg2.connect', side_effect=[MagicMock(), MagicMock()])

    #first query
    mocked_cursor_one = psycopg2.connect().cursor.return_value  # note that we actually call psycopg2.connect -- it's important
    mocked_cursor_one.description = [['name'],['phone']]
    mocked_cursor_one.fetchall.return_value = [('shanky', '347539593')]
    # NOTE(review): bare `==` comparison, result discarded - asserts nothing.
    mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'"

    #second query
    mocked_cursor_two = psycopg2.connect().cursor.return_value
    mocked_cursor_two.description = [['name'],['id']]
    mocked_cursor_two.fetchall.return_value = [('nirma', 12313)]
    # NOTE(review): bare `==` comparison, result discarded - asserts nothing.
    mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'"

    assert mocked_cursor_one is not mocked_cursor_two  # show that they are different

    ret = app.my_function()
    # NOTE(review): my_function as shown in the question returns a list of
    # dicts, while this compares against a single dict.
    assert ret == {'name' : 'nirma', 'id' : 12313}

As per the docs, side_effect allows you to change returned value each time the patched object is called:

If you pass in an iterable, it is used to retrieve an iterator which must yield a value on every call. This value can either be an exception instance to be raised, or a value to be returned from the call to the mock

Upvotes: 2

Related Questions