Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 146 additions & 98 deletions src/EFCore.Design/Design/Internal/DbContextOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,9 @@ public virtual IReadOnlyList<string> Optimize(
var optimizeAllInAssembly = contextTypeName == "*";
var contexts = optimizeAllInAssembly ? CreateAllContexts() : [CreateContext(contextTypeName)];

MSBuildLocator.RegisterDefaults();

List<string> generatedFiles = [];
HashSet<string> generatedFileNames = [];
var contextOptimized = false;
foreach (var context in contexts)
{
using (context)
Expand All @@ -158,6 +157,20 @@ public virtual IReadOnlyList<string> Optimize(
optimizeAllInAssembly,
generatedFiles,
generatedFileNames);
contextOptimized = true;
}
}

if (optimizeAllInAssembly)
{
if (!contextOptimized)
{
throw new OperationException(DesignStrings.NoContextsToOptimize);
}

if (generatedFiles.Count == 0)
{
_reporter.WriteWarning(DesignStrings.OptimizeNoFilesGenerated);
}
}

Expand Down Expand Up @@ -269,6 +282,10 @@ private IReadOnlyList<string> PrecompileQueries(string? outputDir, DbContext con
{
outputDir = Path.GetFullPath(Path.Combine(_projectDir, outputDir ?? "Generated"));

if (!MSBuildLocator.IsRegistered)
{
MSBuildLocator.RegisterDefaults();
}
// TODO: pass through properties
var workspace = MSBuildWorkspace.Create();
workspace.LoadMetadataForReferencedProjects = true;
Expand Down Expand Up @@ -373,6 +390,44 @@ static async Task<object> FormatCode(Project project, ScaffoldedFile generatedFi
: null;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ContextInfo GetContextInfo(string? contextType)
{
using var context = CreateContext(contextType);
var info = new ContextInfo { Type = context.GetType().FullName! };

var provider = context.GetService<IDatabaseProvider>();
info.ProviderName = provider.Name;

if (((IDatabaseFacadeDependenciesAccessor)context.Database).Dependencies is IRelationalDatabaseFacadeDependencies)
{
try
{
var connection = context.Database.GetDbConnection();
info.DataSource = connection.DataSource;
info.DatabaseName = connection.Database;
}
catch (Exception exception)
{
info.DataSource = info.DatabaseName = DesignStrings.BadConnection(exception.Message);
}
}
else
{
info.DataSource = info.DatabaseName = DesignStrings.NoRelationalConnection;
}

var options = context.GetService<IDbContextOptions>();
info.Options = options.BuildOptionsFragment().Trim();

return info;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -445,11 +500,11 @@ public virtual IEnumerable<Type> GetContextTypes()
public virtual Type GetContextType(string? name)
=> FindContextType(name).Key;

private IDictionary<Type, Func<DbContext>> FindContextTypes()
private IDictionary<Type, Func<DbContext>> FindContextTypes(string? name = null)
{
_reporter.WriteVerbose(DesignStrings.FindingContexts);

var contexts = new Dictionary<Type, Func<DbContext>>();
var contexts = new Dictionary<Type, Func<DbContext>?>();

try
{
Expand All @@ -475,28 +530,18 @@ where i.IsGenericType
}

// Look for DbContextAttribute on the assembly
var appServices = _appServicesFactory.Create(_args);
foreach (var contextAttribute in _startupAssembly.GetCustomAttributes<DbContextAttribute>())
{
var context = contextAttribute.ContextType;
_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context)
?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context)));
}
if (contexts.ContainsKey(context))
{
continue;
}

// Look for DbContext classes registered in the service provider
var registeredContexts = appServices.GetServices<DbContextOptions>()
.Select(o => o.ContextType);
foreach (var context in registeredContexts.Where(c => !contexts.ContainsKey(c)))
{
_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context)
?? FindContextFromRuntimeDbContextFactory(appServices, context)
?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context)));
FindContextFactory(context));
}

// Look for DbContext classes in assemblies
Expand All @@ -507,87 +552,93 @@ where i.IsGenericType

var contextTypes = types.Where(t => typeof(DbContext).IsAssignableFrom(t)).Select(
t => t.AsType())
.Concat(
.Concat<Type>(
types.Where(t => typeof(Migration).IsAssignableFrom(t))
.Select(t => t.GetCustomAttribute<DbContextAttribute>()?.ContextType)
.Where(t => t != null)
.Cast<Type>())
.Where(t => t != null)!)
.Distinct();

foreach (var context in contextTypes.Where(c => !contexts.ContainsKey(c)))
foreach (var context in contextTypes)
{
if (contexts.ContainsKey(context))
{
continue;
}

_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context)
?? (() => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, context)));
FindContextFactory(context));
}
}
catch (Exception ex)
{
if (ex is OperationException)

if (!string.IsNullOrEmpty(name))
{
throw;
contexts = FilterTypes(contexts, name, throwOnEmpty: false);
}

if (ex is TargetInvocationException)
if (contexts.Values.All(f => f != null)
&& (string.IsNullOrEmpty(name) || contexts.Count == 1))
{
ex = ex.InnerException!;
return contexts!;
}

throw new OperationException(DesignStrings.CannotFindDbContextTypes(ex.Message), ex);
}

