Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Drivers/ColumnMapping.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
using System;
using System.Collections.Generic;

namespace SqlcGenCsharp.Drivers;

public record DbTypeInfo(int? Length = null, string? NpgsqlTypeOverride = null);

public delegate string ReaderFn(int ordinal);
public delegate string ReaderFn(int ordinal, string dbType);

public delegate string WriterFn(string el, bool notNull, bool isDapper);
public delegate string WriterFn(string el, string dbType, bool notNull, bool isDapper, bool isLegacy);

public delegate string ConvertFunc(string el);

Expand Down
18 changes: 10 additions & 8 deletions Drivers/DbDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,10 @@ protected bool CopyFromQueryExists()
{
if (query is null)
return null;
return Options.Overrides.FirstOrDefault(o =>
o.Column == $"{query.Name}:{column.Name}" || o.Column == $"*:{column.Name}");
foreach (var overrideOption in Options.Overrides)
if (overrideOption.Column == $"{query.Name}:{column.Name}" || overrideOption.Column == $"*:{column.Name}")
return overrideOption;
return null;
}

// If the column data type is overridden, we need to check for nulls in generated code
Expand Down Expand Up @@ -337,29 +339,29 @@ private static bool DoesColumnMappingApply(ColumnMapping columnMapping, Column c
if (writerFn is not null)
return writerFn;

static string DefaultWriterFn(string el, bool notNull, bool isDapper) => notNull ? el : $"{el} ?? (object)DBNull.Value";
static string DefaultWriterFn(string el, string dbType, bool notNull, bool isDapper, bool isLegacy) => notNull ? el : $"{el} ?? (object)DBNull.Value";
return Options.UseDapper ? null : DefaultWriterFn;
}

/* Column reader methods */
private string GetColumnReader(CsharpTypeOption csharpTypeOption, int ordinal)
private string GetColumnReader(CsharpTypeOption csharpTypeOption, Column column, int ordinal)
{
if (ColumnMappings.TryGetValue(csharpTypeOption.Type, out var value))
return value.ReaderFn(ordinal);
return value.ReaderFn(ordinal, column.Type.Name);
throw new NotSupportedException($"Could not find column mapping for type override: {csharpTypeOption.Type}");
}

public virtual string GetColumnReader(Column column, int ordinal, Query? query)
{
if (FindOverrideForQueryColumn(query, column) is { CsharpType: var csharpType })
return GetColumnReader(csharpType, ordinal);
return GetColumnReader(csharpType, column, ordinal);

foreach (var columnMapping in ColumnMappings.Values
.Where(columnMapping => DoesColumnMappingApply(columnMapping, column)))
{
if (column.IsArray)
return columnMapping.ReaderArrayFn?.Invoke(ordinal) ?? throw new InvalidOperationException("ReaderArrayFn is null");
return columnMapping.ReaderFn(ordinal);
return columnMapping.ReaderArrayFn?.Invoke(ordinal, column.Type.Name) ?? throw new InvalidOperationException("ReaderArrayFn is null");
return columnMapping.ReaderFn(ordinal, column.Type.Name);
}
throw new NotSupportedException($"column {column.Name} has unsupported column type: {column.Type.Name} in {GetType().Name}");
}
Expand Down
4 changes: 2 additions & 2 deletions Drivers/Generators/CommonGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public string AddParametersToCommand(Query query)

var notNull = dbDriver.IsColumnNotNull(p.Column, query);
var writerFn = dbDriver.GetWriterFn(p.Column, query);
var paramToWrite = writerFn is null ? param : writerFn(param, notNull, dbDriver.Options.UseDapper);
var paramToWrite = writerFn is null ? param : writerFn(param, p.Column.Type.Name, notNull, dbDriver.Options.UseDapper, dbDriver.Options.DotnetFramework.IsDotnetLegacy());
var addParamToCommand = $"""{commandVar}.Parameters.AddWithValue("@{p.Column.Name}", {paramToWrite});""";
return addParamToCommand;
}).JoinByNewLine();
Expand Down Expand Up @@ -57,7 +57,7 @@ public string ConstructDapperParamsDict(Query query)

