Taylor
Taylor

Reputation: 2087

How to prematurely kill std::async threads before they are finished *without* using a std::atomic_bool?

I have a function that takes a callback, and I use it to do work on 10 separate threads. However, often not all of the work is needed. For example, if the desired result is obtained on the third thread, all work still being done on the remaining threads should stop.

This answer here suggests that it is not possible unless you have the callback functions take an additional std::atomic_bool argument, that signals whether the function should terminate prematurely.

This solution does not work for me. The workers are spun up inside a base class, and the whole point of this base class is to abstract away details of multithreading. How can I do this? I am anticipating that I will have to ditch std::async for something more involved.

#include <future>
#include <iostream>
#include <string>
#include <vector>

/// Base class that fans callback() out over several std::async tasks; derived
/// classes supply the per-task work by overriding callback().
class ABC{
public:
    /// Futures for the tasks launched by doStuffWithCallBack().
    std::vector<std::future<int> > m_results;

    ABC() = default;

    /// Virtual because ABC is a polymorphic base (callback() is pure virtual):
    /// destroying a derived object through an ABC* must run the derived
    /// destructor too.
    virtual ~ABC() = default;

    /// Per-task work. doStuffWithCallBack() invokes it concurrently from
    /// several threads, so implementations must be thread-safe.
    virtual int callback(int a) = 0;

    /// Runs callback() for a = 0..9 in parallel and scans the results.
    void doStuffWithCallBack();
};


/// Launches callback(0) .. callback(9) as std::async tasks, scans their results
/// in index order, and stops scanning at the first result equal to 1. It then
/// prints the last task's result (if the scan did not already consume it) to
/// show that work continued anyway.
///
/// std::async provides no way to cancel a running task, so every task still
/// runs to completion. This function waits for all of them before returning.
/// That wait is required for safety: each task calls the virtual callback()
/// through `this`. A task still running after the caller starts destroying the
/// derived object would call into a partially destroyed object.
void ABC::doStuffWithCallBack(){
    constexpr int kTaskCount = 10;

    // Drop futures left over from a previous call so that indices
    // 0..kTaskCount-1 refer to this call's tasks (destroying a std::async
    // future blocks until its task has finished).
    m_results.clear();
    m_results.reserve(kTaskCount);

    // start working; std::launch::async forces a real thread per task (the
    // default policy may instead defer a task until get() is called on it)
    for (int i = 0; i < kTaskCount; ++i)
        m_results.push_back(std::async(std::launch::async, &ABC::callback, this, i));

    // analyze results in order and stop looking once we get the 1
    for (auto& result : m_results) {
        const int foo = result.get();  // callback() returns int: compare exactly
        if (foo == 1)
            break;  // but the remaining threads keep running
    }

    // get() may be called only once per future; if the scan above already
    // consumed the last one, there is nothing left to print.
    std::future<int>& last = m_results.back();
    if (last.valid())
        std::cout << last.get() << " <- this shouldn't have ever been computed\n";

    // Never return while a task may still be using *this.
    for (auto& result : m_results)
        if (result.valid())
            result.wait();
}

class Derived : public ABC {
public:
    Derived() : ABC() {};
    ~Derived() {};
    int callback(int a){
        std::cout << a << "!\n";
        if (a == 3)
            return 1;
        else
            return 0;
    };
};

int main(int argc, char **argv)
{

    Derived myObj;
    myObj.doStuffWithCallBack();

    return 0;
}

Upvotes: 2

Views: 894

Answers (1)

bob2
bob2

Reputation: 1112

I'll just say that this should probably not be a part of a 'normal' program, since it could leak resources and/or leave your program in an unstable state, but in the interest of science...

If you control the thread loop and don't mind using platform-specific features, you could inject an exception into the thread. On POSIX you can use signals for this; on Windows you would have to use SetThreadContext(). Although the exception will generally unwind the stack and call destructors, your thread may be in a system call or some other non-exception-safe place when the exception occurs. Note also that standard C++ does not allow an exception to propagate out of a signal handler, so this relies on compiler- and platform-specific unwinding support.

Disclaimer: I only have Linux at the moment, so I did not test the Windows code.

// Pick the abort mechanism for this platform: Windows rewrites the thread's
// instruction pointer, everything else uses a POSIX signal.
#if defined(_WIN32)
#   define ITS_WINDOWS
#else
#   define ITS_POSIX
#endif

// Standard headers used by the demo below (std::string, std::thread,
// std::chrono, printf); the original snippet relied on them without
// including them.
#include <chrono>
#include <cstdio>
#include <string>
#include <thread>

#if defined(ITS_WINDOWS)
#include <windows.h>   // HANDLE, SuspendThread, Get/SetThreadContext
#elif defined(ITS_POSIX)
#include <signal.h>    // sigaction, pthread_kill, SIGUSR2
#endif

/// Throws the sentinel exception (an empty std::string) that the aborted
/// thread's try/catch intercepts.
///
/// The Windows branch of abort_thread() jumps straight to this function, so it
/// must never return, hence [[noreturn]]. There is deliberately no dynamic
/// exception specification: `throw(...)` specs are ill-formed since C++17,
/// and the old `throw(std::string())` actually named a function type.
[[noreturn]] void throw_exception()
{
    throw std::string();
}

/// Places a call to throw_exception() in the caller's code path without ever
/// executing it.
///
/// The value is read through a volatile, so the compiler cannot prove the
/// branch dead and must keep the potentially-throwing call.
/// NOTE(review): presumably this is meant to keep the enclosing try/catch and
/// its exception-handling tables alive for the signal-injected throw; confirm
/// per toolchain.
void init_exceptions()
{
    volatile int always_zero = 0;  // optimizer barrier, not thread synchronization
    if (always_zero != 0)
    {
        throw_exception();
    }
}

