Jakob Buron

Reputation: 1238

Make Python unittest fail on exception from any thread

I am using the unittest framework to automate integration tests of multi-threaded Python code, external hardware and embedded C. Despite my blatant abuse of a unit testing framework for integration testing, it works really well. Except for one problem: I need the test to fail if an exception is raised from any of the spawned threads. Is this possible with the unittest framework?

A simple but non-workable solution would be to either a) refactor the code to avoid multi-threading or b) test each thread separately. I cannot do that because the code interacts asynchronously with the external hardware. I have also considered implementing some kind of message passing to forward the exceptions to the main unittest thread. This would require significant testing-related changes to the code being tested, and I want to avoid that.

Time for an example. Can I modify the test script below to fail on the exception raised in my_thread without modifying the x.ExceptionRaiser class?

import unittest
import x

class Test(unittest.TestCase):
    def test_x(self):
        my_thread = x.ExceptionRaiser()
        # Test case should fail when thread is started and raises
        # an exception.
        my_thread.start()
        my_thread.join()

if __name__ == '__main__':
    unittest.main()
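Here is a self-contained version of the test above that shows the problem. The question does not show x.ExceptionRaiser, so the ExceptionRaiser class below is a hypothetical stand-in whose run() always raises. Running it suggests that unittest reports the test as passing. join() returns normally, and the thread's traceback only appears on stderr.

```python
import io
import threading
import unittest


class ExceptionRaiser(threading.Thread):
    """Hypothetical stand-in for x.ExceptionRaiser: its run() always raises."""

    def run(self):
        raise RuntimeError('boom from worker thread')


class Test(unittest.TestCase):
    def test_x(self):
        my_thread = ExceptionRaiser()
        my_thread.start()
        # join() returns normally; the thread's exception is NOT re-raised here
        my_thread.join()


def run_example():
    """Run the test case programmatically and return the TestResult."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test)
    runner = unittest.TextTestRunner(stream=io.StringIO(), verbosity=0)
    return runner.run(suite)


if __name__ == '__main__':
    result = run_example()
    print('test passed:', result.wasSuccessful())
```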

Upvotes: 8

Views: 3374

Answers (3)

Thematrixme

Reputation: 318

I've been using the accepted answer for a while now, but since Python 3.8 its original solution no longer works, because the threading module no longer has the private _format_exc import that it patches.

On the other hand, since Python 3.8 the threading module has a proper way to register a custom except hook (threading.excepthook). Here is a simple solution for unit tests that assert that some exceptions are raised inside threads:

def test_in_thread():
    import threading

    exceptions_caught_in_threads = {}

    def custom_excepthook(args):
        thread_name = args.thread.name
        exceptions_caught_in_threads[thread_name] = {
            'thread': args.thread,
            'exception': {
                'type': args.exc_type,
                'value': args.exc_value,
                'traceback': args.exc_traceback
            }
        }

    # Register our custom excepthook to catch the exceptions raised in threads,
    # keeping the original so it can be restored afterwards
    original_excepthook = threading.excepthook
    threading.excepthook = custom_excepthook

    # dummy function that raises an exception
    def my_function():
        raise Exception('My Exception')

    try:
        # run the function in a thread
        thread_1 = threading.Thread(name='thread_1', target=my_function, args=())
        thread_1.start()
        thread_1.join()
    finally:
        # restore the original hook so other tests are not affected
        threading.excepthook = original_excepthook

    assert 'thread_1' in exceptions_caught_in_threads  # there was an exception in thread 1
    assert exceptions_caught_in_threads['thread_1']['exception']['type'] is Exception
    assert str(exceptions_caught_in_threads['thread_1']['exception']['value']) == 'My Exception'

Upvotes: 2

Ohad

Reputation: 2618

At first, sys.excepthook looked like a solution. It is a global hook which is called every time an uncaught exception is thrown.

Unfortunately, this does not work. Why? Well, threading wraps your run function in code that prints the lovely tracebacks you see on screen (notice how it always says Exception in thread {Name of your thread here}? This is how it's done), so the exception never reaches sys.excepthook.
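You can check this claim with a short script. It temporarily replaces sys.excepthook with a recorder, raises an exception inside a thread, and checks whether the recorder was ever called. The helper name sys_hook_sees_thread_exception is made up for illustration. On CPython it should return False, because the threading module reports the error itself.

```python
import sys
import threading


def sys_hook_sees_thread_exception():
    """Return True if sys.excepthook is invoked for an exception raised in a thread."""
    calls = []
    original = sys.excepthook
    # Record each call instead of printing a traceback
    sys.excepthook = lambda exc_type, exc_value, exc_tb: calls.append(exc_type)
    try:
        worker = threading.Thread(target=lambda: 1 / 0)
        worker.start()
        # The ZeroDivisionError is reported by threading itself, not re-raised here
        worker.join()
    finally:
        sys.excepthook = original
    return bool(calls)


if __name__ == '__main__':
    print(sys_hook_sees_thread_exception())
```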

Starting with Python 3.8, there is a function which you can override to make this work: threading.excepthook

... threading.excepthook() can be overridden to control how uncaught exceptions raised by Thread.run() are handled
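The sketch below illustrates what the hook receives (Python 3.8+ only). It installs a hook that simply stores its single args argument, which exposes exc_type, exc_value, exc_traceback and thread. It also shows that a custom hook receives SystemExit as well, which the default hook silently ignores, so a watcher may want to filter it out. The helper names capture_hook_args, fail and leave are hypothetical.

```python
import threading


def capture_hook_args(target):
    """Run target in a thread and return the args object passed to threading.excepthook.

    Requires Python >= 3.8, where threading.excepthook exists.
    """
    captured = []
    original = threading.excepthook
    # The hook is called with a single args object, so list.append works as a hook
    threading.excepthook = captured.append
    try:
        worker = threading.Thread(target=target, name='worker')
        worker.start()
        worker.join()
    finally:
        threading.excepthook = original
    return captured[0] if captured else None


def fail():
    raise ValueError('bad value')


def leave():
    raise SystemExit(3)


if __name__ == '__main__':
    args = capture_hook_args(fail)
    print(args.exc_type, args.exc_value, args.thread.name)
```

The threading docs warn that keeping exc_value from a hook alive can create a reference cycle. That is one reason longer-lived watchers often keep only the formatted traceback text.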

So what do we do? Replace this function with our logic, and voilà:

For Python >= 3.8:

import os
import threading
import traceback


class GlobalExceptionWatcher(object):
    def _store_excepthook(self, args):
        '''
        Used as an exception handler which stores any uncaught exceptions.
        '''
        # Let the original hook print the usual "Exception in thread ..." report
        self.__org_hook(args)
        formatted_exc = traceback.format_exception(args.exc_type, args.exc_value, args.exc_traceback)
        # Each line already ends with a newline, so join without a separator
        self._exceptions.append(''.join(formatted_exc))

    def __enter__(self):
        '''
        Register us to the hook.
        '''
        self._exceptions = []
        self.__org_hook = threading.excepthook
        threading.excepthook = self._store_excepthook
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        '''
        Remove us from the hook and make sure no exceptions were thrown.
        '''
        threading.excepthook = self.__org_hook
        if len(self._exceptions) != 0:
            tracebacks = os.linesep.join(self._exceptions)
            raise Exception(f'Exceptions in other threads: {tracebacks}')

For older versions of Python, this is a bit more complicated. Long story short, the threading module appears to have an undocumented import that does something along the lines of:

threading._format_exc = traceback.format_exc

Not very surprisingly, this function is only called when an exception is thrown from a thread's run function.

So for Python <= 3.7:

import os
import threading


class GlobalExceptionWatcher(object):
    def _store_excepthook(self):
        '''
        Used as an exception handler which stores any uncaught exceptions.
        '''
        # threading calls _format_exc() with no arguments and prints its result,
        # so the formatted text must be returned
        formatted_exc = self.__org_hook()
        self._exceptions.append(formatted_exc)
        return formatted_exc

    def __enter__(self):
        '''
        Register us to the hook.
        '''
        self._exceptions = []
        self.__org_hook = threading._format_exc
        threading._format_exc = self._store_excepthook
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        '''
        Remove us from the hook and make sure no exceptions were thrown.
        '''
        threading._format_exc = self.__org_hook
        if len(self._exceptions) != 0:
            tracebacks = os.linesep.join(self._exceptions)
            raise Exception('Exceptions in other threads: %s' % tracebacks)
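The two versions differ only in which hook they patch, so a single generator-based context manager could pick the hook at runtime. This is a sketch, not the answer's code, with three caveats. The name watch_thread_exceptions is hypothetical. The Python <= 3.7 branch still depends on the undocumented threading._format_exc. It also raises AssertionError rather than a bare Exception, so unittest counts a hit as a failure instead of an error.

```python
import contextlib
import threading
import traceback


@contextlib.contextmanager
def watch_thread_exceptions():
    """Hypothetical generator-based variant of GlobalExceptionWatcher.

    Collects uncaught exceptions from threads and raises on exit if any occurred.
    """
    collected = []
    # Python >= 3.8 has the documented hook; older versions only the private helper
    hook_name = 'excepthook' if hasattr(threading, 'excepthook') else '_format_exc'
    original = getattr(threading, hook_name)

    if hook_name == 'excepthook':
        def hook(args):
            # Called with an args object; record the formatted traceback, print nothing
            collected.append(''.join(traceback.format_exception(
                args.exc_type, args.exc_value, args.exc_traceback)))
    else:
        def hook():
            # Called with no arguments; its return value is what threading prints
            text = original()
            collected.append(text)
            return text

    setattr(threading, hook_name, hook)
    try:
        yield collected
    finally:
        # Always restore the original hook, even if the with-body raised
        setattr(threading, hook_name, original)
    # Only reached when the with-body itself did not raise
    if collected:
        raise AssertionError('Exceptions in other threads:\n' + '\n'.join(collected))
```

Usage mirrors the class: wrap the start and join calls in "with watch_thread_exceptions():". If the with-body itself raises, that exception propagates instead of the thread report.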

Usage:

my_thread = x.ExceptionRaiser()
# will fail when thread is started and raises an exception.
with GlobalExceptionWatcher():
    my_thread.start()
    my_thread.join()

You still need to join the thread yourself. On exit, the with statement's context manager checks for any exception thrown in other threads and raises an exception if it finds one.


THE CODE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED

The Python <= 3.7 version is an undocumented, sort-of-horrible hack. I tested it on Linux and Windows, and it seems to work. Use it at your own risk.

Upvotes: 6

damzam

Reputation: 1961

I've come across this problem myself, and the only solution I've been able to come up with is subclassing Thread to include an attribute for whether or not it terminates without an uncaught exception:

from threading import Thread


class ErrThread(Thread):
    """
    A subclass of Thread that stores any exception raised by its target
    in self.err instead of letting the thread die with it.
    """
    def run(self):
        self.err = None
        try:
            Thread.run(self)
        except Exception as e:
            # Python 3 only allows a plain name after 'as', so copy it over
            self.err = e


class TaskQueue(object):
    """
    A utility class to run ErrThread objects in parallel and raise an exception
    in the event that *any* of them fail.
    """

    def __init__(self, *tasks):

        self.threads = []

        for t in tasks:
            try:
                self.threads.append(ErrThread(**t))  # passing in a dict of target and args
            except TypeError:
                self.threads.append(ErrThread(target=t))

    def run(self):

        for t in self.threads:
            t.start()
        for t in self.threads:
            t.join()
            if t.err is not None:
                raise Exception('Thread %s failed with error: %s' % (t.name, t.err))

Upvotes: 1
