dell

Reputation: 13

How to Unit Test a Python Class Which Needs to Make an API Call to an External Service?

The class is meant to wrap a remotely hosted large language model and has to make an API call to the service to fetch results. Here is an example:


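from typing import Iterator, List

from anthropic import AnthropicVertex

# CONFIGS, get_gcp_project_id, AbstractLLMInterface, Correspondence and
# llm_logger are defined elsewhere in the project.


class ModelWrapper(AbstractLLMInterface):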
    """The Claude 3 Sonnet model wrapper following the interface."""

    def __init__(
        self,
        region: str = CONFIGS["GCP"]["REGION"],
        project: str = get_gcp_project_id(),
        model: str = CONFIGS["MODELS"]["SONNET_ID"],
    ) -> None:
        """Set up the Claude client using the region and project."""
        self.client: AnthropicVertex = AnthropicVertex(region=region, project_id=project)
        self.model_name: str = model
        self.role_key: str = "role"
        self.content_key: str = "content"
        self.user_key: str = "user"
        llm_logger.debug(msg=f"Initialised sonnet client for {region}, {project} and {self.model_name}.")

    def get_completion(self, user_prompt: str, system_prompt: str, history: List[Correspondence]) -> Iterator[str]:
        """
        Fetch a response from the model. This requires an egress request to GCP, and the
        Anthropic model must be enabled in the Vertex AI console.
        """

        # This is where the API call to GCP service happens
        return self.client.messages.stream(user_prompt, system_prompt, history)

I am aware of MagicMock objects in Python's unittest.mock, which can be configured to return anything I want. But in this instance, all the constructor parameters are simple strings, and the client is built inside the class. So there is no way to inject a mock client, right?

Does this mean every unit test must make an API call? Or is the class designed incorrectly? Any help will be appreciated.

The class should be testable without having to call the API.

Note

I know that changing the constructor to accept an AnthropicVertex object would make mocking trivially simple. But the purpose of the question is to test this class as it exists, not to change it.

However, if the current design of the class violates a basic design principle (such as SOLID), I would love to know and understand why, so any references are appreciated.
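One detail of the constructor affects testing regardless of the SOLID question: Python evaluates default argument values once, when the `def` statement runs. So `project: str = get_gcp_project_id()` (and the `CONFIGS[...]` lookups) execute at import time, before a test has any chance to patch them. The sketch below uses a hypothetical `get_project_id` stand-in that counts its calls, to contrast an eager default with the common `None`-sentinel pattern:

```python
from typing import Optional

calls = 0


def get_project_id() -> str:
    """Hypothetical stand-in for get_gcp_project_id() that counts its calls."""
    global calls
    calls += 1
    return "my-project"


def eager(project: str = get_project_id()) -> str:
    # The default above was computed once, when this `def` executed.
    return project


print(calls)  # 1, although eager() has never been called


def lazy(project: Optional[str] = None) -> str:
    # A None sentinel defers the lookup until the function is actually called.
    if project is None:
        project = get_project_id()
    return project


print(calls)  # still 1
lazy()
print(calls)  # 2
```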

Upvotes: 0

Views: 71

Answers (2)

defalt

Reputation: 320

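import importlib
from unittest import mock

import anthropic

import your_module  # the module that defines ModelWrapper

with mock.patch.object(target=anthropic, attribute="AnthropicVertex"):
    # Re-run your_module (assumed to do `from anthropic import AnthropicVertex`)
    # so that it binds the patched MagicMock instead of the real class.
    importlib.reload(module=your_module)
    assert isinstance(your_module.ModelWrapper().client, mock.MagicMock)

# your_module still holds the MagicMock after the block exits;
# reload it again so later tests get the real AnthropicVertex.
importlib.reload(module=your_module)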

Instantiate ModelWrapper inside the with block: while the patch is active, AnthropicVertex is a MagicMock, so the wrapper's client is a mock and no request can reach GCP.
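A variant that avoids the reload is to patch the name where ModelWrapper looks it up, i.e. `mock.patch("your_module.AnthropicVertex")`, assuming your_module does `from anthropic import AnthropicVertex`. Because `__init__` resolves `AnthropicVertex` from the module globals at call time, replacing that global is enough. To keep the sketch self-contained, it builds a throwaway module with the hypothetical name `wrapper_module`, whose `AnthropicVertex` is a stub that raises if constructed; the region and project defaults are placeholders too:

```python
import sys
import types
from unittest import mock

# Hypothetical module standing in for "your_module". Its AnthropicVertex is
# a stub that raises, simulating a client that must not be built in tests.
SOURCE = '''
class AnthropicVertex:
    def __init__(self, region, project_id):
        raise RuntimeError("real client constructed in a unit test")


class ModelWrapper:
    def __init__(self, region="us-east5", project="my-project"):
        self.client = AnthropicVertex(region=region, project_id=project)

    def get_completion(self, user_prompt, system_prompt, history):
        return self.client.messages.stream(user_prompt, system_prompt, history)
'''

wrapper_module = types.ModuleType("wrapper_module")
exec(SOURCE, wrapper_module.__dict__)
sys.modules["wrapper_module"] = wrapper_module

# Patch the name inside the module that uses it; no reload is required.
with mock.patch("wrapper_module.AnthropicVertex") as fake_vertex:
    # fake_vertex(...) returns fake_vertex.return_value, i.e. the "client".
    fake_vertex.return_value.messages.stream.return_value = iter(["mocked"])
    wrapper = wrapper_module.ModelWrapper()
    result = list(wrapper.get_completion("prompt", "system", []))

print(result)  # ['mocked']
fake_vertex.assert_called_once_with(region="us-east5", project_id="my-project")
# On exit, mock.patch restores the original AnthropicVertex stub automatically.
```

Since patch restores the attribute it replaced, nothing leaks into later tests, which is the main practical difference from patching `anthropic.AnthropicVertex` and reloading.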

Upvotes: 1

quantumsurge

Reputation: 3

I think the simplest way to do this is to pass the client in through the constructor, rather than instantiating it internally:

from typing import Iterator, List, Optional

from anthropic import AnthropicVertex


class ModelWrapper(AbstractLLMInterface):
    def __init__(
        self,
        client: Optional[AnthropicVertex] = None,
        region: str = CONFIGS["GCP"]["REGION"],
        project: str = get_gcp_project_id(),
        model: str = CONFIGS["MODELS"]["SONNET_ID"],
    ) -> None:
        if client is None:
            # Build the real client
            client = AnthropicVertex(region=region, project_id=project)
        self.client = client
        self.model_name: str = model

    def get_completion(self, user_prompt: str, system_prompt: str, history: List[Correspondence]) -> Iterator[str]:
        return self.client.messages.stream(user_prompt, system_prompt, history)

Then you can test like so:

from unittest.mock import MagicMock

def test_model_wrapper():
    mock_client = MagicMock()
    mock_client.messages.stream.return_value = iter(["mocked response"])

    wrapper = ModelWrapper(client=mock_client)
    responses = list(wrapper.get_completion("prompt", "system", []))

    assert responses == ["mocked response"]
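A plain MagicMock accepts any attribute or call, so a typo such as `client.mesages` would pass silently. A stricter option is `unittest.mock.create_autospec`, which copies the attributes and method signatures of a spec object, so misspelled attributes raise AttributeError and wrong argument counts raise TypeError. The sketch below autospecs hypothetical `FakeClient`/`FakeMessages` stand-ins that mirror the call shape used in `get_completion`, plus a simplified copy of the injectable wrapper; in a real test you would autospec the actual client class, which would then check calls against its real signatures:

```python
from typing import Iterator, List
from unittest.mock import create_autospec


# Hypothetical stand-ins describing the client's shape (not the real SDK).
class FakeMessages:
    def stream(self, user_prompt: str, system_prompt: str, history: List[str]) -> Iterator[str]:
        raise NotImplementedError("would call the remote API")


class FakeClient:
    messages = FakeMessages()


class ModelWrapper:
    """Simplified version of the injectable wrapper from this answer."""

    def __init__(self, client: FakeClient) -> None:
        self.client = client

    def get_completion(self, user_prompt: str, system_prompt: str, history: List[str]) -> Iterator[str]:
        return self.client.messages.stream(user_prompt, system_prompt, history)


# The autospecced mock only exposes attributes that FakeClient really has.
mock_client = create_autospec(FakeClient, instance=True)
mock_client.messages.stream.return_value = iter(["mocked response"])

wrapper = ModelWrapper(client=mock_client)
responses = list(wrapper.get_completion("prompt", "system", []))
print(responses)  # ['mocked response']

try:
    mock_client.messages.strem  # misspelled method name
except AttributeError as exc:
    print("caught:", exc)
```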

Upvotes: 0
