Steven Jeuris
Steven Jeuris

Reputation: 19130

Check whether a constructor calls another constructor

During reflection, is it possible in C# to check whether one constructor calls another?

/// <summary>
/// Example from the question: the parameterless constructor chains to <c>Test(bool)</c>,
/// which is the end of the invocation chain.
/// </summary>
class Test
{
    /// <summary>Delegates to <c>Test(bool)</c>, passing <c>false</c>.</summary>
    public Test()
        : this( false )
    {
    }

    /// <summary>Final constructor of the chain; its body is empty.</summary>
    /// <param name="inner">Unused by this example.</param>
    public Test( bool inner )
    {
    }
}

I would like to determine for each ConstructorInfo whether or not it's at the end of the chain of invocation.

Upvotes: 4

Views: 840

Answers (5)

Tetsu
Tetsu

Reputation: 1

I wrote MSIL analysis code that checks whether a constructor calls another constructor of the same type (`this(...)`).

The code can detect all MethodBase objects being called within a constructor. The method IsCallingOtherConstructor checks if one of the other constructors is found among the detected MethodBase objects.

/// <summary>
/// Determines whether <paramref name="constructorInfo"/> chains to another constructor declared on
/// the same type (C# <c>: this(...)</c>).
/// </summary>
/// <param name="constructorInfo">The constructor whose IL is inspected.</param>
/// <returns>
/// <c>true</c> when the constructor's IL contains a <c>call</c> to one of the type's other instance
/// constructors; otherwise <c>false</c>. Also <c>false</c> when there is no declaring type or no
/// other constructor.
/// </returns>
/// <remarks>
/// Only <c>call</c> counts. <c>newobj</c> (e.g. <c>new Test(true)</c> inside a constructor) creates a
/// separate object and is not chaining.
/// NOTE(review): for value types, a <c>call</c> to a same-type constructor on a local rather than on
/// <c>this</c> would also be reported as <c>true</c>.
/// </remarks>
static bool IsCallingOtherConstructor(ConstructorInfo constructorInfo)
{
    var declaringType = constructorInfo.DeclaringType;
    if (declaringType is null)
        return false;

    // Materialize once: the candidates are checked for every decoded instruction below.
    var otherConstructors = declaringType
        .GetConstructors(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic)
        .Where(x => !Equals(x, constructorInfo))
        .ToArray();
    if (otherConstructors.Length == 0)
        return false;

    // Only InlineMethod operands need resolving; every other operand is skipped by the decoder.
    return Instruction.GetList(constructorInfo, new[] { OperandType.InlineMethod })
        .Any(x => x.OpCode == OpCodes.Call
                  && x.Operand is ConstructorInfo target
                  && otherConstructors.Any(y => Equals(target, y)));
}

Here is the MSIL analysis code.

using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.InteropServices;
    
/// <summary>
/// A single decoded MSIL instruction, together with a small decoder (<see cref="GetList"/>) that
/// turns a method body into a doubly linked list of instructions.
/// </summary>
public class Instruction
{
    /// <summary>Little-endian read cursor over a method's raw IL bytes.</summary>
    private sealed class ByteBuffer
    {
        /// <summary>Index of the next unread byte in <see cref="Buffer"/>.</summary>
        public int position;

        public ByteBuffer(byte[] buffer)
        {
            Buffer = buffer;
        }

        /// <summary>The raw IL bytes being decoded.</summary>
        public byte[] Buffer { get; }

        public byte ReadByte()
        {
            CheckCanRead(1);
            return Buffer[position++];
        }

        private byte[] ReadBytes(int length)
        {
            CheckCanRead(length);
            var bytes = new byte[length];
            System.Buffer.BlockCopy(Buffer, position, bytes, 0, length);
            position += length;
            return bytes;
        }

        public short ReadInt16()
        {
            CheckCanRead(2);
            var @short = (short)(Buffer[position]
                                  + (Buffer[position + 1] << 8));
            position += 2;
            return @short;
        }