return contexts;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ContextInfo GetContextInfo(string? contextType)
{
using var context = CreateContext(contextType);
var info = new ContextInfo { Type = context.GetType().FullName! };
// Look for DbContext classes registered in the service provider
var appServices = _appServicesFactory.Create(_args);
foreach (var options in appServices.GetServices<DbContextOptions>())
{
var context = options.ContextType;
if (contexts.ContainsKey(context))
{
continue;
}

var provider = context.GetService<IDatabaseProvider>();
info.ProviderName = provider.Name;
_reporter.WriteVerbose(DesignStrings.FoundDbContext(context.ShortDisplayName()));
contexts.Add(
context,
FindContextFactory(context));
}

if (((IDatabaseFacadeDependenciesAccessor)context.Database).Dependencies is IRelationalDatabaseFacadeDependencies)
{
try
if (!string.IsNullOrEmpty(name))
{
var connection = context.Database.GetDbConnection();
info.DataSource = connection.DataSource;
info.DatabaseName = connection.Database;
contexts = FilterTypes(contexts, name, throwOnEmpty: true);
}
catch (Exception exception)

foreach (var contextPair in contexts)
{
info.DataSource = info.DatabaseName = DesignStrings.BadConnection(exception.Message);
if (contextPair.Value == null)
{
var context = contextPair.Key;
contexts[context] = CreateContextFromServiceProvider(appServices, context);
}
}
}
else
catch (Exception ex)
{
info.DataSource = info.DatabaseName = DesignStrings.NoRelationalConnection;
}
if (ex is OperationException)
{
throw;
}

var options = context.GetService<IDbContextOptions>();
info.Options = options.BuildOptionsFragment().Trim();
if (ex is TargetInvocationException)
{
ex = ex.InnerException!;
}

return info;
throw new OperationException(DesignStrings.CannotFindDbContextTypes(ex.Message), ex);
}

return contexts!;
}

private static Func<DbContext>? FindContextFromRuntimeDbContextFactory(IServiceProvider appServices, Type contextType)
private static Func<DbContext> CreateContextFromServiceProvider(IServiceProvider appServices, Type contextType)
{
var factoryInterface = typeof(IDbContextFactory<>).MakeGenericType(contextType);
var service = appServices.GetService(factoryInterface);
return service == null
? null
var factoryService = appServices.GetService(factoryInterface);
return factoryService == null
? () => (DbContext)ActivatorUtilities.GetServiceOrCreateInstance(appServices, contextType)
: () => (DbContext)factoryInterface
.GetMethod(nameof(IDbContextFactory<DbContext>.CreateDbContext))
!.Invoke(service, null)!;
!.Invoke(factoryService, null)!;
}

private Func<DbContext>? FindContextFactory(Type contextType)
Expand All @@ -609,44 +660,45 @@ private DbContext CreateContextFromFactory(Type factory, Type contextType)

private KeyValuePair<Type, Func<DbContext>> FindContextType(string? name)
{
var types = FindContextTypes();

if (string.IsNullOrEmpty(name))
{
if (types.Count == 0)
var types = FindContextTypes(name);
return !string.IsNullOrEmpty(name)
? types.First()
: types.Count switch
{
throw new OperationException(DesignStrings.NoContext(_assembly.GetName().Name));
}

if (types.Count == 1)
{
return types.First();
}

throw new OperationException(DesignStrings.MultipleContexts);
}
0 => throw new OperationException(DesignStrings.NoContext(_assembly.GetName().Name)),
1 => types.First(),
_ => throw new OperationException(DesignStrings.MultipleContexts)
};
}

var candidates = FilterTypes(types, name, ignoreCase: true);
private Dictionary<Type, Func<DbContext>?> FilterTypes(
Dictionary<Type, Func<DbContext>?> types,
string name,
bool throwOnEmpty)
{
var candidates = FilterTypes(types, name, StringComparison.OrdinalIgnoreCase);
if (candidates.Count == 0)
{
throw new OperationException(DesignStrings.NoContextWithName(name));
return throwOnEmpty
? throw new OperationException(DesignStrings.NoContextWithName(name))
: candidates;
}

if (candidates.Count == 1)
{
return candidates.First();
return candidates;
}

// Disambiguate using case
candidates = FilterTypes(candidates, name);
candidates = FilterTypes(candidates, name, StringComparison.Ordinal);
if (candidates.Count == 0)
{
throw new OperationException(DesignStrings.MultipleContextsWithName(name));
}

if (candidates.Count == 1)
{
return candidates.First();
return candidates;
}

// Allow selecting types in the default namespace
Expand All @@ -658,21 +710,17 @@ private KeyValuePair<Type, Func<DbContext>> FindContextType(string? name)

Check.DebugAssert(candidates.Count == 1, $"candidates.Count is {candidates.Count}");

return candidates.First();
return candidates;
}

private static IDictionary<Type, Func<DbContext>> FilterTypes(
IDictionary<Type, Func<DbContext>> types,
private static Dictionary<Type, Func<DbContext>?> FilterTypes(
Dictionary<Type, Func<DbContext>?> types,
string name,
bool ignoreCase = false)
{
var comparisonType = ignoreCase ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal;

return types
StringComparison comparisonType)
=> types
.Where(
t => string.Equals(t.Key.Name, name, comparisonType)
|| string.Equals(t.Key.FullName, name, comparisonType)
|| string.Equals(t.Key.AssemblyQualifiedName, name, comparisonType))
.ToDictionary();
}
}
12 changes: 12 additions & 0 deletions src/EFCore.Design/Properties/DesignStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading