Marcel Alexandru

Reputation: 600

Get all classes that implement an interface and call a function in .NET Core

How can I get all classes that implement a specific interface, and then call a method on each class whose string member matches a given value?

Basically, what I have is an ICommandHandler interface:

interface ICommandHandler
{
    string Command { get; }
    Task ExecuteAsync();
}

and a class that implements it:

public class StartCommand : ICommandHandler
{

    public string Command { get => "start"; }
    public async Task ExecuteAsync()
    {
      // do something
    }
}

What I want to do is get all the classes that implement the ICommandHandler interface, check whether each class's Command equals a specific string, and if it does, call its ExecuteAsync method.

I've tried using this answer here: https://stackoverflow.com/a/45382386/15306888 but the class.Command is always null

Edit: The answer I got below does what I wanted to do.

What I was looking for was a way to use ICommandHandler to easily gather all the classes implementing it and call their ExecuteAsync method, instead of having to manually add each method to the part of the code that handles TdLib message events.

So now my project directory looks something like this:

Anyway, in the meantime I found another answer on a Stack Overflow question (I had to scroll a bit) that made it much easier and faster to get multiple kinds of handlers without repeating the same code over and over again. I ended up combining the answer linked above with this one: https://stackoverflow.com/a/41650057/15306888

So I ended up with a simple method:

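        public static IEnumerable<T> GetAll<T>()
        {
            return Assembly.GetExecutingAssembly()
                .GetTypes()
                // Keep every type assignable to T (T itself and the classes implementing it).
                .Where(type => typeof(T).IsAssignableFrom(type))
                // Drop interfaces, abstract classes, generic types and types without a
                // public parameterless constructor: Activator cannot construct those.
                .Where(type =>
                    !type.IsAbstract &&
                    !type.IsGenericType &&
                    type.GetConstructor(Type.EmptyTypes) != null)
                // Create one instance per handler type so properties like Command can be read.
                .Select(type => (T)Activator.CreateInstance(type))
                .ToList();
        }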
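To see what each filter in GetAll<T>() does, here is a minimal self-contained sketch. IGreeter and the greeter classes are hypothetical demo types, not part of the bot; GetAll<T> is the method above, using Type.EmptyTypes for the parameterless-constructor lookup. Only HelloGreeter passes every filter.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical demo interface standing in for ICommandHandler.
public interface IGreeter
{
    string Name { get; }
}

// Found: concrete, non-generic, public parameterless constructor.
public class HelloGreeter : IGreeter
{
    public string Name => "hello";
}

// Skipped: abstract classes cannot be instantiated.
public abstract class BaseGreeter : IGreeter
{
    public abstract string Name { get; }
}

// Skipped: no parameterless constructor, so Activator.CreateInstance(type) would throw.
public class NamedGreeter : IGreeter
{
    private readonly string _name;
    public NamedGreeter(string name) => _name = name;
    public string Name => _name;
}

// Skipped: an open generic type definition cannot be instantiated directly.
public class GenericGreeter<TTag> : IGreeter
{
    public string Name => typeof(TTag).Name;
}

public static class HandlerDiscovery
{
    public static IEnumerable<T> GetAll<T>()
    {
        return Assembly.GetExecutingAssembly()
            .GetTypes()
            .Where(type => typeof(T).IsAssignableFrom(type))
            .Where(type =>
                !type.IsAbstract &&
                !type.IsGenericType &&
                type.GetConstructor(Type.EmptyTypes) != null)
            .Select(type => (T)Activator.CreateInstance(type))
            .ToList();
    }
}
```

The interface type itself also satisfies IsAssignableFrom, but interfaces report IsAbstract as true, so the second Where removes it too.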

It can then be used like this:

        private async Task ProcessMessage(TdApi.Update.UpdateNewMessage message)
        {
            var command = GetCommand(message.Message);
            var textMessage = GetMessageText(message.Message);

            if (!String.IsNullOrWhiteSpace(command))
            {
                var commandHandlers = GetAll<ICommandHandler>();

                foreach (var handler in commandHandlers)
                {
                    if (command == handler.Command)
                        await handler.ExecuteAsync(_client, message.Message);
                }
            }

            else if (!String.IsNullOrWhiteSpace(textMessage))
            {
                var messageHandlers = GetAll<IMessageHandler>();
                foreach (var handler in messageHandlers)
                {
                    var outgoing = handler.Outgoing && message.Message.IsOutgoing;
                    var incoming = handler.Incoming && !message.Message.IsOutgoing;

                    if (outgoing || incoming)
                    {
                        if (!String.IsNullOrEmpty(handler.Pattern))
                        {
                            var match = Regex.Match(textMessage, handler.Pattern);
                            if (match.Success)
                                await handler.ExecuteAsync(_client, message.Message);
                        }
                    }
                }
            }
        }
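GetAll<T>() scans the assembly with reflection and creates a fresh instance of every handler each time a message arrives. A possible refinement, sketched here under the assumption that handlers are stateless and each command string belongs to exactly one handler, is to build a dictionary once and look commands up by key. ICommandHandler below is a simplified stand-in without the TdClient and TdApi.Message parameters, and CommandRegistry and PingCommand are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Simplified stand-in for the bot's ICommandHandler (no TdClient / TdApi.Message parameters).
public interface ICommandHandler
{
    string Command { get; }
    Task ExecuteAsync();
}

// Hypothetical handler used to show the lookup; it only counts how often it ran.
public class PingCommand : ICommandHandler
{
    public int Calls { get; private set; }
    public string Command => "/ping";

    public Task ExecuteAsync()
    {
        Calls++;
        return Task.CompletedTask;
    }
}

// Hypothetical registry: built once, then reused for every incoming message.
public sealed class CommandRegistry
{
    private readonly Dictionary<string, ICommandHandler> _handlers;

    // ToDictionary throws ArgumentException if two handlers declare the same Command.
    public CommandRegistry(IEnumerable<ICommandHandler> handlers)
    {
        _handlers = handlers.ToDictionary(h => h.Command, StringComparer.Ordinal);
    }

    // Runs the matching handler; returns false if no handler owns the command.
    public async Task<bool> TryExecuteAsync(string command)
    {
        if (command == null || !_handlers.TryGetValue(command, out var handler))
            return false;

        await handler.ExecuteAsync();
        return true;
    }
}
```

In the bot, such a registry could be created once (for example from GetAll<ICommandHandler>()) and kept in a static field. One behavioral difference from the foreach loop: two handlers declaring the same Command make construction fail at startup instead of both running.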

How the interface is actually implemented:

using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using TdLib;

namespace YoutubeDl_Bot.Handlers.CommandHandlers
{
    public class StartCommand : ICommandHandler
    {
        public string Command { get => "/start"; }

        public async Task ExecuteAsync(TdClient client, TdApi.Message message)
        {
            var stringBuilder = new StringBuilder();
            stringBuilder.AppendLine("Hi! I'm a bot that downloads and sends video and audio files from YouTube links and many other supported services.");
            stringBuilder.AppendLine(String.Empty);
            stringBuilder.AppendLine("**Usage:**");
            stringBuilder.AppendLine("• Send or forward a text message containing links and I will:");
            stringBuilder.AppendLine("• Download the best audio quality available for the video in the specified link");
            stringBuilder.AppendLine("• Download the best video quality available for the video in the specified link");
            stringBuilder.AppendLine("• Send the direct download URL for every link specified in the message");
            stringBuilder.AppendLine("• Supported links are available here: https://ytdl-org.github.io/youtube-dl/supportedsites.html");

            var formattedText = await client.ExecuteAsync(new TdLib.TdApi.ParseTextEntities { Text = stringBuilder.ToString(), ParseMode = new TdLib.TdApi.TextParseMode.TextParseModeMarkdown() });
            await client.ExecuteAsync(new TdLib.TdApi.SendMessage { ChatId = message.ChatId, InputMessageContent = new TdLib.TdApi.InputMessageContent.InputMessageText { Text = formattedText } });
        }
    }
}

Full TelegramBotClient class:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using TdLib;
using YoutubeDl_Bot.Handlers.CallbackHandlers;
using YoutubeDl_Bot.Handlers.CommandHandlers;
using YoutubeDl_Bot.Handlers.MessageHandlers;
using YoutubeDl_Bot.Settings;
using YoutubeDl_Bot.Utils;

namespace YoutubeDl_Bot
{
    class TelegramBotClient
    {
        private static TdClient _client;
        private static TdLib.TdApi.User _me;

        private static int _apiId;
        private static string _apiHash;
        private static string _token;

#if DEBUG
        private static readonly int _verbosityLevel = 4;
#else
        private static readonly int _verbosityLevel = 0; 
#endif

        public static bool AuthCompleted = false;

        public TelegramBotClient(int apiId, string apiHash)
        {
            _apiId = apiId;
            _apiHash = apiHash;
        }

        public async Task<TdClient> CreateClient()
        {
            _client = new TdClient();
            await _client.ExecuteAsync(new TdApi.SetLogVerbosityLevel { NewVerbosityLevel = _verbosityLevel });
            return _client;
        }

        public void StartListening(string botToken)
        {
            _token = botToken;
            _client.UpdateReceived += _client_UpdateReceived;
        }

        private async void _client_UpdateReceived(object sender, TdApi.Update update)
        {
            switch (update)
            {
                case TdApi.Update.UpdateAuthorizationState updateAuthorizationState when updateAuthorizationState.AuthorizationState.GetType() == typeof(TdApi.AuthorizationState.AuthorizationStateWaitTdlibParameters):
                    await _client.ExecuteAsync(new TdApi.SetTdlibParameters 
                    {
                        Parameters = new TdApi.TdlibParameters
                        {
                            ApiId = _apiId,
                            ApiHash = _apiHash,
                            ApplicationVersion = "0.0.1",
                            DeviceModel = "Bot",
                            SystemLanguageCode = "en",
                            SystemVersion = "Unknown"
                        }
                    });
                    break;
                case TdApi.Update.UpdateAuthorizationState updateAuthorizationState when updateAuthorizationState.AuthorizationState.GetType() == typeof(TdLib.TdApi.AuthorizationState.AuthorizationStateWaitEncryptionKey):
                    await _client.ExecuteAsync(new TdLib.TdApi.CheckDatabaseEncryptionKey());
                    break;
                case TdLib.TdApi.Update.UpdateAuthorizationState updateAuthorizationState when updateAuthorizationState.AuthorizationState.GetType() == typeof(TdLib.TdApi.AuthorizationState.AuthorizationStateWaitPhoneNumber):
                    await _client.ExecuteAsync(new TdLib.TdApi.CheckAuthenticationBotToken { Token = _token });
                    break;
                case TdLib.TdApi.Update.UpdateConnectionState updateConnectionState when updateConnectionState.State.GetType() == typeof(TdLib.TdApi.ConnectionState.ConnectionStateReady):
                    // To Do Settings
                    var botSettings = new BotSettings(_apiId, _apiHash, _token);
                    _me = await _client.ExecuteAsync(new TdLib.TdApi.GetMe());
                    Helpers.Print($"Logged in as: {_me.FirstName}");

                    SettingsManager.Set<BotSettings>("BotSettings.data", botSettings);
                    break;
                case TdLib.TdApi.Update.UpdateNewMessage message:
                    if (!message.Message.IsOutgoing)
                        await ProcessMessage(message);
                    break;
                case TdApi.Update.UpdateNewCallbackQuery callbackQuery:
                    await ProcessCallbackQuery(callbackQuery);
                    break;
                default:
                    break;
            }
        }

        #region PROCESS_MESSAGE
        private async Task ProcessMessage(TdApi.Update.UpdateNewMessage message)
        {
            var command = GetCommand(message.Message);
            var textMessage = GetMessageText(message.Message);

            #region COMMAND_HANDLERS

            if (!String.IsNullOrWhiteSpace(command))
            {
                var commandHandlers = GetAll<ICommandHandler>();

                foreach (var handler in commandHandlers)
                {
                    if (command == handler.Command)
                        await handler.ExecuteAsync(_client, message.Message);
                }
            }

            #endregion

            #region MESSAGE_HANDLERS

            else if (!String.IsNullOrWhiteSpace(textMessage))
            {
                var messageHandlers = GetAll<IMessageHandler>();
                foreach (var handler in messageHandlers)
                {
                    var outgoing = handler.Outgoing && message.Message.IsOutgoing;
                    var incoming = handler.Incoming && !message.Message.IsOutgoing;

                    if (outgoing || incoming)
                    {
                        if (!String.IsNullOrEmpty(handler.Pattern))
                        {
                            var match = Regex.Match(textMessage, handler.Pattern);
                            if (match.Success)
                                await handler.ExecuteAsync(_client, message.Message);
                        }
                    }
                }
            }
            #endregion
        }
        #endregion

        #region PROCESS_CALLBACK
        private async Task ProcessCallbackQuery(TdApi.Update.UpdateNewCallbackQuery callbackQuery)
        {
            if (callbackQuery.Payload.GetType() == typeof(TdApi.CallbackQueryPayload.CallbackQueryPayloadData))
            {
                var payload = callbackQuery.Payload as TdApi.CallbackQueryPayload.CallbackQueryPayloadData;
                var callbackHandlers = GetAll<ICallbackHandler>();
                foreach (var handler in callbackHandlers)
                {
                    // Braces are required here: without them, C# pairs the else with the
                    // nearest if (the regex check) instead of with DataIsRegex.
                    if (handler.DataIsRegex)
                    {
                        if (Regex.Match(System.Text.Encoding.UTF8.GetString(payload.Data), handler.Data).Success)
                            await handler.ExecuteAsync(_client, callbackQuery);
                    }
                    else if (handler.Data == System.Text.Encoding.UTF8.GetString(payload.Data))
                    {
                        await handler.ExecuteAsync(_client, callbackQuery);
                    }
                }
            }
        }
        #endregion

        #region COMMAND_PARSER
        public string GetCommand(TdApi.Message message)
        {
            string command = null;
            TdLib.TdApi.FormattedText formattedText = new TdLib.TdApi.FormattedText();
            if (message.Content.GetType() == typeof(TdLib.TdApi.MessageContent.MessageText))
            {
                var messageText = message.Content as TdLib.TdApi.MessageContent.MessageText;
                formattedText = messageText.Text;
            }
            else
            {
                if (message.Content.GetType() == typeof(TdLib.TdApi.MessageContent.MessagePhoto))
                {
                    var messagePhoto = message.Content as TdLib.TdApi.MessageContent.MessagePhoto;
                    formattedText = messagePhoto.Caption;
                }
                else if (message.Content.GetType() == typeof(TdLib.TdApi.MessageContent.MessageDocument))
                {
                    var messageDocument = message.Content as TdLib.TdApi.MessageContent.MessageDocument;
                    formattedText = messageDocument.Caption;
                }
                else if (message.Content.GetType() == typeof(TdLib.TdApi.MessageContent.MessageVideo))
                {
                    var messageVideo = message.Content as TdLib.TdApi.MessageContent.MessageVideo;
                    formattedText = messageVideo.Caption;
                }
            }

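            // Entities may be null (e.g. for the default FormattedText used when the content
            // type is none of the ones handled above), so fall back to an empty array.
            foreach (var entity in formattedText.Entities ?? Array.Empty<TdLib.TdApi.TextEntity>())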
            {
                if (entity.Type.GetType() == typeof(TdLib.TdApi.TextEntityType.TextEntityTypeBotCommand) && String.IsNullOrWhiteSpace(command))
                {
                    if (entity.Offset == 0)
                    {
                        var splitCommand = formattedText.Text.Split();
                        if (splitCommand[0].EndsWith($"@{_me.Username}"))
                        {
                            command = splitCommand[0].Split('@')[0];
                        }
                        else
                        {
                            command = splitCommand[0];
                        }
                    }
                }
            }
            return command;
        }
        #endregion

        #region MESSAGE_PARSER

        public string GetMessageText(TdApi.Message message)
        {
            TdLib.TdApi.FormattedText formattedText = new TdLib.TdApi.FormattedText();
            if (message.Content.GetType() == typeof(TdLib.TdApi.MessageContent.MessageText))
            {
                var messageText = message.Content as TdLib.TdApi.MessageContent.MessageText;
                formattedText = messageText.Text;
            }
            else
            {
                if (message.Content.GetType() == typeof(TdLib.TdApi.MessageContent.MessagePhoto))
                {
                    var messagePhoto = message.Content as TdLib.TdApi.MessageContent.MessagePhoto;
                    formattedText = messagePhoto.Caption;
                }
                else if (message.Content.GetType() == typeof(TdLib.TdApi.MessageContent.MessageDocument))
                {
                    var messageDocument = message.Content as TdLib.TdApi.MessageContent.MessageDocument;
                    formattedText = messageDocument.Caption;
                }
                else if (message.Content.GetType() == typeof(TdLib.TdApi.MessageContent.MessageVideo))
                {
                    var messageVideo = message.Content as TdLib.TdApi.MessageContent.MessageVideo;
                    formattedText = messageVideo.Caption;
                }
            }
            return formattedText.Text;
        }

        #endregion

        #region REFLECTION
        // https://stackoverflow.com/a/41650057/15306888
        public static IEnumerable<T> GetAll<T>()
        {
            return Assembly.GetExecutingAssembly()
                .GetTypes()
                .Where(type => typeof(T).IsAssignableFrom(type))
                .Where(type =>
                    !type.IsAbstract &&
                    !type.IsGenericType &&
                    type.GetConstructor(new Type[0]) != null)
                .Select(type => (T)Activator.CreateInstance(type))
                .ToList();
        }
        #endregion
    }
}
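Without braces, C# binds an else to the nearest if, so in a callback loop an else if written directly after a nested if (Regex.Match(...)) belongs to the regex check rather than to if (handler.DataIsRegex): exact-match handlers never fire, and regex handlers fall back to an exact comparison when the pattern does not match. One way to make the rule explicit and testable is to pull it into a small pure function. CallbackMatcher is a hypothetical helper, not part of TdLib.

```csharp
using System;
using System.Text;
using System.Text.RegularExpressions;

// Hypothetical helper isolating the callback-matching rule from the TdLib types.
public static class CallbackMatcher
{
    // Callback payload data arrives as bytes; the bot decodes it as UTF-8.
    public static string Decode(byte[] data) => Encoding.UTF8.GetString(data);

    // Regex search when dataIsRegex is true, otherwise an exact ordinal comparison.
    public static bool Matches(string payload, string handlerData, bool dataIsRegex)
    {
        if (dataIsRegex)
        {
            return Regex.IsMatch(payload, handlerData);
        }
        else
        {
            return string.Equals(payload, handlerData, StringComparison.Ordinal);
        }
    }
}
```

Inside the loop this would become if (CallbackMatcher.Matches(data, handler.Data, handler.DataIsRegex)) await handler.ExecuteAsync(_client, callbackQuery);, with data decoded once before the loop. Regex.IsMatch searches anywhere in the payload, so patterns that must match the whole string need ^ and $ anchors.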

Upvotes: 1

Views: 5786

Answers (1)

Michal Rosenbaum

Reputation: 2061

Since Command is always null, I think the problem is that you're not creating an instance of your handler. I prepared a demo where I do that, and it works:

using System;
using System.Linq;
using System.Threading.Tasks;

public interface ICommandHandler
{
    string Command { get; }
    Task ExecuteAsync();
}
public class FirstCommandHandler : ICommandHandler
{
    public string Command => "First";

    public async Task ExecuteAsync()
    {
        Console.WriteLine("Hello from first.");
        await Task.Delay(10);
    }
}
public class SecondCommandHandler : ICommandHandler
{
    public string Command => "Second";

    public async Task ExecuteAsync()
    {
        Console.WriteLine("Hello from second.");
        await Task.Delay(10);
    }
}

public class Program
{
    static async Task Main(string[] args)
    {
        var handlers = AppDomain.CurrentDomain.GetAssemblies()
            .SelectMany(s => s.GetTypes())
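            // Abstract classes are skipped because Activator.CreateInstance cannot instantiate them.
            .Where(p => typeof(ICommandHandler).IsAssignableFrom(p) && p.IsClass && !p.IsAbstract);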

        foreach (var handler in handlers)
        {
            var handlerInstance = (ICommandHandler)Activator.CreateInstance(handler);
            if (handlerInstance.Command == "First")
            {
                await handlerInstance.ExecuteAsync();
            }
        }
    }
}
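Scanning every assembly in the AppDomain, as the demo does, can fail with ReflectionTypeLoadException when some assembly contains types whose dependencies cannot be loaded. A common defensive pattern is to fall back to the types that did load, which the exception exposes through its Types property (failed entries are null). TypeScanner, IDemoHandler, DemoHandler and AbstractDemoHandler are hypothetical names for this sketch.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical demo types so the scan has something to find.
public interface IDemoHandler { }
public class DemoHandler : IDemoHandler { }
public abstract class AbstractDemoHandler : IDemoHandler { }

public static class TypeScanner
{
    // GetTypes() throws ReflectionTypeLoadException if any type fails to load;
    // the exception's Types array still holds the loaded ones, and OfType drops the nulls.
    public static IEnumerable<Type> GetLoadableTypes(Assembly assembly)
    {
        try
        {
            return assembly.GetTypes();
        }
        catch (ReflectionTypeLoadException ex)
        {
            return ex.Types.OfType<Type>();
        }
    }

    // Concrete classes implementing TInterface that Activator.CreateInstance(type) can construct.
    public static IEnumerable<Type> FindImplementations<TInterface>()
    {
        return AppDomain.CurrentDomain.GetAssemblies()
            .SelectMany(assembly => GetLoadableTypes(assembly))
            .Where(t => typeof(TInterface).IsAssignableFrom(t)
                        && t.IsClass
                        && !t.IsAbstract
                        && t.GetConstructor(Type.EmptyTypes) != null);
    }
}
```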

If that's not the case, could you show some more code? Are you trying to read the Command value via reflection?
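If Command is being read through reflection, note that it is an instance property: its getter has to run against an object, so PropertyInfo.GetValue needs an instance, and passing null as the target throws a TargetException rather than returning a value. A minimal sketch, where the interface and StartCommand are simplified stand-ins for the question's types and CommandReader is a hypothetical helper:

```csharp
using System;
using System.Reflection;

// Simplified stand-in for the question's interface and handler.
public interface ICommandHandler
{
    string Command { get; }
}

public class StartCommand : ICommandHandler
{
    public string Command => "/start";
}

// Hypothetical helper: reads Command from a handler type found by a reflection scan.
public static class CommandReader
{
    public static string ReadCommand(Type handlerType)
    {
        PropertyInfo property = handlerType.GetProperty(nameof(ICommandHandler.Command));

        // Command is an instance property, so the getter needs a real object to run on.
        object instance = Activator.CreateInstance(handlerType);
        return (string)property.GetValue(instance);
    }
}
```

Once an instance exists anyway, casting it to ICommandHandler and reading handler.Command directly, as in the demo above, is simpler than going through PropertyInfo.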

Upvotes: 6
