diff --git a/src/EntityFrameworkCore.Projectables.Generator/Infrastructure/Diagnostics.cs b/src/EntityFrameworkCore.Projectables.Generator/Infrastructure/Diagnostics.cs index 9fab2f28..bd44b763 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/Infrastructure/Diagnostics.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/Infrastructure/Diagnostics.cs @@ -46,8 +46,8 @@ static internal class Diagnostics public readonly static DiagnosticDescriptor RequiresBodyDefinition = new DiagnosticDescriptor( id: "EFP0006", - title: "Method or property should expose a body definition", - messageFormat: "Method or property '{0}' should expose a body definition (e.g. an expression-bodied member or a block-bodied method) to be used as the source for the generated expression tree.", + title: "Method or property should expose a body definition if not overwritten in classes derived from the declaring class", + messageFormat: "Method or property '{0}' should expose a body definition (e.g. an expression-bodied member or a block-bodied method) to be used as the source for the generated expression tree if not overwritten in at least one class derived from the class where the method or property is declared.", category: "Design", DiagnosticSeverity.Error, isEnabledByDefault: true); diff --git a/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.BodyProcessors.cs b/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.BodyProcessors.cs index 60cdc32f..6648f208 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.BodyProcessors.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.BodyProcessors.cs @@ -14,17 +14,23 @@ static internal partial class ProjectableInterpreter /// Returns false and reports diagnostics on failure. /// private static bool TryApplyMethodBody( + MemberDeclarationSyntax originalMemberDeclarationSyntax, MethodDeclarationSyntax methodDeclarationSyntax, + SemanticModel semanticModel, bool allowBlockBody, ISymbol memberSymbol, ExpressionSyntaxRewriter expressionSyntaxRewriter, DeclarationSyntaxRewriter declarationSyntaxRewriter, SourceProductionContext context, + Compilation? compilation, ProjectableDescriptor descriptor) { ExpressionSyntax? bodyExpression = null; var isExpressionBodied = false; + var derivedTypes = GetDerivedTypes(semanticModel.GetDeclaredSymbol(originalMemberDeclarationSyntax), compilation); + var isHierarchy = derivedTypes?.Count > 0; + if (methodDeclarationSyntax.ExpressionBody is not null) { bodyExpression = methodDeclarationSyntax.ExpressionBody.Expression; @@ -48,7 +54,7 @@ private static bool TryApplyMethodBody( return false; // diagnostics already reported by BlockStatementConverter } } - else + else if (!isHierarchy) { return ReportRequiresBodyAndFail(context, methodDeclarationSyntax, memberSymbol.Name); } @@ -57,7 +63,7 @@ private static bool TryApplyMethodBody( descriptor.ReturnTypeName = returnType.ToString(); // Only rewrite expression-bodied methods; block-bodied methods are already rewritten - descriptor.ExpressionBody = isExpressionBodied + descriptor.ExpressionBody = isExpressionBodied && bodyExpression != null ? (ExpressionSyntax)expressionSyntaxRewriter.Visit(bodyExpression) : bodyExpression; @@ -73,6 +79,13 @@ private static bool TryApplyMethodBody( ApplyExtensionBlockTypeParameters(memberSymbol, descriptor); } + // If we are rewriting a hierarchy method we need to invoke the derived types' overrides + if(isHierarchy) + { + descriptor.HierarchyOriginalExpressionBody = descriptor.ExpressionBody; + descriptor.ExpressionBody = new HierarchyMembersConverter().DuplicateMethodExpression(derivedTypes!, descriptor); + } + return true; } @@ -92,14 +105,18 @@ private static bool TryApplyExpressionPropertyBody( ExpressionSyntaxRewriter expressionSyntaxRewriter, DeclarationSyntaxRewriter declarationSyntaxRewriter, SourceProductionContext context, + Compilation? compilation, ProjectableDescriptor descriptor) { + var derivedTypes = GetDerivedTypes(semanticModel.GetDeclaredSymbol(originalMethodDecl), compilation); + var isHierarchy = derivedTypes?.Count > 0; + var rawExpr = TryGetPropertyGetterExpression(exprPropDecl); var (innerBody, lambdaParamNames) = rawExpr is not null ? TryExtractLambdaBodyAndParams(rawExpr, semanticModel, member.SyntaxTree) : (null, []); - if (innerBody is null) + if (innerBody is null && !isHierarchy) { return ReportRequiresBodyAndFail(context, exprPropDecl, memberSymbol.Name); } @@ -112,77 +129,80 @@ private static bool TryApplyExpressionPropertyBody( // For cross-tree expression properties the rewriter's SemanticModel cannot resolve // nodes from the other file — skip rewriting in that case (simple lambda bodies need // no rewrites; advanced features like null-conditional rewriting are unsupported cross-file). - var visitedBody = exprPropDecl.SyntaxTree == member.SyntaxTree + var visitedBody = exprPropDecl.SyntaxTree == member.SyntaxTree && innerBody != null ? (ExpressionSyntax)expressionSyntaxRewriter.Visit(innerBody) : innerBody; - // For instance methods and C#14 extension members, BuildBaseDescriptor adds an - // implicit @this receiver parameter. If the expression property lambda uses a - // different parameter name (e.g. c => c.Value > 0), rename it so the generated - // code references @this instead of an undefined identifier. + if (visitedBody != null) + { + // For instance methods and C#14 extension members, BuildBaseDescriptor adds an + // implicit @this receiver parameter. If the expression property lambda uses a + // different parameter name (e.g. c => c.Value > 0), rename it so the generated + // code references @this instead of an undefined identifier. #if ROSLYN_5_0_OR_LATER - var isExtensionMember = memberSymbol.ContainingType is { IsExtension: true }; + var isExtensionMember = memberSymbol.ContainingType is { IsExtension: true }; #else - var isExtensionMember = false; + var isExtensionMember = false; #endif - var hasImplicitReceiver = isExtensionMember - || !originalMethodDecl.Modifiers.Any(SyntaxKind.StaticKeyword); + var hasImplicitReceiver = isExtensionMember + || !originalMethodDecl.Modifiers.Any(SyntaxKind.StaticKeyword); - // Collect (lambdaParamName → methodParamName) rename pairs to apply in a - // single multi-variable pass, avoiding cascading renames when names overlap. - var renames = new List<(string From, string To)>(); + // Collect (lambdaParamName → methodParamName) rename pairs to apply in a + // single multi-variable pass, avoiding cascading renames when names overlap. + var renames = new List<(string From, string To)>(); - var lambdaOffset = 0; - if (hasImplicitReceiver) - { - if (lambdaParamNames.Count > 0 && lambdaParamNames[0] != "@this") + var lambdaOffset = 0; + if (hasImplicitReceiver) { - renames.Add((lambdaParamNames[0], "@this")); - } - - lambdaOffset = 1; - } + if (lambdaParamNames.Count > 0 && lambdaParamNames[0] != "@this") + { + renames.Add((lambdaParamNames[0], "@this")); + } - // Rename each explicit method parameter from its lambda counterpart name. - var methodParams = originalMethodDecl.ParameterList.Parameters; - for (var i = 0; i < methodParams.Count; i++) - { - var lambdaIdx = lambdaOffset + i; - if (lambdaIdx >= lambdaParamNames.Count) - { - break; + lambdaOffset = 1; } - var lambdaName = lambdaParamNames[lambdaIdx]; - var methodName = methodParams[i].Identifier.ValueText; - if (lambdaName != methodName) + // Rename each explicit method parameter from its lambda counterpart name. + var methodParams = originalMethodDecl.ParameterList.Parameters; + for (var i = 0; i < methodParams.Count; i++) { - renames.Add((lambdaName, methodName)); - } - } + var lambdaIdx = lambdaOffset + i; + if (lambdaIdx >= lambdaParamNames.Count) + { + break; + } - // Apply all renames. To avoid cascading substitutions when names overlap - // (e.g. swapped parameter names), use a unique sentinel prefix for each - // intermediate name, then replace sentinels with the final names. - if (renames.Count > 0) - { - // Phase 1: rename each source name to a collision-free sentinel. - var sentinels = new List<(string Sentinel, string To)>(renames.Count); - for (var i = 0; i < renames.Count; i++) - { - var sentinel = $"__rename_sentinel_{i}__"; - visitedBody = (ExpressionSyntax)new VariableReplacementRewriter( - renames[i].From, - SyntaxFactory.IdentifierName(sentinel)).Visit(visitedBody); - sentinels.Add((sentinel, renames[i].To)); + var lambdaName = lambdaParamNames[lambdaIdx]; + var methodName = methodParams[i].Identifier.ValueText; + if (lambdaName != methodName) + { + renames.Add((lambdaName, methodName)); + } } - // Phase 2: replace each sentinel with the final target name. - foreach (var (sentinel, to) in sentinels) + // Apply all renames. To avoid cascading substitutions when names overlap + // (e.g. swapped parameter names), use a unique sentinel prefix for each + // intermediate name, then replace sentinels with the final names. + if (renames.Count > 0) { - visitedBody = (ExpressionSyntax)new VariableReplacementRewriter( - sentinel, - SyntaxFactory.IdentifierName(to)).Visit(visitedBody); + // Phase 1: rename each source name to a collision-free sentinel. + var sentinels = new List<(string Sentinel, string To)>(renames.Count); + for (var i = 0; i < renames.Count; i++) + { + var sentinel = $"__rename_sentinel_{i}__"; + visitedBody = (ExpressionSyntax)new VariableReplacementRewriter( + renames[i].From, + SyntaxFactory.IdentifierName(sentinel)).Visit(visitedBody); + sentinels.Add((sentinel, renames[i].To)); + } + + // Phase 2: replace each sentinel with the final target name. + foreach (var (sentinel, to) in sentinels) + { + visitedBody = (ExpressionSyntax)new VariableReplacementRewriter( + sentinel, + SyntaxFactory.IdentifierName(to)).Visit(visitedBody); + } } } @@ -191,6 +211,13 @@ private static bool TryApplyExpressionPropertyBody( ApplyParameterList(originalMethodDecl.ParameterList, declarationSyntaxRewriter, descriptor); ApplyTypeParameters(originalMethodDecl, declarationSyntaxRewriter, descriptor); + // If we are rewriting a hierarchy method we need to invoke the derived types' overrides + if (isHierarchy) + { + descriptor.HierarchyOriginalExpressionBody = descriptor.ExpressionBody; + descriptor.ExpressionBody = new HierarchyMembersConverter().DuplicateMethodExpression(derivedTypes!, descriptor); + } + return true; } @@ -211,14 +238,18 @@ private static bool TryApplyExpressionPropertyBodyForProperty( ExpressionSyntaxRewriter expressionSyntaxRewriter, DeclarationSyntaxRewriter declarationSyntaxRewriter, SourceProductionContext context, + Compilation? compilation, ProjectableDescriptor descriptor) { + var derivedTypes = GetDerivedTypes(semanticModel.GetDeclaredSymbol(originalPropertyDecl), compilation); + var isHierarchy = derivedTypes?.Count > 0; + var rawExpr = TryGetPropertyGetterExpression(exprPropDecl); var (innerBody, firstParamName) = rawExpr is not null ? TryExtractLambdaBodyAndFirstParam(rawExpr, semanticModel, member.SyntaxTree) : (null, null); - if (innerBody is null) + if (innerBody is null && !isHierarchy) { return ReportRequiresBodyAndFail(context, exprPropDecl, memberSymbol.Name); } @@ -229,10 +260,10 @@ private static bool TryApplyExpressionPropertyBodyForProperty( // uses the semantic model which requires the original (pre-rename) syntax nodes. // For cross-tree expression properties the rewriter's SemanticModel cannot resolve // nodes from the other file — skip rewriting in that case. - var visitedBody = exprPropDecl.SyntaxTree == member.SyntaxTree + var visitedBody = exprPropDecl.SyntaxTree == member.SyntaxTree && innerBody != null ? (ExpressionSyntax)expressionSyntaxRewriter.Visit(innerBody) : innerBody; - if (firstParamName is not null && firstParamName != "@this") + if (visitedBody != null && firstParamName != null && firstParamName != "@this") { visitedBody = (ExpressionSyntax)new VariableReplacementRewriter( firstParamName, @@ -243,6 +274,13 @@ private static bool TryApplyExpressionPropertyBodyForProperty( descriptor.ReturnTypeName = returnType.ToString(); descriptor.ExpressionBody = visitedBody; + // If we are rewriting a hierarchy method we need to invoke the derived types' overrides + if (isHierarchy) + { + descriptor.HierarchyOriginalExpressionBody = descriptor.ExpressionBody; + descriptor.ExpressionBody = new HierarchyMembersConverter().DuplicatePropertyExpression(derivedTypes!, descriptor); + } + return true; } @@ -251,14 +289,20 @@ private static bool TryApplyExpressionPropertyBodyForProperty( /// Returns false and reports diagnostics on failure. /// private static bool TryApplyPropertyBody( + MemberDeclarationSyntax originalMemberDeclarationSyntax, PropertyDeclarationSyntax propertyDeclarationSyntax, + SemanticModel semanticModel, bool allowBlockBody, ISymbol memberSymbol, ExpressionSyntaxRewriter expressionSyntaxRewriter, DeclarationSyntaxRewriter declarationSyntaxRewriter, SourceProductionContext context, + Compilation? compilation, ProjectableDescriptor descriptor) { + var derivedTypes = GetDerivedTypes(semanticModel.GetDeclaredSymbol(originalMemberDeclarationSyntax), compilation); + var isHierarchy = derivedTypes?.Count > 0; + ExpressionSyntax? bodyExpression = null; var isBlockBodiedGetter = false; @@ -299,7 +343,7 @@ private static bool TryApplyPropertyBody( } } - if (bodyExpression is null) + if (bodyExpression is null && !isHierarchy) { return ReportRequiresBodyAndFail(context, propertyDeclarationSyntax, memberSymbol.Name); } @@ -308,10 +352,17 @@ private static bool TryApplyPropertyBody( descriptor.ReturnTypeName = returnType.ToString(); // Only rewrite expression-bodied properties; block-bodied getters are already rewritten - descriptor.ExpressionBody = isBlockBodiedGetter + descriptor.ExpressionBody = isBlockBodiedGetter || bodyExpression == null ? bodyExpression : (ExpressionSyntax)expressionSyntaxRewriter.Visit(bodyExpression); + // If we are rewriting a hierarchy method we need to invoke the derived types' overrides + if (isHierarchy) + { + descriptor.HierarchyOriginalExpressionBody = descriptor.ExpressionBody; + descriptor.ExpressionBody = new HierarchyMembersConverter().DuplicatePropertyExpression(derivedTypes!, descriptor); + } + return true; } diff --git a/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.Helpers.cs b/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.Helpers.cs index ec6c25f8..5ee9cd08 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.Helpers.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.Helpers.cs @@ -183,5 +183,100 @@ private static bool ReportRequiresBodyAndFail( memberName)); return false; } + + private static IEnumerable GetAllTypes(INamespaceSymbol namespaceSymbol) + { + foreach (var type in namespaceSymbol.GetTypeMembers()) + { + yield return type; + } + + foreach (var nestedNamespace in namespaceSymbol.GetNamespaceMembers()) + { + foreach (var type in GetAllTypes(nestedNamespace)) + { + yield return type; + } + } + } + + private static IList GetDerivedTypes(ISymbol? symbol, Compilation? compilation) + { + if (symbol != null && compilation != null && (symbol.IsAbstract || symbol.IsVirtual || symbol.IsOverride)) + { + var types = GetAllTypes(compilation.GlobalNamespace) + .Where(t => IsDerivedFrom(t, symbol.ContainingType) && + t.DeclaringSyntaxReferences.Any(s => ((ClassDeclarationSyntax)s.GetSyntax()).Members.Any(m => { + var ss = compilation.GetSemanticModel(m.SyntaxTree).GetDeclaredSymbol(m); + return (ss != null && ss.IsOverride && ss.Kind == symbol.Kind && ss.Name == symbol.Name); + }))) + .OrderByDescending(GetDepth) // More specific types first + .ThenBy(t => t.Name) + .ToList(); + + // Remove types which are derived from another type in the list which has the declared symbol + // with the Projectable attribute (generation will be delegated to them) + var typesToRemove = types.Where(t => types.Any(tt => IsDerivedFrom(t, tt) && + tt.DeclaringSyntaxReferences.Any(s => ((ClassDeclarationSyntax)s.GetSyntax()).Members.First(m => { + var ss = compilation.GetSemanticModel(m.SyntaxTree).GetDeclaredSymbol(m); + return (ss != null && ss.IsOverride && ss.Kind == symbol.Kind && ss.Name == symbol.Name); + }).AttributeLists.Any(a => a.Attributes.Any(aa => { + var attributeSymbol = compilation.GetSemanticModel(aa.SyntaxTree).GetSymbolInfo(aa).Symbol; + + INamedTypeSymbol attributeTypeSymbol; + if (attributeSymbol is IMethodSymbol methodSymbol) + { + attributeTypeSymbol = methodSymbol.ContainingType; + } + else + { + attributeTypeSymbol = ((INamedTypeSymbol)attributeSymbol!); + } + + return attributeTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) == + "global::EntityFrameworkCore.Projectables.ProjectableAttribute"; + }))))).ToList(); + + foreach(var type in typesToRemove) + { + types.Remove(type); + } + + return types; + } + else + { + return Array.Empty(); + } + } + + private static bool IsDerivedFrom(INamedTypeSymbol candidate, INamedTypeSymbol baseClass) + { + var current = candidate.BaseType; + + while (current != null) + { + // SymbolEqualityComparer ensures we compare symbols accurately across compilation boundaries + if (SymbolEqualityComparer.Default.Equals(current, baseClass)) + { + return true; + } + current = current.BaseType; + } + + return false; + } + + private static int GetDepth(INamedTypeSymbol type) + { + var depth = 0; + while(type.BaseType != null) + { + depth++; + type = type.BaseType; + } + + return depth; + } } diff --git a/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.cs b/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.cs index 23db5a47..303eb485 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.cs @@ -75,26 +75,26 @@ static internal partial class ProjectableInterpreter { // Projectable method (_, MethodDeclarationSyntax methodDecl) => - TryApplyMethodBody(methodDecl, allowBlockBody, memberSymbol, - expressionSyntaxRewriter, declarationSyntaxRewriter, context, descriptor), + TryApplyMethodBody(member, methodDecl, semanticModel, allowBlockBody, memberSymbol, + expressionSyntaxRewriter, declarationSyntaxRewriter, context, compilation, descriptor), // Projectable method whose body is an Expression property (MethodDeclarationSyntax originalMethodDecl, PropertyDeclarationSyntax exprPropDecl) => TryApplyExpressionPropertyBody(originalMethodDecl, exprPropDecl, semanticModel, member, memberSymbol, - expressionSyntaxRewriter, declarationSyntaxRewriter, context, descriptor), + expressionSyntaxRewriter, declarationSyntaxRewriter, context, compilation, descriptor), // Projectable property whose body is an Expression property (PropertyDeclarationSyntax originalPropertyDecl, PropertyDeclarationSyntax exprPropDecl) when IsExpressionDelegatePropertyDecl(exprPropDecl, semanticModel) => TryApplyExpressionPropertyBodyForProperty(originalPropertyDecl, exprPropDecl, semanticModel, member, memberSymbol, - expressionSyntaxRewriter, declarationSyntaxRewriter, context, descriptor), + expressionSyntaxRewriter, declarationSyntaxRewriter, context, compilation, descriptor), // Projectable property (_, PropertyDeclarationSyntax propDecl) => - TryApplyPropertyBody(propDecl, allowBlockBody, memberSymbol, - expressionSyntaxRewriter, declarationSyntaxRewriter, context, descriptor), + TryApplyPropertyBody(member, propDecl, semanticModel, allowBlockBody, memberSymbol, + expressionSyntaxRewriter, declarationSyntaxRewriter, context, compilation, descriptor), // Projectable constructor (_, ConstructorDeclarationSyntax ctorDecl) => diff --git a/src/EntityFrameworkCore.Projectables.Generator/Models/ProjectableDescriptor.cs b/src/EntityFrameworkCore.Projectables.Generator/Models/ProjectableDescriptor.cs index e653a018..55dfdbbe 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/Models/ProjectableDescriptor.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/Models/ProjectableDescriptor.cs @@ -34,4 +34,6 @@ internal class ProjectableDescriptor public SyntaxList? ConstraintClauses { get; set; } public ExpressionSyntax? ExpressionBody { get; set; } + + public ExpressionSyntax? HierarchyOriginalExpressionBody { get; set; } } \ No newline at end of file diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index 36d68bed..6347c30c 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -7,7 +7,6 @@ using System.Text; using EntityFrameworkCore.Projectables.CodeFixes; using EntityFrameworkCore.Projectables.Generator.Comparers; -using EntityFrameworkCore.Projectables.Generator.Infrastructure; using EntityFrameworkCore.Projectables.Generator.Interpretation; using EntityFrameworkCore.Projectables.Generator.Models; using EntityFrameworkCore.Projectables.Generator.Registry; @@ -229,80 +228,91 @@ private static void Execute( } var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName, projectable.ParameterTypeNames); - var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs"; - - var classSyntax = ClassDeclaration(generatedClassName) - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) - .WithTypeParameterList(projectable.ClassTypeParameterList) - .WithConstraintClauses(projectable.ClassConstraintClauses ?? List()) - .AddAttributeLists( - AttributeList() - .AddAttributes(_editorBrowsableAttribute) - ) - .WithLeadingTrivia(member is ConstructorDeclarationSyntax ctor && compilation is not null ? BuildSourceDocComment(ctor, compilation) : TriviaList()) - .AddMembers( - MethodDeclaration( - GenericName( - Identifier("global::System.Linq.Expressions.Expression"), - TypeArgumentList( - SingletonSeparatedList( - (TypeSyntax)GenericName( - Identifier("global::System.Func"), - GetLambdaTypeArgumentListSyntax(projectable) + + AddSource(generatedClassName, projectable.ExpressionBody); + if(projectable.HierarchyOriginalExpressionBody != null) + { + AddSource(generatedClassName + "_Base", projectable.HierarchyOriginalExpressionBody); + } + + + void AddSource(string generatedClassName, ExpressionSyntax? body) + { + var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs"; + + var classSyntax = ClassDeclaration(generatedClassName) + .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) + .WithTypeParameterList(projectable.ClassTypeParameterList) + .WithConstraintClauses(projectable.ClassConstraintClauses ?? List()) + .AddAttributeLists( + AttributeList() + .AddAttributes(_editorBrowsableAttribute) + ) + .WithLeadingTrivia(member is ConstructorDeclarationSyntax ctor && compilation is not null ? BuildSourceDocComment(ctor, compilation) : TriviaList()) + .AddMembers( + MethodDeclaration( + GenericName( + Identifier("global::System.Linq.Expressions.Expression"), + TypeArgumentList( + SingletonSeparatedList( + (TypeSyntax)GenericName( + Identifier("global::System.Func"), + GetLambdaTypeArgumentListSyntax(projectable) + ) ) ) - ) - ), - "Expression" - ) - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) - .WithTypeParameterList(projectable.TypeParameterList) - .WithConstraintClauses(projectable.ConstraintClauses ?? List()) - .WithBody( - Block( - ReturnStatement( - ParenthesizedLambdaExpression( - projectable.ParametersList ?? ParameterList(), - null, - projectable.ExpressionBody + ), + "Expression" + ) + .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) + .WithTypeParameterList(projectable.TypeParameterList) + .WithConstraintClauses(projectable.ConstraintClauses ?? List()) + .WithBody( + Block( + ReturnStatement( + ParenthesizedLambdaExpression( + projectable.ParametersList ?? ParameterList(), + null, + body + ) ) ) ) - ) - ); + ); - var compilationUnit = CompilationUnit(); + var compilationUnit = CompilationUnit(); - foreach (var usingDirective in projectable.UsingDirectives!) - { - compilationUnit = compilationUnit.AddUsings(usingDirective); - } + foreach (var usingDirective in projectable.UsingDirectives!) + { + compilationUnit = compilationUnit.AddUsings(usingDirective); + } - if (projectable.ClassNamespace is not null) - { - compilationUnit = compilationUnit.AddUsings( - UsingDirective( - ParseName(projectable.ClassNamespace) - ) - ); - } + if (projectable.ClassNamespace is not null) + { + compilationUnit = compilationUnit.AddUsings( + UsingDirective( + ParseName(projectable.ClassNamespace) + ) + ); + } - compilationUnit = compilationUnit - .AddMembers( - NamespaceDeclaration( - ParseName("EntityFrameworkCore.Projectables.Generated") - ).AddMembers(classSyntax) - ) - .WithLeadingTrivia( - TriviaList( - Comment("// "), - // Uncomment line below, for debugging purposes, to see when the generator is run on source generated files - // CarriageReturnLineFeed, Comment($"// Generated at {DateTime.UtcNow:yyyy-MM-dd HH:mm:ss.fff} UTC for '{memberSymbol.Name}' in '{memberSymbol.ContainingType?.Name}'"), - Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)) + compilationUnit = compilationUnit + .AddMembers( + NamespaceDeclaration( + ParseName("EntityFrameworkCore.Projectables.Generated") + ).AddMembers(classSyntax) ) - ); + .WithLeadingTrivia( + TriviaList( + Comment("// "), + // Uncomment line below, for debugging purposes, to see when the generator is run on source generated files + // CarriageReturnLineFeed, Comment($"// Generated at {DateTime.UtcNow:yyyy-MM-dd HH:mm:ss.fff} UTC for '{memberSymbol.Name}' in '{memberSymbol.ContainingType?.Name}'"), + Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)) + ) + ); - context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8)); + context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8)); + } static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable) { diff --git a/src/EntityFrameworkCore.Projectables.Generator/SyntaxRewriters/ExpressionSyntaxRewriter.cs b/src/EntityFrameworkCore.Projectables.Generator/SyntaxRewriters/ExpressionSyntaxRewriter.cs index 619a90f9..183548db 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/SyntaxRewriters/ExpressionSyntaxRewriter.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/SyntaxRewriters/ExpressionSyntaxRewriter.cs @@ -116,12 +116,22 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition public override SyntaxNode? VisitBaseExpression(BaseExpressionSyntax node) { - // Swap out the use of this to @this - return VisitThisBaseExpression(node); + // Swap out the use of this to @this and cast it to the base type + return SyntaxFactory.ParenthesizedExpression( + SyntaxFactory.CastExpression( + SyntaxFactory.ParseTypeName(_semanticModel.GetTypeInfo(node).Type!.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + SyntaxFactory.IdentifierName("@this"))) + .WithLeadingTrivia(node.GetLeadingTrivia()) + .WithTrailingTrivia(node.GetTrailingTrivia()); } public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) { + if (node.Identifier.Text == "@this") + { + return node; + } + // Handle C# 14 extension parameter replacement (e.g., `e` in `extension(Entity e)` becomes `@this`) #if ROSLYN_5_0_OR_LATER if (_extensionParameterName is not null && node.Identifier.Text == _extensionParameterName) diff --git a/src/EntityFrameworkCore.Projectables.Generator/SyntaxRewriters/HierarchyMembersConverter.cs b/src/EntityFrameworkCore.Projectables.Generator/SyntaxRewriters/HierarchyMembersConverter.cs new file mode 100644 index 00000000..845d40c2 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/SyntaxRewriters/HierarchyMembersConverter.cs @@ -0,0 +1,142 @@ +using EntityFrameworkCore.Projectables.Generator.Models; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace EntityFrameworkCore.Projectables.Generator.SyntaxRewriters +{ + /// + /// Converts methods/properties bodies of hierarchies of classes into typed expressions. + /// + internal class HierarchyMembersConverter + { + public ExpressionSyntax DuplicateMethodExpression(IList derivedTypes, ProjectableDescriptor descriptor) + { + var @this = SyntaxFactory.IdentifierName("@this"); + + var arguments = descriptor.ParametersList?.Parameters.Count > 1 ? ConvertParameters(descriptor.ParametersList) : null; + + // Check if the method has an implementation or if it is abstract, if it is not abstract it will be added + // as the last result in the if/else if/else chain, otherwise the last type will be used instead + if (descriptor.ExpressionBody != null) + { + // @this is Type1 ? ((Type1)@this).Method(...) : ... + // ... ? ... : + // @this is TypeN ? ((TypeN)@this).Method(...) : ... + // virtualImplementation + return derivedTypes.Reverse().Aggregate(descriptor.ExpressionBody, AggregateTypes); + } + else + { + // DEV: handle generic types + var lastType = derivedTypes[derivedTypes.Count - 1]; + + // @this is Type1 ? ((Type1)@this).Method(...) : ... + // ... ? ... : + // ((TypeN)@this).Method(...) + return derivedTypes.Reverse().Skip(1) + .Aggregate((ExpressionSyntax)GetMethodInvocationExpression(lastType, descriptor.MemberName!, arguments), AggregateTypes); + } + + + ExpressionSyntax AggregateTypes(ExpressionSyntax expr, INamedTypeSymbol type) + { + return SyntaxFactory.ConditionalExpression( + SyntaxFactory.BinaryExpression(SyntaxKind.IsExpression, @this, GetTypeName(type)), + GetMethodInvocationExpression(type, descriptor.MemberName!, arguments), + expr); + } + } + + public ExpressionSyntax DuplicatePropertyExpression(IList derivedTypes, ProjectableDescriptor descriptor) + { + var @this = SyntaxFactory.IdentifierName("@this"); + + // Check if the property has an implementation or if it is abstract, if it is not abstract it will be added + // as the last result in the if/else if/else chain, otherwise the last type will be used instead + if (descriptor.ExpressionBody != null) + { + // @this is Type1 ? ((Type1)@this).Property : ... + // ... ? ... : + // @this is TypeN ? ((TypeN)@this).Property : ... + // virtualImplementation + return derivedTypes.Reverse().Aggregate(descriptor.ExpressionBody, AggregateTypes); + } + else + { + // DEV: handle generic types + var lastType = derivedTypes[derivedTypes.Count - 1]; + + // @this is Type1 ? ((Type1)@this).Property : ... + // ... ? ... : + // ((TypeN)@this).Property + return derivedTypes.Reverse().Skip(1) + .Aggregate((ExpressionSyntax)GetPropertyExpression(lastType, descriptor.MemberName!), AggregateTypes); + } + + + ExpressionSyntax AggregateTypes(ExpressionSyntax expr, INamedTypeSymbol type) + { + return SyntaxFactory.ConditionalExpression( + SyntaxFactory.BinaryExpression(SyntaxKind.IsExpression, @this, GetTypeName(type)), + GetPropertyExpression(type, descriptor.MemberName!), + expr); + } + } + + private static ArgumentListSyntax ConvertParameters(ParameterListSyntax parameters) + { + return SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(parameters.Parameters.Skip(1).Select(p => { + // Extract the name of the parameter (e.g., "myParam") + ExpressionSyntax identifier = SyntaxFactory.IdentifierName(p.Identifier); + + // Handle parameter modifiers (like 'ref', 'out', or 'in') + SyntaxToken? refKindKeyword = null; + if (p.Modifiers.Any(SyntaxKind.RefKeyword)) + refKindKeyword = SyntaxFactory.Token(SyntaxKind.RefKeyword); + else if (p.Modifiers.Any(SyntaxKind.OutKeyword)) + refKindKeyword = SyntaxFactory.Token(SyntaxKind.OutKeyword); + else if (p.Modifiers.Any(SyntaxKind.InKeyword)) + refKindKeyword = SyntaxFactory.Token(SyntaxKind.InKeyword); + + // Create the Argument node. If it has a ref/out modifier, pass it along. + if (refKindKeyword != null) + { + return SyntaxFactory.Argument(null, refKindKeyword.Value, identifier); + } + else + { + return SyntaxFactory.Argument(identifier); + } + }))); + } + + private static TypeSyntax GetTypeName(INamedTypeSymbol type) + { + return SyntaxFactory.ParseTypeName(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + } + + private static InvocationExpressionSyntax GetMethodInvocationExpression(INamedTypeSymbol type, string methodName, ArgumentListSyntax? arguments) + { + var typeName = GetTypeName(type); + + var method = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.ParenthesizedExpression(SyntaxFactory.CastExpression(typeName, SyntaxFactory.IdentifierName("@this"))), + SyntaxFactory.IdentifierName(methodName)); + + // ((Type)@this).Method(...) + return arguments != null ? SyntaxFactory.InvocationExpression(method, arguments) : SyntaxFactory.InvocationExpression(method); + } + + private static MemberAccessExpressionSyntax GetPropertyExpression(INamedTypeSymbol type, string propertyName) + { + var typeName = GetTypeName(type); + + return SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.ParenthesizedExpression(SyntaxFactory.CastExpression(typeName, SyntaxFactory.IdentifierName("@this"))), + SyntaxFactory.IdentifierName(propertyName)); + } + } +} diff --git a/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs b/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs index 3de24d8d..b411b8dc 100644 --- a/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs +++ b/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs @@ -9,5 +9,8 @@ public static class ExpressionExtensions /// Replaces all calls to properties and methods that are marked with the Projectable attribute with their respective expression tree /// public static Expression ExpandProjectables(this Expression expression) - => new ProjectableExpressionReplacer(new ProjectionExpressionResolver(), false).Replace(expression); + { + var resolver = new ProjectionExpressionResolver(); + return new ProjectableExpressionReplacer(resolver, resolver, false).Replace(expression); + } } \ No newline at end of file diff --git a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs index f8a206e9..e0958e5e 100644 --- a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs +++ b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs @@ -61,7 +61,8 @@ public CustomQueryCompiler(IQueryCompiler decoratedQueryCompiler, var trackingByDefault = (contextOptions.FindExtension()?.QueryTrackingBehavior ?? QueryTrackingBehavior.TrackAll) == QueryTrackingBehavior.TrackAll; - _projectableExpressionReplacer = new ProjectableExpressionReplacer(new ProjectionExpressionResolver(), trackingByDefault); + var resolver = new ProjectionExpressionResolver(); + _projectableExpressionReplacer = new ProjectableExpressionReplacer(resolver, resolver, trackingByDefault); } public override Func CreateCompiledAsyncQuery(Expression query) diff --git a/src/EntityFrameworkCore.Projectables/Services/IProjectionExpressionBaseResolver.cs b/src/EntityFrameworkCore.Projectables/Services/IProjectionExpressionBaseResolver.cs new file mode 100644 index 00000000..f3455279 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables/Services/IProjectionExpressionBaseResolver.cs @@ -0,0 +1,10 @@ +using System.Linq.Expressions; +using System.Reflection; + +namespace EntityFrameworkCore.Projectables.Services; + +public interface IProjectionExpressionBaseResolver +{ + LambdaExpression FindGeneratedBaseExpression(MemberInfo projectableMemberInfo, + ProjectableAttribute? projectableAttribute = null); +} \ No newline at end of file diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index 72bc0f8a..244abbc8 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -13,8 +13,10 @@ namespace EntityFrameworkCore.Projectables.Services public sealed class ProjectableExpressionReplacer : ExpressionVisitor { private readonly IProjectionExpressionResolver _resolver; + private readonly IProjectionExpressionBaseResolver _resolverBase; private readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new(); private readonly Dictionary _projectableMemberCache = new(); + private readonly Dictionary _projectableBaseMemberCache = new(); private readonly HashSet _expandingConstructors = new(); private IQueryProvider? _currentQueryProvider; private bool _disableRootRewrite = false; @@ -38,15 +40,38 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor private readonly static ConditionalWeakTable _closedSelectCache = new(); private readonly static ConditionalWeakTable _closedWhereCache = new(); - public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver, bool trackByDefault = false) + public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver, bool trackByDefault = false): + this(projectionExpressionResolver, null!, trackByDefault) { } + public ProjectableExpressionReplacer( + IProjectionExpressionResolver projectionExpressionResolver, + IProjectionExpressionBaseResolver projectionExpressionBaseResolver, + bool trackByDefault = false) { _trackingByDefault = trackByDefault; _resolver = projectionExpressionResolver; + _resolverBase = projectionExpressionBaseResolver; } bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression) { - if (!_projectableMemberCache.TryGetValue(memberInfo, out reflectedExpression)) + return TryGetReflectedExpression(memberInfo, false, out reflectedExpression); + } + bool TryGetReflectedExpression(MemberInfo memberInfo, bool isBase, [NotNullWhen(true)] out LambdaExpression? reflectedExpression) + { + if (isBase) + { + if (!_projectableBaseMemberCache.TryGetValue(memberInfo, out reflectedExpression)) + { + var projectableAttribute = memberInfo.GetCustomAttribute(false); + + reflectedExpression = projectableAttribute is not null + ? _resolverBase?.FindGeneratedBaseExpression(memberInfo, projectableAttribute) + : null; + + _projectableBaseMemberCache.Add(memberInfo, reflectedExpression); + } + } + else if (!_projectableMemberCache.TryGetValue(memberInfo, out reflectedExpression)) { var projectableAttribute = memberInfo.GetCustomAttribute(false); @@ -185,7 +210,12 @@ protected override Expression VisitMethodCall(MethodCallExpression node) _disableRootRewrite = false; } - if (TryGetReflectedExpression(methodInfo, out var reflectedExpression)) + // Check if we are rewriting a base invocation ((BaseType)@this).MyMethod(...) or ((BaseBaseType)(BaseType)@this).MyMethod(...) + // We are only checking for a type cast from a type to its immediate parent, + // unwrapping nested casts, because the original parameter might have been replaced + var isBase = (node.Object is UnaryExpression u && UnwrapUnaryConvert(u) != u); + + if (TryGetReflectedExpression(methodInfo, isBase, out var reflectedExpression)) { for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++) { @@ -198,6 +228,19 @@ protected override Expression VisitMethodCall(MethodCallExpression node) if (mappedArgumentExpression is not null) { + // If the type is different in case of a base call we re-cast it + if(isBase && mappedArgumentExpression.Type != parameterExpression.Type && + mappedArgumentExpression.Type.IsAssignableTo(parameterExpression.Type) && + mappedArgumentExpression is UnaryExpression u2) + { + var unwrapped = UnwrapUnaryConvert(u2); + if (unwrapped != u2) + { + mappedArgumentExpression = Expression.Convert(unwrapped, parameterExpression.Type); + } + } + + _expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterExpression, mappedArgumentExpression); } } @@ -213,6 +256,17 @@ protected override Expression VisitMethodCall(MethodCallExpression node) return base.VisitMethodCall(node); } + private Expression UnwrapUnaryConvert(UnaryExpression node) + { + if (node.NodeType != ExpressionType.Convert || node.Type != node.Operand.Type.BaseType) + return node; + + if (node.Operand is UnaryExpression u) + return UnwrapUnaryConvert(u); + else + return node.Operand; + } + protected override Expression VisitNew(NewExpression node) { var constructor = node.Constructor; @@ -300,7 +354,11 @@ PropertyInfo property when nodeExpression is not null _ => node.Member }; - if (TryGetReflectedExpression(nodeMember, out var reflectedExpression)) + // Check if we are rewriting a base property ((BaseType)@this).MyProp + var isBase = (node.Expression is UnaryExpression u && u.NodeType == ExpressionType.Convert && + u.Type == u.Operand.Type.BaseType && u.Operand is ParameterExpression p && p.Name == "@this"); + + if (TryGetReflectedExpression(nodeMember, isBase, out var reflectedExpression)) { if (nodeExpression is not null) { diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index 21da9a46..e33fe99b 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -9,7 +9,7 @@ namespace EntityFrameworkCore.Projectables.Services { - public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver + public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver, IProjectionExpressionBaseResolver { // We never store null in the dictionary; assemblies without a registry use a sentinel delegate. private readonly static Func _nullRegistry = static _ => null!; @@ -20,6 +20,7 @@ public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver /// EF Core never repeats reflection work for the same member across queries. /// private readonly static ConcurrentDictionary _expressionCache = new(); + private readonly static ConcurrentDictionary _expressionBaseCache = new(); /// /// Caches → C#-formatted name strings, since the same parameter types @@ -78,16 +79,22 @@ public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo, ProjectableAttribute? projectableAttribute = null) - => _expressionCache.GetOrAdd(projectableMemberInfo, static (mi, a) => ResolveExpressionCore(mi, a), + => _expressionCache.GetOrAdd(projectableMemberInfo, static (mi, a) => ResolveExpressionCore(mi, false, a), projectableAttribute); - private static LambdaExpression ResolveExpressionCore(MemberInfo projectableMemberInfo, + public LambdaExpression FindGeneratedBaseExpression(MemberInfo projectableMemberInfo, ProjectableAttribute? projectableAttribute = null) + => _expressionBaseCache.GetOrAdd(projectableMemberInfo, static (mi, a) => ResolveExpressionCore(mi, true, a), + projectableAttribute); + + private static LambdaExpression ResolveExpressionCore(MemberInfo projectableMemberInfo, + bool isBase, + ProjectableAttribute? projectableAttribute) { projectableAttribute ??= projectableMemberInfo.GetCustomAttribute() ?? throw new InvalidOperationException("Expected member to have a Projectable attribute. None found"); - var expression = GetExpressionFromGeneratedType(projectableMemberInfo); + var expression = GetExpressionFromGeneratedType(projectableMemberInfo, isBase); if (expression is null && projectableAttribute.UseMemberBody is not null) { @@ -108,20 +115,28 @@ private static LambdaExpression ResolveExpressionCore(MemberInfo projectableMemb throw new InvalidOperationException($"Unable to resolve generated expression for {fullName}."); } - private static LambdaExpression? GetExpressionFromGeneratedType(MemberInfo projectableMemberInfo) + private static LambdaExpression? GetExpressionFromGeneratedType(MemberInfo projectableMemberInfo, bool isBase) { var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); - // Fast path: check the per-assembly static registry (generated by source generator). + // Fast path (isBase=false): check the per-assembly static registry (generated by source generator). // The first call per assembly does a reflection lookup to find the registry class and // caches it as a delegate; subsequent calls use the cached delegate for an O(1) dictionary lookup. - var registry = GetAssemblyRegistry(declaringType.Assembly); - var registeredExpr = registry?.Invoke(projectableMemberInfo); + LambdaExpression? registeredExpr; + if (!isBase) + { + var registry = GetAssemblyRegistry(declaringType.Assembly); + registeredExpr = registry?.Invoke(projectableMemberInfo); + } + else + { + registeredExpr = null; // We don't have a registry for base expressions for now + } return registeredExpr ?? // Slow path: reflection fallback for open-generic class members and generic methods // that are not yet in the registry. - FindGeneratedExpressionViaReflection(projectableMemberInfo); + FindGeneratedExpressionViaReflection(projectableMemberInfo, isBase); } private static LambdaExpression? GetExpressionFromMemberBody(MemberInfo projectableMemberInfo, string memberName) @@ -217,6 +232,7 @@ private static bool ParameterTypesMatch( /// significantly more expensive to build than simple method-body trees. /// private readonly static ConcurrentDictionary _reflectionCache = new(); + private readonly static ConcurrentDictionary _reflectionBaseCache = new(); /// /// Resolves the for a [Projectable] member using the @@ -227,8 +243,12 @@ private static bool ParameterTypesMatch( /// public static LambdaExpression? FindGeneratedExpressionViaReflection(MemberInfo projectableMemberInfo) { - var result = _reflectionCache.GetOrAdd(projectableMemberInfo, - static mi => BuildReflectionExpression(mi) ?? _reflectionNullSentinel); + return FindGeneratedExpressionViaReflection(projectableMemberInfo, false); + } + private static LambdaExpression? FindGeneratedExpressionViaReflection(MemberInfo projectableMemberInfo, bool isBase) + { + var result = (isBase ? _reflectionBaseCache : _reflectionCache).GetOrAdd(projectableMemberInfo, + mi => BuildReflectionExpression(mi, isBase) ?? _reflectionNullSentinel); return ReferenceEquals(result, _reflectionNullSentinel) ? null : result; } @@ -244,7 +264,7 @@ private static bool ParameterTypesMatch( /// instance is ultimately stored per member. /// /// - private static LambdaExpression? BuildReflectionExpression(MemberInfo projectableMemberInfo) + private static LambdaExpression? BuildReflectionExpression(MemberInfo projectableMemberInfo, bool isBase) { var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); @@ -326,6 +346,9 @@ private static bool ParameterTypesMatch( memberLookupName, parameterTypeNames); + if (isBase) + generatedContainingTypeName = generatedContainingTypeName + "_Base"; + var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName); if (expressionFactoryType is null) diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_Hierarchy.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_Hierarchy.verified.txt new file mode 100644 index 00000000..019208a1 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_Hierarchy.verified.txt @@ -0,0 +1,41 @@ +[ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 1; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : 1; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_HierarchyMultiple.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_HierarchyMultiple.verified.txt new file mode 100644 index 00000000..64e49e09 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_HierarchyMultiple.verified.txt @@ -0,0 +1,41 @@ +[ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 1; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : @this is global::Foo.Baz ? ((global::Foo.Baz)@this).Id() : 1; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_HierarchyNestedWithoutAttribute.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_HierarchyNestedWithoutAttribute.verified.txt new file mode 100644 index 00000000..7958b53a --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_HierarchyNestedWithoutAttribute.verified.txt @@ -0,0 +1,41 @@ +[ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 1; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Baz ? ((global::Foo.Baz)@this).Id() : @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : 1; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_HierarchyNotNestedWithAttribute.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_HierarchyNotNestedWithAttribute.verified.txt new file mode 100644 index 00000000..44aa798b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.BlockBodiedMethod_HierarchyNotNestedWithAttribute.verified.txt @@ -0,0 +1,81 @@ +[ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Bar_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Bar @this) => 2; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Bar_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Bar @this) => @this is global::Foo.Baz ? ((global::Foo.Baz)@this).Id() : 2; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 1; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : 1; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.cs index 2ea9fe0b..ebf0fd3c 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/BlockBodyTests.cs @@ -911,4 +911,155 @@ public static bool IsTerminal(this Entity entity) return Verifier.Verify(result.GeneratedTrees[0].ToString()); } + + [Fact] + public Task BlockBodiedMethod_Hierarchy() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public class Foo { + [Projectable(AllowBlockBody = true)] + public virtual int Id(){ + return 1; + } + } + + public class Bar : Foo { + override public int Id(){ + return 2; + } + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(2, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + + [Fact] + public Task BlockBodiedMethod_HierarchyMultiple() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public class Foo { + [Projectable(AllowBlockBody = true)] + public virtual int Id(){ + return 1; + } + } + + public class Bar : Foo { + override public int Id(){ + return 2; + } + } + + public class Baz : Foo { + override public int Id(){ + return 3; + } + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(2, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + + [Fact] + public Task BlockBodiedMethod_HierarchyNotNestedWithAttribute() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public class Foo { + [Projectable(AllowBlockBody = true)] + public virtual int Id(){ + return 1; + } + } + + public class Bar : Foo { + [Projectable(AllowBlockBody = true)] + override public int Id(){ + return 2; + } + } + + public class Baz : Bar { + override public int Id(){ + return 3; + } + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(4, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + + [Fact] + public Task BlockBodiedMethod_HierarchyNestedWithoutAttribute() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public class Foo { + [Projectable(AllowBlockBody = true)] + public virtual int Id(){ + return 1; + } + } + + public class Bar : Foo { + override public int Id(){ + return 2; + } + } + + public class Baz : Bar { + override public int Id(){ + return 3; + } + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(2, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } } \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.BaseMemberExplicitReference.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.BaseMemberExplicitReference.verified.txt index 6544cf94..cd002e78 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.BaseMemberExplicitReference.verified.txt +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.BaseMemberExplicitReference.verified.txt @@ -10,7 +10,7 @@ namespace EntityFrameworkCore.Projectables.Generated { static global::System.Linq.Expressions.Expression> Expression() { - return (global::Projectables.Repro.Derived @this) => @this.Foo; + return (global::Projectables.Repro.Derived @this) => ((global::Projectables.Repro.Base)@this).Foo; } } } \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.BaseMethodExplicitReference.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.BaseMethodExplicitReference.verified.txt index 56ee3532..72dbe3b3 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.BaseMethodExplicitReference.verified.txt +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.BaseMethodExplicitReference.verified.txt @@ -10,7 +10,7 @@ namespace EntityFrameworkCore.Projectables.Generated { static global::System.Linq.Expressions.Expression> Expression() { - return (global::Projectables.Repro.Derived @this) => @this.Foo(); + return (global::Projectables.Repro.Derived @this) => ((global::Projectables.Repro.Base)@this).Foo(); } } } \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.Hierarchy.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.Hierarchy.verified.txt new file mode 100644 index 00000000..019208a1 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.Hierarchy.verified.txt @@ -0,0 +1,41 @@ +[ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 1; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : 1; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyAbstract.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyAbstract.verified.txt new file mode 100644 index 00000000..dfca5ebd --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyAbstract.verified.txt @@ -0,0 +1,19 @@ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => ((global::Foo.Bar)@this).Id(); + } + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyAbstractMultiple.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyAbstractMultiple.verified.txt new file mode 100644 index 00000000..651d6ffe --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyAbstractMultiple.verified.txt @@ -0,0 +1,19 @@ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Baz ? ((global::Foo.Baz)@this).Id() : ((global::Foo.Bar)@this).Id(); + } + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyBase.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyBase.verified.txt new file mode 100644 index 00000000..6ac79c7c --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyBase.verified.txt @@ -0,0 +1,61 @@ +[ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Bar_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Bar @this) => true ? 2 : ((global::Foo.Foo)@this).Id(); + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 1; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : 1; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyMultiple.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyMultiple.verified.txt new file mode 100644 index 00000000..64e49e09 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyMultiple.verified.txt @@ -0,0 +1,41 @@ +[ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 1; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : @this is global::Foo.Baz ? ((global::Foo.Baz)@this).Id() : 1; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyNestedWithoutAttribute.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyNestedWithoutAttribute.verified.txt new file mode 100644 index 00000000..7958b53a --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyNestedWithoutAttribute.verified.txt @@ -0,0 +1,41 @@ +[ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 1; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Baz ? ((global::Foo.Baz)@this).Id() : @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : 1; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyNotNestedWithAttribute.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyNotNestedWithAttribute.verified.txt new file mode 100644 index 00000000..44aa798b --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.HierarchyNotNestedWithAttribute.verified.txt @@ -0,0 +1,81 @@ +[ +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Bar_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Bar @this) => 2; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Bar_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Bar @this) => @this is global::Foo.Baz ? ((global::Foo.Baz)@this).Id() : 2; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 1; + } + } +} + +// +#nullable disable +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : 1; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.cs index cc5e1ed1..36353157 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/MethodTests.cs @@ -799,4 +799,275 @@ class Bar { return Verifier.Verify(result.GeneratedTrees[0].ToString()); } + + [Fact] + public Task Hierarchy() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public class Foo { + [Projectable] + public virtual int Id() => 1; + } + + public class Bar : Foo { + override public int Id() => 2; + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(2, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + + [Fact] + public Task HierarchyMultiple() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public class Foo { + [Projectable] + public virtual int Id() => 1; + } + + public class Bar : Foo { + override public int Id() => 2; + } + + public class Baz : Foo { + override public int Id() => 3; + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(2, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + + [Fact] + public Task HierarchyNotNestedWithAttribute() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public class Foo { + [Projectable] + public virtual int Id() => 1; + } + + public class Bar : Foo { + [Projectable] + override public int Id() => 2; + } + + public class Baz : Bar { + override public int Id() => 3; + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(4, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + + [Fact] + public Task HierarchyNestedWithoutAttribute() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public class Foo { + [Projectable] + public virtual int Id() => 1; + } + + public class Bar : Foo { + override public int Id() => 2; + } + + public class Baz : Bar { + override public int Id() => 3; + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(2, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + + [Fact] + public Task HierarchyAbstract() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public abstract class Foo { + [Projectable] + public abstract int Id(); + } + + public class Bar : Foo { + override public int Id() => 2; + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Single(result.GeneratedTrees); + + return Verifier.Verify(result.GeneratedTrees[0].ToString()); + } + + [Fact] + public Task HierarchyAbstractMultiple() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public abstract class Foo { + [Projectable] + public abstract int Id(); + } + + public class Bar : Foo { + override public int Id() => 2; + } + + public class Baz : Bar { + override public int Id() => 3; + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Single(result.GeneratedTrees); + + return Verifier.Verify(result.GeneratedTrees[0].ToString()); + } + + [Fact] + public void HierarchyAbstractWithNoDerived() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public abstract class Foo { + [Projectable] + public abstract int Id(); + } +} +"); + + var result = RunGenerator(compilation); + + var diag = Assert.Single(result.Diagnostics); + Assert.Equal("EFP0006", diag.Id); + Assert.Equal(DiagnosticSeverity.Error, diag.Severity); + } + + [Fact] + public void HierarchyAbstractWithNoDerivedOverwritten() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public abstract class Foo { + [Projectable] + public abstract int Id(); + } + + public abstract class Bar : Foo { } +} +"); + + var result = RunGenerator(compilation); + + var diag = Assert.Single(result.Diagnostics); + Assert.Equal("EFP0006", diag.Id); + Assert.Equal(DiagnosticSeverity.Error, diag.Severity); + } + + [Fact] + public Task HierarchyBase() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using System.Collections.Generic; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public class Foo { + [Projectable] + public virtual int Id() => 1; + } + + public class Bar : Foo { + [Projectable] + override public int Id() => true ? 2 : base.Id(); + } +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(3, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } } \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Method_UsesExpressionPropertyBody_Hierarchy.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Method_UsesExpressionPropertyBody_Hierarchy.verified.txt new file mode 100644 index 00000000..9725ed0c --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Method_UsesExpressionPropertyBody_Hierarchy.verified.txt @@ -0,0 +1,39 @@ +[ +// +#nullable disable +using System; +using System.Linq.Expressions; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 2; + } + } +} + +// +#nullable disable +using System; +using System.Linq.Expressions; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : 2; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Method_UsesMethodBody_Hierarchy.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Method_UsesMethodBody_Hierarchy.verified.txt new file mode 100644 index 00000000..ca113d15 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Method_UsesMethodBody_Hierarchy.verified.txt @@ -0,0 +1,37 @@ +[ +// +#nullable disable +using System; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 2; + } + } +} + +// +#nullable disable +using System; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id() : 2; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Property_UsesExpressionPropertyBody_Hierarchy.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Property_UsesExpressionPropertyBody_Hierarchy.verified.txt new file mode 100644 index 00000000..0a41464f --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Property_UsesExpressionPropertyBody_Hierarchy.verified.txt @@ -0,0 +1,39 @@ +[ +// +#nullable disable +using System; +using System.Linq.Expressions; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 2; + } + } +} + +// +#nullable disable +using System; +using System.Linq.Expressions; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id : 2; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Property_UsesPropertyBody_Hierarchy.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Property_UsesPropertyBody_Hierarchy.verified.txt new file mode 100644 index 00000000..e77af574 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.Property_UsesPropertyBody_Hierarchy.verified.txt @@ -0,0 +1,37 @@ +[ +// +#nullable disable +using System; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id_Base + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => 2; + } + } +} + +// +#nullable disable +using System; +using EntityFrameworkCore.Projectables; +using Foo; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_Foo_Id + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Foo.Foo @this) => @this is global::Foo.Bar ? ((global::Foo.Bar)@this).Id : 2; + } + } +} +] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.cs index 1885e655..01e7ce2d 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/UseMemberBodyTests.cs @@ -42,6 +42,33 @@ class C { return Verifier.Verify(result.GeneratedTrees[0].ToString()); } + [Fact] + public Task Method_UsesMethodBody_Hierarchy() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +namespace Foo { + public class Foo { + [Projectable(UseMemberBody = nameof(IdImpl))] + public virtual int Id() => 1; + + private int IdImpl() => 2; + } + + public class Bar : Foo { + override public int Id() => 3; + } +} +"); + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(2, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + [Fact] public Task Method_UsesExpressionPropertyBody_StaticExtension() { @@ -211,6 +238,34 @@ private static Expression> IsPositiveExpr { return Verifier.Verify(result.GeneratedTrees[0].ToString()); } + [Fact] + public Task Method_UsesExpressionPropertyBody_Hierarchy() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq.Expressions; +using EntityFrameworkCore.Projectables; +namespace Foo { + public class Foo { + [Projectable(UseMemberBody = nameof(IdImpl))] + public virtual int Id() => 1; + + private static Expression> IdImpl => @this => 2; + } + + public class Bar : Foo { + override public int Id() => 3; + } +} +"); + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(2, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + [Fact] public Task Property_UsesPropertyBody_SameType() { @@ -236,6 +291,33 @@ class C { return Verifier.Verify(result.GeneratedTrees[0].ToString()); } + [Fact] + public Task Property_UsesPropertyBody_Hierarchy() + { + var compilation = CreateCompilation(@" +using System; +using EntityFrameworkCore.Projectables; +namespace Foo { + public class Foo { + [Projectable(UseMemberBody = nameof(IdImpl))] + public virtual int Id => 1; + + private int IdImpl => 2; + } + + public class Bar : Foo { + override public int Id => 3; + } +} +"); + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(2, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + [Fact] public Task StaticMethod_UsesStaticMethodBody() { @@ -285,6 +367,34 @@ class C { return Verifier.Verify(result.GeneratedTrees[0].ToString()); } + [Fact] + public Task Property_UsesExpressionPropertyBody_Hierarchy() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq.Expressions; +using EntityFrameworkCore.Projectables; +namespace Foo { + public class Foo { + [Projectable(UseMemberBody = nameof(IdImpl))] + public virtual int Id => 1; + + private static Expression> IdImpl => @this => 2; + } + + public class Bar : Foo { + override public int Id => 3; + } +} +"); + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(2, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees.OrderBy(t => t.FilePath).Select(t => t.ToString())); + } + [Fact] public void Property_UsesExpressionPropertyBody_IncompatibleReturnType_EmitsEFP0011() { diff --git a/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs b/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs index 19d58391..1eda5954 100644 --- a/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs @@ -209,5 +209,79 @@ private sealed class FakeClosureWithIQueryableProperty { public IQueryable? Items { get; set; } } + + public class ProjectableExpressionResolverStubBase : IProjectionExpressionResolver, IProjectionExpressionBaseResolver + { + readonly Func _implementation; + readonly Func _implementationBase; + + public ProjectableExpressionResolverStubBase(Func implementation, + Func implementationBase) + { + _implementation = implementation; + _implementationBase = implementationBase; + } + + public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo, + ProjectableAttribute? projectableAttribute = null) => _implementation(projectableMemberInfo, projectableAttribute); + public LambdaExpression FindGeneratedBaseExpression(MemberInfo projectableMemberInfo, + ProjectableAttribute? projectableAttribute = null) => _implementationBase(projectableMemberInfo, projectableAttribute); + } + + class Foo + { + [Projectable] + public virtual int VirtualProperty => 1; + + [Projectable] + public virtual int VirtualMethod() => 1; + } + + class Bar : Foo + { + [Projectable] + override public int VirtualProperty => true ? 2 : base.VirtualProperty; + + [Projectable] + override public int VirtualMethod() => true ? 2 : base.VirtualProperty; + } + + [Fact] + public void VisitMember_HierarchyBaseProperty() + { + Expression> input = x => x.VirtualProperty; + Expression> expectedFooBase = x => 1; + Expression> expectedBar = x => true ? 2 : ((Foo)x).VirtualProperty; + Expression> expectedFoo = x => x is Bar ? true ? 2 : 1 : 1; + + var resolver = new ProjectableExpressionResolverStubBase( + (x, a) => x.DeclaringType == typeof(Foo) ? expectedFoo : expectedBar, + (x, a) => expectedFooBase + ); + var subject = new ProjectableExpressionReplacer(resolver); + + var actual = subject.Replace(input); + + Assert.Equal(expectedFoo.ToString(), actual.ToString()); + } + + [Fact] + public void VisitMember_HierarchyBaseMethod() + { + Expression> input = x => x.VirtualMethod(); + Expression> expectedFooBase = x => 1; + Expression> expectedBar = x => true ? 2 : ((Foo)x).VirtualMethod(); + Expression> expectedFoo = x => x is Bar ? true ? 2 : 1 : 1; + + var resolver = new ProjectableExpressionResolverStubBase( + (x, a) => x.DeclaringType == typeof(Foo) ? expectedFoo : expectedBar, + (x, a) => expectedFooBase + ); + var subject = new ProjectableExpressionReplacer(resolver); + + var actual = subject.Replace(input); + + Assert.Equal(expectedFoo.ToString(), actual.ToString()); + } } }