diff --git a/src/EntityFramework/DbContext.cs b/src/EntityFramework/DbContext.cs index 1675919303..9421c37639 100644 --- a/src/EntityFramework/DbContext.cs +++ b/src/EntityFramework/DbContext.cs @@ -13,6 +13,7 @@ namespace System.Data.Entity using System.Data.Entity.Utilities; using System.Data.Entity.Validation; using System.Diagnostics.CodeAnalysis; + using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -63,7 +64,8 @@ public class DbContext : IDisposable, IObjectContextAdapter private Database _database; - + private ForceDateTimeTypeAttribute _forcedDateTimeTypeAttribute; + /// /// Constructs a new context instance using conventions to create the name of the database to /// which a connection will be made. The by-convention name is the full name (namespace + class name) @@ -76,7 +78,7 @@ protected DbContext() { InitializeLazyInternalContext(new LazyInternalConnection(this, GetType().DatabaseName())); } - + /// /// Constructs a new context instance using conventions to create the name of the database to /// which a connection will be made, and initializes it from the given model. @@ -205,6 +207,20 @@ private void DiscoverAndInitializeSets() new DbSetDiscoveryService(this).InitializeSets(); } + internal DbType? ForcedDateTimeType + { + get + { + if (_forcedDateTimeTypeAttribute == null) + { + _forcedDateTimeTypeAttribute + = GetType().GetCustomAttributes(inherit: false).OfType().FirstOrDefault(); + } + + return _forcedDateTimeTypeAttribute?.DateTimeType; + } + } + #endregion #region Model building diff --git a/src/EntityFramework/Infrastructure/ForceDateTimeTypeAttribute.cs b/src/EntityFramework/Infrastructure/ForceDateTimeTypeAttribute.cs new file mode 100644 index 0000000000..ec66e771e0 --- /dev/null +++ b/src/EntityFramework/Infrastructure/ForceDateTimeTypeAttribute.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information. + +namespace System.Data.Entity.Infrastructure +{ + /// + /// Place this attribute on a type that inherits from to + /// force all parameters to have the given . + /// + [AttributeUsage(AttributeTargets.Class, AllowMultiple = false)] + public sealed class ForceDateTimeTypeAttribute : Attribute + { + private readonly DbType _type; + + /// + /// Creates a new instance with the given . + /// + /// The type to force. + public ForceDateTimeTypeAttribute(DbType type) + { + _type = type; + } + + /// + /// The that will be forced for all parameters. + /// + public DbType DateTimeType + { + get { return _type; } + } + } +} \ No newline at end of file diff --git a/src/EntityFramework/Infrastructure/Interception/DbCommandDispatcher.cs b/src/EntityFramework/Infrastructure/Interception/DbCommandDispatcher.cs index 7c69375f45..2bac860471 100644 --- a/src/EntityFramework/Infrastructure/Interception/DbCommandDispatcher.cs +++ b/src/EntityFramework/Infrastructure/Interception/DbCommandDispatcher.cs @@ -6,6 +6,7 @@ namespace System.Data.Entity.Infrastructure.Interception using System.Data.Common; using System.Data.Entity.Utilities; using System.Diagnostics.CodeAnalysis; + using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -53,6 +54,8 @@ public virtual int NonQuery(DbCommand command, DbCommandInterceptionContext inte Check.NotNull(command, "command"); Check.NotNull(interceptionContext, "interceptionContext"); + ForceDateTimeTypes(command, interceptionContext); + return _internalDispatcher.Dispatch( command, (t, c) => t.ExecuteNonQuery(), @@ -61,6 +64,22 @@ public virtual int NonQuery(DbCommand command, DbCommandInterceptionContext inte (i, t, c) => i.NonQueryExecuted(t, c)); } + private static void ForceDateTimeTypes(DbCommand command, DbCommandInterceptionContext interceptionContext) + { + var forcedDateTimeType = interceptionContext.ForcedDateTimeType; + if (forcedDateTimeType.HasValue) + { + foreach (DbParameter parameter in command.Parameters) + { + if (parameter.DbType == DbType.DateTime + || parameter.DbType == DbType.DateTime2) + { + parameter.DbType = forcedDateTimeType.Value; + } + } + } + } + /// /// Sends and /// to any @@ -80,6 +99,8 @@ public virtual object Scalar(DbCommand command, DbCommandInterceptionContext int Check.NotNull(command, "command"); Check.NotNull(interceptionContext, "interceptionContext"); + ForceDateTimeTypes(command, interceptionContext); + return _internalDispatcher.Dispatch( command, (t, c) => t.ExecuteScalar(), @@ -108,6 +129,8 @@ public virtual DbDataReader Reader( Check.NotNull(command, "command"); Check.NotNull(interceptionContext, "interceptionContext"); + ForceDateTimeTypes(command, interceptionContext); + return _internalDispatcher.Dispatch( command, (t, c) => t.ExecuteReader(c.CommandBehavior), @@ -138,6 +161,8 @@ public virtual Task NonQueryAsync( Check.NotNull(command, "command"); Check.NotNull(interceptionContext, "interceptionContext"); + ForceDateTimeTypes(command, interceptionContext); + return _internalDispatcher.DispatchAsync( command, (t, c, ct) => t.ExecuteNonQueryAsync(ct), @@ -168,6 +193,8 @@ public virtual Task ScalarAsync( Check.NotNull(command, "command"); Check.NotNull(interceptionContext, "interceptionContext"); + ForceDateTimeTypes(command, interceptionContext); + return _internalDispatcher.DispatchAsync( command, (t, c, ct) => t.ExecuteScalarAsync(ct), @@ -198,6 +225,8 @@ public virtual Task ReaderAsync( Check.NotNull(command, "command"); Check.NotNull(interceptionContext, "interceptionContext"); + ForceDateTimeTypes(command, interceptionContext); + return _internalDispatcher.DispatchAsync( command, (t, c, ct) => t.ExecuteReaderAsync(c.CommandBehavior, ct), diff --git a/src/EntityFramework/Infrastructure/Interception/DbCommandInterceptionContext.cs b/src/EntityFramework/Infrastructure/Interception/DbCommandInterceptionContext.cs index fcffdba670..25fe8a28ca 100644 --- a/src/EntityFramework/Infrastructure/Interception/DbCommandInterceptionContext.cs +++ b/src/EntityFramework/Infrastructure/Interception/DbCommandInterceptionContext.cs @@ -46,7 +46,7 @@ public DbCommandInterceptionContext(DbInterceptionContext copyFrom) _commandBehavior = asThisType._commandBehavior; } } - + /// /// The that will be used or has been used to execute the command with a /// . This property is only used for diff --git a/src/EntityFramework/Infrastructure/Interception/DbInterceptionContext.cs b/src/EntityFramework/Infrastructure/Interception/DbInterceptionContext.cs index 5d8e0699e6..f708c0e651 100644 --- a/src/EntityFramework/Infrastructure/Interception/DbInterceptionContext.cs +++ b/src/EntityFramework/Infrastructure/Interception/DbInterceptionContext.cs @@ -30,6 +30,7 @@ public class DbInterceptionContext private readonly IList _dbContexts; private readonly IList _objectContexts; private bool _isAsync; + private DbType? _forcedDateTimeType; /// /// Constructs a new with no state. @@ -52,6 +53,7 @@ protected DbInterceptionContext(DbInterceptionContext copyFrom) _dbContexts = copyFrom.DbContexts.Where(c => c.InternalContext == null || !c.InternalContext.IsDisposed).ToList(); _objectContexts = copyFrom.ObjectContexts.Where(c => !c.IsDisposed).ToList(); _isAsync = copyFrom._isAsync; + _forcedDateTimeType = copyFrom._forcedDateTimeType; } private DbInterceptionContext(IEnumerable copyFrom) @@ -69,6 +71,20 @@ private DbInterceptionContext(IEnumerable copyFrom) .Where(c => !c.IsDisposed).ToList(); _isAsync = copyFrom.Any(c => c.IsAsync); + + foreach (var context in _dbContexts) + { + var forcedDateTimeType = context.ForcedDateTimeType; + if (forcedDateTimeType.HasValue) + { + _forcedDateTimeType = forcedDateTimeType; + } + } + } + + internal DbType? ForcedDateTimeType + { + get { return _forcedDateTimeType; } } /// @@ -98,6 +114,12 @@ public DbInterceptionContext WithDbContext(DbContext context) if (!copy._dbContexts.Contains(context, ObjectReferenceEqualityComparer.Default)) { copy._dbContexts.Add(context); + + var forcedDateTimeType = context.ForcedDateTimeType; + if (forcedDateTimeType.HasValue) + { + copy._forcedDateTimeType = forcedDateTimeType; + } } return copy; } diff --git a/test/EntityFramework/FunctionalTests/Interception/BlogContext.cs b/test/EntityFramework/FunctionalTests/Interception/BlogContext.cs index b65ed2a88a..e3e9be137a 100644 --- a/test/EntityFramework/FunctionalTests/Interception/BlogContext.cs +++ b/test/EntityFramework/FunctionalTests/Interception/BlogContext.cs @@ -29,7 +29,8 @@ public static void DoStuff(BlogContext context) blog.Posts.Add( new Post { - Title = "Throw it away..." + Title = "Throw it away...", + Spacetime = new DateTime(1915, 11, 25) }); ExtendedSqlAzureExecutionStrategy.ExecuteNew( @@ -75,6 +76,8 @@ public class Blog public string Title { get; set; } public virtual ICollection Posts { get; set; } + + public DateTime TimeDilation { get; set; } } public class Post @@ -84,6 +87,8 @@ public class Post public int BlogId { get; set; } public virtual Blog Blog { get; set; } + + public DateTime Spacetime { get; set; } } public class BlogInitializer : DropCreateDatabaseAlways @@ -94,9 +99,11 @@ protected override void Seed(BlogContext context) new Post { Title = "Wrap it up...", + Spacetime = new DateTime(1915, 11, 25), Blog = new Blog { - Title = "Half a Unicorn" + Title = "Half a Unicorn", + TimeDilation = new DateTime(1905, 6, 30) } }); } diff --git a/test/EntityFramework/FunctionalTests/Interception/CommandInterceptionTests.cs b/test/EntityFramework/FunctionalTests/Interception/CommandInterceptionTests.cs index d13e19519f..78f10f82f5 100644 --- a/test/EntityFramework/FunctionalTests/Interception/CommandInterceptionTests.cs +++ b/test/EntityFramework/FunctionalTests/Interception/CommandInterceptionTests.cs @@ -5,6 +5,7 @@ namespace System.Data.Entity.Interception using System.Collections.Concurrent; using System.Collections.Generic; using System.Data.Common; + using System.Data.Entity.Infrastructure; using System.Data.Entity.Infrastructure.Interception; using System.Data.Entity.TestHelpers; using System.Data.SqlClient; @@ -25,13 +26,21 @@ public CommandInterceptionTests() [Fact] [UseDefaultExecutionStrategy] public void Initialization_and_simple_query_and_update_commands_can_be_logged() + { + CommandInterceptionTest(); + CommandInterceptionTest(); + CommandInterceptionTest(); + } + + private void CommandInterceptionTest() + where TContextType : BlogContext, new() { var logger = new CommandLogger(); DbInterception.Add(logger); try { - using (var context = new BlogContextLogAll()) + using (var context = new TContextType()) { BlogContext.DoStuff(context); } @@ -52,6 +61,20 @@ public void Initialization_and_simple_query_and_update_commands_can_be_logged() { Assert.Equal(method + 1, logger.Log[i + 1].Method); Assert.Same(logger.Log[i].Command, logger.Log[i + 1].Command); + + var parameters = logger.Log[i].Command.Parameters; + + var expectedDbType = typeof(TContextType) == typeof(BlogContextDateTimeLogAll) + ? DbType.DateTime + : DbType.DateTime2; + + foreach (DbParameter parameter in parameters) + { + if (parameter.Value is DateTime) + { + Assert.Equal(expectedDbType, parameter.DbType); + } + } } } @@ -78,6 +101,23 @@ static BlogContextLogAll() Database.SetInitializer(new BlogInitializer()); } } + [ForceDateTimeType(DbType.DateTime)] + public class BlogContextDateTimeLogAll : BlogContext + { + static BlogContextDateTimeLogAll() + { + Database.SetInitializer(new BlogInitializer()); + } + } + + [ForceDateTimeType(DbType.DateTime2)] + public class BlogContextDateTime2LogAll : BlogContext + { + static BlogContextDateTime2LogAll() + { + Database.SetInitializer(new BlogInitializer()); + } + } [Fact] public void Commands_that_result_in_exceptions_are_still_intercepted() diff --git a/test/EntityFramework/FunctionalTests/Interception/CommitFailureTests.cs b/test/EntityFramework/FunctionalTests/Interception/CommitFailureTests.cs index d9d9520113..fb91bd0c98 100644 --- a/test/EntityFramework/FunctionalTests/Interception/CommitFailureTests.cs +++ b/test/EntityFramework/FunctionalTests/Interception/CommitFailureTests.cs @@ -178,7 +178,7 @@ private void Execute_commit_failure_test( }); failingTransactionInterceptor.ShouldFailTimes = 1; - context.Blogs.Add(new BlogContext.Blog()); + context.Blogs.Add(new BlogContext.Blog { TimeDilation = new DateTime(1905, 6, 30) }); verifySaveChanges(() => context.SaveChanges()); var expectedCommitCount = useTransactionHandler @@ -329,7 +329,7 @@ private void TransactionHandler_and_ExecutionStrategy_does_not_retry_on_false_co failingCommandInterceptor.FailAfter = 1; failingCommandInterceptor.ShouldFailTimes = queryFailures; - context.Blogs.Add(new BlogContext.Blog()); + context.Blogs.Add(new BlogContext.Blog { TimeDilation = new DateTime(1905, 6, 30) }); runAndVerify(context); @@ -428,7 +428,7 @@ private void CommitFailureHandler_prunes_transactions_after_set_amount_implement for (var i = 0; i < transactionHandler.PruningLimit; i++) { - context.Blogs.Add(new BlogContext.Blog()); + context.Blogs.Add(new BlogContext.Blog { TimeDilation = new DateTime(1905, 6, 30) }); context.SaveChanges(); } @@ -440,10 +440,10 @@ private void CommitFailureHandler_prunes_transactions_after_set_amount_implement failingTransactionInterceptor.ShouldRollBack = false; } - context.Blogs.Add(new BlogContext.Blog()); + context.Blogs.Add(new BlogContext.Blog { TimeDilation = new DateTime(1905, 6, 30) }); context.SaveChanges(); - context.Blogs.Add(new BlogContext.Blog()); + context.Blogs.Add(new BlogContext.Blog { TimeDilation = new DateTime(1905, 6, 30) }); context.SaveChanges(); AssertTransactionHistoryCount(context, 1); @@ -679,7 +679,7 @@ private void CommitFailureHandler_with_ExecutionStrategy_test( context.Database.Delete(); Assert.Equal(1, context.Blogs.Count()); - context.Blogs.Add(new BlogContext.Blog()); + context.Blogs.Add(new BlogContext.Blog { TimeDilation = new DateTime(1905, 6, 30) }); context.SaveChanges(); @@ -790,7 +790,7 @@ public void CommitFailureHandler_supports_nested_transactions() Assert.Equal(1, context.Blogs.Count()); }); - context.Blogs.Add(new BlogContext.Blog()); + context.Blogs.Add(new BlogContext.Blog { TimeDilation = new DateTime(1905, 6, 30) }); ExtendedSqlAzureExecutionStrategy.ExecuteNew( () => @@ -802,7 +802,7 @@ public void CommitFailureHandler_supports_nested_transactions() using (var innerTransaction = innerContext.Database.BeginTransaction()) { Assert.Equal(1, innerContext.Blogs.Count()); - innerContext.Blogs.Add(new BlogContext.Blog()); + innerContext.Blogs.Add(new BlogContext.Blog { TimeDilation = new DateTime(1905, 6, 30) }); innerContext.SaveChanges(); innerTransaction.Commit(); } @@ -853,7 +853,7 @@ public void BuildDatabaseInitializationScript_can_be_used_to_initialize_the_data using (var context = new BlogContextCommit()) { - context.Blogs.Add(new BlogContext.Blog()); + context.Blogs.Add(new BlogContext.Blog { TimeDilation = new DateTime(1905, 6, 30) }); Assert.Throws(() => context.SaveChanges()); diff --git a/test/EntityFramework/FunctionalTests/Interception/DatabaseLogFormatterTests.cs b/test/EntityFramework/FunctionalTests/Interception/DatabaseLogFormatterTests.cs index 49dd80fe77..047deb699f 100644 --- a/test/EntityFramework/FunctionalTests/Interception/DatabaseLogFormatterTests.cs +++ b/test/EntityFramework/FunctionalTests/Interception/DatabaseLogFormatterTests.cs @@ -63,7 +63,7 @@ public void Simple_query_and_update_commands_can_be_logged() const int selectCount = 5; const int updateCount = 1; const int asyncCount = 2; - const int paramCount = 4; + const int paramCount = 5; const int imALoggerCount = 1; const int transactionCount = 2; const int connectionCount = 4; @@ -187,7 +187,7 @@ public void The_command_formatter_to_use_can_be_changed() Assert.Equal(3, logLines.Length); Assert.Equal( - "Context 'BlogContextNoInit' is executing command 'SELECT TOP (2) [c].[Id] AS [Id], [c].[Title] AS [Title] FROM [dbo].[Blogs] AS [c]'", + "Context 'BlogContextNoInit' is executing command 'SELECT TOP (2) [c].[Id] AS [Id], [c].[Title] AS [Title], [c].[TimeDilation] AS [TimeDilation] FROM [dbo].[Blogs] AS [c]'", logLines[0]); Assert.Equal(