nop
nop

Reputation: 6311

Is it safe to replace Dictionary with ConcurrentDictionary and what modifications should be made?

I wonder if it's safe to replace Dictionary with ConcurrentDictionary and what modifications should I do to for ex. TryAdd, TryGetValue, removing locks, etc.?

protected class SubscriptionManager
{
    private readonly DeribitV2Client _client;
    private readonly Dictionary<string, SubscriptionEntry> _subscriptionMap;

    public SubscriptionManager(DeribitV2Client client)
    {
        _client = client;
        _subscriptionMap = new Dictionary<string, SubscriptionEntry>();
    }

    public async Task<SubscriptionToken> Subscribe(ISubscriptionChannel channel, Action<Notification> callback)
    {
        if (callback == null)
        {
            return SubscriptionToken.Invalid;
        }

        var channelName = channel.ToChannelName();
        TaskCompletionSource<SubscriptionToken> taskSource = null;
        SubscriptionEntry entry;

        lock (_subscriptionMap)
        {
            if (!_subscriptionMap.TryGetValue(channelName, out entry))
            {
                entry = new SubscriptionEntry();
                if (!_subscriptionMap.TryAdd(channelName, entry))
                {
                    _client.Logger?.Error("Subscribe: Could not add internal item for channel {Channel}", channelName);
                    return SubscriptionToken.Invalid;
                }

                taskSource = new TaskCompletionSource<SubscriptionToken>();
                entry.State = SubscriptionState.Subscribing;
                entry.SubscribeTask = taskSource.Task;
            }

            // Entry already exists but is completely unsubscribed
            if (entry.State == SubscriptionState.Unsubscribed)
            {
                taskSource = new TaskCompletionSource<SubscriptionToken>();
                entry.State = SubscriptionState.Subscribing;
                entry.SubscribeTask = taskSource.Task;
            }

            // Already subscribed - Put the callback in there and let's go
            if (entry.State == SubscriptionState.Subscribed)
            {
                _client.Logger?.Debug("Subscribe: Subscription for channel already exists. Adding callback to list (Channel: {Channel})", channelName);
                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);
                return callbackEntry.Token;
            }

            // We are in the middle of unsubscribing from the channel
            if (entry.State == SubscriptionState.Unsubscribing)
            {
                _client.Logger?.Debug("Subscribe: Channel is unsubscribing. Abort subscribe (Channel: {Channel})", channelName);
                return SubscriptionToken.Invalid;
            }
        }

        // Only one state left: Subscribing

        // We are already subscribing
        if (taskSource == null && entry.State == SubscriptionState.Subscribing)
        {
            _client.Logger?.Debug("Subscribe: Channel is already subscribing. Waiting for the task to complete ({Channel})", channelName);

            var subscribeResult = entry.SubscribeTask != null && await entry.SubscribeTask != SubscriptionToken.Invalid;

            if (!subscribeResult && entry.State != SubscriptionState.Subscribed)
            {
                _client.Logger?.Debug("Subscribe: Subscription has failed. Abort subscribe (Channel: {Channel})", channelName);
                return SubscriptionToken.Invalid;
            }

            _client.Logger?.Debug("Subscribe: Subscription was successful. Adding callback (Channel: {Channel}", channelName);
            var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
            entry.Callbacks.Add(callbackEntry);
            return callbackEntry.Token;
        }

        if (taskSource == null)
        {
            _client.Logger?.Error("Subscribe: Invalid execution state. Missing TaskCompletionSource (Channel: {Channel}", channelName);
            return SubscriptionToken.Invalid;
        }

