Skip to content

Commit

Permalink
FromSql work (#34065)
Browse files Browse the repository at this point in the history
Closes #34064
Closes #29177
  • Loading branch information
roji committed Jun 25, 2024
1 parent 2295738 commit 5ac6eca
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 87 deletions.
62 changes: 56 additions & 6 deletions src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,46 @@ source.Provider is EntityQueryProvider
: source;
}

/// <summary>
/// Creates a LINQ query based on an interpolated string representing a SQL query.
/// </summary>
/// <remarks>
/// <para>
/// If the database provider supports composing on the supplied SQL, you can compose on top of the raw SQL query using
/// LINQ operators.
/// </para>
/// <para>
/// As with any API that accepts SQL it is important to parameterize any user input to protect against a SQL injection
/// attack. You can include interpolated parameter place holders in the SQL query string. Any interpolated parameter values
/// you supply will automatically be converted to a Cosmos parameter.
/// </para>
/// <para>
/// See <see href="https://aka.ms/efcore-docs-raw-sql">Executing raw SQL commands with EF Core</see>
/// for more information and examples.
/// </para>
/// </remarks>
/// <typeparam name="TEntity">The type of the elements of <paramref name="source" />.</typeparam>
/// <param name="source">
/// An <see cref="IQueryable{T}" /> to use as the base of the interpolated string SQL query (typically a <see cref="DbSet{TEntity}" />).
/// </param>
/// <param name="sql">The interpolated string representing a SQL query with parameters.</param>
/// <returns>An <see cref="IQueryable{T}" /> representing the interpolated string SQL query.</returns>
public static IQueryable<TEntity> FromSql<TEntity>(
this DbSet<TEntity> source,
[NotParameterized] FormattableString sql)
where TEntity : class
{
Check.NotNull(sql, nameof(sql));
Check.NotEmpty(sql.Format, nameof(source));

var queryableSource = (IQueryable)source;
return queryableSource.Provider.CreateQuery<TEntity>(
GenerateFromSqlQueryRoot(
queryableSource,
sql.Format,
sql.GetArguments()));
}

/// <summary>
/// Creates a LINQ query based on a raw SQL query.
/// </summary>
Expand Down Expand Up @@ -103,14 +143,26 @@ source.Provider is EntityQueryProvider
public static IQueryable<TEntity> FromSqlRaw<TEntity>(
this DbSet<TEntity> source,
[NotParameterized] string sql,
params object[] parameters)
params object?[] parameters)
where TEntity : class
{
Check.NotEmpty(sql, nameof(sql));
Check.NotNull(parameters, nameof(parameters));

var queryableSource = (IQueryable)source;
var entityQueryRootExpression = (EntityQueryRootExpression)queryableSource.Expression;
return queryableSource.Provider.CreateQuery<TEntity>(
GenerateFromSqlQueryRoot(
queryableSource,
sql,
parameters));
}

