Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions CodeGenerator/Generators/EnumsGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace SqlcGenCsharp.Generators;

internal class EnumsGen(DbDriver dbDriver)

Check warning on line 9 in CodeGenerator/Generators/EnumsGen.cs

View workflow job for this annotation

GitHub Actions / Build (WASM)

Parameter 'dbDriver' is unread.

Check warning on line 9 in CodeGenerator/Generators/EnumsGen.cs

View workflow job for this annotation

GitHub Actions / Codegen Tests

Parameter 'dbDriver' is unread.
{
public MemberDeclarationSyntax[] Generate(string name, IList<string> possibleValues)
{
Expand All @@ -21,10 +21,6 @@
{{enumValuesDef}}
}
""")!;

if (dbDriver.Options.UseDapper)
return [enumType];

var enumExtensions = ParseMemberDeclaration($$"""
public static class {{name}}Extensions
{
Expand All @@ -40,6 +36,11 @@
{
return StringToEnum[me];
}

public static {{name}}[] To{{name}}Arr(this string me)
{
return me.Split(',').ToList().Select(v => StringToEnum[v]).ToArray();
}
}
""")!;
return [enumType, enumExtensions];
Expand Down
3 changes: 0 additions & 3 deletions CodegenTests/CodegenUtilsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ public void TestMysqlCopyFromGenerateUtilsMembers()
var expected = new HashSet<string>
{
MySqlConnectorDriver.NullToStringCsvConverter,
MySqlConnectorDriver.BoolToBitCsvConverter,
MySqlConnectorDriver.ByteCsvConverter,
MySqlConnectorDriver.ByteArrayCsvConverter
};
var actual = members
.FindAll(m => m is ClassDeclarationSyntax)
Expand Down
74 changes: 48 additions & 26 deletions Drivers/DbDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public abstract class DbDriver

public Dictionary<string, Dictionary<string, Plugin.Enum>> Enums { get; }

private IList<Query> Queries { get; }
protected IList<Query> Queries { get; }

private HashSet<string> NullableTypesInDotnetCore { get; } =
[
Expand Down Expand Up @@ -59,15 +59,13 @@ public abstract class DbDriver
public abstract Dictionary<string, ColumnMapping> ColumnMappings { get; }

protected const string JsonElementTypeHandler =
"""
public class JsonElementTypeHandler : SqlMapper.TypeHandler<JsonElement>
"""
private class JsonElementTypeHandler : SqlMapper.TypeHandler<JsonElement>
{
public override JsonElement Parse(object value)
{
if (value is string s)
return JsonDocument.Parse(s).RootElement;
if (value is null)
return default;
throw new DataException($"Cannot convert {value?.GetType()} to JsonElement");
}

Expand All @@ -76,7 +74,7 @@ public override void SetValue(IDbDataParameter parameter, JsonElement value)
parameter.Value = value.GetRawText();
}
}
""";
""";

protected const string TransformQueryForSliceArgsImpl = """
public static string TransformQueryForSliceArgs(string originalSql, int sliceSize, string paramName)
Expand Down Expand Up @@ -148,7 +146,7 @@ private ISet<string> GetUsingDirectivesForColumnMappings()

public virtual ISet<string> GetUsingDirectivesForUtils()
{
return new HashSet<string>()
return new HashSet<string>
{
"System.Linq"
}
Expand All @@ -158,20 +156,24 @@ public virtual ISet<string> GetUsingDirectivesForUtils()

public virtual ISet<string> GetUsingDirectivesForModels()
{
return GetUsingDirectivesForColumnMappings();
return new HashSet<string>
{
"System.Linq"
}
.AddRangeExcludeNulls(GetUsingDirectivesForColumnMappings());
}

public virtual string[] GetConstructorStatements()
public string[] GetConstructorStatements()
{
return [$"this.{Variable.ConnectionString.AsPropertyName()} = {Variable.ConnectionString.AsVarName()};"];
}

public virtual string[] GetTransactionConstructorStatements()
public string[] GetTransactionConstructorStatements()
{
return [$"this.{Variable.Transaction.AsPropertyName()} = {Variable.Transaction.AsVarName()};"];
}

protected ISet<string> GetConfigureSqlMappings()
protected virtual ISet<string> GetConfigureSqlMappings()
{
return ColumnMappings
.Where(m => TypeExistsInQueries(m.Key) && m.Value.SqlMapper is not null)
Expand Down Expand Up @@ -200,9 +202,15 @@ public static void ConfigureSqlMapper()

protected bool TypeExistsInQueries(string csharpType)
{
return Queries
.SelectMany(query => query.Columns)
.Any(column => csharpType == GetCsharpTypeWithoutNullableSuffix(column, null));
return Queries.Any(q => TypeExistsInQuery(csharpType, q));
}

protected bool TypeExistsInQuery(string csharpType, Query query)
{
return query.Columns
.Any(column => csharpType == GetCsharpTypeWithoutNullableSuffix(column, query)) ||
query.Params
.Any(p => csharpType == GetCsharpTypeWithoutNullableSuffix(p.Column, query));
}

public string AddNullableSuffixIfNeeded(string csharpType, bool notNull)
Expand All @@ -218,6 +226,16 @@ public string GetCsharpType(Column column, Query? query)
return AddNullableSuffixIfNeeded(csharpType, IsColumnNotNull(column, query));
}

public string GetColumnSchema(Column column)
{
return column.Table.Schema == DefaultSchema ? string.Empty : column.Table.Schema;
}

public virtual string GetEnumTypeAsCsharpType(Column column, Plugin.Enum enumType)
{
return column.Type.Name.ToModelName(GetColumnSchema(column), DefaultSchema);
}

public string GetCsharpTypeWithoutNullableSuffix(Column column, Query? query)
{
if (column.EmbedTable != null)
Expand All @@ -226,8 +244,8 @@ public string GetCsharpTypeWithoutNullableSuffix(Column column, Query? query)
if (string.IsNullOrEmpty(column.Type.Name))
return "object";

if (IsEnumType(column))
return column.Type.Name.ToModelName(column.Table.Schema, DefaultSchema);
if (GetEnumType(column) is { } enumType)
return GetEnumTypeAsCsharpType(column, enumType);

if (FindOverrideForQueryColumn(query, column) is { CsharpType: var csharpType })
return csharpType.Type;
Expand All @@ -242,14 +260,14 @@ public string GetCsharpTypeWithoutNullableSuffix(Column column, Query? query)
throw new NotSupportedException($"Column {column.Name} has unsupported column type: {column.Type.Name}");
}

private bool IsEnumType(Column column)
public Plugin.Enum? GetEnumType(Column column)
{
if (column.Table is null)
return false;
var enumSchema = column.Table.Schema == DefaultSchema ? string.Empty : column.Table.Schema;
if (!Enums.TryGetValue(enumSchema, value: out var enumsInSchema))
return false;
return enumsInSchema.ContainsKey(column.Type.Name);
return null;
var schemaName = GetColumnSchema(column);
if (!Enums.TryGetValue(schemaName, value: out var enumsInSchema))
return null;
return enumsInSchema.GetValueOrDefault(column.Type.Name);
}

private static bool DoesColumnMappingApply(ColumnMapping columnMapping, Column column)
Expand All @@ -269,16 +287,20 @@ private string GetColumnReader(CsharpTypeOption csharpTypeOption, int ordinal)
throw new NotSupportedException($"Could not find column mapping for type override: {csharpTypeOption.Type}");
}

private string GetEnumReader(Column column, int ordinal)
private string GetEnumReader(Column column, int ordinal, Plugin.Enum enumType)
{
var enumName = column.Type.Name.ToModelName(column.Table.Schema, DefaultSchema);
return $"{Variable.Reader.AsVarName()}.GetString({ordinal}).To{enumName}()";
var fullEnumType = GetEnumTypeAsCsharpType(column, enumType);
var readStmt = $"{Variable.Reader.AsVarName()}.GetString({ordinal})";
if (fullEnumType.EndsWith("[]"))
return $"{readStmt}.To{enumName}Arr()";
return $"{readStmt}.To{enumName}()";
}

public string GetColumnReader(Column column, int ordinal, Query? query)
{
if (IsEnumType(column))
return GetEnumReader(column, ordinal);
if (GetEnumType(column) is { } enumType)
return GetEnumReader(column, ordinal, enumType);

if (FindOverrideForQueryColumn(query, column) is { CsharpType: var csharpType })
return GetColumnReader(csharpType, ordinal);
Expand Down
16 changes: 14 additions & 2 deletions Drivers/Generators/CommonGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,22 @@ public static string GetMethodParameterList(string argInterface, IEnumerable<Par
if (writerFn is not null)
return writerFn;

var defaultWriterFn = (string el, bool notNull, bool isDapper) => notNull ? el : $"{el} ?? (object)DBNull.Value";
return dbDriver.Options.UseDapper ? null : defaultWriterFn;
if (dbDriver.GetEnumType(column) is { } enumType)
if (dbDriver.GetEnumTypeAsCsharpType(column, enumType).EndsWith("[]"))
return (el, notNull, isDapper) =>
{
var stringJoinStmt = $"string.Join(\",\", {el})";
var nullValue = isDapper ? "null" : "(object)DBNull.Value";
return notNull
? stringJoinStmt
: $"{el} != null ? {stringJoinStmt} : {nullValue}";
};

string DefaultWriterFn(string el, bool notNull, bool isDapper) => notNull ? el : $"{el} ?? (object)DBNull.Value";
return dbDriver.Options.UseDapper ? null : DefaultWriterFn;
}

// TODO: extract AddWithValue statement generation to a method + possible override for Npgsql for type override
public string AddParametersToCommand(Query query)
{
return query.Params.Select(p =>
Expand Down
Loading
Loading