        try
        {
            var subscribeResponse = await _client.Send(
            IsPrivateChannel(channelName) ? "private/subscribe" : "public/subscribe",
            new { channels = new[] { channelName } },
            new ListJsonConverter<string>()).ConfigureAwait(false);

            var response = subscribeResponse.ResultData;

            if (response.Count != 1 || response[0] != channelName)
            {
                _client.Logger?.Debug("Subscribe: Invalid result (Channel: {Channel}): {@Response}", channelName, response);
                entry.State = SubscriptionState.Unsubscribed;
                entry.SubscribeTask = null;
                Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
                taskSource.SetResult(SubscriptionToken.Invalid);
            }
            else
            {
                _client.Logger?.Debug("Subscribe: Successfully subscribed. Adding callback (Channel: {Channel})", channelName);

                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);
                entry.State = SubscriptionState.Subscribed;
                entry.SubscribeTask = null;
                Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
                taskSource.SetResult(callbackEntry.Token);
            }
        }
        catch (Exception e)
        {
            entry.State = SubscriptionState.Unsubscribed;
            entry.SubscribeTask = null;
            Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
            taskSource.SetException(e);
        }

        return await taskSource.Task;
    }

    public async Task<bool> Unsubscribe(SubscriptionToken token)
    {
        string channelName;
        SubscriptionEntry entry;
        SubscriptionCallback callbackEntry;
        TaskCompletionSource<bool> taskSource;

        lock (_subscriptionMap)
        {
            (channelName, entry, callbackEntry) = GetEntryByToken(token);

            if (string.IsNullOrEmpty(channelName) || entry == null || callbackEntry == null)
            {
                _client.Logger?.Warning("Unsubscribe: Could not find token {token}", token.Token);
                return false;
            }

            switch (entry.State)
            {
                case SubscriptionState.Subscribing:
                    _client.Logger?.Debug("Unsubscribe: Channel is currently subscribing. Abort unsubscribe (Channel: {Channel})", channelName);
                    return false;
                case SubscriptionState.Unsubscribed:
                case SubscriptionState.Unsubscribing:
                    _client.Logger?.Debug("Unsubscribe: Channel is unsubscribed or unsubscribing. Remove callback (Channel: {Channel})", channelName);
                    entry.Callbacks.Remove(callbackEntry);
                    return true;
                case SubscriptionState.Subscribed:
                    if (entry.Callbacks.Count > 1)
                    {
                        _client.Logger?.Debug("Unsubscribe: There are still callbacks left. Remove callback but don't unsubscribe (Channel: {Channel})", channelName);
                        entry.Callbacks.Remove(callbackEntry);
                        return true;
                    }

                    _client.Logger?.Debug("Unsubscribe: No callbacks left. Unsubscribe and remove callback (Channel: {Channel})", channelName);
                    break;
                default:
                    return false;
            }

            // At this point it's only possible that the entry-State is Subscribed
            // and the callback list is empty after removing this callback.
            // Hence we unsubscribe at the server now
            entry.State = SubscriptionState.Unsubscribing;
            taskSource = new TaskCompletionSource<bool>();
            entry.UnsubscribeTask = taskSource.Task;
        }

        try
        {
            var unsubscribeResponse = await _client.Send(
            IsPrivateChannel(channelName) ? "private/unsubscribe" : "public/unsubscribe",
            new { channels = new[] { channelName } },
            new ListJsonConverter<string>()).ConfigureAwait(false);

            var response = unsubscribeResponse.ResultData;

            if (response.Count != 1 || response[0] != channelName)
            {
                entry.State = SubscriptionState.Subscribed;
                entry.UnsubscribeTask = null;
                taskSource.SetResult(false);
            }
            else
            {
                entry.Callbacks.Remove(callbackEntry);
                entry.State = SubscriptionState.Unsubscribed;
                entry.UnsubscribeTask = null;
                taskSource.SetResult(true);
            }
        }
        catch (Exception e)
        {
            entry.State = SubscriptionState.Subscribed;
            entry.UnsubscribeTask = null;
            taskSource.SetException(e);
        }

        return await taskSource.Task;
    }

    public IEnumerable<Action<Notification>> GetCallbacks(string channel)
    {
        if (_subscriptionMap.TryGetValue(channel, out var entry))
        {
            foreach (var callbackEntry in entry.Callbacks)
            {
                yield return callbackEntry.Action;
            }
        }
    }

    public void Reset()
    {
        _subscriptionMap.Clear();
    }

    private static bool IsPrivateChannel(string channel)
    {
        return channel.StartsWith("user.");
    }

    private (string channelName, SubscriptionEntry entry, SubscriptionCallback callbackEntry) GetEntryByToken(SubscriptionToken token)
    {
        lock (_subscriptionMap)
        {
            foreach (var kvp in _subscriptionMap)
            {
                foreach (var callbackEntry in kvp.Value.Callbacks)
                {
                    if (callbackEntry.Token == token)
                    {
                        return (kvp.Key, kvp.Value, callbackEntry);
                    }
                }
            }
        }

        return (null, null, null);
    }
}

GitHub

My attempt

public class SubscriptionToken
{
    public static readonly SubscriptionToken Invalid = new(Guid.Empty);

    public SubscriptionToken(Guid token)
    {
        Token = token;
    }

    public Guid Token { get; }
}

public class SubscriptionCallback
{
    public SubscriptionCallback(SubscriptionToken token, Action<Notification> action)
    {
        Token = token;
        Action = action;
    }

    public Action<Notification> Action { get; }
    public SubscriptionToken Token { get; }
}

