Sagar Ghanwat

Reputation: 15

Mocking an http.get function isn't working in pytest

The test file is in status_update_job/tests, and sync_mdl_status.py is in the status_update/lib/ folder. The test file contains:

from lib.sync_mdl_status import update_mdl_status    
class TestStatusUpdaterMdl(unittest.TestCase):
    """Exercise update_mdl_status against a throwaway PostgreSQL instance."""

    def setUp(self):
        # Start a disposable PostgreSQL server and initialise it via db_init.
        self.postgresql = testing.postgresql.Postgresql()
        db_init(self.postgresql)

    def tearDown(self):
        db_ops.close_session()
        self.postgresql.stop()

    def test_status_updater_mdl(self):
        # patch(...).start() must be paired with .stop(). addCleanup guarantees
        # the patch is undone even if the test fails, so the mocked http.get
        # cannot leak into later tests.
        patcher = patch('lib.sync_mdl_status.http.get')
        mock_mdl_http_get = patcher.start()
        self.addCleanup(patcher.stop)
        mock_mdl_http_get.return_value.status_code = 200

        # NOTE(review): update_mdl_status runs its HTTP calls inside
        # multiprocessing.Pool worker processes. A patch applied here is only
        # visible to workers that inherit this process's memory (the "fork"
        # start method). With "spawn" or "forkserver", each worker re-imports
        # lib.sync_mdl_status and calls the real http.get. Confirm which start
        # method is in use.
        # NOTE(review): aws_auth_token is not defined in the visible code.
        batch = update_mdl_status(self.postgresql.url(), "test", 1, aws_auth_token)

sync_mdl_status.py, which defines update_mdl_status, contains:

import json
import time
from functools import partial
import multiprocessing
from requests.adapters import HTTPAdapter
from urllib3 import Retry
import db_ops.main as db_operations
import logging as log
import pandas as pd
import numpy as np
import requests
import pandas
from lib.utils import execute_batch_query,is_valid_http_status_code
from scripts.aws_secrets import fetch_aws_secret

# Module-level session; fetch_mdl_status issues its GET requests through it.
# Because it is a module attribute, patch('lib.sync_mdl_status.http.get')
# replaces its `get`, but only in the process where the patch is applied.
# NOTE(review): TimeoutHTTPAdapter and HTTP_RETRY_STRATEGY are not defined in
# the visible code; presumably they come from elided lines. Verify.
http = requests.Session()
http.mount("http://", TimeoutHTTPAdapter(max_retries=HTTP_RETRY_STRATEGY))
http.mount("https://", TimeoutHTTPAdapter(max_retries=HTTP_RETRY_STRATEGY))


    def fetch_mdl_token(mdl_auth_aws_secret_name, region: str):
        """Return whatever fetch_aws_secret yields for the given secret name and region.

        Thin pass-through: the result is returned unchanged.
        """
        return fetch_aws_secret(mdl_auth_aws_secret_name, region)


    def fetch_mdl_status(mdl_auth_tokens: dict, mdl_endpoint: str, row: dict):
        """Issue a GET to the MDL service for a single request row.

        :param mdl_auth_tokens: mapping from database name to MDL auth token;
            the token is selected with ``row['obfuscate_info']['db_name']``.
        :param mdl_endpoint: base URL; the request goes to
            ``{mdl_endpoint}/{row['mdl_id']}``.
        :param row: record providing ``mdl_id`` and ``obfuscate_info['db_name']``.
        """
        mdl_id = row['mdl_id']
        # The tokens are credentials, so they are deliberately not printed or logged.
        mdl_auth_token_to_use = mdl_auth_tokens[row['obfuscate_info']['db_name']]
        mdl_headers = {'Content-Type': 'application/json', 'Authorization': 'MDL-AUTH ' + mdl_auth_token_to_use}
        # NOTE(review): the response is bound to `res` but is neither used nor
        # returned in the visible code, so this function returns None. Confirm
        # whether the rest of the body was elided.
        res = http.get(f"{mdl_endpoint}/{mdl_id}", headers=mdl_headers)

    def process_mdl_batch(batch: pandas.DataFrame, mdl_endpoint: str, mdl_auth_tokens: dict):
        """Normalise one chunk of MDL rows, then fetch the status of every row.

        Strips whitespace from the ``status`` column in place, copies it into a
        new ``new_status`` column, and returns the result of calling
        fetch_mdl_status on each row via ``DataFrame.apply(axis=1)``.
        """
        batch['status'] = batch['status'].apply(lambda raw: raw.strip())
        batch['new_status'] = batch['status']
        # Bind the shared arguments once; apply() supplies each row last.
        per_row = partial(fetch_mdl_status, mdl_auth_tokens, mdl_endpoint)
        return batch.apply(per_row, axis=1)

    def update_mdl_status(mdl_con_url: str, schema: str, mdl_threads: int, mdl_auth_tokens: dict):
        """Fetch MDL statuses for one pending batch using a process pool.

        Returns None when get_a_batch_to_process_its_status() returns None, and
        the batch id when get_pending_mdl_batch() returns None. The visible
        code stops after the pool.map/concat step; anything after it is elided.

        NOTE(review): mdl_con_url, schema and start are not used in the visible code.
        """
        start = time.time()
        log.info("Getting MDL Token and endpoint details")

        log.info("Getting one batch_id to fetch status from MDL")
        batch_info = db_operations.db.get_a_batch_to_process_its_status()
        print(batch_info)

        # batch_info presumably maps batch_id -> region; only its first key is used.
        if batch_info is not None:
            batch_id = list(batch_info.keys())[0]
            region = batch_info[batch_id]
        else:
            log.info("There are not pending/obfuscated batch ids to process")
            return None

        # NOTE(review): fetch_mdl_endpoint is not defined in the visible code.
        mdl_endpoint_for_current_batch = fetch_mdl_endpoint(region)

        log.info(f"Current mdl batch to fetch the status for is {batch_id}")
        df = db_operations.db.get_pending_mdl_batch(batch_id)
        if df is None:
            log.info("No pending/obfuscated MDL requests in batch_table_mapping.")
            return batch_id

        size = len(df.index)
        log.debug(f"Retrieving MDL status for {size} batch requests")
        # Split the rows into at most mdl_threads chunks; pool.map hands each
        # chunk to a worker *process* (not a thread).
        # NOTE(review): if df is empty rather than None, min(size, mdl_threads)
        # is 0 and np.array_split raises ValueError. Confirm that
        # get_pending_mdl_batch never returns an empty frame.
        batch = np.array_split(df, min(size, mdl_threads))
        # NOTE(review): this Pool is not closed or joined in the visible code.
        # Consider `with multiprocessing.Pool(mdl_threads) as pool:`.
        # NOTE(review): the HTTP calls run in the worker processes. A mock.patch
        # applied in the parent (e.g. on lib.sync_mdl_status.http.get) is only
        # seen by workers created with the "fork" start method. With "spawn" or
        # "forkserver", each worker re-imports this module and uses the real
        # http.get. The default start method depends on platform and Python version.
        pool = multiprocessing.Pool(mdl_threads)

        res = pd.concat(pool.map(partial(process_mdl_batch, mdl_endpoint=mdl_endpoint_for_current_batch,
                                         mdl_auth_tokens=mdl_auth_tokens), batch))

Mocking http.get isn't working. I tried using patch and fixtures, but an actual call is still made to the API instead of the mock response being used. I also tried mocking lib.sync_mdl_status.requests.Session.get, with the same result.

Upvotes: 0

Views: 77

Answers (1)

Daviid

Reputation: 1596

If update_mdl_status is a function inside lib/sync_mdl_status.py that creates the session itself, like this:

def update_mdl_status():
  """Illustrative fragment: the session is created locally inside the function.

  Because `http` is a local variable here, the module has no `http`
  attribute for patch('lib.sync_mdl_status.http.get') to replace.
  Patching requests.Session.get on the class still affects this local
  session. mdl_endpoint, mdl_id and mdl_headers are not defined in this
  fragment.
  """
  http = requests.Session()
  http.mount("http://", TimeoutHTTPAdapter(max_retries=HTTP_RETRY_STRATEGY))
  http.mount("https://", TimeoutHTTPAdapter(max_retries=HTTP_RETRY_STRATEGY))
  http.get(f"{mdl_endpoint}/{mdl_id}", headers=mdl_headers)

Then you can't patch it with patch('lib.sync_mdl_status.http.get'), because http is a local variable inside the function rather than a module attribute.

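You're probably importing requests inside lib.sync_mdl_status, so you probably want patch('lib.sync_mdl_status.requests.Session.get'). That replaces get on the Session class itself, so it also affects sessions created inside the function.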

If http = requests.Session() is defined at module level, outside update_mdl_status(), your original patch target may also work.

Edit:

Here's a minimal reproducible example (MRE):

import unittest
from unittest.mock import patch
from sync_mdl_status import update_mdl_status

class TestUpdateMdlStatus(unittest.TestCase):
    """Check that update_mdl_status issues its GET through a patched Session.get."""

    # Patch `get` on the Session *class*: every Session created inside
    # update_mdl_status() then resolves .get to this mock.
    @patch('sync_mdl_status.requests.Session.get')
    def test_status_updater_mdl(self, mock_get):
        # Without the decorator, remove the mock_get argument and use:
        #     with patch('sync_mdl_status.requests.Session.get') as mock_get:
        mock_get.return_value.status_code = 200

        db_url = "http://example.com/db"
        mdl_id = "test"
        mdl_endpoint = "http://example.com/mdl"
        aws_auth_token = "dummy_token"

        batch = update_mdl_status(db_url, mdl_id, mdl_endpoint, aws_auth_token)

        # On a 200 response update_mdl_status returns response.json(), which
        # here is the mock's own json() return value.
        self.assertIs(batch, mock_get.return_value.json.return_value)
        # The mock replaces a class attribute, so it is called without `self`.
        mock_get.assert_called_once_with(
            "http://example.com/mdl/test",
            headers={
                'Authorization': 'Bearer dummy_token',
                'Content-Type': 'application/json',
            },
        )


if __name__ == "__main__":
    unittest.main()

sync_mdl_status.py

import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

class TimeoutHTTPAdapter(HTTPAdapter):
    """HTTPAdapter that supplies a default timeout when a request specifies none.

    Accepts an optional ``timeout`` keyword (default ``None``). It is used
    whenever ``send`` receives no timeout, or ``timeout=None``. If left at
    ``None``, no timeout is applied. All other arguments are forwarded to
    ``HTTPAdapter``.
    """

    # HTTPAdapter pickles only the attributes listed in __attrs__. Without
    # "timeout" here, an unpickled adapter (e.g. inside a pickled Session) has
    # no self.timeout and send() raises AttributeError.
    __attrs__ = HTTPAdapter.__attrs__ + ["timeout"]

    def __init__(self, *args, **kwargs):
        self.timeout = kwargs.pop("timeout", None)
        super().__init__(*args, **kwargs)

    def send(self, request, **kwargs):
        # Fill in the default only when the caller gave no timeout (missing or None).
        if kwargs.get("timeout") is None:
            kwargs["timeout"] = self.timeout
        return super().send(request, **kwargs)

# At most 3 retries in total. Responses whose status is in status_forcelist
# are retried only for the listed (read-only) methods.
# NOTE(review): `allowed_methods` requires urllib3 >= 1.26 (older releases
# call it `method_whitelist`).
HTTP_RETRY_STRATEGY = Retry(
    allowed_methods=["HEAD", "GET", "OPTIONS"],
    status_forcelist=[429, 500, 502, 503, 504],
    total=3,
)

def update_mdl_status(db_url, mdl_id, mdl_endpoint, aws_auth_token):
    """Fetch ``{mdl_endpoint}/{mdl_id}`` and return its decoded JSON body.

    :param db_url: not used in this example; kept for interface compatibility.
    :param mdl_id: identifier appended to *mdl_endpoint*.
    :param mdl_endpoint: base URL of the MDL service.
    :param aws_auth_token: sent as ``Authorization: Bearer <token>``.
    :returns: ``response.json()`` when the status code is 200; otherwise
        ``None`` unless ``raise_for_status()`` raises.
    :raises requests.HTTPError: from ``raise_for_status()`` on 4xx/5xx responses.
    """
    # A context manager closes the session (and its pooled connections) on
    # exit; previously the session was never closed.
    with requests.Session() as http:
        # One adapter instance per scheme, both using the shared retry policy.
        http.mount("http://", TimeoutHTTPAdapter(max_retries=HTTP_RETRY_STRATEGY))
        http.mount("https://", TimeoutHTTPAdapter(max_retries=HTTP_RETRY_STRATEGY))

        mdl_headers = {
            'Authorization': f'Bearer {aws_auth_token}',
            'Content-Type': 'application/json'
        }

        response = http.get(f"{mdl_endpoint}/{mdl_id}", headers=mdl_headers)

        if response.status_code == 200:
            print("update_mdl_status(): status_code is 200")
            return response.json()
        response.raise_for_status()
        # Non-200 statuses that raise_for_status() accepts (e.g. 204) fall through here.
        return None

# Example of a database update function placeholder
def update_database(db_url, mdl_data):
    """Placeholder for persisting *mdl_data* to the database at *db_url*; currently a no-op."""
    return None

These run, and I get the mocked status_code and json. If you can't get this to work, post the full code and the error.

Upvotes: 0

Related Questions