sakurashinken

Reputation: 4070

How to test a decorator in a python package

I am writing my first Python package and I want to write unit tests for the following decorator:

import errno
import os
from functools import wraps

class MaxTriesExceededError(Exception):
    pass

def tries(max_tries=3, error_message=os.strerror(errno.ETIME)):
    def decorator(func):
        try_count = 0
        def wrapper(*args, **kwargs):
            try_count+=1
            try:
                if try_count <= max_tries:
                    result = func(*args,**kwargs)
                    return result
                else:
                    raise MaxTriesExceededError(error_message)
            except:
                if try_count <= max_tries:
                    wrapper(*args,**kwargs)
                else:
                    raise Exception

        return wraps(func)(wrapper)

    return decorator

The purpose of the decorator is to raise an error if the function fails more than max_tries times, but to swallow the error and try again if the maximum try count has not been exceeded. To be honest, I'm not sure the code is free of bugs. My question is therefore twofold: is the code correct, and how do I write unit tests for it using unittest?
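To see what goes wrong with try_count, here is a minimal sketch of the same closure pattern. The names counter_broken and counter_fixed are hypothetical and exist only for illustration; the point is that assigning to a variable inside an inner function makes it local to that function unless it is declared nonlocal.

```python
# Hypothetical names: counter_broken / counter_fixed mirror the question's
# try_count pattern in the smallest possible form.

def counter_broken():
    count = 0

    def inc():
        # Assigning to count anywhere in inc makes it local to inc, so this
        # line raises UnboundLocalError on the very first call.
        count += 1
        return count

    return inc


def counter_fixed():
    count = 0

    def inc():
        nonlocal count  # rebind the enclosing variable instead of creating a local
        count += 1
        return count

    return inc


inc = counter_fixed()
print(inc(), inc())  # prints: 1 2

try:
    counter_broken()()
except UnboundLocalError as exc:
    print("broken version raised", type(exc).__name__)
```

Note that even with nonlocal, a counter stored in the decorator's closure would be shared by every call of the decorated function and never reset, so once max_tries failures had accumulated across calls, the function could never succeed again. Counting attempts inside each call avoids that.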

Upvotes: 7

Views: 7895

Answers (1)

Adi Levin

Reputation: 5233

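Your version has several problems. Because wrapper assigns to try_count (try_count+=1), Python treats try_count as a local variable of wrapper, so the very first call raises UnboundLocalError. The recursive wrapper(*args,**kwargs) call in the except branch discards its result, so even a successful retry returns None. And the bare except: also catches your own MaxTriesExceededError and replaces it with a plain Exception. Here is a corrected version that retries in a loop instead of recursing, with unit tests: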

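import functools

class MaxTriesExceededError(Exception):
    pass

def tries(max_tries=3, error_message="failure"):
    def decorator(func):
        @functools.wraps(func)  # preserve func's name and docstring
        def wrapper(*args, **kwargs):
            # A fresh loop on every call, so no counter is shared between calls.
            for _ in range(max_tries):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    pass  # swallow the error and try again
            raise MaxTriesExceededError(error_message)
        return wrapper
    return decorator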


import unittest

class TestDecorator(unittest.TestCase):

    def setUp(self):
        self.count = 0

    def test_success_single_try(self):
        @tries(1)
        def a():
            self.count += 1
            return "expected_result"
        self.assertEqual(a(), "expected_result")
        self.assertEqual(self.count, 1)

    def test_success_two_tries(self):
        # Succeeds immediately, so only one of the two allowed tries is used.
        @tries(2)
        def a():
            self.count += 1
            return "expected_result"
        self.assertEqual(a(), "expected_result")
        self.assertEqual(self.count, 1)

    def test_failure_two_tries(self):
        @tries(2)
        def a():
            self.count += 1
            raise Exception()
        with self.assertRaises(MaxTriesExceededError):
            a()
        self.assertEqual(self.count, 2)

    def test_success_after_third_try(self):
        @tries(5)
        def a():
            self.count += 1
            if self.count == 3:
                return "expected_result"
            raise Exception()
        self.assertEqual(a(), "expected_result")
        self.assertEqual(self.count, 3)

if __name__ == '__main__':
    unittest.main()
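Another way to write these tests, sketched below under the assumption that you are happy to use unittest.mock from the standard library, is to decorate a Mock whose side_effect scripts the sequence of failures and successes. The class name TestTriesWithMock is illustrative, and the decorator is repeated here (with functools.wraps and except Exception) only so the snippet runs on its own.

```python
import functools
import unittest
from unittest import mock


class MaxTriesExceededError(Exception):
    pass


# Copy of the corrected decorator so this sketch is self-contained.
def tries(max_tries=3, error_message="failure"):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for _ in range(max_tries):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    pass
            raise MaxTriesExceededError(error_message)
        return wrapper
    return decorator


class TestTriesWithMock(unittest.TestCase):

    def test_retries_until_success(self):
        # With an iterable side_effect, exception instances are raised and
        # any other values are returned, one item per call.
        func = mock.Mock(side_effect=[ValueError(), ValueError(), "ok"])
        self.assertEqual(tries(3)(func)("arg", key=1), "ok")
        self.assertEqual(func.call_count, 3)
        # Arguments must reach the wrapped function unchanged.
        func.assert_called_with("arg", key=1)

    def test_gives_up_after_max_tries(self):
        # A single exception instance as side_effect is raised on every call.
        func = mock.Mock(side_effect=RuntimeError("boom"))
        with self.assertRaises(MaxTriesExceededError):
            tries(2)(func)()
        self.assertEqual(func.call_count, 2)
```

Because the mock records every call, call_count replaces the hand-maintained self.count, and the same test methods can be added to TestDecorator above.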

Upvotes: 16