        public int ReadInt32()
        {
            CheckCanRead(4);
            var @int = Buffer[position]
                       + (Buffer[position + 1] << 8)
                       + (Buffer[position + 2] << 16)
                       + (Buffer[position + 3] << 24);
            position += 4;
            return @int;
        }

        /// <summary>Reads a little-endian 64-bit integer.</summary>
        public long ReadInt64()
        {
            CheckCanRead(8);
            // Compose from two 32-bit halves. A byte is promoted to int, and C# masks int shift
            // counts to 5 bits, so shifting it by 32..56 would silently wrap into the low bits.
            var low = (uint)ReadInt32();
            var high = (uint)ReadInt32();
            return ((long)high << 32) | low;
        }

        public float ReadSingle()
        {
            if (!BitConverter.IsLittleEndian)
            {
                // IL is little-endian; reorder before converting on big-endian hosts.
                var bytes = ReadBytes(4);
                Array.Reverse(bytes);
                return BitConverter.ToSingle(bytes, 0);
            }

            CheckCanRead(4);
            var value = BitConverter.ToSingle(Buffer, position);
            position += 4;
            return value;
        }

        public double ReadDouble()
        {
            if (!BitConverter.IsLittleEndian)
            {
                var bytes = ReadBytes(8);
                Array.Reverse(bytes);
                return BitConverter.ToDouble(bytes, 0);
            }

            CheckCanRead(8);
            var value = BitConverter.ToDouble(Buffer, position);
            position += 8;
            return value;
        }

        /// <summary>Reads a 1-byte branch displacement and returns the absolute target offset.</summary>
        public int ReadShortBranchTarget()
        {
            // The displacement is relative to the start of the *next* instruction, so the
            // operand must be consumed before `position` is used.
            var displacement = (sbyte)ReadByte();
            return position + displacement;
        }

        /// <summary>Reads a 4-byte branch displacement and returns the absolute target offset.</summary>
        public int ReadBranchTarget()
        {
            var displacement = ReadInt32();
            return position + displacement;
        }

        /// <summary>Reads a <c>switch</c> operand and returns the absolute target offsets.</summary>
        public int[] ReadBranches()
        {
            var length = ReadInt32();
            var branches = new int[length];
            var offsets = new int[length];
            for (var i = 0; i < length; i++)
                offsets[i] = ReadInt32();
            // Targets are relative to the end of the whole switch instruction.
            for (var i = 0; i < length; i++)
                branches[i] = position + offsets[i];

            return branches;
        }

        private void CheckCanRead(int count)
        {
            if (position + count > Buffer.Length)
                throw new ArgumentOutOfRangeException();
        }

        public void MoveByte() => position++;

        private void MoveBytes(int length) => position += length;

        public void MoveInt16() => position += 2;

        public void MoveInt32() => position += 4;

        public void MoveInt64() => position += 8;

        // Skipping does not depend on byte order.
        public void MoveSingle() => MoveBytes(4);

        public void MoveDouble() => MoveBytes(8);

        public void MoveBranches()
        {
            var length = ReadInt32();
            MoveBytes(4 * length);
        }
    }

    /// <summary>One-byte opcodes, indexed by their byte value.</summary>
    public static readonly OpCode[] OneByteOpcodes;

    /// <summary>Two-byte (0xFE-prefixed) opcodes, indexed by their second byte.</summary>
    public static readonly OpCode[] TwoBytesOpcodes;

    static Instruction()
    {
        var oneByteOpcodes = new OpCode[0xe1];
        var twoBytesOpcodes = new OpCode[0x1f];
        var opCodeFields = typeof(OpCodes).GetFields(BindingFlags.Public | BindingFlags.Static);

        foreach (var field in opCodeFields)
        {
            var opcode = (OpCode)field.GetValue(null)!;
            // Prefix pseudo-opcodes are not real instructions.
            if (opcode.OpCodeType == OpCodeType.Nternal)
                continue;

            if (opcode.Size == 1)
                oneByteOpcodes[opcode.Value] = opcode;
            else
                twoBytesOpcodes[opcode.Value & 0xff] = opcode;
        }

        OneByteOpcodes = oneByteOpcodes;
        TwoBytesOpcodes = twoBytesOpcodes;
    }

