Skip to content

Commit

Permalink
Move main savepoint logic up to RelationalTransaction
Browse files Browse the repository at this point in the history
Subclasses of RelationalTransaction now only have to provide the SQL
to manage savepoints, and RelationalTransaction does the rest. Default
SQL is defined to support most databases.
  • Loading branch information
roji committed Jun 3, 2020
1 parent 0065e1d commit 5ae2e35
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 183 deletions.
116 changes: 92 additions & 24 deletions src/EFCore.Relational/Storage/RelationalTransaction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,10 @@ public RelationalTransaction(
/// </summary>
protected virtual IDiagnosticsLogger<DbLoggerCategory.Database.Transaction> Logger { get; }

/// <summary>
/// A correlation ID that allows this transaction to be identified and
/// correlated across multiple database calls.
/// </summary>
/// <inheritdoc />
public virtual Guid TransactionId { get; }

/// <summary>
/// Commits all changes made to the database in the current transaction.
/// </summary>
/// <inheritdoc />
public virtual void Commit()
{
var startTime = DateTimeOffset.UtcNow;
Expand Down Expand Up @@ -125,9 +120,7 @@ public virtual void Commit()
ClearTransaction();
}

/// <summary>
/// Discards all changes made to the database in the current transaction.
/// </summary>
/// <inheritdoc />
public virtual void Rollback()
{
var startTime = DateTimeOffset.UtcNow;
Expand Down Expand Up @@ -170,11 +163,7 @@ public virtual void Rollback()
ClearTransaction();
}

/// <summary>
/// Commits all changes made to the database in the current transaction asynchronously.
/// </summary>
/// <param name="cancellationToken"> The cancellation token. </param>
/// <returns> A <see cref="Task" /> representing the asynchronous operation. </returns>
/// <inheritdoc />
public virtual async Task CommitAsync(CancellationToken cancellationToken = default)
{
var startTime = DateTimeOffset.UtcNow;
Expand Down Expand Up @@ -223,11 +212,7 @@ await Logger.TransactionErrorAsync(
await ClearTransactionAsync(cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Discards all changes made to the database in the current transaction asynchronously.
/// </summary>
/// <param name="cancellationToken"> The cancellation token. </param>
/// <returns> A <see cref="Task" /> representing the asynchronous operation. </returns>
/// <inheritdoc />
public virtual async Task RollbackAsync(CancellationToken cancellationToken = default)
{
var startTime = DateTimeOffset.UtcNow;
Expand Down Expand Up @@ -276,9 +261,94 @@ await Logger.TransactionErrorAsync(
await ClearTransactionAsync(cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
public virtual void Save(string savepointName)
{
using var command = Connection.DbConnection.CreateCommand();
command.Transaction = _dbTransaction;
command.CommandText = GetSavepointSaveSql(savepointName);
command.ExecuteNonQuery();
}

/// <inheritdoc />
public virtual Task SaveAsync(string savepointName, CancellationToken cancellationToken = default)
{
using var command = Connection.DbConnection.CreateCommand();
command.Transaction = _dbTransaction;
command.CommandText = GetSavepointSaveSql(savepointName);
return command.ExecuteNonQueryAsync(cancellationToken);
}

/// <summary>
/// When implemented in a provider supporting transaction savepoints, this method should return an
/// SQL statement which creates a savepoint with the given name.
/// </summary>
/// <param name="name"> The name of the savepoint to be created. </param>
/// <returns> An SQL string to create the savepoint. </returns>
protected virtual string GetSavepointSaveSql([NotNull] string name) => "SAVEPOINT " + name;

/// <inheritdoc />
public virtual void Rollback(string savepointName)
{
using var command = Connection.DbConnection.CreateCommand();
command.Transaction = _dbTransaction;
command.CommandText = GetSavepointRollbackSql(savepointName);
command.ExecuteNonQuery();
}

/// <inheritdoc />
public virtual Task RollbackAsync(string savepointName, CancellationToken cancellationToken = default)
{
using var command = Connection.DbConnection.CreateCommand();
command.Transaction = _dbTransaction;
command.CommandText = GetSavepointRollbackSql(savepointName);
return command.ExecuteNonQueryAsync(cancellationToken);
}

/// <summary>
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
/// When implemented in a provider supporting transaction savepoints, this method should return an
/// SQL statement which rolls back a savepoint with the given name.
/// </summary>
/// <param name="name"> The name of the savepoint to be created. </param>
/// <returns> An SQL string to create the savepoint. </returns>
protected virtual string GetSavepointRollbackSql([NotNull] string name) => "ROLLBACK TO " + name;

/// <inheritdoc />
public virtual void Release(string savepointName)
{
using var command = Connection.DbConnection.CreateCommand();
command.Transaction = _dbTransaction;
command.CommandText = GetSavepointReleaseSql(savepointName);
command.ExecuteNonQuery();
}

/// <inheritdoc />
public virtual Task ReleaseAsync(string savepointName, CancellationToken cancellationToken = default)
{
using var command = Connection.DbConnection.CreateCommand();
command.Transaction = _dbTransaction;
command.CommandText = GetSavepointReleaseSql(savepointName);
return command.ExecuteNonQueryAsync(cancellationToken);
}

/// <summary>
/// <para>
/// When implemented in a provider supporting transaction savepoints, this method should return an
/// SQL statement which releases a savepoint with the given name.
/// </para>
/// <para>
/// If savepoint release isn't supported, <see cref="Release "/> and <see cref="ReleaseAsync "/> should
/// be overridden to do nothing.
/// </para>
/// </summary>
/// <param name="name"> The name of the savepoint to be created. </param>
/// <returns> An SQL string to create the savepoint. </returns>
protected virtual string GetSavepointReleaseSql([NotNull] string name) => "RELEASE SAVEPOINT " + name;

/// <inheritdoc />
public virtual bool AreSavepointsSupported => true;

/// <inheritdoc />
public virtual void Dispose()
{
if (!_disposed)
Expand All @@ -300,9 +370,7 @@ public virtual void Dispose()
}
}

/// <summary>
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
/// </summary>
/// <inheritdoc />
public virtual async ValueTask DisposeAsync()
{
if (!_disposed)
Expand Down
39 changes: 6 additions & 33 deletions src/EFCore.SqlServer/Storage/Internal/SqlServerTransaction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Storage.Internal
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public class SqlServerTransaction : RelationalTransaction, IDbContextTransaction
public class SqlServerTransaction : RelationalTransaction
{
private readonly DbTransaction _dbTransaction;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand All @@ -34,45 +32,20 @@ public SqlServerTransaction(
[NotNull] IDiagnosticsLogger<DbLoggerCategory.Database.Transaction> logger,
bool transactionOwned)
: base(connection, transaction, transactionId, logger, transactionOwned)
=> _dbTransaction = transaction;

/// <inheritdoc />
public virtual void Save(string savepointName)
{
using var command = Connection.DbConnection.CreateCommand();
command.Transaction = _dbTransaction;
command.CommandText = "SAVE TRANSACTION " + savepointName;
command.ExecuteNonQuery();
}

/// <inheritdoc />
public virtual async Task SaveAsync(string savepointName, CancellationToken cancellationToken = default)
{
using var command = Connection.DbConnection.CreateCommand();
command.Transaction = _dbTransaction;
command.CommandText = "SAVE TRANSACTION " + savepointName;
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
protected override string GetSavepointSaveSql(string name) => "SAVE TRANSACTION " + name;

/// <inheritdoc />
public virtual void Rollback(string savepointName)
{
using var command = Connection.DbConnection.CreateCommand();
command.Transaction = _dbTransaction;
command.CommandText = "ROLLBACK TRANSACTION " + savepointName;
command.ExecuteNonQuery();
}
protected override string GetSavepointRollbackSql(string name) => "ROLLBACK TRANSACTION " + name;

/// <inheritdoc />
public virtual async Task RollbackAsync(string savepointName, CancellationToken cancellationToken = default)
{
using var command = Connection.DbConnection.CreateCommand();
command.Transaction = _dbTransaction;
command.CommandText = "ROLLBACK TRANSACTION " + savepointName;
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
public override void Release(string savepointName) {}

/// <inheritdoc />
public virtual bool AreSavepointsSupported => true;
public override Task ReleaseAsync(string savepointName, CancellationToken cancellationToken = default)
=> Task.CompletedTask;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ public static IServiceCollection AddEntityFrameworkSqlite([NotNull] this IServic
.TryAdd<IModelValidator, SqliteModelValidator>()
.TryAdd<IProviderConventionSetBuilder, SqliteConventionSetBuilder>()
.TryAdd<IUpdateSqlGenerator, SqliteUpdateSqlGenerator>()
.TryAdd<IRelationalTransactionFactory, SqliteTransactionFactory>()
.TryAdd<IModificationCommandBatchFactory, SqliteModificationCommandBatchFactory>()
.TryAdd<IRelationalConnection>(p => p.GetService<ISqliteRelationalConnection>())
.TryAdd<IMigrationsSqlGenerator, SqliteMigrationsSqlGenerator>()
Expand Down
96 changes: 0 additions & 96 deletions src/EFCore.Sqlite.Core/Storage/Internal/SqliteTransaction.cs

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,7 @@ public override async Task CommitAsync(CancellationToken cancellationToken = def

await base.CommitAsync(cancellationToken);
}

public override bool AreSavepointsSupported => false;
}
}

0 comments on commit 5ae2e35

Please sign in to comment.