var notNull = dbDriver.IsColumnNotNull(p.Column, query);
var writerFn = dbDriver.GetWriterFn(p.Column, query);
var paramToWrite = writerFn is null ? $"{argsVar}.{param}" : writerFn($"{argsVar}.{param}", notNull, dbDriver.Options.UseDapper);
var paramToWrite = writerFn is null ? $"{argsVar}.{param}" : writerFn($"{argsVar}.{param}", p.Column.Type.Name, notNull, dbDriver.Options.UseDapper, dbDriver.Options.DotnetFramework.IsDotnetLegacy());
var addParamToDict = $"{queryParamsVar}.Add(\"{p.Column.Name}\", {paramToWrite});";
return addParamToDict;
});
Expand Down
32 changes: 16 additions & 16 deletions Drivers/MySqlConnectorDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public sealed partial class MySqlConnectorDriver(
{
{ "tinyint", new(Length: 1) }
},
ordinal => $"reader.GetBoolean({ordinal})"
readerFn: (ordinal, _) => $"reader.GetBoolean({ordinal})"
),
["short"] = new(
new()
Expand All @@ -35,7 +35,7 @@ public sealed partial class MySqlConnectorDriver(
{ "smallint", new() },
{ "year", new() }
},
ordinal => $"reader.GetInt16({ordinal})"
readerFn: (ordinal, _) => $"reader.GetInt16({ordinal})"
),
["int"] = new(
new()
Expand All @@ -44,15 +44,15 @@ public sealed partial class MySqlConnectorDriver(
{ "integer", new() },
{ "mediumint", new() }
},
ordinal => $"reader.GetInt32({ordinal})",
readerFn: (ordinal, _) => $"reader.GetInt32({ordinal})",
convertFunc: x => $"Convert.ToInt32{x}"
),
["long"] = new(
new()
{
{ "bigint", new() }
},
ordinal => $"reader.GetInt64({ordinal})",
readerFn: (ordinal, _) => $"reader.GetInt64({ordinal})",
convertFunc: x => $"Convert.ToInt64{x}"
),
["double"] = new(
Expand All @@ -61,14 +61,14 @@ public sealed partial class MySqlConnectorDriver(
{ "double", new() },
{ "float", new() }
},
ordinal => $"reader.GetDouble({ordinal})"
readerFn: (ordinal, _) => $"reader.GetDouble({ordinal})"
),
["decimal"] = new(
new()
{
{ "decimal", new() }
},
ordinal => $"reader.GetDecimal({ordinal})"
readerFn: (ordinal, _) => $"reader.GetDecimal({ordinal})"
),

/* Binary data types */
Expand All @@ -77,7 +77,7 @@ public sealed partial class MySqlConnectorDriver(
{
{ "bit", new() }
},
ordinal => $"reader.GetFieldValue<byte>({ordinal})"
readerFn: (ordinal, _) => $"reader.GetFieldValue<byte>({ordinal})"
),
["byte[]"] = new(
new()
Expand All @@ -89,7 +89,7 @@ public sealed partial class MySqlConnectorDriver(
{ "tinyblob", new() },
{ "varbinary", new() }
},
ordinal => $"reader.GetFieldValue<byte[]>({ordinal})"
readerFn: (ordinal, _) => $"reader.GetFieldValue<byte[]>({ordinal})"
),

/* String data types */
Expand All @@ -104,7 +104,7 @@ public sealed partial class MySqlConnectorDriver(
{ "varchar", new() },
{ "var_string", new() },
},
ordinal => $"reader.GetString({ordinal})"
readerFn: (ordinal, _) => $"reader.GetString({ordinal})"
),

/* Date and time data types */
Expand All @@ -115,14 +115,14 @@ public sealed partial class MySqlConnectorDriver(
{ "datetime", new() },
{ "timestamp", new() }
},
readerFn: ordinal => $"reader.GetDateTime({ordinal})"
readerFn: (ordinal, _) => $"reader.GetDateTime({ordinal})"
),
["TimeSpan"] = new(
new()
{
{ "time", new() }
},
readerFn: ordinal => $"reader.GetFieldValue<TimeSpan>({ordinal})"
readerFn: (ordinal, _) => $"reader.GetFieldValue<TimeSpan>({ordinal})"
),

/* Unstructured data types */
Expand All @@ -131,8 +131,8 @@ public sealed partial class MySqlConnectorDriver(
{
{ "json", new() }
},
readerFn: ordinal => $"JsonSerializer.Deserialize<JsonElement>(reader.GetString({ordinal}))",
writerFn: (el, notNull, isDapper) =>
readerFn: (ordinal, _) => $"JsonSerializer.Deserialize<JsonElement>(reader.GetString({ordinal}))",
writerFn: (el, _, notNull, isDapper, isLegacy) =>
{
if (notNull)
return $"{el}.GetRawText()";
Expand All @@ -150,7 +150,7 @@ public sealed partial class MySqlConnectorDriver(
{
{ "any", new() }
},
ordinal => $"reader.GetValue({ordinal})"
readerFn: (ordinal, _) => $"reader.GetValue({ordinal})"
)
};

Expand Down Expand Up @@ -621,7 +621,7 @@ private bool IsSetDataType(Column column)
return writerFn;

if (GetEnumType(column) is { } enumType && IsSetDataType(column, enumType))
return (el, notNull, isDapper) =>
return (el, dbType, notNull, isDapper, isLegacy) =>
{
var stringJoinStmt = $"string.Join(\",\", {el})";
var nullValue = isDapper ? "null" : "(object)DBNull.Value";
Expand All @@ -630,7 +630,7 @@ private bool IsSetDataType(Column column)
: $"{el} != null ? {stringJoinStmt} : {nullValue}";
};

static string DefaultWriterFn(string el, bool notNull, bool isDapper) => notNull ? el : $"{el} ?? (object)DBNull.Value";
static string DefaultWriterFn(string el, string dbType, bool notNull, bool isDapper, bool isLegacy) => notNull ? el : $"{el} ?? (object)DBNull.Value";
return Options.UseDapper ? null : DefaultWriterFn;
}

Expand Down
Loading
Loading