    /// <summary>
    /// Decodes the IL body of <paramref name="methodBase"/> into a read-only list of instructions.
    /// </summary>
    /// <param name="methodBase">The method or constructor to decode.</param>
    /// <param name="targetOperandTypes">
    /// When given, only operands of these types are resolved. All other instructions get a
    /// <c>null</c> <see cref="Operand"/>, and their operand bytes are skipped.
    /// </param>
    /// <returns>The instructions in IL order; empty when the method has no IL body.</returns>
    /// <exception cref="NotSupportedException">The IL contains an unknown opcode or operand type.</exception>
    public static IList<Instruction> GetList(MethodBase methodBase, [Optional] OperandType[]? targetOperandTypes)
    {
        static bool TryGetByteBuffer(MethodBase method, [NotNullWhen(true)] out IList<LocalVariableInfo>? localVariableInfos, [NotNullWhen(true)] out ByteBuffer? byteBuffer)
        {
            var body = method.GetMethodBody();
            if (body is not null)
            {
                var bytes = body.GetILAsByteArray();
                if (bytes != null)
                {
                    localVariableInfos = body.LocalVariables;
                    byteBuffer = new ByteBuffer(bytes);
                    return true;
                }
            }

            localVariableInfos = null;
            byteBuffer = null;
            return false;
        }

        if (!TryGetByteBuffer(methodBase, out var locals, out var byteBuffer))
            return new List<Instruction>();

        // Constructors cannot be generic, and GetGenericArguments throws on them.
        var methodGenericArguments = methodBase is not ConstructorInfo ? methodBase.GetGenericArguments() : default;
        var typeGenericArguments = methodBase.DeclaringType is not null ? methodBase.DeclaringType.GetGenericArguments() : default;

        var parameters = methodBase.GetParameters();
        var module = methodBase.Module;

        // Maps an ldloc/stloc/ldarg/starg-style index to its LocalVariableInfo or ParameterInfo.
        object? GetVariable(OpCode opCode, int index)
        {
            if (opCode.Name is not null && opCode.Name.Contains("loc"))
                return locals[index];

            if (!methodBase.IsStatic)
            {
                // Argument 0 of an instance method is the implicit `this`, which has no ParameterInfo.
                if (index == 0)
                    return null;
                index--;
            }

            return parameters[index];
        }

        object? GetOperand(OpCode opCode, ByteBuffer byteBuffer)
            => opCode.OperandType switch
            {
                OperandType.InlineNone => null,
                OperandType.InlineSwitch => byteBuffer.ReadBranches(),
                OperandType.ShortInlineBrTarget => byteBuffer.ReadShortBranchTarget(),
                OperandType.InlineBrTarget => byteBuffer.ReadBranchTarget(),
                OperandType.ShortInlineI => opCode == OpCodes.Ldc_I4_S ? (sbyte)byteBuffer.ReadByte() : byteBuffer.ReadByte(),
                OperandType.InlineI => byteBuffer.ReadInt32(),
                OperandType.ShortInlineR => byteBuffer.ReadSingle(),
                OperandType.InlineR => byteBuffer.ReadDouble(),
                OperandType.InlineI8 => byteBuffer.ReadInt64(),
                OperandType.InlineSig => module.ResolveSignature(byteBuffer.ReadInt32()),
                OperandType.InlineString => module.ResolveString(byteBuffer.ReadInt32()),
                OperandType.InlineTok => module.ResolveMember(byteBuffer.ReadInt32(), typeGenericArguments, methodGenericArguments),
                OperandType.InlineType => module.ResolveType(byteBuffer.ReadInt32(), typeGenericArguments, methodGenericArguments),
                OperandType.InlineMethod => module.ResolveMethod(byteBuffer.ReadInt32(), typeGenericArguments, methodGenericArguments),
                OperandType.InlineField => module.ResolveField(byteBuffer.ReadInt32(), typeGenericArguments, methodGenericArguments),
                OperandType.ShortInlineVar => GetVariable(opCode, byteBuffer.ReadByte()),
                // The 16-bit variable index is unsigned.
                OperandType.InlineVar => GetVariable(opCode, (ushort)byteBuffer.ReadInt16()),
                _ => throw new NotSupportedException(),
            };

        // Skips the operand bytes of an instruction whose operand is not wanted.
        static void Move(OpCode opCode, ByteBuffer byteBuffer)
        {
            switch (opCode.OperandType)
            {
                case OperandType.InlineNone:
                    break;
                case OperandType.ShortInlineBrTarget:
                case OperandType.ShortInlineVar:
                case OperandType.ShortInlineI:
                    byteBuffer.MoveByte();
                    break;
                case OperandType.InlineVar:
                    byteBuffer.MoveInt16();
                    break;
                case OperandType.InlineBrTarget:
                case OperandType.InlineI:
                case OperandType.InlineSig:
                case OperandType.InlineString:
                case OperandType.InlineTok:
                case OperandType.InlineMethod:
                case OperandType.InlineType:
                case OperandType.InlineField:
                    byteBuffer.MoveInt32();
                    break;
                case OperandType.InlineI8:
                    byteBuffer.MoveInt64();
                    break;
                case OperandType.ShortInlineR:
                    byteBuffer.MoveSingle();
                    break;
                case OperandType.InlineR:
                    byteBuffer.MoveDouble();
                    break;
                case OperandType.InlineSwitch:
                    byteBuffer.MoveBranches();
                    break;
                default:
                    throw new NotSupportedException();
            }
        }

        // Reads a one- or two-byte opcode and rejects byte values that are not valid opcodes.
        static OpCode ReadOpCode(ByteBuffer byteBuffer)
        {
            var first = byteBuffer.ReadByte();
            var isTwoByte = first == 0xfe;
            var table = isTwoByte ? TwoBytesOpcodes : OneByteOpcodes;
            var index = isTwoByte ? byteBuffer.ReadByte() : first;
            // Unassigned table slots hold default(OpCode), whose Size is 0.
            if (index >= table.Length || table[index].Size == 0)
                throw new NotSupportedException($"Unknown IL opcode 0x{(isTwoByte ? "fe" : "")}{index:x2}.");
            return table[index];
        }

        var result = new List<Instruction>(byteBuffer.Buffer.Length / 3);
        var previous = default(Instruction);
        while (byteBuffer.position < byteBuffer.Buffer.Length)
        {
            // The instruction's offset is where its opcode starts, before anything is consumed.
            var offset = byteBuffer.position;
            var opCode = ReadOpCode(byteBuffer);
            object? operand;
            if (targetOperandTypes is null || Array.IndexOf(targetOperandTypes, opCode.OperandType) >= 0)
                operand = GetOperand(opCode, byteBuffer);
            else
            {
                operand = null;
                Move(opCode, byteBuffer);
            }

            var instruction = new Instruction(offset, opCode, operand);

            if (previous != null)
            {
                instruction.Previous = previous;
                previous.Next = instruction;
            }

            result.Add(previous = instruction);
        }

        return result.AsReadOnly();
    }

