Ben RR

Reputation: 943

Python unittest and multithreading

I am using Python's unittest and would like to write a test that starts a few threads and waits for them to finish. The threads execute a function that contains some unittest assertions. If any of the assertions fail, I want the test to fail, but it does not: the test still passes.

EDIT: Minimal runnable example (python3)

import unittest
import threading

class MyTests(unittest.TestCase):

    def test_sample(self):
        t = threading.Thread(target=lambda: self.fail())
        t.start()
        t.join()

if __name__ == '__main__':
    unittest.main()

and the output is:

sh-4.3$ python main.py -v
test_sample (__main__.MyTests) ... Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib64/python2.7/threading.py", line 813, in __bootstrap_inner
    self.run()
  File "/usr/lib64/python2.7/threading.py", line 766, in run
    self.__target(*self.__args, **self.__kwargs)
  File "main.py", line 7, in <lambda>
    t = threading.Thread(target=lambda: self.fail())
  File "/usr/lib64/python2.7/unittest/case.py", line 450, in fail
    raise self.failureException(msg)
AssertionError: None

ok

----------------------------------------------------------------------
Ran 1 test in 0.002s

OK

Upvotes: 16

Views: 33901

Answers (3)

Thomas Grainger

Reputation: 2431

Use a concurrent.futures.ThreadPoolExecutor, or threading.excepthook (https://docs.python.org/3/library/threading.html#threading.excepthook, available since Python 3.8), to collect exceptions raised in threads:

import unittest
import threading
from concurrent import futures

class catch_threading_exception:
    """
    https://docs.python.org/3/library/test.html#test.support.catch_threading_exception
    Context manager catching threading.Thread exception using
    threading.excepthook.

    Attributes set when an exception is caught:

    * exc_type
    * exc_value
    * exc_traceback
    * thread

    See threading.excepthook() documentation for these attributes.

    These attributes are deleted at the context manager exit.

    Usage:

        with support.catch_threading_exception() as cm:
            # code spawning a thread which raises an exception
            ...

            # check the thread exception, use cm attributes:
            # exc_type, exc_value, exc_traceback, thread
            ...

        # exc_type, exc_value, exc_traceback, thread attributes of cm no longer
        # exist at this point
        # (to avoid reference cycles)
    """

    def __init__(self):
        self.exc_type = None
        self.exc_value = None
        self.exc_traceback = None
        self.thread = None
        self._old_hook = None

    def _hook(self, args):
        self.exc_type = args.exc_type
        self.exc_value = args.exc_value
        self.exc_traceback = args.exc_traceback
        self.thread = args.thread

    def __enter__(self):
        self._old_hook = threading.excepthook
        threading.excepthook = self._hook
        return self

    def __exit__(self, *exc_info):
        threading.excepthook = self._old_hook
        del self.exc_type
        del self.exc_value
        del self.exc_traceback
        del self.thread


class MyTests(unittest.TestCase):
    def test_tpe(self):
        with futures.ThreadPoolExecutor() as pool:
            pool.submit(self.fail).result()

    def test_t_excepthook(self):
        with catch_threading_exception() as cm:
            t = threading.Thread(target=self.fail)
            t.start()
            t.join()
            if cm.exc_value is not None:
                raise cm.exc_value


if __name__ == '__main__':
    unittest.main()
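
The question asks about starting a few threads, so here is a minimal sketch of the ThreadPoolExecutor approach with several workers. The worker `check_even` and the class name `PoolTests` are made up for illustration. The key point is that `Future.result()` re-raises the worker's exception in the thread that calls it.

```python
import unittest
from concurrent import futures


def check_even(n):
    # Hypothetical worker: a plain assert, so it can run in any thread.
    assert n % 2 == 0, f"{n} is not even"
    return n


class PoolTests(unittest.TestCase):
    def test_all_workers(self):
        with futures.ThreadPoolExecutor(max_workers=4) as pool:
            tasks = [pool.submit(check_even, n) for n in (2, 4, 6)]
            # result() re-raises, in the calling (main) thread, any
            # exception the worker raised, so unittest sees the failure.
            results = [task.result() for task in tasks]
        self.assertEqual(results, [2, 4, 6])
```

If `result()` is never called on a future, an exception stored in it goes unnoticed, so every future should be checked.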

With pytest, unhandled thread exceptions are collected and reported for you: https://docs.pytest.org/en/latest/how-to/failures.html?highlight=unraisable#warning-about-unraisable-exceptions-and-unhandled-thread-exceptions

Upvotes: 7

Fred S

Reputation: 995

Your test isn't failing for the same reason that this code prints "no exception":

import threading

def raise_err():
    raise Exception()

try:
    t = threading.Thread(target=raise_err)
    t.start()
    t.join()
    print('no exception')
except Exception:
    print('caught exception')

unittest decides whether a test passes or fails by checking whether the test method raises an exception. An exception raised inside a thread stays in that thread; it is never raised in the main thread, where unittest is watching.

You could do something like this if you think you HAVE to get a pass/fail result from running something in a thread. But this is really not how unittest is designed to work, and there's probably a much easier way to do what you're trying to accomplish.

import threading
import unittest

def raise_err():
    raise Exception()


def no_err():
    return


class Runner:
    """Runs targets in threads and records whether each one finished."""

    def __init__(self):
        self.threads = {}
        self.thread_results = {}

    def add(self, target, name):
        self.threads[name] = threading.Thread(target=self.run, args=[target, name])
        self.threads[name].start()

    def run(self, target, name):
        # Assume failure; only overwritten if target() returns without raising.
        self.thread_results[name] = 'fail'
        target()
        self.thread_results[name] = 'pass'

    def check_result(self, name):
        # Called from the main thread, so a failed assert here fails the test.
        self.threads[name].join()
        assert self.thread_results[name] == 'pass'

runner = Runner()

class MyTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        runner.add(raise_err, 'test_raise_err')
        runner.add(no_err, 'test_no_err')

    def test_raise_err(self):
        runner.check_result('test_raise_err')

    def test_no_err(self):
        runner.check_result('test_no_err')

if __name__ == '__main__':
    unittest.main()
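
A variation on the same idea, as a rough sketch: instead of storing only 'pass' or 'fail', keep the exception object and re-raise it in the main thread, so unittest reports the original error. The helper name `run_in_threads` and the class `HelperTests` are made up for illustration.

```python
import threading
import unittest


def run_in_threads(*targets):
    """Run each target in its own thread; re-raise the first error recorded."""
    errors = []  # appended to from worker threads, read after join()

    def wrap(target):
        def runner():
            try:
                target()
            except Exception as exc:  # AssertionError is an Exception too
                errors.append(exc)
        return runner

    threads = [threading.Thread(target=wrap(target)) for target in targets]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    if errors:
        # Raised here, in the caller's thread, where unittest can see it.
        raise errors[0]


class HelperTests(unittest.TestCase):
    def test_failure_reaches_main_thread(self):
        with self.assertRaises(AssertionError):
            run_in_threads(lambda: self.assertEqual(1, 2), lambda: None)
```

Only the first recorded exception is re-raised. If several threads fail, the others are dropped unless you also inspect the list.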

Upvotes: 4

Grumbel

Reputation: 7043

Python unittest assertions are communicated by exceptions, so you have to ensure that the exceptions end up in the main thread. So for a thread that means you have to run .join(), as that will throw the exception from the thread over into the main thread:

    t = threading.Thread(target=lambda: self.assertTrue(False))
    t.start()
    t.join()

Also make sure that you don't have any try/except blocks that might eat up the exception before the unittest can register them.

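Edit: this does not actually work. Thread.join() only waits for the thread to finish; it does not re-raise the thread's exception in the main thread. A failure from self.fail() or self.assertTrue(False) inside a thread is therefore not reported to unittest, even when .join() is present.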
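
Since join() on a plain threading.Thread only waits and does not re-raise anything, one possible workaround is a small Thread subclass whose join() does. This is a sketch, not part of the standard library. `PropagatingThread` and `ThreadTests` are hypothetical names.

```python
import threading
import unittest


class PropagatingThread(threading.Thread):
    """Hypothetical Thread subclass whose join() re-raises the target's exception."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exc = None

    def run(self):
        try:
            super().run()  # calls the target passed to the constructor
        except BaseException as exc:
            self.exc = exc  # remember it for join()

    def join(self, timeout=None):
        super().join(timeout)
        if self.exc is not None:
            # Raised in whichever thread calls join(), typically the main one.
            raise self.exc


class ThreadTests(unittest.TestCase):
    def test_fail_in_thread(self):
        t = PropagatingThread(target=lambda: self.assertTrue(False))
        t.start()
        with self.assertRaises(AssertionError):
            t.join()
```

In a real test you would drop the assertRaises wrapper and call t.join() directly, so that the assertion failure from the thread fails the test.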

Upvotes: 3
