George

ValueError: use np.all() or np.any() after using an assert

I have this code:

import numpy as np


class Variables(object):

    def __init__(self, var_name, the_method):

        self.var_name = var_name
        self.the_method = the_method

    def evaluate_v(self):
        var_name, the_method = self.var_name, self.the_method

        if the_method == 'diff':
            return var_name[0] - var_name[1]

and this test code:

import unittest
import pytest
import numpy as np

from .variables import Variables


class TestVariables():

    @classmethod
    def setup_class(cls):
        var_name = np.array([[1, 2, 3], [2, 3, 4]])
        the_method = 'diff'
        cls.variables = Variables(var_name, the_method)

    @pytest.mark.parametrize(
        "var_name, the_method, expected_output", [
            (np.array([[1, 2, 3], [2, 3, 4]]), 'diff', np.array([-1, -1, -1]) ),
        ])
    def test_evaluate_v_method_returns_correct_results(
        self, var_name, the_method,expected_output):

        var_name, the_method = self.variables.var_name, self.variables.the_method

        obs = self.variables.evaluate_v()  
        assert obs == expected_output

if __name__ == "__main__":
    unittest.main()

where I want to compute the difference between the first and the last row of the array (with only two rows, that is var_name[0] - var_name[1]).

The result should be an array [-1, -1, -1].
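As a quick check of the arithmetic, the sketch below does the same row subtraction outside the class. It uses the input array from the test. The second subtraction, with var_name[-1], is an illustrative addition and not part of the original code. It shows the indexing that selects the last row when an array has more than two rows.

```python
import numpy as np

# Same input array as in the question: two rows of three values each.
var_name = np.array([[1, 2, 3], [2, 3, 4]])

# var_name[0] is the first row, var_name[1] the second; subtraction is elementwise.
diff_first_second = var_name[0] - var_name[1]
print(diff_first_second)  # [-1 -1 -1]

# Illustrative only: var_name[-1] always selects the last row, so this is
# "first minus last" even if the array had more than two rows.
diff_first_last = var_name[0] - var_name[-1]
print(diff_first_last)  # [-1 -1 -1]
```

With exactly two rows both expressions give the same result. They only differ once a third row is added.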

When I run the test, I get:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I am not sure whether I need np.all() in my case, and if so, how to use it.
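The error comes from the comparison, not from evaluate_v. The minimal sketch below reproduces it with the expected array from the question. It calls bool() directly, because that is the conversion a bare assert performs on its condition.

```python
import numpy as np

obs = np.array([-1, -1, -1])
expected_output = np.array([-1, -1, -1])

# == on two arrays compares element by element and returns a boolean array.
comparison = obs == expected_output
print(comparison)  # [ True  True  True]

# `assert comparison` needs a single True/False, so Python calls bool() on the
# array; NumPy refuses this for arrays with more than one element.
error_message = ""
try:
    bool(comparison)
except ValueError as exc:
    error_message = str(exc)
print(error_message)

# np.all() collapses the boolean array into one value, which assert can use.
all_equal = bool(np.all(comparison))
print(all_equal)  # True
```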

Answers (1)

Mike Müller

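obs == expected_output compares element by element and gives a boolean array, which assert cannot interpret. Reduce it to a single boolean with np.all(); assert np.all(obs == expected_output) works: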

def test_evaluate_v_method_returns_correct_results(
        self, var_name, the_method, expected_output):

    var_name, the_method = self.variables.var_name, self.variables.the_method

    obs = self.variables.evaluate_v()
    assert np.all(obs == expected_output)
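One caveat with np.all(obs == expected_output) is broadcasting. If the expected array has a different but compatible shape, the element-wise comparison can still come out all True. The sketch below uses a deliberately wrong, hypothetical expected value, wrong_expected. It contrasts the result with np.array_equal, which also requires the shapes to match.

```python
import numpy as np

obs = np.array([-1, -1, -1])

# Hypothetical wrong expectation: shape (1,) instead of (3,).
wrong_expected = np.array([-1])

# Broadcasting stretches wrong_expected to shape (3,), so every element matches.
broadcast_check = bool(np.all(obs == wrong_expected))
print(broadcast_check)  # True

# np.array_equal checks shape as well as values, so it catches the mistake.
shape_aware_check = np.array_equal(obs, wrong_expected)
print(shape_aware_check)  # False

correct_check = np.array_equal(obs, np.array([-1, -1, -1]))
print(correct_check)  # True
```

In a test, assert np.array_equal(obs, expected_output) is therefore the stricter of the two spellings.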

Test it:

py.test np_test.py 
================================== test session starts ===================================
platform darwin -- Python 3.5.2, pytest-2.9.2, py-1.4.31, pluggy-0.3.1
rootdir: /Users/mike/tmp, inifile: 
plugins: hypothesis-3.4.0, asyncio-0.4.1
collected 1 items 

np_test.py .

================================ 1 passed in 0.10 seconds ================================
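Another option is NumPy's own testing helper, np.testing.assert_array_equal. It raises an AssertionError that lists the mismatching elements, instead of a bare assertion failure. The sketch below is standalone: it copies the Variables class from the question.

The function check_evaluate_v is a hypothetical name. It is a plain function, so the snippet runs without pytest. Under pytest, @pytest.mark.parametrize would supply its three arguments. Unlike the question's test, which overwrites its parametrized arguments with self.variables, it uses the values it receives.

```python
import numpy as np


class Variables(object):
    # Copied from the question so this snippet runs on its own.
    def __init__(self, var_name, the_method):
        self.var_name = var_name
        self.the_method = the_method

    def evaluate_v(self):
        if self.the_method == 'diff':
            return self.var_name[0] - self.var_name[1]


def check_evaluate_v(var_name, the_method, expected_output):
    # Hypothetical helper: under pytest these arguments would come from
    # @pytest.mark.parametrize; here they are passed in directly.
    obs = Variables(var_name, the_method).evaluate_v()
    # Raises AssertionError with a per-element report if shapes or values differ.
    np.testing.assert_array_equal(obs, expected_output)


# Passing case: same data as the question's parametrize entry.
check_evaluate_v(np.array([[1, 2, 3], [2, 3, 4]]), 'diff', np.array([-1, -1, -1]))

# Failing case, caught here only to show what the report looks like.
failure_report = ""
try:
    check_evaluate_v(np.array([[1, 2, 3], [2, 3, 4]]), 'diff', np.array([-1, 0, -1]))
except AssertionError as exc:
    failure_report = str(exc)
print(failure_report)
```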
