Rudra

Reputation: 148

Pytest for Google Spanner Database call

I am trying to write pytest cases for the Google Spanner database call below:

class Database:
    _client = None

    def __init__(self, instance_id: str, database_id: str, pool=None):
        if not Database._client:
            Database._client = SpannerClient()
        instance = Database._client.instance(instance_id)
        self._database = instance.database(database_id, pool=pool)

    def execute_query(
        self, query: str, params: Dict | None = None, param_types: Dict | None = None
    ):
        try:
            with self._database.snapshot() as snapshot:
                results = snapshot.execute_sql(query, params, param_types)
                df = DataFrame(
                    data=[row for row in results],
                    columns=[col.name for col in results.fields],
                )
                return df
        except GoogleAPICallError as e:
            print(f"Error code:{e.code},Error message: {e.message}")
            raise GoogleAPIError() from e

    def get_database(self):
        return self._database

Here is the pytest code I have so far:

import pytest
from unittest.mock import patch, MagicMock
from database import Database
from google.api_core.exceptions import GoogleAPICallError
from pandas import DataFrame


@pytest.fixture
def mock_spanner_client():
    """Swap the shared ``Database._client`` for a MagicMock during one test.

    The patch is undone automatically when the test finishes.
    """
    with patch("database.Database._client") as fake_client:
        yield fake_client


@pytest.fixture
def mock_instance(mock_spanner_client):
    """Stub Spanner instance returned by the patched client's ``instance()``."""
    stub_instance = MagicMock()
    mock_spanner_client.instance.return_value = stub_instance
    return stub_instance


@pytest.fixture
def mock_database(mock_instance):
    """Stub database handle returned by the stub instance's ``database()``."""
    stub_database = MagicMock()
    mock_instance.database.return_value = stub_database
    return stub_database


def test_database_initialization(mock_spanner_client, mock_instance, mock_database):
    """Database() resolves the instance and database through the shared client."""
    db = Database("test_instance", "test_database")

    assert db._database is mock_database
    # Also pin what was forwarded to the client, not just the final handle.
    mock_spanner_client.instance.assert_called_once_with("test_instance")
    mock_instance.database.assert_called_once_with("test_database", pool=None)


def test_get_database(mock_database):
    """get_database() hands back the handle that __init__ obtained."""
    database = Database("test_instance", "test_database")
    handle = database.get_database()
    assert handle == mock_database


def test_execute_query_success(mock_database):
    """execute_query() turns the streamed rows and fields into a DataFrame."""
    # ``name`` is a reserved MagicMock constructor argument: it only sets the
    # mock's repr, so ``MagicMock(name="col1").name`` would be another
    # (iterable) mock. pandas then saw zero columns. Assign the attribute
    # after construction instead.
    column = MagicMock()
    column.name = "col1"

    mock_results = MagicMock(fields=[column])
    # MagicMock calls iter() on this value on every iteration, so the fake
    # result set can be iterated more than once.
    mock_results.__iter__.return_value = [["row1"], ["row2"]]

    mock_snapshot = MagicMock()
    mock_snapshot.execute_sql.return_value = mock_results
    mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot

    db = Database("test_instance", "test_database")
    query = "SELECT * FROM test_table"
    result = db.execute_query(query)

    mock_snapshot.execute_sql.assert_called_once_with(query, None, None)
    assert isinstance(result, DataFrame)
    assert not result.empty
    assert list(result.columns) == ["col1"]
    assert result["col1"].tolist() == ["row1", "row2"]

I need help mocking the `execute_sql` call, which in Spanner returns a `StreamedResultSet` iterator. Setting `mock_snapshot.execute_sql.return_value` as shown above produces the following error:

ValueError: 0 columns passed, passed data had 1 columns

Upvotes: 0

Views: 31

Answers (1)

Rudra

Reputation: 148

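I mocked the result set with a small custom iterable class, shown below. The previous setup never supplied real column names, so pandas raised an error about the columns. The cause is that `MagicMock(name="col1")` does not create a `name` attribute. `name` is a reserved constructor argument that only sets the mock's repr, so `col.name` returned another mock instead of `"col1"`.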

def mock_spanner_results(fields=None, rows=None):
    """Build a fake Spanner result set for unit tests.

    The returned object provides what ``Database.execute_query`` reads.
    Iterating it yields one list of values per row, and its ``fields``
    attribute lists column descriptors that expose ``.name``. Plain classes
    are used on purpose: ``MagicMock(name=...)`` would not set a ``name``
    attribute.

    Args:
        fields: Optional sequence of ``(name, type_code)`` pairs describing
            the columns. Defaults to ``a``/``STRING``, ``b``/``STRING`` and
            ``c``/``TIMESTAMP``.
        rows: Optional sequence of row value sequences. Defaults to
            ``[1, 2, 3]`` and ``[4, 5, 6]``.

    Returns:
        An iterable results object with ``fields`` and ``rows`` attributes.
    """

    class MockField:
        def __init__(self, name, code):
            self.name = name
            # Type metadata for the column; execute_query only reads ``name``.
            self.type_ = {'code': code}

    class MockRow:
        def __init__(self, values):
            self.values = values

    class MockResults:
        def __init__(self, fields, rows):
            self.fields = fields
            self.rows = rows

        def __iter__(self):
            # A fresh generator per call, so the fake can be iterated
            # more than once.
            for row in self.rows:
                yield row.values

    if fields is None:
        fields = [("a", "STRING"), ("b", "STRING"), ("c", "TIMESTAMP")]
    if rows is None:
        rows = [[1, 2, 3], [4, 5, 6]]

    return MockResults(
        [MockField(name, code) for name, code in fields],
        [MockRow(list(values)) for values in rows],
    )

This yields the rows and builds the DataFrame as expected.

Upvotes: 0

Related Questions