private static FromSqlQueryRootExpression GenerateFromSqlQueryRoot(
IQueryable source,
string sql,
object?[] arguments)
{
var entityQueryRootExpression = (EntityQueryRootExpression)source.Expression;

var entityType = entityQueryRootExpression.EntityType;

Expand All @@ -119,12 +171,10 @@ public static IQueryable<TEntity> FromSqlRaw<TEntity>(
|| entityType.FindDiscriminatorProperty() is not null,
"Found FromSql on a TPT entity type, but TPT isn't supported on Cosmos");

var fromSqlQueryRootExpression = new FromSqlQueryRootExpression(
return new FromSqlQueryRootExpression(
entityQueryRootExpression.QueryProvider!,
entityType,
sql,
Expression.Constant(parameters));

return queryableSource.Provider.CreateQuery<TEntity>(fromSqlQueryRootExpression);
Expression.Constant(arguments));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ public static Task MigrateAsync(
public static int ExecuteSqlRaw(
this DatabaseFacade databaseFacade,
string sql,
params object[] parameters)
=> ExecuteSqlRaw(databaseFacade, sql, (IEnumerable<object>)parameters);
params object?[] parameters)
=> ExecuteSqlRaw(databaseFacade, sql, (IEnumerable<object?>)parameters);

/// <summary>
/// Executes the given SQL against the database and returns the number of rows affected.
Expand Down Expand Up @@ -211,7 +211,7 @@ public static int ExecuteSqlRaw(
public static int ExecuteSqlInterpolated(
this DatabaseFacade databaseFacade,
FormattableString sql)
=> ExecuteSqlRaw(databaseFacade, sql.Format, sql.GetArguments()!);
=> ExecuteSqlRaw(databaseFacade, sql.Format, sql.GetArguments());

/// <summary>
/// Executes the given SQL against the database and returns the number of rows affected.
Expand Down Expand Up @@ -243,7 +243,7 @@ public static int ExecuteSqlInterpolated(
public static int ExecuteSql(
this DatabaseFacade databaseFacade,
FormattableString sql)
=> ExecuteSqlRaw(databaseFacade, sql.Format, sql.GetArguments()!);
=> ExecuteSqlRaw(databaseFacade, sql.Format, sql.GetArguments());

/// <summary>
/// Executes the given SQL against the database and returns the number of rows affected.
Expand Down Expand Up @@ -281,7 +281,7 @@ public static int ExecuteSql(
public static int ExecuteSqlRaw(
this DatabaseFacade databaseFacade,
string sql,
IEnumerable<object> parameters)
IEnumerable<object?> parameters)
{
Check.NotNull(sql, nameof(sql));
Check.NotNull(parameters, nameof(parameters));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public static DbCommand CreateDbCommand(this IQueryable source)
public static IQueryable<TEntity> FromSqlRaw<TEntity>(
this DbSet<TEntity> source,
[NotParameterized] string sql,
params object[] parameters)
params object?[] parameters)
where TEntity : class
{
Check.NotEmpty(sql, nameof(sql));
Expand Down Expand Up @@ -182,14 +182,7 @@ public static IQueryable<TEntity> FromSql<TEntity>(
sql.GetArguments()));
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
[EntityFrameworkInternal]
public static FromSqlQueryRootExpression GenerateFromSqlQueryRoot(
private static FromSqlQueryRootExpression GenerateFromSqlQueryRoot(
IQueryable source,
string sql,
object?[] arguments,
Expand Down
9 changes: 2 additions & 7 deletions src/EFCore.Relational/Storage/IRawSqlCommandBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ public interface IRawSqlCommandBuilder
/// <param name="sql">The command text.</param>
/// <param name="parameters">Parameters for the command.</param>
/// <returns>The newly created command.</returns>
RawSqlCommand Build(
string sql,
IEnumerable<object> parameters);
RawSqlCommand Build(string sql, IEnumerable<object?> parameters);

/// <summary>
/// Creates a new command based on SQL command text.
Expand All @@ -49,8 +47,5 @@ RawSqlCommand Build(
/// <param name="parameters">Parameters for the command.</param>
/// <param name="model">The model.</param>
/// <returns>The newly created command.</returns>
RawSqlCommand Build(
string sql,
IEnumerable<object> parameters,
IModel model);
RawSqlCommand Build(string sql, IEnumerable<object?> parameters, IModel model);
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public virtual IRelationalCommand Build(string sql)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual RawSqlCommand Build(string sql, IEnumerable<object> parameters)
public virtual RawSqlCommand Build(string sql, IEnumerable<object?> parameters)
=> Build(sql, parameters, null);

/// <summary>
Expand All @@ -61,7 +61,7 @@ public virtual RawSqlCommand Build(string sql, IEnumerable<object> parameters)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual RawSqlCommand Build(string sql, IEnumerable<object> parameters, IModel? model)
public virtual RawSqlCommand Build(string sql, IEnumerable<object?> parameters, IModel? model)
{
var relationalCommandBuilder = _relationalCommandBuilderFactory.Create();

Expand Down Expand Up @@ -89,14 +89,14 @@ public virtual RawSqlCommand Build(string sql, IEnumerable<object> parameters, I
var substitutedName = _sqlGenerationHelper.GenerateParameterName(parameterName);

substitutions.Add(substitutedName);
var typeMapping = parameter == null
var typeMapping = parameter is null
? model == null
? _typeMappingSource.GetMappingForValue(null)
: _typeMappingSource.GetMappingForValue(null, model)
: model == null
? _typeMappingSource.GetMapping(parameter.GetType())
: _typeMappingSource.GetMapping(parameter.GetType(), model);
var nullable = parameter == null || parameter.GetType().IsNullableType();
var nullable = parameter is null || parameter.GetType().IsNullableType();

relationalCommandBuilder.AddParameter(parameterName, substitutedName, typeMapping, nullable);
parameterValues.Add(parameterName, parameter);
Expand Down
Loading

0 comments on commit 5ac6eca

Please sign in to comment.