Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CodeGenerator/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private void InitGenerators(GenerateRequest generateRequest)
DbDriver = InstantiateDriver();

// initialize file generators
CsprojGen = new(outputDirectory, projectName, namespaceName, Options);
CsprojGen = new(DbDriver, outputDirectory, projectName, namespaceName);
QueriesGen = new(DbDriver, namespaceName);
ModelsGen = new(DbDriver, namespaceName);
UtilsGen = new(DbDriver, namespaceName);
Expand Down
69 changes: 12 additions & 57 deletions CodeGenerator/Generators/CsprojGen.cs
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
using Google.Protobuf;
using SqlcGenCsharp.Drivers;
using System;
using System.Linq;
using File = Plugin.File;


namespace SqlcGenCsharp.Generators;

internal class CsprojGen(string outputDirectory, string projectName, string namespaceName, Options options)
internal class CsprojGen(DbDriver dbDriver, string outputDirectory, string projectName, string namespaceName)
{
// TODO this logic needs to be moved to the Drivers project
private const string DefaultDapperVersion = "2.1.66";
private const string DefaultNpgsqlVersion = "8.0.6";
private const string DefaultMysqlConnectorVersion = "2.4.0";
private const string DefaultSqliteVersion = "9.0.0";
private const string DefaultCsvHelperVersion = "33.0.1";
private const string DefaultSystemTextJsonVersion = "9.0.6";

public File GenerateFile()
{
var csprojContents = GetFileContents();
Expand All @@ -27,7 +21,11 @@ public File GenerateFile()

private string GetFileContents()
{
var optionalNullableProperty = options.DotnetFramework.IsDotnetCore() ? Environment.NewLine + " <Nullable>enable</Nullable>" : "";
var optionalNullableProperty = dbDriver.Options.DotnetFramework.IsDotnetCore() ? Environment.NewLine + " <Nullable>enable</Nullable>" : "";
var referenceItems = dbDriver.GetPackageReferences()
.Select(p => $""" <PackageReference Include="{p.Key}" Version="{p.Value}"/>""")
.JoinByNewLine();

return $"""
<!--{Consts.AutoGeneratedComment}-->
<!--Run the following to add the project to the solution:
Expand All @@ -36,59 +34,16 @@ dotnet sln add {outputDirectory}/{projectName}.csproj
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>{options.DotnetFramework.ToName()}</TargetFramework>
<TargetFramework>{dbDriver.Options.DotnetFramework.ToName()}</TargetFramework>
<RootNamespace>{namespaceName}</RootNamespace>
<OutputType>Library</OutputType>{optionalNullableProperty}
</PropertyGroup>

{GetPackageReferences()}
<ItemGroup>
{referenceItems}
</ItemGroup>

</Project>
""";

string GetPackageReferences()
{
var optionalDapperPackageReference = options.UseDapper
? Environment.NewLine + $""" <PackageReference Include="Dapper" Version="{GetDapperVersion(options)}"/>"""
: string.Empty;
var optionalCsvHelper = options.DriverName is DriverName.MySqlConnector
? Environment.NewLine + $""" <PackageReference Include="CsvHelper" Version="{DefaultCsvHelperVersion}"/>"""
: string.Empty;
var optionalSystemTextJson = IsSystemTextJsonNeeded()
? Environment.NewLine + $""" <PackageReference Include="System.Text.Json" Version="{DefaultSystemTextJsonVersion}"/>"""
: string.Empty;
return $"""
<ItemGroup>
<PackageReference Include="{options.DriverName.ToName()}" Version="{GetDriverVersion(options)}"/>{optionalDapperPackageReference}{optionalCsvHelper}{optionalSystemTextJson}
</ItemGroup>
""";
}
}

private bool IsSystemTextJsonNeeded()
{
if (options.DotnetFramework.IsDotnetCore())
return false;
return options.DriverName is DriverName.MySqlConnector or DriverName.Npgsql;
}

private static string GetDriverVersion(Options options)
{
if (string.IsNullOrEmpty(options.OverrideDriverVersion))
return options.DriverName switch
{
DriverName.Npgsql => DefaultNpgsqlVersion,
DriverName.MySqlConnector => DefaultMysqlConnectorVersion,
DriverName.Sqlite => DefaultSqliteVersion,
_ => throw new NotSupportedException($"unsupported driver: {options.DriverName}")
};
return options.OverrideDriverVersion;
}

private static string GetDapperVersion(Options options)
{
return string.IsNullOrEmpty(options.OverrideDapperVersion)
? DefaultDapperVersion
: options.OverrideDapperVersion;
}
}
4 changes: 2 additions & 2 deletions Drivers/ColumnMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class ColumnMapping(
Dictionary<string, DbTypeInfo> dbTypes,
ReaderFn readerFn,
ReaderFn? readerArrayFn = null,
string? usingDirective = null,
string[]? usingDirectives = null,
WriterFn? writerFn = null,
ConvertFunc? convertFunc = null,
string? sqlMapper = null,
Expand All @@ -25,7 +25,7 @@ public class ColumnMapping(
public Dictionary<string, DbTypeInfo> DbTypes { get; } = dbTypes;
public ReaderFn ReaderFn { get; } = readerFn;
public ReaderFn? ReaderArrayFn { get; } = readerArrayFn;
public string? UsingDirective { get; } = usingDirective;
public string[]? UsingDirectives { get; } = usingDirectives;
public WriterFn? WriterFn { get; } = writerFn;
public ConvertFunc? ConvertFunc { get; } = convertFunc;
public string? SqlMapper { get; } = sqlMapper;
Expand Down
104 changes: 83 additions & 21 deletions Drivers/DbDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ public record ConnectionGenCommands(string EstablishConnection, string Connectio

public abstract class DbDriver
{
protected const string DefaultDapperVersion = "2.1.66";
protected const string DefaultSystemTextJsonVersion = "9.0.6";
protected const string DefaultNodaTimeVersion = "3.2.0";

public Options Options { get; }

public string DefaultSchema { get; }
Expand Down Expand Up @@ -51,6 +55,7 @@ public abstract class DbDriver
"NpgsqlCircle",
"JsonElement",
"NpgsqlCidr",
"Instant"
];

protected abstract Dictionary<string, ColumnMapping> ColumnMappings { get; }
Expand All @@ -68,6 +73,27 @@ public static string TransformQueryForSliceArgs(string originalSql, int sliceSiz
throw new InvalidOperationException("Transaction is provided, but its connection is null.");
""";

protected static readonly SqlMapperImplFunc DateTimeNodaInstantTypeHandler = _ => $$"""
private class NodaInstantTypeHandler : SqlMapper.TypeHandler<Instant>
{
public override Instant Parse(object value)
{
if (value is DateTime dt)
{
if (dt.Kind != DateTimeKind.Utc)
dt = DateTime.SpecifyKind(dt, DateTimeKind.Utc);
return dt.ToInstant();
}
throw new DataException($"Cannot convert {value?.GetType()} to Instant");
}

public override void SetValue(IDbDataParameter parameter, Instant value)
{
parameter.Value = value;
}
}
""";

protected DbDriver(
Options options,
Catalog catalog,
Expand Down Expand Up @@ -101,6 +127,21 @@ private static Dictionary<string, Dictionary<string, Table>> ConstructTablesLook
);
}

public virtual IDictionary<string, string> GetPackageReferences()
{
return new Dictionary<string, string> {
{ "Dapper", Options.OverrideDapperVersion != string.Empty ? Options.OverrideDapperVersion : DefaultDapperVersion }
}
.MergeIf(new Dictionary<string, string>
{
{ "System.Text.Json", DefaultSystemTextJsonVersion }
}, IsSystemTextJsonNeeded())
.MergeIf(new Dictionary<string, string>
{
{ "NodaTime", DefaultNodaTimeVersion }
}, TypeExistsInQueries("Instant"));
}

public virtual ISet<string> GetUsingDirectivesForQueries()
{
return new HashSet<string>
Expand All @@ -118,17 +159,16 @@ public virtual ISet<string> GetUsingDirectivesForQueries()
private ISet<string> GetUsingDirectivesForColumnMappings()
{
var usingDirectives = new HashSet<string>();
foreach (var schemaTables in Tables.Values)
foreach (var table in schemaTables.Values)
foreach (var column in table.Columns)
{
var csharpType = GetCsharpTypeWithoutNullableSuffix(column, null);
if (!ColumnMappings.ContainsKey(csharpType))
continue;
foreach (var query in Queries)
foreach (var column in query.Columns)
{
var csharpType = GetCsharpTypeWithoutNullableSuffix(column, query);
if (!ColumnMappings.ContainsKey(csharpType))
continue;

var columnMapping = ColumnMappings[csharpType];
usingDirectives.AddRangeExcludeNulls([columnMapping.UsingDirective]);
}
var columnMapping = ColumnMappings[csharpType];
usingDirectives.AddRangeIf(columnMapping.UsingDirectives!, columnMapping.UsingDirectives is not null);
}
return usingDirectives;
}

Expand Down Expand Up @@ -223,6 +263,32 @@ public virtual string[] GetLastIdStatement(Query query)
];
}

public virtual string AddParametersToCommand(Query query)
{
return query.Params.Select(p =>
{
var commandVar = Variable.Command.AsVarName();
var param = $"{Variable.Args.AsVarName()}.{p.Column.Name.ToPascalCase()}";
var columnMapping = GetCsharpTypeWithoutNullableSuffix(p.Column, query);

if (p.Column.IsSqlcSlice)
return $$"""
for (int i = 0; i < {{param}}.Length; i++)
{{commandVar}}.Parameters.AddWithValue($"@{{p.Column.Name}}Arg{i}", {{param}}[i]);
""";

var writerFn = GetWriterFn(p.Column, query);
var paramToWrite = writerFn is null ? param : writerFn(
param,
p.Column.Type.Name,
IsColumnNotNull(p.Column, query),
Options.UseDapper,
Options.DotnetFramework.IsDotnetLegacy());
var addParamToCommand = $"""{commandVar}.Parameters.AddWithValue("@{p.Column.Name}", {paramToWrite});""";
return addParamToCommand;
}).JoinByNewLine();
}

public Column GetColumnFromParam(Parameter queryParam, Query query)
{
if (string.IsNullOrEmpty(queryParam.Column.Name))
Expand Down Expand Up @@ -285,17 +351,6 @@ public string AddNullableSuffixIfNeeded(string csharpType, bool notNull)
return IsTypeNullable(csharpType) ? $"{csharpType}?" : csharpType;
}

protected string? GetColumnDbTypeOverride(Column column)
{
var columnType = column.Type.Name.ToLower();
foreach (var columnMapping in ColumnMappings.Values)
{
if (columnMapping.DbTypes.TryGetValue(columnType, out var dbTypeOverride))
return dbTypeOverride.NpgsqlTypeOverride;
}
throw new NotSupportedException($"Column {column.Name} has unsupported column type: {column.Type.Name}");
}

public bool IsTypeNullable(string csharpType)
{
if (NullableTypes.Contains(csharpType.Replace("?", ""))) return true;
Expand Down Expand Up @@ -365,4 +420,11 @@ public virtual string GetColumnReader(Column column, int ordinal, Query? query)
}
throw new NotSupportedException($"column {column.Name} has unsupported column type: {column.Type.Name} in {GetType().Name}");
}

private bool IsSystemTextJsonNeeded()
{
if (Options.DotnetFramework.IsDotnetCore())
return false;
return TypeExistsInQueries("JsonElement");
}
}
21 changes: 0 additions & 21 deletions Drivers/Generators/CommonGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,6 @@ public static string GetMethodParameterList(string argInterface, IEnumerable<Par
: $"{argInterface} {Variable.Args.AsVarName()}")}";
}

// TODO: extract AddWithValue statement generation to a method + possible override for Npgsql for type override
public string AddParametersToCommand(Query query)
{
return query.Params.Select(p =>
{
var commandVar = Variable.Command.AsVarName();
var param = $"{Variable.Args.AsVarName()}.{p.Column.Name.ToPascalCase()}";
if (p.Column.IsSqlcSlice)
return $$"""
for (int i = 0; i < {{param}}.Length; i++)
{{commandVar}}.Parameters.AddWithValue($"@{{p.Column.Name}}Arg{i}", {{param}}[i]);
""";

var notNull = dbDriver.IsColumnNotNull(p.Column, query);
var writerFn = dbDriver.GetWriterFn(p.Column, query);
var paramToWrite = writerFn is null ? param : writerFn(param, p.Column.Type.Name, notNull, dbDriver.Options.UseDapper, dbDriver.Options.DotnetFramework.IsDotnetLegacy());
var addParamToCommand = $"""{commandVar}.Parameters.AddWithValue("@{p.Column.Name}", {paramToWrite});""";
return addParamToCommand;
}).JoinByNewLine();
}

public string ConstructDapperParamsDict(Query query)
{
if (!query.Params.Any()) return string.Empty;
Expand Down
4 changes: 2 additions & 2 deletions Drivers/Generators/ExecDeclareGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private string GetDriverNoTxBody(string sqlVar, Query query)
{
var (establishConnection, connectionOpen) = dbDriver.EstablishConnection(query);
var createSqlCommand = dbDriver.CreateSqlCommand(sqlVar);
var commandParameters = CommonGen.AddParametersToCommand(query);
var commandParameters = dbDriver.AddParametersToCommand(query);
return $$"""
using ({{establishConnection}})
{
Expand All @@ -89,7 +89,7 @@ private string GetDriverWithTxBody(string sqlVar, Query query)
{
var transactionProperty = Variable.Transaction.AsPropertyName();
var commandVar = Variable.Command.AsVarName();
var commandParameters = CommonGen.AddParametersToCommand(query);
var commandParameters = dbDriver.AddParametersToCommand(query);

return $$"""
{{dbDriver.TransactionConnectionNullExcetionThrow}}
Expand Down
4 changes: 2 additions & 2 deletions Drivers/Generators/ExecLastIdDeclareGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ private string GetDriverNoTxBody(string sqlVar, Query query)
{
var (establishConnection, connectionOpen) = dbDriver.EstablishConnection(query);
var createSqlCommand = dbDriver.CreateSqlCommand(sqlVar);
var commandParameters = CommonGen.AddParametersToCommand(query);
var commandParameters = dbDriver.AddParametersToCommand(query);
var returnLastId = ((IExecLastId)dbDriver).GetLastIdStatement(query).JoinByNewLine();
return $$"""
using ({{establishConnection}})
Expand All @@ -86,7 +86,7 @@ private string GetDriverWithTxBody(string sqlVar, Query query)
{
var transactionProperty = Variable.Transaction.AsPropertyName();
var commandVar = Variable.Command.AsVarName();
var commandParameters = CommonGen.AddParametersToCommand(query);
var commandParameters = dbDriver.AddParametersToCommand(query);
var returnLastId = ((IExecLastId)dbDriver).GetLastIdStatement(query).JoinByNewLine();

return $$"""
Expand Down
4 changes: 2 additions & 2 deletions Drivers/Generators/ExecRowsDeclareGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ private string GetDriverNoTxBody(string sqlVar, Query query)
{
var (establishConnection, connectionOpen) = dbDriver.EstablishConnection(query);
var createSqlCommand = dbDriver.CreateSqlCommand(sqlVar);
var commandParameters = CommonGen.AddParametersToCommand(query);
var commandParameters = dbDriver.AddParametersToCommand(query);
return $$"""
using ({{establishConnection}})
{
Expand All @@ -87,7 +87,7 @@ private string GetDriverWithTxBody(string sqlVar, Query query)
{
var transactionProperty = Variable.Transaction.AsPropertyName();
var commandVar = Variable.Command.AsVarName();
var commandParameters = CommonGen.AddParametersToCommand(query);
var commandParameters = dbDriver.AddParametersToCommand(query);

return $$"""
{{dbDriver.TransactionConnectionNullExcetionThrow}}
Expand Down
4 changes: 2 additions & 2 deletions Drivers/Generators/ManyDeclareGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private string GetDriverNoTxBody(string sqlVar, string returnInterface, Query qu
{
var (establishConnection, connectionOpen) = dbDriver.EstablishConnection(query);
var createSqlCommand = dbDriver.CreateSqlCommand(sqlVar);
var commandParameters = CommonGen.AddParametersToCommand(query);
var commandParameters = dbDriver.AddParametersToCommand(query);
var initDataReader = CommonGen.InitDataReader();
var awaitReaderRow = CommonGen.AwaitReaderRow();
var dataclassInit = CommonGen.InstantiateDataclass(query.Columns.ToArray(), returnInterface, query);
Expand Down Expand Up @@ -111,7 +111,7 @@ private string GetDriverWithTxBody(string sqlVar, string returnInterface, Query
{
var transactionProperty = Variable.Transaction.AsPropertyName();
var commandVar = Variable.Command.AsVarName();
var commandParameters = CommonGen.AddParametersToCommand(query);
var commandParameters = dbDriver.AddParametersToCommand(query);
var initDataReader = CommonGen.InitDataReader();
var awaitReaderRow = CommonGen.AwaitReaderRow();
var dataclassInit = CommonGen.InstantiateDataclass(query.Columns.ToArray(), returnInterface, query);
Expand Down
Loading