Reputation: 9426
Given an `INamedTypeSymbol` (that comes from a referenced assembly, not source), how can I find all types (in both source and referenced assemblies) that inherit from this type?
In my particular case, I'm looking for all types that inherit from `NUnit.Framework.TestAttribute`. I can get access to the named type symbol as follows:
```csharp
var ws = MSBuildWorkspace.Create();
var soln = ws.OpenSolutionAsync(@"C:\Users\...\SampleInheritanceStuff.sln").Result;
var proj = soln.Projects.Single();
var compilation = proj.GetCompilationAsync().Result;
string TEST_ATTRIBUTE_METADATA_NAME = "NUnit.Framework.TestAttribute";
var testAttributeType = compilation.GetTypeByMetadataName(TEST_ATTRIBUTE_METADATA_NAME);
// Now how do I find types that inherit from this type?
```
I've taken a look at `SymbolFinder`, `Compilation` and `INamedTypeSymbol`, but I haven't had any luck.
Edit: The `FindDerivedClassesAsync` method looks close to what I need (I'm not 100% sure that it finds derived classes in referenced assemblies). However, it's internal, so I've opened an issue.
Upvotes: 9
Views: 1874
Reputation: 3416
`FindDerivedClassesAsync` is indeed what you are looking for.
It finds derived classes in referenced assemblies, as you can see in the source code for `DependentTypeFinder` (notice the `locationsInMetadata` variable).
As for using it, you can call it through reflection in the meantime:

```csharp
// DependentTypeFinder is presumably a Lazy<Type> for Roslyn's internal DependentTypeFinder class, defined elsewhere in the borrowed code.
private static readonly Lazy<Func<INamedTypeSymbol, Solution, IImmutableSet<Project>, CancellationToken, Task<IEnumerable<INamedTypeSymbol>>>> FindDerivedClassesAsync
    = new Lazy<Func<INamedTypeSymbol, Solution, IImmutableSet<Project>, CancellationToken, Task<IEnumerable<INamedTypeSymbol>>>>(() => (Func<INamedTypeSymbol, Solution, IImmutableSet<Project>, CancellationToken, Task<IEnumerable<INamedTypeSymbol>>>)Delegate.CreateDelegate(
        typeof(Func<INamedTypeSymbol, Solution, IImmutableSet<Project>, CancellationToken, Task<IEnumerable<INamedTypeSymbol>>>),
        DependentTypeFinder.Value.GetMethod("FindDerivedClassesAsync", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic)));
```
(code borrowed from Tunnel Vision Laboratories GitHub)
Good luck!
UPDATE: This method has since been made public as `SymbolFinder.FindDerivedClassesAsync`. (source)
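For reference, here is a minimal sketch of calling the public method directly, so the reflection workaround is no longer needed. It assumes a Roslyn version where `SymbolFinder.FindDerivedClassesAsync(INamedTypeSymbol, Solution)` is public; `DerivedTypeSearch` and `FindDerivedAsync` are hypothetical names used only for illustration.

```csharp
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.FindSymbols;

// Hypothetical helper wrapping the public SymbolFinder API.
public static class DerivedTypeSearch
{
    // Resolves the base type by metadata name in the given project, then asks
    // SymbolFinder for classes in the solution that derive from it.
    // Returns an empty list if the base type cannot be resolved.
    public static async Task<IReadOnlyList<INamedTypeSymbol>> FindDerivedAsync(
        Solution solution, Project project, string baseTypeMetadataName)
    {
        var compilation = await project.GetCompilationAsync().ConfigureAwait(false);
        var baseType = compilation?.GetTypeByMetadataName(baseTypeMetadataName);
        if (baseType == null)
            return new List<INamedTypeSymbol>();

        // Optional parameters (projects, cancellation token, and in newer
        // Roslyn versions a 'transitive' flag) are left at their defaults.
        var derived = await SymbolFinder.FindDerivedClassesAsync(baseType, solution).ConfigureAwait(false);
        return derived.ToList();
    }
}
```

With the question's variables this would be roughly `await DerivedTypeSearch.FindDerivedAsync(soln, proj, "NUnit.Framework.TestAttribute")`.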
Upvotes: 3
Reputation: 947
You can get this information using the `SemanticModel` exposed by the `Compilation`:
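```csharp
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;        // GetDeclaredSymbol(BaseTypeDeclarationSyntax) overload
using Microsoft.CodeAnalysis.CSharp.Syntax;

public static IEnumerable<INamedTypeSymbol> GetBaseClasses(SemanticModel model, BaseTypeDeclarationSyntax type)
{
    var classSymbol = model.GetDeclaredSymbol(type);
    var returnValue = new List<INamedTypeSymbol>();
    if (classSymbol == null)
        return returnValue;

    // Walk up the inheritance chain, collecting each base class together with
    // the interfaces declared directly on the current type
    while (classSymbol.BaseType != null)
    {
        returnValue.Add(classSymbol.BaseType);
        returnValue.AddRange(classSymbol.Interfaces);
        classSymbol = classSymbol.BaseType;
    }
    return returnValue;
}
```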
This will give you a list of all the base classes, as well as the interfaces declared on the type and on each of its base classes. You can then filter for the `INamedTypeSymbol` that you are interested in:
```csharp
public static IEnumerable<BaseTypeDeclarationSyntax> FindClassesDerivedOrImplementedByType(
    Compilation compilation, INamedTypeSymbol target)
{
    foreach (var tree in compilation.SyntaxTrees)
    {
        var semanticModel = compilation.GetSemanticModel(tree);
        foreach (var type in tree.GetRoot().DescendantNodes()
                                 .OfType<TypeDeclarationSyntax>())
        {
            var baseClasses = GetBaseClasses(semanticModel, type);
            // Compare symbols with SymbolEqualityComparer rather than reference equality
            if (baseClasses.Contains(target, SymbolEqualityComparer.Default))
                yield return type;
        }
    }
}
```
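The approach above only looks at types declared in the compilation's own syntax trees. If you also need types from referenced assemblies without going through `SymbolFinder`, one possible sketch (assuming a Roslyn version that has `SymbolEqualityComparer`; `MetadataTypeWalker` and `FindSubclasses` are hypothetical names) is to walk `Compilation.GlobalNamespace`, which merges the namespaces of the source and all referenced assemblies, and check each type's base-class chain:

```csharp
using System.Collections.Generic;
using Microsoft.CodeAnalysis;

// Hypothetical helper: finds classes in source *and* referenced assemblies.
public static class MetadataTypeWalker
{
    // Yields every named type visible to the compilation whose base-class
    // chain contains the target type (interfaces are not considered).
    public static IEnumerable<INamedTypeSymbol> FindSubclasses(Compilation compilation, INamedTypeSymbol target)
    {
        var pending = new Stack<INamespaceOrTypeSymbol>();
        pending.Push(compilation.GlobalNamespace);
        while (pending.Count > 0)
        {
            var current = pending.Pop();
            if (current is INamespaceSymbol ns)
            {
                // Child namespaces and top-level types of this namespace
                foreach (var member in ns.GetMembers())
                    pending.Push(member);
            }
            else if (current is INamedTypeSymbol type)
            {
                // Nested types are members of their containing type
                foreach (var nested in type.GetTypeMembers())
                    pending.Push(nested);
                if (InheritsFrom(type, target))
                    yield return type;
            }
        }
    }

    private static bool InheritsFrom(INamedTypeSymbol type, INamedTypeSymbol target)
    {
        for (var baseType = type.BaseType; baseType != null; baseType = baseType.BaseType)
        {
            // OriginalDefinition lets Base<int> match a generic target Base<T>
            if (SymbolEqualityComparer.Default.Equals(baseType.OriginalDefinition, target))
                return true;
        }
        return false;
    }
}
```

This visits every type in every referenced assembly (including the framework), so it may be slow on large solutions; caching the result or skipping namespaces known to be irrelevant would help.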
Upvotes: 0