/// Attempts to make `t` throw an exception asynchronously so that it unwinds
/// and leaves its current work.
///
/// Windows: suspends the thread, points its instruction pointer at
/// throw_exception(), and resumes it.
/// POSIX: sends SIGUSR2, whose handler (installed by init_threads()) throws.
///
/// @return true if the platform call that injects the abort succeeded. This
///         does not mean the thread has already unwound; join() it to find out.
///         Returns false for a non-joinable thread or on unsupported platforms.
bool abort_thread(std::thread &t)
{
    // A default-constructed, joined or detached std::thread has no live native
    // handle to target.
    if (!t.joinable())
        return false;

#if defined(ITS_WINDOWS)

    bool bSuccess = false;
    HANDLE h = t.native_handle();
    if (h == nullptr || INVALID_HANDLE_VALUE == h)
        return false;

    // SuspendThread reports failure as (DWORD)-1.
    if (SuspendThread(h) == static_cast<DWORD>(-1))
        return false;

    CONTEXT ctx{};
    ctx.ContextFlags = CONTEXT_CONTROL;
    if (GetThreadContext(h, &ctx))
    {
        // Keep the full pointer width: the original cast through DWORD
        // truncated the 64-bit address on Win64.
        // NOTE(review): only the instruction pointer is redirected; no return
        // address is pushed and the stack pointer is untouched. Verify that
        // unwinding out of throw_exception() works on x64 (the Windows path
        // was never tested).
        const DWORD_PTR target = reinterpret_cast<DWORD_PTR>(&throw_exception);
#if defined( _WIN64 )
        ctx.Rip = static_cast<DWORD64>(target);
#else
        ctx.Eip = static_cast<DWORD>(target);
#endif

        bSuccess = SetThreadContext(h, &ctx) != FALSE;
    }

    ResumeThread(h);

    return bSuccess;

#elif defined(ITS_POSIX)

    // pthread_kill returns 0 on success or an error number.
    return pthread_kill(t.native_handle(), SIGUSR2) == 0;

#else

    return false;  // no abort mechanism for this platform

#endif
}


#if defined(ITS_POSIX)

/// Signal handler installed for SIGUSR2 by init_threads(). It turns the signal
/// into a C++ exception (an empty std::string) on the thread that received it;
/// any other signal number is ignored.
///
/// NOTE(review): standard C++ does not allow an exception to leave a signal
/// handler. This depends on the toolchain being able to unwind through the
/// signal frame (and, with GCC, on -fnon-call-exceptions if the signal can
/// interrupt code that is not a function call). Allocating the exception object
/// is also not async-signal-safe.
void worker_thread_sig(int sig)
{
    if (sig != SIGUSR2)
    {
        return;
    }
    throw std::string();
}

#endif

/// Installs worker_thread_sig() as the process-wide SIGUSR2 handler (POSIX
/// only; a no-op elsewhere).
///
/// Call this before abort_thread() sends SIGUSR2: the default action for
/// SIGUSR2 terminates the whole process.
void init_threads()
{
#if defined(ITS_POSIX)

    struct sigaction sa{};  // value-initialised: no garbage in fields not set below
    sigemptyset(&sa.sa_mask);
    // SA_NODEFER: normally SIGUSR2 is blocked while its handler runs and is
    // unblocked when the handler returns. Our handler leaves by throwing, so
    // it never returns, and the thread would keep SIGUSR2 blocked forever; it
    // could not be aborted a second time.
    sa.sa_flags = SA_NODEFER;
    sa.sa_handler = worker_thread_sig;
    if (sigaction(SIGUSR2, &sa, nullptr) != 0)
        perror("init_threads: sigaction(SIGUSR2)");

#endif
}

/// Probe object that announces its construction and destruction on stdout, so
/// the demo output shows whether the aborted thread's stack was unwound (i.e.
/// whether its destructor ran).
class tracker
{
public:
    tracker()
    {
        printf("tracker()\n");
    }

    ~tracker()
    {
        printf("~tracker()\n");
    }
};

/// Demo: start a worker that would sleep for 1000 minutes, abort it by
/// injecting an exception, and show that its destructors ran before join().
int main(int /*argc*/, char * /*argv*/[])
{
    // The SIGUSR2 handler must be installed before any abort is attempted.
    init_threads();

    // The worker fulfils this once it is inside its try block. An abort
    // delivered before that point would throw outside the try, escape the
    // thread function, and call std::terminate.
    std::promise<void> started;
    std::future<void> started_future = started.get_future();

    printf("main: starting thread...\n");
    std::thread t([&started]()
    {
        try
        {
            tracker a;

            init_exceptions();

            printf("thread: started...\n");
            started.set_value();
            std::this_thread::sleep_for(std::chrono::minutes(1000));
            printf("thread: stopping...\n");
        }
        catch (const std::string&)
        {
            printf("thread: exception caught...\n");
        }
    });

    printf("main: sleeping...\n");
    started_future.wait();
    // Grace period so the worker is (very likely) parked in sleep_for rather
    // than still inside set_value() when the exception is injected.
    std::this_thread::sleep_for(std::chrono::seconds(2));

    printf("main: aborting...\n");
    // NOTE: if the abort is not delivered, join() below waits out the full
    // 1000-minute sleep.
    abort_thread(t);

    printf("main: joining...\n");
    t.join();

    printf("main: exiting...\n");

    return 0;
}

Output:

main: starting thread...
main: sleeping...
tracker()
thread: started...
main: aborting...
main: joining...
~tracker()
thread: exception caught...
main: exiting...

Upvotes: 1

Related Questions