Reuven Bass

Reputation: 680

How to get the current task reference?

How can I get reference to the task my code is executed within?

ISomeInterface impl = new SomeImplementation();
Task.Factory.StartNew(state => impl.MethodFromSomeInterface(), new MyState());

...

void MethodFromSomeInterface()
{
    Task currentTask = Task.GetCurrentTask();    // No such method?
    MyState state = (MyState) currentTask.AsyncState;
}

Since I'm calling some interface method, I can't just pass the newly created task as an additional parameter.

Upvotes: 19

Views: 26317

Answers (5)

Tim Lovell-Smith

Reputation: 16125

If you can change the interface (which was not a constraint when I had a similar problem), Lazy<Task> seemed like a reasonable way to solve this, so I tried it out.

It works, at least for what I want 'the current task' to mean, but the code is subtle: AsyncMethodThatYouWantToRun has to call Task.Yield() before it reads lazyThisTask.Value.

If you don't yield, the method reads Value while it is still running inside the Lazy's value factory, and t.Wait() fails with System.AggregateException: 'One or more errors occurred. (ValueFactory attempted to access the Value property of this instance.)'

Lazy<Task> eventuallyATask = null; // assigned first to avoid the "use of unassigned local variable" error in the lambda
eventuallyATask = new Lazy<Task>(
    () => AsyncMethodThatYouWantToRun(eventuallyATask));

Task t = eventuallyATask.Value; // actually start the task!

async Task AsyncMethodThatYouWantToRun(Lazy<Task> lazyThisTask)
{
    await Task.Yield(); // return to the caller so the Lazy can finish creating the task before we read Value

    Task thisTask = lazyThisTask.Value;
    Console.WriteLine("you win! Your task got a reference to itself");
}

t.Wait();
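A minimal sketch of the failure mode described above, assuming the same Lazy<Task> setup with the Task.Yield() call removed (the names LazyWithoutYieldDemo, WithoutYield and Run are hypothetical). Without the yield, the method reads Value while Lazy<T> is still running its value factory, Lazy<T> throws InvalidOperationException, and the async method stores that exception in the task it returns, so t.Wait() surfaces it wrapped in an AggregateException:

```csharp
using System;
using System.Threading.Tasks;

static class LazyWithoutYieldDemo
{
    // Same shape as AsyncMethodThatYouWantToRun, but with Task.Yield() removed.
    static async Task WithoutYield(Lazy<Task> lazyThisTask)
    {
        // Runs synchronously inside the Lazy's value factory, so this is a
        // recursive Value access and Lazy<T> throws InvalidOperationException.
        Task thisTask = lazyThisTask.Value;
        await Task.CompletedTask;
    }

    // Returns the exception thrown by t.Wait() for the no-yield version.
    public static AggregateException Run()
    {
        Lazy<Task> eventuallyATask = null;
        eventuallyATask = new Lazy<Task>(() => WithoutYield(eventuallyATask));

        Task t = eventuallyATask.Value; // the factory returns an already-faulted task
        try
        {
            t.Wait();
            return null;
        }
        catch (AggregateException ex)
        {
            return ex; // InnerException is the InvalidOperationException from Lazy<T>
        }
    }
}
```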

Alternatively, instead of relying on the subtlety of Task.Yield, we can go tasks all the way and use TaskCompletionSource<Task>. This eliminates those errors and any potential deadlocks, since the task releases its thread until it is handed a reference to itself.

var eventuallyTheTask = new TaskCompletionSource<Task>();
Task t = AsyncMethodThatYouWantToRun(eventuallyTheTask.Task); // start the task!
eventuallyTheTask.SetResult(t); // unblock the task and give it self-knowledge

async Task AsyncMethodThatYouWantToRun(Task<Task> thisTaskAsync)
{
    Task thisTask = await thisTaskAsync; // gets this task :)
    Console.WriteLine("you win! Your task got a reference to itself (== 't')");
}

t.Wait();
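One detail worth knowing: by default, SetResult may run the code after `await thisTaskAsync` synchronously on the thread that calls SetResult. The sketch below (TcsSelfReferenceDemo, Body and Start are hypothetical names) passes TaskCreationOptions.RunContinuationsAsynchronously, available from .NET Framework 4.6, so that continuation is queued instead, and has the method return the Task it received so the caller can check that it really is t:

```csharp
using System.Threading.Tasks;

static class TcsSelfReferenceDemo
{
    // Variant of AsyncMethodThatYouWantToRun that hands back the Task it was given.
    static async Task<Task> Body(Task<Task> thisTaskAsync)
    {
        Task thisTask = await thisTaskAsync; // suspends until SetResult is called
        return thisTask;
    }

    public static Task<Task> Start()
    {
        // Keeps SetResult from running Body's continuation inline on this thread.
        var eventuallyTheTask = new TaskCompletionSource<Task>(
            TaskCreationOptions.RunContinuationsAsynchronously);

        Task<Task> t = Body(eventuallyTheTask.Task); // start the task
        eventuallyTheTask.SetResult(t);              // give it its own reference
        return t;
    }
}
```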

Upvotes: 0

Seth

Reputation: 1074

If you can use .NET Framework 4.6 or later, .NET Standard or .NET Core, this problem is solved by AsyncLocal<T>: https://learn.microsoft.com/en-gb/dotnet/api/system.threading.asynclocal-1?view=netframework-4.7.1
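A minimal sketch of the AsyncLocal<T> approach, assuming a static ambient slot that holds a string in place of MyState (AsyncLocalDemo, CurrentState, Worker and RunAll are hypothetical names). Each task sets its own value, the value survives an await even if the continuation resumes on another pool thread, and the change does not leak back to the caller:

```csharp
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

static class AsyncLocalDemo
{
    // Ambient slot: each logical async flow sees its own value.
    static readonly AsyncLocal<string> CurrentState = new AsyncLocal<string>();

    static async Task<bool> Worker(string name)
    {
        CurrentState.Value = name;          // set inside this task's flow
        await Task.Delay(10);               // may resume on a different pool thread
        return CurrentState.Value == name;  // the value flowed across the await
    }

    public static bool RunAll()
    {
        Task<bool>[] tasks = Enumerable.Range(0, 10)
            .Select(i => Task.Run(() => Worker("task " + i)))
            .ToArray();
        Task.WaitAll(tasks);
        return tasks.All(t => t.Result);
    }

    // What the calling flow sees; Worker's assignments do not flow back here.
    public static string ValueInCaller => CurrentState.Value;
}
```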

If not, you need to set up a data store at some point before it is used and access it through a closure, not through a thread or task. A ConcurrentDictionary will help cover up any mistakes you make doing this.

When code awaits, the current task releases its thread, so a task is not tied to any one thread, at least in the programming model.

Demo:

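using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// I feel like demo code about threading needs to guarantee
// it actually has some in the first place :)
// The second argument is the number of I/O completion port threads,
// which would only matter if we were doing I/O.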
var threads = Environment.ProcessorCount * 4;
ThreadPool.SetMaxThreads(threads, threads);
ThreadPool.SetMinThreads(threads, threads);

var rand = new Random(DateTime.Now.Millisecond);

var tasks = Enumerable.Range(0, 50)
    .Select(_ =>
    {
        // State store tied to task by being created in the same closure.
        var taskState = new ConcurrentDictionary<string, object>();
        // There is absolutely no need for this to be a thread-safe
        // data structure in this instance but given the copy-pasta,
        // I thought I'd save people some trouble.

        // Random is not thread-safe, so pick the delay here, on the
        // calling thread, rather than inside Task.Run.
        var delay = rand.Next(100);

        return Task.Run(async () =>
        {
            taskState["ThreadId"] = Thread.CurrentThread.ManagedThreadId;
            await Task.Delay(delay);
            return Thread.CurrentThread.ManagedThreadId == (int)taskState["ThreadId"];
        });
    })
    .ToArray();

Task.WaitAll(tasks);
Console.WriteLine("Tasks that stayed on the same thread: " + tasks.Count(t => t.Result));
Console.WriteLine("Tasks that didn't stay on the same thread: " + tasks.Count(t => !t.Result));

Upvotes: 4

hofi

Reputation: 111

Here is a "hacky" class that can be used for that. It reads the non-public static Task.InternalCurrent property through reflection.
Just use the CurrentTask property to get the currently running Task.
I strongly advise against using it anywhere near production code!

using System;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;

public static class TaskGetter
{
    private static string _propertyName;
    private static Type _taskType;
    private static PropertyInfo _property;
    private static Func<Task> _getter;

    static TaskGetter()
    {
        _taskType = typeof(Task);
        _propertyName = "InternalCurrent";
        SetupGetter();
    }

    public static void SetPropertyName(string newName)
    {
        _propertyName = newName;
        SetupGetter();
    }

    public static Task CurrentTask
    {
        get
        {
            return _getter();
        }
    }

    private static void SetupGetter()
    {
        _getter = () => null;
        _property = _taskType.GetProperties(BindingFlags.Static | BindingFlags.NonPublic).Where(p => p.Name == _propertyName).FirstOrDefault();
        if (_property != null)
        {
            _getter = () =>
            {
                var val = _property.GetValue(null);
                return val == null ? null : (Task)val;
            };
        }
    }
}
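For comparison, the public API exposes only the ID of the currently executing task, through the static Task.CurrentId property (null when no task is running). A minimal sketch, with the hypothetical helper CurrentIdDemo.CurrentIdMatchesTaskId, showing that the ID observed inside the delegate is the Id of the task returned by Task.Run; mapping that ID back to a Task reference would need your own bookkeeping:

```csharp
using System.Threading.Tasks;

static class CurrentIdDemo
{
    public static bool CurrentIdMatchesTaskId()
    {
        int? observedId = null;

        // Task.Run(Action) returns the task that executes the delegate itself.
        Task task = Task.Run(() => { observedId = Task.CurrentId; });
        task.Wait();

        return observedId.HasValue && observedId.Value == task.Id;
    }
}
```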

Upvotes: 4

Ananke

Reputation: 1230

The following example shows how this can be achieved, resolving the issue with the answer provided by @stephen-cleary. It is a bit convoluted, but the key is the TaskContext class below, which uses CallContext.LogicalSetData, CallContext.LogicalGetData and CallContext.FreeNamedDataSlot; these are useful for creating your own task contexts. The rest of the code answers the OP's question:

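using System;
using System.Diagnostics;
using System.Runtime.Remoting.Messaging;
using System.Threading.Tasks;

class Program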
{
    static void Main(string[] args)
    {
        var t1 = Task.Factory.StartNewWithContext(async () => { await DoSomething(); });
        var t2 = Task.Factory.StartNewWithContext(async () => { await DoSomething(); });

        Task.WaitAll(t1, t2);
    }

    private static async Task DoSomething()
    {
        var id1 = TaskContext.Current.Task.Id;
        Console.WriteLine(id1);
        await Task.Delay(1000);

        var id2 = TaskContext.Current.Task.Id;
        Console.WriteLine(id2);
        Console.WriteLine(id1 == id2);
    }
}

public static class TaskFactoryExtensions
{
    public static Task StartNewWithContext(this TaskFactory factory, Action action)
    {
        Task task = null;

        task = new Task(() =>
        {
            Debug.Assert(TaskContext.Current == null);
            TaskContext.Current = new TaskContext(task);
            try
            {
                action();
            }
            finally
            {
                TaskContext.Current = null;
            }
        });

        task.Start();

        return task;
    }

    public static Task StartNewWithContext(this TaskFactory factory, Func<Task> action)
    {
        Task<Task> task = null;

        task = new Task<Task>(async () =>
        {
            Debug.Assert(TaskContext.Current == null);
            TaskContext.Current = new TaskContext(task);
            try
            {
                await action();
            }
            finally
            {
                TaskContext.Current = null;
            }
        });

        task.Start();

        return task.Unwrap();
    }
}

public sealed class TaskContext
{
    // Use your own unique key for better performance
    private static readonly string contextKey = Guid.NewGuid().ToString();

    public TaskContext(Task task)
    {
        this.Task = task;
    }

    public Task Task { get; private set; }

    public static TaskContext Current
    {
        get { return (TaskContext)CallContext.LogicalGetData(contextKey); }
        internal set
        {
            if (value == null)
            {
                CallContext.FreeNamedDataSlot(contextKey);
            }
            else
            {
                CallContext.LogicalSetData(contextKey, value);
            }
        }
    }
}
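CallContext.LogicalSetData and LogicalGetData are .NET Framework APIs; on .NET Core, AsyncLocal<T> plays the same role. The sketch below is a hypothetical port of TaskContext and the Func<Task> overload of StartNewWithContext onto AsyncLocal<T> (AsyncLocalTaskContext is an illustrative name, not part of the answer). Because a value set inside an async method flows into the code it calls but not back out to its caller, this version does not reset the value in a finally block:

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public sealed class AsyncLocalTaskContext
{
    private static readonly AsyncLocal<AsyncLocalTaskContext> current =
        new AsyncLocal<AsyncLocalTaskContext>();

    public AsyncLocalTaskContext(Task task) { Task = task; }

    public Task Task { get; }

    public static AsyncLocalTaskContext Current => current.Value;

    // Mirrors StartNewWithContext(Func<Task>) from the answer above.
    public static Task StartNewWithContext(Func<Task> action)
    {
        Task<Task> task = null;
        task = new Task<Task>(async () =>
        {
            // Visible to action() and its continuations, not to the caller.
            current.Value = new AsyncLocalTaskContext(task);
            await action();
        });

        task.Start(); // started only after 'task' has been assigned
        return task.Unwrap();
    }
}
```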

Upvotes: 2

Stephen Cleary

Reputation: 456457

Since you can change neither the interface nor the implementation, you'll have to track the task yourself, e.g., using ThreadStaticAttribute:

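static class SomeInterfaceTask
{
  // [ThreadStatic] is only valid on static fields (not properties), and the
  // field must be accessible from MethodFromSomeInterface.
  [ThreadStatic]
  public static Task Current;
}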

...

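ISomeInterface impl = new SomeImplementation();
// Create the task first and start it afterwards, so 'task' is guaranteed
// to be assigned before the delegate reads it.
Task task = null;
task = new Task(state =>
{
  SomeInterfaceTask.Current = task;
  try
  {
    impl.MethodFromSomeInterface();
  }
  finally
  {
    SomeInterfaceTask.Current = null; // don't leak into the next work item on this thread
  }
}, new MyState());
task.Start();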

...

void MethodFromSomeInterface()
{
  Task currentTask = SomeInterfaceTask.Current;
  MyState state = (MyState) currentTask.AsyncState;
}
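[ThreadStatic] fields hold one value per thread, so this approach works as long as MethodFromSomeInterface reads SomeInterfaceTask.Current synchronously on the thread that set it; code after an await may resume on a different thread and see another value or null (presumably the issue the CallContext-based answer refers to). A minimal sketch of the per-thread behaviour, using the hypothetical class ThreadStaticDemo:

```csharp
using System.Threading;

static class ThreadStaticDemo
{
    // Each thread has its own copy of this field; it starts out null on every thread.
    [ThreadStatic]
    private static string slot;

    // Sets the field on the calling thread, then reads it from a new thread.
    public static string ReadFromOtherThread()
    {
        slot = "set on the calling thread";
        string seenElsewhere = "not read yet";

        var other = new Thread(() => { seenElsewhere = slot; });
        other.Start();
        other.Join();

        return seenElsewhere; // the other thread saw its own (null) copy
    }

    public static string ReadHere() => slot;
}
```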

Upvotes: 10
