Reputation: 5681
I have this code:
import numpy as np
class Variables(object):
def __init__(self, var_name, the_method):
self.var_name = var_name
self.the_method = the_method
def evaluate_v(self):
var_name, the_method = self.var_name, self.the_method
if the_method == 'diff':
return var_name[0] - var_name[1]
and this test code:
import unittest
import pytest
import numpy as np
from .variables import Variables
class TestVariables():
@classmethod
def setup_class(cls):
var_name = np.array([[1, 2, 3], [2, 3, 4]])
the_method = 'diff'
cls.variables = Variables(var_name, the_method)
@pytest.mark.parametrize(
"var_name, the_method, expected_output", [
(np.array([[1, 2, 3], [2, 3, 4]]), 'diff', np.array([-1, -1, -1]) ),
])
def test_evaluate_v_method_returns_correct_results(
self, var_name, the_method,expected_output):
var_name, the_method = self.variables.var_name, self.variables.the_method
obs = self.variables.evaluate_v()
assert obs == expected_output
if __name__ == "__main__":
unittest.main()
where I want to compute the difference between the first and last elements.
The result should be an array [-1, -1, -1]
.
If I try to run the test, it gives me:
ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()
I am not sure how to use (and if I have to) the np.all()
in my case.
Upvotes: 1
Views: 1209
Reputation: 85462
assert np.all(obs == expected_output)
works:
def test_evaluate_v_method_returns_correct_results(
self, var_name, the_method,expected_output):
var_name, the_method = self.variables.var_name, self.variables.the_method
obs = self.variables.evaluate_v()
assert np.all(obs == expected_output)
Test it:
py.test np_test.py
================================== test session starts ===================================
platform darwin -- Python 3.5.2, pytest-2.9.2, py-1.4.31, pluggy-0.3.1
rootdir: /Users/mike/tmp, inifile:
plugins: hypothesis-3.4.0, asyncio-0.4.1
collected 1 items
np_test.py .
================================ 1 passed in 0.10 seconds ================================
Upvotes: 1