Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions ClashBehaviour.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
namespace Dimension.DataFrame.Extensions;

/// <summary>
/// Defines the behavior when adding a column to a DataFrame and a column with the same name already exists
/// </summary>
public enum ClashBehaviour
{
/// <summary>
/// Keep the existing column and do not add the new column
/// </summary>
KeepOriginal,

/// <summary>
/// Remove the existing column and add the new column in its place
/// </summary>
ReplaceOriginal,

/// <summary>
/// Throw an InvalidOperationException when a name clash occurs (default behavior)
/// </summary>
Exception
}
2 changes: 1 addition & 1 deletion DataFrameExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public static class DataFrameExtensionsCalculations
public static PrimitiveDataFrameColumn<T> Apply<T>(this PrimitiveDataFrameColumn<T> column, Func<T, T> operation, string name = "")
where T : unmanaged, INumber<T>
{
if (operation == null)
if (operation is null)
{
throw new ArgumentNullException(nameof(operation));
}
Expand Down
6 changes: 3 additions & 3 deletions DataFrameExtensionsArithmetic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public static PrimitiveDataFrameColumn<T> Minus<T>(this PrimitiveDataFrameColumn

if (string.IsNullOrEmpty(name))
{
name = $"{column.Name}_Minus_{columnToSubtract.Name}";
name = $"{column.Name}-{columnToSubtract.Name}";
}

return new PrimitiveDataFrameColumn<T>(name, result);
Expand Down Expand Up @@ -99,8 +99,8 @@ public static PrimitiveDataFrameColumn<T> Times<T>(this PrimitiveDataFrameColumn

if (string.IsNullOrEmpty(name))
{
var otherNames = otherColumns.Select(c => c.Name);
name = $"{column.Name}_Times_{string.Join("_", otherNames)}";
var namesToConcat = new[] {column.Name}.Concat(otherColumns.Select(c => c.Name));
name = string.Join("*", namesToConcat);
}

return new PrimitiveDataFrameColumn<T>(name, result);
Expand Down
101 changes: 32 additions & 69 deletions DataFrameExtensionsFilters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,75 +49,7 @@ public static Microsoft.Data.Analysis.DataFrame Filter(this Microsoft.Data.Analy
var newColumns = new List<DataFrameColumn>();
foreach (var column in df.Columns)
{
DataFrameColumn newColumn;

// Support common numeric types
if (column.DataType == typeof(int))
{
newColumn = new PrimitiveDataFrameColumn<int>(column.Name);
}
else if (column.DataType == typeof(long))
{
newColumn = new PrimitiveDataFrameColumn<long>(column.Name);
}
else if (column.DataType == typeof(float))
{
newColumn = new PrimitiveDataFrameColumn<float>(column.Name);
}
else if (column.DataType == typeof(double))
{
newColumn = new PrimitiveDataFrameColumn<double>(column.Name);
}
else if (column.DataType == typeof(decimal))
{
newColumn = new PrimitiveDataFrameColumn<decimal>(column.Name);
}
// Support other common types
else if (column.DataType == typeof(bool))
{
newColumn = new PrimitiveDataFrameColumn<bool>(column.Name);
}
else if (column.DataType == typeof(byte))
{
newColumn = new PrimitiveDataFrameColumn<byte>(column.Name);
}
else if (column.DataType == typeof(sbyte))
{
newColumn = new PrimitiveDataFrameColumn<sbyte>(column.Name);
}
else if (column.DataType == typeof(short))
{
newColumn = new PrimitiveDataFrameColumn<short>(column.Name);
}
else if (column.DataType == typeof(ushort))
{
newColumn = new PrimitiveDataFrameColumn<ushort>(column.Name);
}
else if (column.DataType == typeof(uint))
{
newColumn = new PrimitiveDataFrameColumn<uint>(column.Name);
}
else if (column.DataType == typeof(ulong))
{
newColumn = new PrimitiveDataFrameColumn<ulong>(column.Name);
}
else if (column.DataType == typeof(char))
{
newColumn = new PrimitiveDataFrameColumn<char>(column.Name);
}
else if (column.DataType == typeof(DateTime))
{
newColumn = new PrimitiveDataFrameColumn<DateTime>(column.Name);
}
else if (column.DataType == typeof(string))
{
newColumn = new StringDataFrameColumn(column.Name);
}
else
{
throw new NotSupportedException($"Column type {column.DataType.Name} is not supported. Supported types: int, long, float, double, decimal, bool, byte, sbyte, short, ushort, uint, ulong, char, DateTime, string");
}

var newColumn = CreateColumnByType(column.DataType, column.Name);
newColumns.Add(newColumn);
}

Expand All @@ -137,4 +69,35 @@ public static Microsoft.Data.Analysis.DataFrame Filter(this Microsoft.Data.Analy

return newDf;
}

/// <summary>
/// Creates a new DataFrame column based on the specified type
/// </summary>
/// <param name="dataType">The type of data the column will hold</param>
/// <param name="columnName">The name for the new column</param>
/// <returns>A new DataFrameColumn of the appropriate type</returns>
/// <exception cref="NotSupportedException">Thrown when the data type is not supported</exception>
private static DataFrameColumn CreateColumnByType(Type dataType, string columnName)
{
// Use pattern matching for cleaner type checking
if (dataType == typeof(int)) return new PrimitiveDataFrameColumn<int>(columnName);
if (dataType == typeof(long)) return new PrimitiveDataFrameColumn<long>(columnName);
if (dataType == typeof(float)) return new PrimitiveDataFrameColumn<float>(columnName);
if (dataType == typeof(double)) return new PrimitiveDataFrameColumn<double>(columnName);
if (dataType == typeof(decimal)) return new PrimitiveDataFrameColumn<decimal>(columnName);
if (dataType == typeof(bool)) return new PrimitiveDataFrameColumn<bool>(columnName);
if (dataType == typeof(byte)) return new PrimitiveDataFrameColumn<byte>(columnName);
if (dataType == typeof(sbyte)) return new PrimitiveDataFrameColumn<sbyte>(columnName);
if (dataType == typeof(short)) return new PrimitiveDataFrameColumn<short>(columnName);
if (dataType == typeof(ushort)) return new PrimitiveDataFrameColumn<ushort>(columnName);
if (dataType == typeof(uint)) return new PrimitiveDataFrameColumn<uint>(columnName);
if (dataType == typeof(ulong)) return new PrimitiveDataFrameColumn<ulong>(columnName);
if (dataType == typeof(char)) return new PrimitiveDataFrameColumn<char>(columnName);
if (dataType == typeof(DateTime)) return new PrimitiveDataFrameColumn<DateTime>(columnName);
if (dataType == typeof(string)) return new StringDataFrameColumn(columnName);

throw new NotSupportedException(
$"Column type {dataType.Name} is not supported. " +
"Supported types: int, long, float, double, decimal, bool, byte, sbyte, short, ushort, uint, ulong, char, DateTime, string");
}
}
6 changes: 3 additions & 3 deletions DataFrameExtensionsNullsNaNs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ public static class DataFrameExtensionsNullsNaNs
public static PrimitiveDataFrameColumn<T> DropNulls<T>(this PrimitiveDataFrameColumn<T> column)
where T : unmanaged, INumber<T>
{
var newColumn = new PrimitiveDataFrameColumn<T>(column.Name, column.Length);
var validValues = new List<T?>();
foreach (var value in column)
{
var shouldAddValue = value != null && !(value is float f && float.IsNaN(f)) && !(value is double d && double.IsNaN(d));
if (shouldAddValue)
{
newColumn.Append(value);
validValues.Add(value);
}
}

return newColumn;
return new PrimitiveDataFrameColumn<T>(column.Name, validValues);
}

public static Microsoft.Data.Analysis.DataFrame DropNulls(this Microsoft.Data.Analysis.DataFrame df)
Expand Down
17 changes: 11 additions & 6 deletions DataFrameExtensionsRolling.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ public static PrimitiveDataFrameColumn<T> Rolling<T>(this PrimitiveDataFrameColu
where T : unmanaged, INumber<T>
{
var result = new PrimitiveDataFrameColumn<T>(column.Name + "_Rolling", column.Length);

// Pre-allocate a reusable buffer to avoid repeated allocations
var windowBuffer = new T?[windowSize];

for (var i = 0; i < column.Length; i++)
{
if (i < windowSize - 1)
Expand All @@ -32,19 +36,20 @@ public static PrimitiveDataFrameColumn<T> Rolling<T>(this PrimitiveDataFrameColu
continue;
}

var window = new List<T?>();
// Reuse the buffer instead of creating new List
var windowCount = 0;
for (var j = i - windowSize + 1; j <= i; j++)
{
if (!column[j].HasValue)
if (column[j].HasValue)
{
continue;
windowBuffer[windowCount++] = column[j];
}

window.Add(column[j]);
}

if (window.Count > 0)
if (windowCount > 0)
{
// Create a span/array view of only the valid values
var window = new ArraySegment<T?>(windowBuffer, 0, windowCount);
var opResult = operation(window);
result[i] = opResult;
}
Expand Down
48 changes: 31 additions & 17 deletions DataFrameExtensionsStatistics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ public static class DataFrameExtensionsStatistics
/// </summary>
/// <typeparam name="T">Numeric type</typeparam>
/// <param name="column">Column to calculate median for</param>
/// <returns>Median value, or null if column is empty or all values are null</returns>
public static T? Median<T>(this PrimitiveDataFrameColumn<T> column)
/// <returns>Median value as double, or null if column is empty or all values are null</returns>
public static double? Median<T>(this PrimitiveDataFrameColumn<T> column)
where T : unmanaged, INumber<T>
{
if (column == null || column.Length == 0)
{
return null;
}

var values = column.Where(v => v.HasValue).Select(v => v!.Value).OrderBy(v => v).ToList();
var values = column.Where(v => v.HasValue).Select(v => Convert.ToDouble(v!.Value)).OrderBy(v => v).ToList();

if (values.Count == 0)
{
Expand All @@ -71,7 +71,7 @@ public static class DataFrameExtensionsStatistics
if (values.Count % 2 == 0)
{
// Even number of elements - average the two middle values
return (values[middleIndex - 1] + values[middleIndex]) / T.CreateChecked(2);
return (values[middleIndex - 1] + values[middleIndex]) / 2.0;
}
else
{
Expand All @@ -95,7 +95,7 @@ public static class DataFrameExtensionsStatistics
}

/// <summary>
/// Calculates the variance of a column
/// Calculates the variance of a column using Welford's online algorithm for numerical stability
/// </summary>
/// <typeparam name="T">Numeric type</typeparam>
/// <param name="column">Column to calculate variance for</param>
Expand All @@ -109,18 +109,32 @@ public static class DataFrameExtensionsStatistics
return null;
}

var values = column.Where(v => v.HasValue).Select(v => Convert.ToDouble(v!.Value)).ToList();
// Single-pass variance calculation using Welford's algorithm
var count = 0;
var mean = 0.0;
var m2 = 0.0;

if (values.Count < (sample ? 2 : 1))
for (var i = 0; i < column.Length; i++)
{
return null;
var value = column[i];
if (value.HasValue)
{
count++;
var doubleValue = Convert.ToDouble(value.Value);
var delta = doubleValue - mean;
mean += delta / count;
var delta2 = doubleValue - mean;
m2 += delta * delta2;
}
}

var mean = values.Average();
var sumOfSquaredDifferences = values.Sum(v => Math.Pow(v - mean, 2));
var divisor = sample ? values.Count - 1 : values.Count;
if (count < (sample ? 2 : 1))
{
return null;
}

return sumOfSquaredDifferences / divisor;
var divisor = sample ? count - 1 : count;
return m2 / divisor;
}

/// <summary>
Expand Down Expand Up @@ -212,7 +226,7 @@ public static long Count<T>(this PrimitiveDataFrameColumn<T> column)
/// <typeparam name="T">Numeric type</typeparam>
/// <param name="column">Column to calculate statistics for</param>
/// <returns>Tuple containing (count, mean, stddev, min, 25th percentile, median, 75th percentile, max)</returns>
public static (long Count, T? Mean, double? StdDev, T? Min, T? Q25, T? Median, T? Q75, T? Max) Describe<T>(this PrimitiveDataFrameColumn<T> column)
public static (long Count, T? Mean, double? StdDev, T? Min, double? Q25, double? Median, double? Q75, T? Max) Describe<T>(this PrimitiveDataFrameColumn<T> column)
where T : unmanaged, INumber<T>
{
var count = column.Count();
Expand All @@ -233,16 +247,16 @@ public static (long Count, T? Mean, double? StdDev, T? Min, T? Q25, T? Median, T
/// <typeparam name="T">Numeric type</typeparam>
/// <param name="column">Column to calculate quantile for</param>
/// <param name="quantile">Quantile to calculate (0.0 to 1.0, e.g., 0.25 for 25th percentile)</param>
/// <returns>Quantile value, or null if column is empty</returns>
public static T? Quantile<T>(this PrimitiveDataFrameColumn<T> column, double quantile)
/// <returns>Quantile value as double, or null if column is empty</returns>
public static double? Quantile<T>(this PrimitiveDataFrameColumn<T> column, double quantile)
where T : unmanaged, INumber<T>
{
if (column == null || column.Length == 0 || quantile < 0 || quantile > 1)
{
return null;
}

var values = column.Where(v => v.HasValue).Select(v => v!.Value).OrderBy(v => v).ToList();
var values = column.Where(v => v.HasValue).Select(v => Convert.ToDouble(v!.Value)).OrderBy(v => v).ToList();

if (values.Count == 0)
{
Expand All @@ -258,7 +272,7 @@ public static (long Count, T? Mean, double? StdDev, T? Min, T? Q25, T? Median, T
return values[lowerIndex];
}

var weight = T.CreateChecked(index - lowerIndex);
var weight = index - lowerIndex;
return values[lowerIndex] + weight * (values[upperIndex] - values[lowerIndex]);
}
}
Loading
Loading