    private Instruction? previous;
    private Instruction? next;

    public Instruction(int offset, OpCode opCode, object? operand)
    {
        Offset = offset;
        OpCode = opCode;
        Operand = operand;
    }

    /// <summary>Byte offset of this instruction's opcode within the method's IL.</summary>
    public int Offset { get; }

    public OpCode OpCode { get; }

    /// <summary>
    /// The decoded operand. It is <c>null</c> when the instruction has no operand, when the operand
    /// type was not requested, or when the operand is the implicit <c>this</c> argument.
    /// </summary>
    public object? Operand { get; }

    /// <summary>The preceding instruction; throws when this is the first instruction.</summary>
    public Instruction Previous { get => previous ?? throw new NullReferenceException(nameof(previous)); set => previous = value; }

    /// <summary>The following instruction; throws when this is the last instruction.</summary>
    public Instruction Next { get => next ?? throw new NullReferenceException(nameof(next)); set => next = value; }
}

Upvotes: 0

Joe White
Joe White

Reputation: 97818

Consider looking at Cecil or Roslyn.

Cecil operates on the compiled assembly, like Reflection does. It has higher-level libraries built on top of it to support refactorings in the SharpDevelop IDE, so it might have something to make this easier.

Roslyn operates on source code and gives you an object model based on that, so if you're willing to work against the source instead of binaries, it might be even easier to work with.

