Partho

Reputation: 19

How to unit test the python code using mock patch?

I want to unit test the following method:

def extract_features_from_bytes(self, binary: bytes) -> str:
    with TimingMetric("fx_time") as fx_timing_metric:
        fv = self.feature_extractor.extract_features_from_bytes(binary)
    if self.metrics_reporter:
        self.metrics_reporter.report_metric(fx_timing_metric)
    return fv

Basically, I am trying to measure the feature-extraction time with TimingMetric and report it via self.metrics_reporter, which is an instance of the MetricsReporter class. In the test, I need to mock the following call:

self.feature_extractor.extract_features_from_bytes(binary)

This is because the time feature extraction takes is unpredictable. The idea is to replace the call with a mock that sleeps (e.g. time.sleep(10)) and then check that the time reported to the metrics reporter matches the sleep duration. So both the feature extractor's method and the MetricsReporter need to be mocked, and patch can be used to pass those mocks into the test.
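Here is a minimal, self-contained sketch of that sleep-based idea. TimingMetric and Classifier below are hypothetical stand-ins (the real classes are not shown), and the assumption that TimingMetric records milliseconds using time.perf_counter() is mine. The sleep is shortened to 0.05 s so the test stays fast, and the check is a lower bound rather than exact equality, because a real sleep can overshoot.

```python
import time
from unittest.mock import Mock


class TimingMetric:
    """Hypothetical stand-in: times a with-block and stores elapsed milliseconds."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._elapsed_ms = None

    def __enter__(self) -> "TimingMetric":
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self._elapsed_ms = (time.perf_counter() - self._start) * 1000
        return False  # do not swallow exceptions raised inside the block

    def get_metric_value(self) -> float:
        return self._elapsed_ms


class Classifier:
    """Hypothetical host class containing the method under test."""

    def __init__(self, feature_extractor, metrics_reporter=None) -> None:
        self.feature_extractor = feature_extractor
        self.metrics_reporter = metrics_reporter

    def extract_features_from_bytes(self, binary: bytes) -> str:
        with TimingMetric("fx_time") as fx_timing_metric:
            fv = self.feature_extractor.extract_features_from_bytes(binary)
        if self.metrics_reporter:
            self.metrics_reporter.report_metric(fx_timing_metric)
        return fv


SLEEP_S = 0.05  # short sleep keeps the test suite fast


def test_reported_time_covers_sleep():
    extractor = Mock()

    def slow_extract(binary):
        # Simulate slow feature extraction, then return a fake feature vector.
        time.sleep(SLEEP_S)
        return "fake-features"

    extractor.extract_features_from_bytes.side_effect = slow_extract
    reporter = Mock()

    result = Classifier(extractor, reporter).extract_features_from_bytes(b"binary")

    assert result == "fake-features"
    extractor.extract_features_from_bytes.assert_called_once_with(b"binary")
    reporter.report_metric.assert_called_once()
    # Grab the TimingMetric object that was passed to report_metric.
    metric = reporter.report_metric.call_args.args[0]
    assert metric.name == "fx_time"
    # Lower bound only, with a little slack for timer resolution.
    assert metric.get_metric_value() >= SLEEP_S * 1000 - 5
```

An exact-equality check against the sleep duration (such as == 3000) is likely to be flaky for the same reason.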

Any idea how to write a unit test for the above code snippet?

I tried the following, but the mocking seems to be wrong:

##################
# Classifier
##################
@pytest.mark.classification
class TestClassifierBase:

    @pytest.fixture
    def mocked_stats_client_and_class(self, mocker):
        mock_stats_client = MagicMock()
        mock_stats_client_class = mocker.patch("dsci_core.metrics.lib.StatsClient")
        mock_stats_client_class.return_value = mock_stats_client
        return mock_stats_client, mock_stats_client_class

    @pytest.fixture
    def mocked_stats_client_and_reporter(self, mocked_stats_client_and_class):
        mocked_stats_client, _ = mocked_stats_client_and_class
        reporter = MetricsReporter()
        return mocked_stats_client, reporter

    def test_extract_features_from_bytes(self, test_classifier, mocked_stats_client_and_reporter):
        mock_stats_client, reporter = mocked_stats_client_and_reporter
        metric_name = "my_metric" 
        
        with TimingMetric(metric_name) as timing_metric:
            time.sleep(3)
        assert timing_metric.get_metric_value() == 3000
        reporter.report_metric(timing_metric)
        mock_stats_client.timing.assert_called_with(metric_name, 3000)

Upvotes: 0

Views: 148

Answers (1)

jagmitg

Reputation: 4566

You need to mock the feature_extractor and metrics_reporter attributes of the instance under test (test_classifier) rather than creating a real MetricsReporter or TimingMetric inside the test.

You also need to make sure the test actually calls the extract_features_from_bytes method of test_classifier; otherwise the code you care about is never exercised.

Here's an example:

import time
from unittest.mock import MagicMock, patch

import pytest

@pytest.mark.classification
class TestClassifierBase:
    @pytest.fixture
    def feature_extractor_mock(self, mocker):
        mock = mocker.Mock()
        mock.extract_features_from_bytes = mocker.Mock(side_effect=lambda _: time.sleep(3))
        return mock

    @pytest.fixture
    def metrics_reporter_mock(self, mocker):
        return mocker.Mock()

    def test_extract_features_from_bytes(self, test_classifier, feature_extractor_mock, metrics_reporter_mock):
        test_classifier.feature_extractor = feature_extractor_mock
        test_classifier.metrics_reporter = metrics_reporter_mock

        test_classifier.extract_features_from_bytes(b'binary')

        metrics_reporter_mock.report_metric.assert_called_once()
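One possible refinement: a bare Mock accepts any attribute, so a misspelled call such as report_metrc would pass silently. unittest.mock.create_autospec builds a mock that only exposes the methods of the spec'd class and checks their call signatures. The MetricsReporter class below is a hypothetical stand-in with an assumed report_metric(metric) signature; in a real test you would pass the actual class.

```python
from unittest.mock import create_autospec


class MetricsReporter:
    """Hypothetical stand-in; only the method name and signature matter here."""

    def report_metric(self, metric) -> None:
        raise NotImplementedError  # real implementation is not shown


def test_autospec_rejects_typos():
    # instance=True mocks an instance, so calls omit `self`.
    reporter_mock = create_autospec(MetricsReporter, instance=True)

    # Calls that match the real signature are recorded as usual.
    reporter_mock.report_metric("fx_time metric")
    reporter_mock.report_metric.assert_called_once_with("fx_time metric")

    # A misspelled method name raises AttributeError instead of passing silently.
    try:
        reporter_mock.report_metrc("fx_time metric")  # deliberate typo
    except AttributeError:
        pass
    else:
        raise AssertionError("autospec should reject unknown attributes")
```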

An example using patch.object instead:

class TestClassifierBase:

    def test_extract_features_from_bytes(self, test_classifier):

        mock_feature_method = MagicMock(side_effect=lambda _: time.sleep(3))

        with patch.object(test_classifier, 'feature_extractor', create=True) as mock_extractor, \
             patch.object(test_classifier, 'metrics_reporter', create=True) as mock_reporter:

            mock_extractor.extract_features_from_bytes = mock_feature_method

            test_classifier.extract_features_from_bytes(b'binary')
            mock_extractor.extract_features_from_bytes.assert_called_once_with(b'binary')
            mock_reporter.report_metric.assert_called_once()
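Both examples still spend real seconds sleeping. A deterministic alternative (a different technique from the sleep-based one in the question) is to patch the clock that the timer reads, so the "elapsed" time is fixed and the test runs instantly. This sketch assumes TimingMetric reads time.perf_counter(); the real class is not shown, so the correct patch target in your code base depends on how and where it reads the clock. TimingMetric and the free-function copy of the method are hypothetical stand-ins.

```python
import time
from unittest.mock import Mock, patch


class TimingMetric:
    """Hypothetical stand-in: assumes the real class reads time.perf_counter()."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.value_ms = None

    def __enter__(self) -> "TimingMetric":
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self.value_ms = (time.perf_counter() - self._start) * 1000
        return False

    def get_metric_value(self) -> float:
        return self.value_ms


def extract_features_from_bytes(feature_extractor, metrics_reporter, binary: bytes):
    """Free-function copy of the method under test, for a self-contained sketch."""
    with TimingMetric("fx_time") as fx_timing_metric:
        fv = feature_extractor.extract_features_from_bytes(binary)
    if metrics_reporter:
        metrics_reporter.report_metric(fx_timing_metric)
    return fv


def test_reports_exact_duration_without_sleeping():
    extractor = Mock()
    extractor.extract_features_from_bytes.return_value = "fake-features"
    reporter = Mock()

    # The fake clock returns 100.0 on entry and 103.0 on exit: "3 seconds".
    with patch("time.perf_counter", side_effect=[100.0, 103.0]):
        fv = extract_features_from_bytes(extractor, reporter, b"binary")

    assert fv == "fake-features"
    metric = reporter.report_metric.call_args.args[0]
    assert metric.get_metric_value() == 3000.0  # exact, because the clock is fake
```

If the real module does from time import perf_counter, patch that module's own perf_counter name instead of time.perf_counter, because patch replaces a name where it is looked up.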

Upvotes: 0
