Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/EntityFramework/DbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace System.Data.Entity
using System.Data.Entity.Utilities;
using System.Data.Entity.Validation;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -63,7 +64,8 @@ public class DbContext : IDisposable, IObjectContextAdapter

private Database _database;


private ForceDateTimeTypeAttribute _forcedDateTimeTypeAttribute;

/// <summary>
/// Constructs a new context instance using conventions to create the name of the database to
/// which a connection will be made. The by-convention name is the full name (namespace + class name)
Expand All @@ -76,7 +78,7 @@ protected DbContext()
{
InitializeLazyInternalContext(new LazyInternalConnection(this, GetType().DatabaseName()));
}

/// <summary>
/// Constructs a new context instance using conventions to create the name of the database to
/// which a connection will be made, and initializes it from the given model.
Expand Down Expand Up @@ -205,6 +207,20 @@ private void DiscoverAndInitializeSets()
new DbSetDiscoveryService(this).InitializeSets();
}

internal DbType? ForcedDateTimeType
{
get
{
if (_forcedDateTimeTypeAttribute == null)
{
_forcedDateTimeTypeAttribute
= GetType().GetCustomAttributes(inherit: false).OfType<ForceDateTimeTypeAttribute>().FirstOrDefault();
}

return _forcedDateTimeTypeAttribute?.DateTimeType;
}
}

#endregion

#region Model building
Expand Down
31 changes: 31 additions & 0 deletions src/EntityFramework/Infrastructure/ForceDateTimeTypeAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.