(I've never actually used Cecil for anything like this and I've never used Roslyn at all, so I can't do much more than point you at the projects and wish you luck. If you do manage to get something working, I'd be interested to hear how it went!)

Upvotes: 1

Ivo
Ivo

Reputation: 8362

What you can do is add a property to the object indicating that the aspect was applied. That way, you won't apply the aspect several times, because you can check that property first. It's not what you asked, but it may help with your underlying issue.

Upvotes: 1

Steven Jeuris
Steven Jeuris

Reputation: 19130

This is a temporary answer, to state what I found so far.

I didn't find any property of ConstructorInfo which could indicate whether the constructor calls another constructor or not. Neither did the properties of MethodBody.

I am having some success evaluating the MSIL byte code. My first findings indicate that the constructor which is eventually called starts with OpCodes.Call immediately, apart from a few possible other OpCodes. Constructors which call other constructors have 'unexpected' OpCodes.

/// <summary>
/// Determines whether <paramref name="constructor"/> chains to another constructor declared on the
/// same type (C# <c>: this( ... )</c>), as opposed to calling a base-type constructor.
/// </summary>
/// <remarks>
/// The IL is decoded instruction by instruction; operand bytes are never mistaken for opcodes. The
/// first <c>call</c> whose target is a constructor decides the result. Field initializers, constant
/// arguments and base-constructor calls therefore do not influence the answer.
/// NOTE(review): for value types, a <c>call</c> to a same-type constructor on a local rather than on
/// <c>this</c> would also be reported as <c>true</c>.
/// </remarks>
/// <param name="constructor">The constructor to inspect.</param>
/// <returns><c>true</c> when the constructor delegates to another constructor of its own type.</returns>
/// <exception cref="ArgumentNullException"><paramref name="constructor"/> is null.</exception>
/// <exception cref="ArgumentException">The constructor has no IL body.</exception>
public static bool CallsOtherConstructor( this ConstructorInfo constructor )
{
    if ( constructor == null )
    {
        throw new ArgumentNullException( nameof( constructor ) );
    }

    MethodBody body = constructor.GetMethodBody();
    byte[] il = body == null ? null : body.GetILAsByteArray();
    if ( il == null )
    {
        throw new ArgumentException( "Constructors are expected to always contain byte code." );
    }

    Type declaringType = constructor.DeclaringType;
    // Needed to resolve tokens that refer to members of a generic declaring type.
    Type[] typeArguments = declaringType == null ? null : declaringType.GetGenericArguments();

    int position = 0;
    while ( position < il.Length )
    {
        OpCode opCode = ReadILOpCode( il, ref position );
        if ( opCode == OpCodes.Call )
        {
            MethodBase target = constructor.Module.ResolveMethod( ReadILInt32( il, position ), typeArguments, null );
            if ( target is ConstructorInfo )
            {
                // The first constructor call is the this( ... ) / base( ... ) chaining call.
                return target.DeclaringType == declaringType;
            }
        }

        position += GetILOperandSize( opCode, il, position );
    }

    return false;
}