public class SubscriptionEntry
{
    public List<SubscriptionCallback> Callbacks { get; } = new();
    public Task<SubscriptionToken>? SubscribeTask { get; set; }
    public Task<bool>? UnsubscribeTask { get; set; }
    public SubscriptionState State { get; set; } = SubscriptionState.Unsubscribed;
}

public class SubscriptionManager
{
    private readonly DeribitClient _client;
    private readonly ConcurrentDictionary<string, SubscriptionEntry> _subscriptions = new();

    public SubscriptionManager(DeribitClient client)
    {
        _client = client ?? throw new ArgumentNullException(nameof(client));
    }

    public async Task<SubscriptionToken> SubscribeAsync(string channel, Action<Notification>? callback)
    {
        if (callback == null)
        {
            throw new ArgumentNullException(nameof(callback));
        }

        TaskCompletionSource<SubscriptionToken>? tcs = null;

        if (_subscriptions.TryGetValue(channel, out var entry))
        {
            if (entry.State == SubscriptionState.Subscribed)
            {
                Log.Debug("Subscribe: Subscription for channel already exists. Adding callback to list (Channel: {Channel})", channel);

                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);
                return callbackEntry.Token;
            }

            if (entry.State == SubscriptionState.Unsubscribing)
            {
                Log.Debug("Subscribe: Channel is unsubscribing. Abort subscribe (Channel: {Channel})", channel);
                return SubscriptionToken.Invalid;
            }

            if (entry.State == SubscriptionState.Unsubscribed)
            {
                Log.Debug("Subscribe: Entry already exists but is completely unsubscribed (Channel: {Channel})", channel);

                tcs = new TaskCompletionSource<SubscriptionToken>();
                entry.State = SubscriptionState.Subscribing;
                entry.SubscribeTask = tcs.Task;
            }
        }
        else
        {
            tcs = new TaskCompletionSource<SubscriptionToken>();
            entry = new SubscriptionEntry
            {
                State = SubscriptionState.Subscribing,
                SubscribeTask = tcs.Task
            };

            if (!_subscriptions.TryAdd(channel, entry))
            {
                Log.Error("Subscribe: Could not add internal item for channel {Channel}", channel);
                return SubscriptionToken.Invalid;
            }
        }

        if (tcs == null && entry.State == SubscriptionState.Subscribing)
        {
            Log.Debug("Subscribe: Channel is already subscribing. Waiting for the task to complete ({Channel})", channel);

            var subscribeResult = entry.SubscribeTask != null && await entry.SubscribeTask.ConfigureAwait(false) != SubscriptionToken.Invalid;

            if (!subscribeResult && entry.State != SubscriptionState.Subscribed)
            {
                Log.Debug("Subscribe: Subscription has failed. Abort subscribe (Channel: {Channel})", channel);
                return SubscriptionToken.Invalid;
            }

            Log.Debug("Subscribe: Subscription was successful. Adding callback (Channel: {Channel}", channel);

            var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
            entry.Callbacks.Add(callbackEntry);
            return callbackEntry.Token;
        }

        if (tcs == null)
        {
            Log.Error("Subscribe: Invalid execution state. Missing TaskCompletionSource (Channel: {Channel}", channel);
            return SubscriptionToken.Invalid;
        }

