Reputation: 148
I am trying to write pytest cases for below spanner DB call
class Database:
_client = None
def __init__(self, instance_id: str, database_id: str, pool=None):
if not Database._client:
Database._client = SpannerClient()
instance = Database._client.instance(instance_id)
self._database = instance.database(database_id, pool=pool)
def execute_query(
self, query: str, params: Dict | None = None, param_types: Dict | None = None
):
try:
with self._database.snapshot() as snapshot:
results = snapshot.execute_sql(query, params, param_types)
df = DataFrame(
data=[row for row in results],
columns=[col.name for col in results.fields],
)
return df
except GoogleAPICallError as e:
print(f"Error code:{e.code},Error message: {e.message}")
raise GoogleAPIError() from e
def get_database(self):
return self._database
Below is the pytest code that is currently written:
import pytest
from unittest.mock import patch, MagicMock
from database import Database
from google.api_core.exceptions import GoogleAPICallError
from pandas import DataFrame
@pytest.fixture
def mock_spanner_client():
with patch("database.Database._client") as MockClient:
yield MockClient
@pytest.fixture
def mock_instance(mock_spanner_client):
mock_instance = MagicMock()
mock_spanner_client.instance.return_value = mock_instance
yield mock_instance
@pytest.fixture
def mock_database(mock_instance):
mock_database = MagicMock()
mock_instance.database.return_value = mock_database
yield mock_database
def test_database_initialization(mock_spanner_client, mock_instance, mock_database):
db = Database("test_instance", "test_database")
assert db._database == mock_database
def test_get_database(mock_database):
db = Database("test_instance", "test_database")
assert db.get_database() == mock_database
def test_execute_query_success(mock_database):
mock_snapshot = MagicMock()
mock_snapshot.execute_sql.return_value = MagicMock(
__iter__=lambda self: iter([["row1"], ["row2"]]),
fields=[MagicMock(name="col1")],
)
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
db = Database("test_instance", "test_database")
query = "SELECT * FROM test_table"
result = db.execute_query(query)
assert isinstance(result, DataFrame)
assert not result.empty
assert list(result.columns) == ["col1"]
Need help with mocking the execute sql part where spanner returns a StreamedResultSet Iterator. mock_snapshot.execute_sql.return_value is giving the below error.
ValueError: 0 columns passed, passed data had 1 columns
Upvotes: 0
Views: 31
Reputation: 148
Mocked the iterator by creating a custom iterator as below, since the prev setup was not taking up values for columns and raised error for not passing columns.
def mock_spanner_results():
class MockField:
def __init__(self, name, code):
self.name = name
self.type_ = {'code': code}
class MockRow:
def __init__(self, values):
self.values = values
fields = [
MockField("a", "STRING"),
MockField("b", "STRING"),
MockField("c", "TIMESTAMP"),
]
rows = [
MockRow([1,2,3]),
MockRow([4,5,6]),
]
class MockResults:
def __init__(self, fields, rows):
self.fields = fields
self.rows = rows
def __iter__(self):
for row in self.rows:
yield row.values
return MockResults(fields, rows)
This is yielding results and creating dataframe as expected.
Upvotes: 0