namespace System.Data.Entity.Infrastructure
{
/// <summary>
/// Place this attribute on a type that inherits from <see cref="System.Data.Entity.DbContext"/> to
/// force all <see cref="System.DateTime"/> parameters to have the given <see cref="System.Data.DbType"/>.
/// </summary>
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false)]
public sealed class ForceDateTimeTypeAttribute : Attribute
{
private readonly DbType _type;

/// <summary>
/// Creates a new <see cref="ForceDateTimeTypeAttribute"/> instance with the given <see cref="System.Data.DbType"/>.
/// </summary>
/// <param name="type"> The type to force. </param>
public ForceDateTimeTypeAttribute(DbType type)
{
_type = type;
}

/// <summary>
/// The <see cref="System.Data.DbType"/> that will be forced for all <see cref="System.DateTime"/> parameters.
/// </summary>
public DbType DateTimeType
{
get { return _type; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace System.Data.Entity.Infrastructure.Interception
using System.Data.Common;
using System.Data.Entity.Utilities;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -53,6 +54,8 @@ public virtual int NonQuery(DbCommand command, DbCommandInterceptionContext inte
Check.NotNull(command, "command");
Check.NotNull(interceptionContext, "interceptionContext");

ForceDateTimeTypes(command, interceptionContext);

return _internalDispatcher.Dispatch(
command,
(t, c) => t.ExecuteNonQuery(),
Expand All @@ -61,6 +64,22 @@ public virtual int NonQuery(DbCommand command, DbCommandInterceptionContext inte
(i, t, c) => i.NonQueryExecuted(t, c));
}

private static void ForceDateTimeTypes(DbCommand command, DbCommandInterceptionContext interceptionContext)
{
var forcedDateTimeType = interceptionContext.ForcedDateTimeType;
if (forcedDateTimeType.HasValue)
{
foreach (DbParameter parameter in command.Parameters)
{
if (parameter.DbType == DbType.DateTime
|| parameter.DbType == DbType.DateTime2)
{
parameter.DbType = forcedDateTimeType.Value;
}
}
}
}

/// <summary>
/// Sends <see cref="IDbCommandInterceptor.ScalarExecuting" /> and
/// <see cref="IDbCommandInterceptor.ScalarExecuted" /> to any <see cref="IDbCommandInterceptor" />
Expand All @@ -80,6 +99,8 @@ public virtual object Scalar(DbCommand command, DbCommandInterceptionContext int
Check.NotNull(command, "command");
Check.NotNull(interceptionContext, "interceptionContext");

ForceDateTimeTypes(command, interceptionContext);

return _internalDispatcher.Dispatch(
command,
(t, c) => t.ExecuteScalar(),
Expand Down Expand Up @@ -108,6 +129,8 @@ public virtual DbDataReader Reader(
Check.NotNull(command, "command");
Check.NotNull(interceptionContext, "interceptionContext");

ForceDateTimeTypes(command, interceptionContext);

return _internalDispatcher.Dispatch(
command,
(t, c) => t.ExecuteReader(c.CommandBehavior),
Expand Down Expand Up @@ -138,6 +161,8 @@ public virtual Task<int> NonQueryAsync(
Check.NotNull(command, "command");
Check.NotNull(interceptionContext, "interceptionContext");

ForceDateTimeTypes(command, interceptionContext);

return _internalDispatcher.DispatchAsync(
command,
(t, c, ct) => t.ExecuteNonQueryAsync(ct),
Expand Down Expand Up @@ -168,6 +193,8 @@ public virtual Task<object> ScalarAsync(
Check.NotNull(command, "command");
Check.NotNull(interceptionContext, "interceptionContext");

ForceDateTimeTypes(command, interceptionContext);

return _internalDispatcher.DispatchAsync(
command,
(t, c, ct) => t.ExecuteScalarAsync(ct),
Expand Down Expand Up @@ -198,6 +225,8 @@ public virtual Task<DbDataReader> ReaderAsync(
Check.NotNull(command, "command");
Check.NotNull(interceptionContext, "interceptionContext");

ForceDateTimeTypes(command, interceptionContext);

return _internalDispatcher.DispatchAsync(
command,
(t, c, ct) => t.ExecuteReaderAsync(c.CommandBehavior, ct),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public DbCommandInterceptionContext(DbInterceptionContext copyFrom)
_commandBehavior = asThisType._commandBehavior;
}
}

/// <summary>
/// The <see cref="CommandBehavior" /> that will be used or has been used to execute the command with a
/// <see cref="DbDataReader" />. This property is only used for <see cref="DbCommand.ExecuteReader(System.Data.CommandBehavior)" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class DbInterceptionContext
private readonly IList<DbContext> _dbContexts;
private readonly IList<ObjectContext> _objectContexts;
private bool _isAsync;
private DbType? _forcedDateTimeType;

/// <summary>
/// Constructs a new <see cref="DbInterceptionContext" /> with no state.
Expand All @@ -52,6 +53,7 @@ protected DbInterceptionContext(DbInterceptionContext copyFrom)
_dbContexts = copyFrom.DbContexts.Where(c => c.InternalContext == null || !c.InternalContext.IsDisposed).ToList();
_objectContexts = copyFrom.ObjectContexts.Where(c => !c.IsDisposed).ToList();
_isAsync = copyFrom._isAsync;
_forcedDateTimeType = copyFrom._forcedDateTimeType;
}

private DbInterceptionContext(IEnumerable<DbInterceptionContext> copyFrom)
Expand All @@ -69,6 +71,20 @@ private DbInterceptionContext(IEnumerable<DbInterceptionContext> copyFrom)
.Where(c => !c.IsDisposed).ToList();

_isAsync = copyFrom.Any(c => c.IsAsync);

foreach (var context in _dbContexts)
{
var forcedDateTimeType = context.ForcedDateTimeType;
if (forcedDateTimeType.HasValue)
{
_forcedDateTimeType = forcedDateTimeType;
}
}
}

internal DbType? ForcedDateTimeType
{
get { return _forcedDateTimeType; }
}

/// <summary>
Expand Down Expand Up @@ -98,6 +114,12 @@ public DbInterceptionContext WithDbContext(DbContext context)
if (!copy._dbContexts.Contains(context, ObjectReferenceEqualityComparer.Default))
{
copy._dbContexts.Add(context);

var forcedDateTimeType = context.ForcedDateTimeType;
if (forcedDateTimeType.HasValue)
{
copy._forcedDateTimeType = forcedDateTimeType;
}
}
return copy;
}
Expand Down
11 changes: 9 additions & 2 deletions test/EntityFramework/FunctionalTests/Interception/BlogContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public static void DoStuff(BlogContext context)
blog.Posts.Add(
new Post
{
Title = "Throw it away..."
Title = "Throw it away...",
Spacetime = new DateTime(1915, 11, 25)
});

ExtendedSqlAzureExecutionStrategy.ExecuteNew(
Expand Down Expand Up @@ -75,6 +76,8 @@ public class Blog
public string Title { get; set; }

public virtual ICollection<Post> Posts { get; set; }

public DateTime TimeDilation { get; set; }
}

public class Post
Expand All @@ -84,6 +87,8 @@ public class Post

public int BlogId { get; set; }
public virtual Blog Blog { get; set; }

public DateTime Spacetime { get; set; }
}

public class BlogInitializer : DropCreateDatabaseAlways<BlogContext>
Expand All @@ -94,9 +99,11 @@ protected override void Seed(BlogContext context)
new Post
{
Title = "Wrap it up...",
Spacetime = new DateTime(1915, 11, 25),
Blog = new Blog
{
Title = "Half a Unicorn"
Title = "Half a Unicorn",
TimeDilation = new DateTime(1905, 6, 30)
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace System.Data.Entity.Interception
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Data.Common;
using System.Data.Entity.Infrastructure;
using System.Data.Entity.Infrastructure.Interception;
using System.Data.Entity.TestHelpers;
using System.Data.SqlClient;
Expand All @@ -25,13 +26,21 @@ public CommandInterceptionTests()
[Fact]
[UseDefaultExecutionStrategy]
public void Initialization_and_simple_query_and_update_commands_can_be_logged()
{
CommandInterceptionTest<BlogContextLogAll>();
CommandInterceptionTest<BlogContextDateTime2LogAll>();
CommandInterceptionTest<BlogContextDateTimeLogAll>();
}

private void CommandInterceptionTest<TContextType>()
where TContextType : BlogContext, new()
{
var logger = new CommandLogger();
DbInterception.Add(logger);

try
{
using (var context = new BlogContextLogAll())
using (var context = new TContextType())
{
BlogContext.DoStuff(context);
}
Expand All @@ -52,6 +61,20 @@ public void Initialization_and_simple_query_and_update_commands_can_be_logged()
{
Assert.Equal(method + 1, logger.Log[i + 1].Method);
Assert.Same(logger.Log[i].Command, logger.Log[i + 1].Command);

var parameters = logger.Log[i].Command.Parameters;

var expectedDbType = typeof(TContextType) == typeof(BlogContextDateTimeLogAll)
? DbType.DateTime
: DbType.DateTime2;

foreach (DbParameter parameter in parameters)
{
if (parameter.Value is DateTime)
{
Assert.Equal(expectedDbType, parameter.DbType);
}
}
}
}

Expand All @@ -78,6 +101,23 @@ static BlogContextLogAll()
Database.SetInitializer<BlogContextLogAll>(new BlogInitializer());
}
}
[ForceDateTimeType(DbType.DateTime)]
public class BlogContextDateTimeLogAll : BlogContext
{
static BlogContextDateTimeLogAll()
{
Database.SetInitializer<BlogContextDateTimeLogAll>(new BlogInitializer());
}
}

[ForceDateTimeType(DbType.DateTime2)]
public class BlogContextDateTime2LogAll : BlogContext
{
static BlogContextDateTime2LogAll()
{
Database.SetInitializer<BlogContextDateTime2LogAll>(new BlogInitializer());
}
}

[Fact]
public void Commands_that_result_in_exceptions_are_still_intercepted()
Expand Down
Loading