Guiorgy

Reputation: 1734

Consume all messages in a System.Threading.Channels.Channel

Suppose I have a many-producer, single-consumer unbounded Channel, with a consumer:

await foreach (var message in channel.Reader.ReadAllAsync(cts.Token))
{
    await consume(message);
}

The problem is that the consume function does some IO access and potentially some network access too, so many more messages may be produced before a single message is consumed. But since the IO resources can't be accessed concurrently, I can't have multiple consumers, nor can I throw the consume function into a Task and forget it.

The consume function is such that it can be easily modified to take multiple messages and handle them all in a batch. So my question is if there's a way to make the consumer take all messages in the channel queue whenever it tries to access it, something like this:

while (true) {
    Message[] messages = await channel.Reader.TakeAll();
    await consumeAll(messages);
}

Edit: One option I can come up with is:

List<Message> messages = new();
await foreach (var message in channel.Reader.ReadAllAsync(cts.Token))
{
    await consume(message);
    Message msg;
    while (channel.Reader.TryRead(out msg))
        messages.Add(msg);
    if (messages.Count > 0)
    {
        await consumeAll(messages);
        messages.Clear();
    }
}

But I feel like there should be a better way to do this.

Upvotes: 5

Views: 6092

Answers (2)

Theodor Zoulias

Reputation: 43545

Here is a slightly more polished version of the ReadBatchesAsync method in spender's answer. It also features an optional maxSize parameter that sets an upper limit on the size of each batch:

/// <summary>
/// Consumes the items in the channel in batches. Each batch contains all
/// the items that are immediately available, up to a specified maximum number.
/// </summary>
public static IAsyncEnumerable<T[]> ReadBatchImmediateAsync<T>(
    this ChannelReader<T> source, int maxSize = -1)
{
    ArgumentNullException.ThrowIfNull(source);
    if (maxSize == -1) maxSize = Array.MaxLength;
    if (maxSize < 1) throw new ArgumentOutOfRangeException(nameof(maxSize));
    return Implementation();

    async IAsyncEnumerable<T[]> Implementation(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        while (await source.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
        {
            List<T> buffer = new();
            while (buffer.Count < maxSize && source.TryRead(out T item))
                buffer.Add(item);
            if (buffer.Count > 0)
                yield return buffer.ToArray();
        }
    }
}

Usage example:

await foreach (Item[] batch in channel.Reader.ReadBatchImmediateAsync())
{
    await ConsumeBatch(batch);
}

Like spender's answer, the above implementation is non-destructive: no element that has been consumed from the source channel can be lost, whether the enumeration is canceled or the consuming await foreach loop is exited early.

A CancellationToken can be injected with the standard WithCancellation operator.

The above implementation guarantees that each emitted T[] batch will contain at least one element, and at most maxSize elements.
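To make the batch-size guarantee and the WithCancellation remark concrete, here is a minimal sketch. The class name BatchDemo and the helper methods BatchSizesAsync and CancelAsync are hypothetical names introduced only for this illustration, and the extension method is repeated so that the snippet compiles on its own (it assumes .NET 6 or later, as the original does).

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public static class BatchDemo
{
    // Same logic as the method above, repeated so that the sketch compiles on its own.
    public static IAsyncEnumerable<T[]> ReadBatchImmediateAsync<T>(
        this ChannelReader<T> source, int maxSize = -1)
    {
        ArgumentNullException.ThrowIfNull(source);
        if (maxSize == -1) maxSize = Array.MaxLength;
        if (maxSize < 1) throw new ArgumentOutOfRangeException(nameof(maxSize));
        return Implementation();

        async IAsyncEnumerable<T[]> Implementation(
            [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            while (await source.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
            {
                List<T> buffer = new();
                while (buffer.Count < maxSize && source.TryRead(out var item))
                    buffer.Add(item);
                if (buffer.Count > 0) yield return buffer.ToArray();
            }
        }
    }

    // Five items are already queued and the writer is completed,
    // so with maxSize: 2 the emitted batches have sizes 2, 2, 1.
    public static async Task<int[]> BatchSizesAsync()
    {
        var channel = Channel.CreateUnbounded<int>();
        for (int i = 1; i <= 5; i++) channel.Writer.TryWrite(i);
        channel.Writer.Complete();

        var sizes = new List<int>();
        await foreach (int[] batch in channel.Reader.ReadBatchImmediateAsync(maxSize: 2))
            sizes.Add(batch.Length);
        return sizes.ToArray();
    }

    // Cancels the consumer while the channel is still open, then checks
    // that an item written afterwards is still readable from the channel.
    public static async Task<(bool Canceled, int Leftover)> CancelAsync()
    {
        var channel = Channel.CreateUnbounded<int>();
        channel.Writer.TryWrite(1);
        using var cts = new CancellationTokenSource();
        bool canceled = false;
        try
        {
            await foreach (int[] batch in channel.Reader
                .ReadBatchImmediateAsync().WithCancellation(cts.Token))
            {
                cts.Cancel(); // the next wait for items observes the canceled token
            }
        }
        catch (OperationCanceledException) { canceled = true; }

        channel.Writer.TryWrite(2); // arrives after the consumer has stopped
        channel.Reader.TryRead(out int leftover);
        return (canceled, leftover);
    }

    public static async Task Main()
    {
        Console.WriteLine(string.Join(", ", await BatchSizesAsync()));
        Console.WriteLine(await CancelAsync());
    }
}
```

In this sketch the cancellation surfaces as an OperationCanceledException from the await foreach loop, and the item written afterwards stays in the channel for whichever consumer reads next.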

Upvotes: 4

spender

Reputation: 120460

After reading Stephen Toub's primer on channels, I had a stab at writing an extension method that should do what you need (It's been a while since I did any C#, so this was fun).

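using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;

public static class ChannelReaderEx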
{
    public static async IAsyncEnumerable<IEnumerable<T>> ReadBatchesAsync<T>(
        this ChannelReader<T> reader, 
        [EnumeratorCancellation] CancellationToken cancellationToken = default
    )
    {
        while (await reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
        {
            yield return reader.Flush().ToList();
        }
    }

    public static IEnumerable<T> Flush<T>(this ChannelReader<T> reader)
    {
        while (reader.TryRead(out T item))
        {
            yield return item;
        }
    }
}

which can be used like this:

await foreach (var batch in channel.Reader.ReadBatchesAsync())
{
    await ConsumeBatch(batch);
}
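For the scenario in the question (many producers, one slow consumer), an end-to-end run might look like the sketch below. The names BatchingDemo and RunAsync, the producer counts, and the Task.Delay standing in for the IO work are all assumptions made for illustration. The batching method is a close variant of the one above that builds a List<T> directly and skips empty batches.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public static class BatchingDemo
{
    // Variant of ReadBatchesAsync: drains everything that is immediately available.
    public static async IAsyncEnumerable<List<T>> ReadBatchesAsync<T>(
        this ChannelReader<T> reader,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        while (await reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
        {
            var batch = new List<T>();
            while (reader.TryRead(out var item)) batch.Add(item);
            if (batch.Count > 0) yield return batch;
        }
    }

    public static async Task<(int Items, int Batches)> RunAsync(
        int producers = 4, int perProducer = 25)
    {
        var channel = Channel.CreateUnbounded<int>(
            new UnboundedChannelOptions { SingleReader = true });

        // Several producers write concurrently into the same channel.
        Task[] writers = Enumerable.Range(0, producers)
            .Select(p => Task.Run(async () =>
            {
                for (int i = 0; i < perProducer; i++)
                {
                    await channel.Writer.WriteAsync(p * perProducer + i);
                    await Task.Yield();
                }
            }))
            .ToArray();

        // Complete the channel once every producer has finished, so the consumer loop ends.
        Task closing = Task.WhenAll(writers)
            .ContinueWith(_ => channel.Writer.Complete());

        int items = 0, batches = 0;
        await foreach (List<int> batch in channel.Reader.ReadBatchesAsync())
        {
            await Task.Delay(5); // stand-in for the slow, non-concurrent IO work
            items += batch.Count;
            batches++;
        }
        await closing;
        return (items, batches);
    }

    public static async Task Main()
    {
        var (items, batches) = await RunAsync();
        Console.WriteLine($"{items} items consumed in {batches} batches");
    }
}
```

The number of batches varies from run to run because it depends on how many messages pile up while the consumer is busy. The total number of items consumed should always equal the number produced.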

Upvotes: 14