// One-byte opcodes indexed by value, and 0xFE-prefixed opcodes indexed by their second byte.
private static readonly OpCode[] singleByteILOpCodes = BuildILOpCodeTable( false );
private static readonly OpCode[] doubleByteILOpCodes = BuildILOpCodeTable( true );

/// <summary>Builds a 256-entry opcode lookup table for either one-byte or two-byte opcodes.</summary>
private static OpCode[] BuildILOpCodeTable( bool twoByte )
{
    var table = new OpCode[ 256 ];
    foreach ( FieldInfo field in typeof( OpCodes ).GetFields( BindingFlags.Public | BindingFlags.Static ) )
    {
        var opCode = (OpCode)field.GetValue( null );
        // Prefix pseudo-opcodes are not real instructions.
        if ( opCode.OpCodeType == OpCodeType.Nternal )
        {
            continue;
        }

        if ( ( opCode.Size == 2 ) == twoByte )
        {
            table[ opCode.Value & 0xFF ] = opCode;
        }
    }

    return table;
}

/// <summary>Reads the opcode at <paramref name="position"/> and advances past it.</summary>
private static OpCode ReadILOpCode( byte[] il, ref int position )
{
    byte first = il[ position++ ];
    OpCode opCode = first == 0xFE ? doubleByteILOpCodes[ il[ position++ ] ] : singleByteILOpCodes[ first ];
    // Unassigned table slots hold default(OpCode), whose Size is 0.
    if ( opCode.Size == 0 )
    {
        throw new NotSupportedException( "Unknown IL opcode encountered." );
    }

    return opCode;
}

/// <summary>Reads a little-endian 32-bit integer (IL byte order, independent of the host).</summary>
private static int ReadILInt32( byte[] il, int position )
{
    return il[ position ]
           | ( il[ position + 1 ] << 8 )
           | ( il[ position + 2 ] << 16 )
           | ( il[ position + 3 ] << 24 );
}

/// <summary>Returns the number of operand bytes following <paramref name="opCode"/>.</summary>
private static int GetILOperandSize( OpCode opCode, byte[] il, int operandStart )
{
    switch ( opCode.OperandType )
    {
        case OperandType.InlineNone:
            return 0;
        case OperandType.ShortInlineBrTarget:
        case OperandType.ShortInlineI:
        case OperandType.ShortInlineVar:
            return 1;
        case OperandType.InlineVar:
            return 2;
        case OperandType.InlineI8:
        case OperandType.InlineR:
            return 8;
        case OperandType.InlineSwitch:
            // A 4-byte count followed by that many 4-byte branch offsets.
            return 4 + 4 * ReadILInt32( il, operandStart );
        default:
            // InlineBrTarget, InlineField, InlineI, InlineMethod, InlineSig, InlineString,
            // InlineTok, InlineType and ShortInlineR all carry 4 bytes.
            return 4;
    }
}

I'm not sure at all about MSIL. Perhaps it's impossible to have no-ops in between there, or there is no need at all to start out a constructor like that, but for all my current unit tests it seems to work.

/// <summary>
/// Unit tests for the <c>CallsOtherConstructor</c> extension method. The nested fixture classes
/// exist only so that their constructors' IL can be inspected; do not change their bodies.
/// </summary>
[TestClass]
public class ConstructorInfoExtensionsTest
{
    class PublicConstructors
    {
        // First
        public PublicConstructors() : this( true ) {}

        // Second
        public PublicConstructors( bool one ) : this( true, true ) {}

        // Final
        public PublicConstructors( bool one, bool two ) {}

        // Alternate final
        public PublicConstructors( bool one, bool two, bool three ) {}
    }

    class PrivateConstructors
    {
        // First
        PrivateConstructors() : this( true ) {}

        // Second
        PrivateConstructors( bool one ) : this( true, true ) {}

        // Final
        PrivateConstructors( bool one, bool two ) {}

        // Alternate final
        PrivateConstructors( bool one, bool two, bool three ) {}
    }