        try
        {
            var method = IsPrivateChannel(channel) ? "private/subscribe" : "public/subscribe";
            var @params = new Dictionary<string, string[]>
            {
                { "channels", new[] { channel } }
            };
            var subscribeResponse = await _client.SendAsync<Notification>(method, @params).ConfigureAwait(false);

            if (subscribeResponse == null)
            {
                Log.Debug("Subscribe: Invalid result (Channel: {Channel}): {@Response}", channel, subscribeResponse);

                entry.State = SubscriptionState.Unsubscribed;
                entry.SubscribeTask = null;

                Debug.Assert(tcs != null);

                tcs.SetResult(SubscriptionToken.Invalid);
            }
            else
            {
                Log.Debug("Subscribe: Successfully subscribed. Adding callback (Channel: {Channel})", channel);

                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);

                entry.State = SubscriptionState.Subscribed;
                entry.SubscribeTask = null;

                Debug.Assert(tcs != null);

                tcs.SetResult(callbackEntry.Token);
            }
        }
        catch (Exception ex)
        {
            entry.State = SubscriptionState.Unsubscribed;
            entry.SubscribeTask = null;

            Debug.Assert(tcs != null);

            tcs.SetException(ex);
        }

        return await tcs.Task.ConfigureAwait(false);
    }

    public async Task<bool> UnsubscribeAsync(SubscriptionToken token)
    {
        TaskCompletionSource<bool> tcs;

        var (channel, entry, callbackEntry) = GetEntryByToken(token);

        if (string.IsNullOrEmpty(channel) || entry == null || callbackEntry == null)
        {
            Log.Warning("UnsubscribeAsync: Could not find token {token}", token.Token);
            return false;
        }

        switch (entry.State)
        {
            case SubscriptionState.Subscribing:
                Log.Debug("UnsubscribeAsync: Channel is currently subscribing. Abort unsubscribe (Channel: {Channel})", channel);
                return false;
            case SubscriptionState.Unsubscribed:
            case SubscriptionState.Unsubscribing:
                Log.Debug("UnsubscribeAsync: Channel is unsubscribed or unsubscribing. Remove callback (Channel: {Channel})", channel);
                entry.Callbacks.Remove(callbackEntry);
                return true;
            case SubscriptionState.Subscribed when entry.Callbacks.Count > 1:
                Log.Debug("UnsubscribeAsync: There are still callbacks left. Remove callback but don't unsubscribe (Channel: {Channel})", channel);
                entry.Callbacks.Remove(callbackEntry);
                return true;
            case SubscriptionState.Subscribed:
                Log.Debug("UnsubscribeAsync: No callbacks left. UnsubscribeAsync and remove callback (Channel: {Channel})", channel);
                tcs = new TaskCompletionSource<bool>();
                entry.State = SubscriptionState.Unsubscribing;
                entry.UnsubscribeTask = tcs.Task;
                break;
            default:
                return false;
        }

        try
        {
            var method = IsPrivateChannel(channel) ? "private/unsubscribe" : "public/unsubscribe";
            var @params = new Dictionary<string, string[]>
            {
                { "channels", new[] { channel } }
            };
            var unsubscribeResponse = await _client.SendAsync<Notification>(method, @params).ConfigureAwait(false);

            if (unsubscribeResponse == null)
            {
                entry.State = SubscriptionState.Subscribed;
                entry.UnsubscribeTask = null;

                tcs.SetResult(false);
            }
            else
            {
                entry.Callbacks.Remove(callbackEntry);

                entry.State = SubscriptionState.Unsubscribed;
                entry.UnsubscribeTask = null;

                tcs.SetResult(true);
            }
        }
        catch (Exception ex)
        {
            entry.State = SubscriptionState.Subscribed;
            entry.UnsubscribeTask = null;

            tcs.SetException(ex);
        }

        return await tcs.Task.ConfigureAwait(false);
    }

    private (string? channelName, SubscriptionEntry? entry, SubscriptionCallback? callbackEntry) GetEntryByToken(SubscriptionToken token)
    {
        foreach (var (key, value) in _subscriptions)
        {
            foreach (var callbackEntry in value.Callbacks.Where(callbackEntry => callbackEntry.Token == token))
            {
                return (key, value, callbackEntry);
            }
        }

        return (null, null, null);
    }

    public IEnumerable<Action<Notification>> GetCallbacks(string channel)
    {
        if (_subscriptions.TryGetValue(channel, out var entry))
        {
            foreach (var callbackEntry in entry.Callbacks)
            {
                yield return callbackEntry.Action;
            }
        }
    }

    private static bool IsPrivateChannel(string channel)
    {
        return channel.StartsWith("user.");
    }
}

Upvotes: 0

Views: 765

Answers (1)

Theodor Zoulias
Theodor Zoulias

Reputation: 43390

A ConcurrentDictionary<K,V> is thread-safe in the sense that it protects its internal state from corruption. It doesn't protect the keys and values it contains, in case these are mutable objects.

In your case the values stored in the dictionary (SubscriptionEntry) are mutable objects. They have public setters, and they expose a public property of type List<SubscriptionCallback>. The List<T> class is not thread-safe. So, no, you can't replace the Dictionary with a ConcurrentDictionary the way you've shown in the question (the My attempt section). Here are some options:

  1. Make sure that the SubscriptionCallback type is immutable. If you want to change it, create a new SubscriptionCallback instance and discard the previous one.
  2. Keep the SubscriptionCallback mutable, but make it thread-safe.
  3. Just keep the Dictionary, and forget about switching to ConcurrentDictionary. The overhead of a lock is minuscule, provided that you are not doing anything non trivial while holding the lock. If you are doing only basic operations (Add/TryGetValue/Remove), it's unlikely that you'll notice any measurable contention, unless you are doing 100,000 operations per second or more.

Upvotes: 1

Related Questions