Skip to content

Commit

Permalink
Query: Disallow FromSql/TVF with TPT
Browse files Browse the repository at this point in the history
- Throw exception when using FromSql* methods
- Throw if a TVF is mapped to TPT
- Throw if construction of SelectExpression with custom TableExpression is done when TPT

Resolves #21508
  • Loading branch information
smitpatel committed Jul 22, 2020
1 parent d0ec185 commit 65c6fac
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 8 deletions.
13 changes: 11 additions & 2 deletions src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
Expand Down Expand Up @@ -148,13 +149,21 @@ public static IQueryable<TEntity> FromSqlInterpolated<TEntity>(
private static FromSqlQueryRootExpression GenerateFromSqlQueryRoot(
IQueryable source,
string sql,
object[] arguments)
object[] arguments,
[CallerMemberName] string memberName = null)
{
var queryRootExpression = (QueryRootExpression)source.Expression;

var entityType = queryRootExpression.EntityType;
if ((entityType.BaseType != null || entityType.GetDirectlyDerivedTypes().Any())
&& entityType.GetDiscriminatorProperty() == null)
{
throw new InvalidOperationException(RelationalStrings.NonTPHOnFromSqlNotSupported(memberName, entityType.DisplayName()));
}

return new FromSqlQueryRootExpression(
queryRootExpression.QueryProvider,
queryRootExpression.EntityType,
entityType,
sql,
Expression.Constant(arguments));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ protected virtual void ValidateDbFunctions(
{
var elementType = dbFunction.ReturnType.GetGenericArguments()[0];
var entityType = model.FindEntityType(elementType);

if (entityType?.IsOwned() == true
|| ((IConventionModel)model).IsOwned(elementType)
|| (entityType == null && model.GetEntityTypes().Any(e => e.ClrType == elementType)))
Expand All @@ -130,6 +131,13 @@ protected virtual void ValidateDbFunctions(
throw new InvalidOperationException(RelationalStrings.DbFunctionInvalidReturnEntityType(
dbFunction.ModelName, dbFunction.ReturnType.ShortDisplayName(), elementType.ShortDisplayName()));
}

if ((entityType.BaseType != null || entityType.GetDerivedTypes().Any())
&& entityType.GetDiscriminatorProperty() == null)
{
throw new InvalidOperationException(
RelationalStrings.TableValuedFunctionNonTPH(dbFunction.ModelName, entityType.DisplayName()));
}
}

foreach (var parameter in dbFunction.Parameters)
Expand Down
25 changes: 24 additions & 1 deletion src/EFCore.Relational/Properties/RelationalStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions src/EFCore.Relational/Properties/RelationalStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -708,4 +708,13 @@
<value>An error occurred while the batch executor was releasing a transaction savepoint.</value>
<comment>Debug RelationalEventId.BatchExecutorFailedToReleaseSavepoint</comment>
</data>
<data name="NonTPHOnFromSqlNotSupported" xml:space="preserve">
<value>Using '{memberName}' on DbSet of '{entityType}' is not supported since '{entityType}' is part of hierarchy and does not contain a discriminator property.</value>
</data>
<data name="SelectExpressionNonTPHWithCustomTable" xml:space="preserve">
<value>Cannot create 'SelectExpression' with custom 'TableExpressionBase' since result type '{entityType}' is part of hierarchy and does not contain a discriminator property.</value>
</data>
<data name="TableValuedFunctionNonTPH" xml:space="preserve">
<value>The element type of result of '{dbFunction}' is mapped to '{entityType}'. This is not supported since '{entityType}' is part of hierarchy and does not contain a discriminator property.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -579,10 +579,7 @@ private static Expression MatchTypes(Expression expression, Type targetType)
if (targetType != expression.Type
&& targetType.TryGetElementType(typeof(IQueryable<>)) == null)
{
if (targetType.MakeNullable() != expression.Type)
{
throw new InvalidFilterCriteriaException();
}
Check.DebugAssert(targetType.MakeNullable() == expression.Type, "expression.Type must be nullable of targetType");

expression = Expression.Convert(expression, targetType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ internal SelectExpression(IEntityType entityType, ISqlExpressionFactory sqlExpre
internal SelectExpression(IEntityType entityType, TableExpressionBase tableExpressionBase)
: base(null)
{
if ((entityType.BaseType != null || entityType.GetDirectlyDerivedTypes().Any())
&& entityType.GetDiscriminatorProperty() == null)
{
throw new InvalidOperationException(RelationalStrings.SelectExpressionNonTPHWithCustomTable(entityType.DisplayName()));
}

var table = tableExpressionBase switch
{
TableExpression tableExpression => tableExpression.Table,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@


// ReSharper disable InconsistentNaming
using System;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.TestModels.InheritanceModel;
using Xunit;

namespace Microsoft.EntityFrameworkCore.Query
Expand All @@ -30,5 +34,19 @@ public TPTInheritanceQueryTestBase(TFixture fixture)

// TPT does not have discriminator
public override Task Discriminator_with_cast_in_shadow_property(bool async) => Task.CompletedTask;

[ConditionalFact]
public virtual void Using_from_sql_throws()
{
using var context = CreateContext();

var message = Assert.Throws<InvalidOperationException>(() => context.Set<Bird>().FromSqlRaw("Select * from Birds")).Message;

Assert.Equal(RelationalStrings.NonTPHOnFromSqlNotSupported("FromSqlRaw", typeof(Bird).Name), message);

message = Assert.Throws<InvalidOperationException>(() => context.Set<Bird>().FromSqlInterpolated($"Select * from Birds")).Message;

Assert.Equal(RelationalStrings.NonTPHOnFromSqlNotSupported("FromSqlInterpolated", typeof(Bird).Name), message);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,19 @@ public void Detects_named_index_properties_mapped_to_different_tables_in_TPT_hie
LogLevel.Error);
}

[ConditionalFact]
public virtual void Non_TPH_as_a_result_of_DbFunction_throws()
{
var modelBuilder = CreateConventionalModelBuilder();
modelBuilder.Entity<A>().ToTable("A").HasNoDiscriminator();
modelBuilder.Entity<C>().ToTable("C");

modelBuilder.HasDbFunction(TestMethods.MethodFMi);

VerifyError(RelationalStrings.TableValuedFunctionNonTPH(
TestMethods.MethodFMi.DeclaringType.FullName + "." + TestMethods.MethodFMi.Name + "()", "C"), modelBuilder.Model);
}

private static void GenerateMapping(IMutableProperty property)
=> property[CoreAnnotationNames.TypeMapping]
= new TestRelationalTypeMappingSource(
Expand Down Expand Up @@ -1712,7 +1725,7 @@ protected class Employee : Person

public class TestDecimalToLongConverter : ValueConverter<decimal, long>
{
private static readonly Expression<Func<decimal, long>> convertToProviderExpression = d => (long)(d*100);
private static readonly Expression<Func<decimal, long>> convertToProviderExpression = d => (long)(d * 100);
private static readonly Expression<Func<long, decimal>> convertFromProviderExpression = l => l / 100m;

public TestDecimalToLongConverter()
Expand Down Expand Up @@ -1750,12 +1763,14 @@ private class TestMethods : BaseTestMethods
public static readonly MethodInfo MethodCMi = typeof(TestMethods).GetTypeInfo().GetDeclaredMethod(nameof(TestMethods.MethodC));
public static readonly MethodInfo MethodDMi = typeof(TestMethods).GetTypeInfo().GetDeclaredMethod(nameof(TestMethods.MethodD));
public static readonly MethodInfo MethodEMi = typeof(TestMethods).GetTypeInfo().GetDeclaredMethod(nameof(TestMethods.MethodE));
public static readonly MethodInfo MethodFMi = typeof(TestMethods).GetTypeInfo().GetDeclaredMethod(nameof(TestMethods.MethodF));

public static IQueryable<TestMethods> MethodA() => throw new NotImplementedException();
public static IQueryable<TestMethods> MethodB(int id) => throw new NotImplementedException();
public static TestMethods MethodC() => throw new NotImplementedException();
public static int MethodD(TestMethods methods) => throw new NotImplementedException();
public static int MethodE() => throw new NotImplementedException();
public static IQueryable<C> MethodF() => throw new NotImplementedException();
}

protected virtual ModelBuilder CreateModelBuilderWithoutConvention<T>(
Expand Down

0 comments on commit 65c6fac

Please sign in to comment.