    class TripleBaseConstructors : DoubleBaseConstructors
    {
        public TripleBaseConstructors() : base() { }
        public TripleBaseConstructors( bool one ) : base( one ) { }
    }

    class DoubleBaseConstructors : BaseConstructors
    {
        public DoubleBaseConstructors() : base() {}
        public DoubleBaseConstructors( bool one ) : base( one ) {}
    }

    class BaseConstructors : Base
    {
        public BaseConstructors() : base() {}
        public BaseConstructors( bool one ) : base( one ) {}
    }

    class Base
    {
        // No parameters
        public Base() {}

        // One parameter
        public Base( bool one ) {} 
    }

    class ContentConstructor
    {
        public ContentConstructor()
        {
            SomeMethod();
        }

        public ContentConstructor( bool one )
        {
            // Deliberately gives the constructor a body with a local.
            int bleh = 0;
        }

        bool setTwo;
        public ContentConstructor( bool one, bool two )
        {
            setTwo = two;
        }

        void SomeMethod() {}
    }

    /// <summary>Returns the first constructor taking exactly <paramref name="parameterCount"/> parameters.</summary>
    static ConstructorInfo WithParameterCount( ConstructorInfo[] candidates, int parameterCount )
    {
        return candidates.First( c => c.GetParameters().Length == parameterCount );
    }

    /// <summary>
    /// Checks the "First"/"Second" constructors (0 and 1 parameters) chain onwards, while the
    /// "Final"/"Alternate final" ones (2 and 3 parameters) do not.
    /// </summary>
    static void AssertChainedConstructors( ConstructorInfo[] constructors )
    {
        Assert.IsTrue( WithParameterCount( constructors, 0 ).CallsOtherConstructor() );
        Assert.IsTrue( WithParameterCount( constructors, 1 ).CallsOtherConstructor() );
        Assert.IsFalse( WithParameterCount( constructors, 2 ).CallsOtherConstructor() );
        Assert.IsFalse( WithParameterCount( constructors, 3 ).CallsOtherConstructor() );
    }

    /// <summary>
    /// Only constructors declared on the inspected type count, so calling a base-class
    /// constructor must not be reported as chaining.
    /// </summary>
    static void AssertOnlyBaseConstructorsCalled( ConstructorInfo[] constructors )
    {
        ConstructorInfo withoutParameters = WithParameterCount( constructors, 0 );
        ConstructorInfo withOneParameter = WithParameterCount( constructors, 1 );

        Assert.IsFalse( withoutParameters.CallsOtherConstructor() );
        Assert.IsFalse( withOneParameter.CallsOtherConstructor() );
    }

    [TestMethod]
    public void CallsOtherConstructorTest()
    {
        // Public and private constructors.
        AssertChainedConstructors( typeof( PublicConstructors ).GetConstructors() );
        AssertChainedConstructors( typeof( PrivateConstructors ).GetConstructors( BindingFlags.NonPublic | BindingFlags.Instance ) );

        // Inheritance, at increasing depth.
        Type[] derivedTypes = { typeof( BaseConstructors ), typeof( DoubleBaseConstructors ), typeof( TripleBaseConstructors ) };
        foreach ( Type derived in derivedTypes )
        {
            AssertOnlyBaseConstructorsCalled( derived.GetConstructors() );
        }

        // Constructors with a body but no chaining.
        ConstructorInfo[] contentConstructors = typeof( ContentConstructor ).GetConstructors();
        for ( int i = 0; i < contentConstructors.Length; i++ )
        {
            Assert.IsFalse( contentConstructors[ i ].CallsOtherConstructor() );
        }
    }
}

Upvotes: 3

Shahar Prish
Shahar Prish

Reputation: 4847

As far as I know, you cannot check or inspect code using reflection in an easy fashion. All reflection lets you do is reflect on the metadata information of the assembly.

You can use GetMethodBody to grab the content of the method, but then you will have to actually parse it and understand the IL yourself.

Upvotes: 0

Related Questions