I have written the function below, which waits for all the tasks to complete, or throws an exception when a cancellation or a time-out occurs.
```csharp
public static async Task WhenAll(
    IEnumerable<Task> tasks,
    CancellationToken cancellationToken,
    int millisecondsTimeOut)
{
    Task timeoutTask = Task.Delay(millisecondsTimeOut, cancellationToken);
    Task completedTask = await Task.WhenAny(
        Task.WhenAll(tasks),
        timeoutTask
    );
    if (completedTask == timeoutTask)
    {
        throw new TimeoutException();
    }
}
```
If all the tasks finish well before a long time-out (e.g. millisecondsTimeOut = 60,000), would timeoutTask stay around until the 60 seconds have elapsed, even after the function returns? If so, what is the best way to fix this runaway problem?
Yes. timeoutTask, together with the timer behind Task.Delay, stays around until the timeout elapses (or the CancellationToken is canceled).
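To see this behavior in isolation, here is a small sketch (the class and method names are invented for illustration) that starts a 60-second Task.Delay with a token. The delay stays pending, with its timer still scheduled, until the token is canceled; at that point it completes immediately in the Canceled state.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical demo class, not part of the original answer.
public static class DelayDemo
{
    // Reports whether a 60-second delay is pending before Cancel and canceled after it.
    public static async Task<(bool PendingBeforeCancel, bool CanceledAfterCancel)> RunAsync()
    {
        using var cts = new CancellationTokenSource();
        Task delay = Task.Delay(60_000, cts.Token);

        // Nothing has completed the delay yet: its timer is still scheduled.
        bool pending = !delay.IsCompleted;

        // Canceling the token completes the delay right away instead of after 60 seconds.
        cts.Cancel();
        try
        {
            await delay;
        }
        catch (TaskCanceledException)
        {
            // Expected: awaiting a canceled task throws.
        }

        return (pending, delay.IsCanceled);
    }

    public static async Task Main()
    {
        var (pending, canceled) = await RunAsync();
        Console.WriteLine($"Pending before Cancel: {pending}");
        Console.WriteLine($"Canceled after Cancel: {canceled}");
    }
}
```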
You can fix it by passing Task.Delay a different CancellationToken: create a new CancellationTokenSource with CancellationTokenSource.CreateLinkedTokenSource(cancellationToken), pass its Token to Task.Delay, and cancel that source once the tasks have finished. You should also await the completed task; otherwise you aren't observing exceptions (or cancellations) from the tasks:
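```csharp
public static async Task WhenAll(
    IEnumerable<Task> tasks,
    CancellationToken cancellationToken,
    int millisecondsTimeOut)
{
    // Canceled by the caller's token, or by us once the tasks have finished.
    using (var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
    {
        var timeoutTask = Task.Delay(millisecondsTimeOut, cancellationTokenSource.Token);
        var completedTask = await Task.WhenAny(Task.WhenAll(tasks), timeoutTask);
        if (completedTask == timeoutTask)
        {
            // The delay also completes (as canceled) when the caller cancels;
            // report that as a cancellation rather than as a time-out.
            cancellationToken.ThrowIfCancellationRequested();
            throw new TimeoutException();
        }
        // The tasks won: cancel the pending delay so its timer is released right away.
        cancellationTokenSource.Cancel();
        // Observe any exception or cancellation from the tasks.
        await completedTask;
    }
}
```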
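Cancelling the linked source at the end is safe because cancellation flows only one way: cancelling the caller's token cancels the linked source, but cancelling the linked source leaves the caller's token untouched. A minimal sketch of that behavior (the class and method names are illustrative only):

```csharp
using System;
using System.Threading;

// Hypothetical demo class, not part of the original answer.
public static class LinkedTokenDemo
{
    // Cancels the linked source and reports whether the caller's source saw it.
    public static bool ParentCanceledByLinkedCancel()
    {
        using var parent = new CancellationTokenSource();
        using var linked = CancellationTokenSource.CreateLinkedTokenSource(parent.Token);
        linked.Cancel();
        return parent.IsCancellationRequested; // false: cancellation does not flow upward
    }

    // Cancels the caller's source and reports whether the linked source saw it.
    public static bool LinkedCanceledByParentCancel()
    {
        using var parent = new CancellationTokenSource();
        using var linked = CancellationTokenSource.CreateLinkedTokenSource(parent.Token);
        parent.Cancel();
        return linked.IsCancellationRequested; // true: cancellation flows downward
    }

    public static void Main()
    {
        Console.WriteLine(ParentCanceledByLinkedCancel()); // False
        Console.WriteLine(LinkedCanceledByParentCancel()); // True
    }
}
```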
However, I think there's a simpler way to achieve what you want if you don't need to distinguish between a TimeoutException and a TaskCanceledException. Just add a continuation that is canceled when the CancellationToken is canceled or when the timeout elapses:
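```csharp
public static async Task WhenAll(
    IEnumerable<Task> tasks,
    CancellationToken cancellationToken,
    int millisecondsTimeOut)
{
    // Canceled by the caller's token or, through CancelAfter, when the time-out elapses.
    // Disposing it at the end also stops the CancelAfter timer, so nothing lingers.
    using (var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
    {
        cancellationTokenSource.CancelAfter(millisecondsTimeOut);
        // Rethrows the tasks' outcome, unless the token is canceled first,
        // in which case this continuation is canceled immediately.
        await Task.WhenAll(tasks).ContinueWith(
            t => t.GetAwaiter().GetResult(),
            cancellationTokenSource.Token,
            TaskContinuationOptions.ExecuteSynchronously,
            TaskScheduler.Default);
    }
}
```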
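A sketch of how such a token-guarded continuation behaves, using a hypothetical Guard helper that mirrors the ContinueWith call in this approach. When the token fires first (a CancellationTokenSource with a 50 ms delay stands in for the time-out), awaiting the result throws TaskCanceledException even though the antecedent never finishes; when the antecedent faults first, its own exception is rethrown. Note that the antecedent tasks are not stopped, only no longer waited on, and that a time-out and a caller cancellation look identical, which is why this variant only fits if you don't need to tell them apart.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical demo class, not part of the original answer.
public static class ContinuationDemo
{
    // Hypothetical helper mirroring the answer's ContinueWith call.
    static Task Guard(Task antecedent, CancellationToken token) =>
        antecedent.ContinueWith(
            t => t.GetAwaiter().GetResult(),
            token,
            TaskContinuationOptions.ExecuteSynchronously,
            TaskScheduler.Default);

    // Name of the exception seen when the time-out fires before the antecedent finishes.
    public static async Task<string> TimeoutOutcomeAsync()
    {
        var never = new TaskCompletionSource<bool>();    // an antecedent that never completes
        using var cts = new CancellationTokenSource(50); // 50 ms stand-in for the time-out
        try
        {
            await Guard(never.Task, cts.Token);
            return "completed";
        }
        catch (Exception ex)
        {
            return ex.GetType().Name;
        }
    }

    // Name of the exception seen when the antecedent faults before the time-out.
    public static async Task<string> FaultOutcomeAsync()
    {
        using var cts = new CancellationTokenSource(60_000);
        Task faulted = Task.FromException(new InvalidOperationException("boom"));
        try
        {
            await Guard(faulted, cts.Token);
            return "completed";
        }
        catch (Exception ex)
        {
            return ex.GetType().Name;
        }
    }

    public static async Task Main()
    {
        Console.WriteLine(await TimeoutOutcomeAsync()); // TaskCanceledException
        Console.WriteLine(await FaultOutcomeAsync());   // InvalidOperationException
    }
}
```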