diff --git a/.gitignore b/.gitignore index cace7651..26a3d0c4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ -.bonsai/Bonsai.exe* -.bonsai/Packages/ -.bonsai/Settings/ +**/.bonsai/Bonsai.exe* +**/.bonsai/Packages/ +**/.bonsai/Settings/ .vs/ /artifacts/ .venv diff --git a/src/Bonsai.ML.Torch/Bind.cs b/src/Bonsai.ML.Torch/Bind.cs new file mode 100644 index 00000000..c13ca1a8 --- /dev/null +++ b/src/Bonsai.ML.Torch/Bind.cs @@ -0,0 +1,129 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Collections.Generic; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Represents an operator that gathers incoming tensors into zero or more tensors by concatenating them along the first dimension. +/// +/// +/// The operator maintains an internal buffer that accumulates incoming tensors until it reaches the specified count. +/// When the buffer reaches the specified count, it is emitted as a single tensor. After emitting the buffer, the operator skips a specified number of incoming tensors before starting to fill the buffer again. +/// +[Combinator] +[Description("Gathers incoming tensors into zero or more tensors by concatenating them along the first dimension.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class Bind +{ + private int _count = 1; + /// + /// Gets or sets the number of tensors to accumulate in the buffer before emitting. + /// + [Description("The number of tensors to accumulate in the buffer before emitting.")] + public int Count + { + get => _count; + set => _count = value <= 0 + ? throw new ArgumentOutOfRangeException("Count must be greater than zero.") + : value; + } + + private int _skip = 1; + /// + /// Gets or sets the number of tensors to skip after emitting the buffer. + /// + [Description("The number of tensors to skip after emitting the buffer.")] + public int Skip + { + get => _skip; + set => _skip = value < 0 + ? throw new ArgumentOutOfRangeException("Skip must be non-negative.") + : value; + } + + /// + /// Processes an observable sequence of tensors, buffering them and concatenating along the first dimension. + /// + public IObservable Process(IObservable source) + { + return Observable.Create(observer => + { + var count = Count; + var skip = Skip; + + Tensor buffer = null; + int current = 0; + Tensor idxSrc = null; + Tensor idxDst = null; + + return source.Subscribe(input => + { + if (input is null) + return; + + if (buffer is null) + { + var shape = input.shape.Prepend(count).ToArray(); + buffer = empty(shape, dtype: input.dtype, device: input.device); + + if (skip < count) + { + idxSrc = arange(skip, count, dtype: ScalarType.Int64, device: input.device); + idxDst = arange(0, count - skip, dtype: ScalarType.Int64, device: input.device); + } + } + + if (current >= 0) + { + buffer[current] = input; + } + + current++; + + if (current >= count) + { + var output = buffer.clone(); + if (skip < count) + { + var src = index_select(buffer, 0, idxSrc); + buffer.index_copy_(0, idxDst, src); + buffer[torch.TensorIndex.Slice(count - skip, null)].zero_(); + } + else + { + buffer.zero_(); + } + current = count - skip; + observer.OnNext(output); + } + }, + observer.OnError, + () => + { + var remainder = current; + + if (remainder > 0 && buffer is not null) + { + var outputShape = buffer.shape.ToArray(); + outputShape[0] = remainder; + var output = empty(outputShape, dtype: buffer.dtype, device: buffer.device); + for (int i = 0; i < remainder; i++) + { + output[i] = buffer[i]; + } + observer.OnNext(output.clone()); + } + + buffer?.Dispose(); + idxSrc?.Dispose(); + idxDst?.Dispose(); + observer.OnCompleted(); + }); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ConvertScalarType.cs b/src/Bonsai.ML.Torch/ConvertScalarType.cs index 2af31b13..745fa722 100644 --- a/src/Bonsai.ML.Torch/ConvertScalarType.cs +++ b/src/Bonsai.ML.Torch/ConvertScalarType.cs @@ -20,7 +20,7 @@ public class ConvertScalarType public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Returns an observable sequence that converts the input tensor to the specified scalar type. + /// Converts the data type of the input tensor to the newly specified scalar type. /// public IObservable Process(IObservable source) { diff --git a/src/Bonsai.ML.Torch/ConvertToArray.cs b/src/Bonsai.ML.Torch/ConvertToArray.cs index 271cde16..2c6b3bb4 100644 --- a/src/Bonsai.ML.Torch/ConvertToArray.cs +++ b/src/Bonsai.ML.Torch/ConvertToArray.cs @@ -32,7 +32,7 @@ public override Expression Build(IEnumerable arguments) MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); methodInfo = methodInfo.MakeGenericMethod(ScalarTypeLookup.GetTypeFromScalarType(Type)); Expression sourceExpression = arguments.First(); - + return Expression.Call( Expression.Constant(this), methodInfo, @@ -43,12 +43,12 @@ public override Expression Build(IEnumerable arguments) /// /// Converts the input tensor into a flattened array of the specified element type. /// - /// - /// - /// + /// The element type of the output item. + /// The sequence of input tensors. + /// The sequence of output arrays of the specified element type. public IObservable Process(IObservable source) where T : unmanaged { - return source.Select(tensor => + return source.Select(tensor => { if (tensor.dtype != Type) { @@ -58,4 +58,4 @@ public IObservable Process(IObservable source) where T : unmanag }); } } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/ConvertToItem.cs b/src/Bonsai.ML.Torch/ConvertToItem.cs new file mode 100644 index 00000000..e91bf6e9 --- /dev/null +++ b/src/Bonsai.ML.Torch/ConvertToItem.cs @@ -0,0 +1,61 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.Linq.Expressions; +using System.Reflection; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Represents an operator that converts the input tensor into a single value of the specified element type. The input tensor must only contain a single element. +/// +[Combinator] +[Description("Converts the input tensor into a single value of the specified element type.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class ConvertToItem : SingleArgumentExpressionBuilder +{ + /// + /// Gets or sets the type of the item. + /// + [Description("Gets or sets the type of the item.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + public override Expression Build(IEnumerable arguments) + { + MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); + var type = ScalarTypeLookup.GetTypeFromScalarType(Type); + methodInfo = methodInfo.MakeGenericMethod(type); + Expression sourceExpression = arguments.First(); + + return Expression.Call( + Expression.Constant(this), + methodInfo, + sourceExpression + ); + } + + /// + /// Converts the input tensor into a single item. + /// + /// The element type of the output item. + /// The sequence of input tensors. + /// The sequence of output items of the specified element type. + public IObservable Process(IObservable source) where T : unmanaged + { + return source.Select(tensor => + { + if (tensor.dtype != Type) + { + tensor = tensor.to_type(Type); + } + return tensor.item(); + }); + } +} diff --git a/src/Bonsai.ML.Torch/ConvertToNDArray.cs b/src/Bonsai.ML.Torch/ConvertToNDArray.cs index 171724e0..9c1ea652 100644 --- a/src/Bonsai.ML.Torch/ConvertToNDArray.cs +++ b/src/Bonsai.ML.Torch/ConvertToNDArray.cs @@ -41,7 +41,7 @@ public override Expression Build(IEnumerable arguments) Type arrayType = Array.CreateInstance(type, lengths).GetType(); methodInfo = methodInfo.MakeGenericMethod(type, arrayType); Expression sourceExpression = arguments.First(); - + return Expression.Call( Expression.Constant(this), methodInfo, @@ -52,13 +52,13 @@ public override Expression Build(IEnumerable arguments) /// /// Converts the input tensor into an array of the specified element type. /// - /// - /// - /// - /// + /// The element type of the output item. + /// The type of the output array. + /// The sequence of input tensors. + /// The sequence of output arrays of the specified element type and rank. public IObservable Process(IObservable source) where T : unmanaged { - return source.Select(tensor => + return source.Select(tensor => { if (tensor.dtype != Type) { @@ -68,4 +68,4 @@ public IObservable Process(IObservable source) wher }); } } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index b531bd9e..e594d4a4 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -11,109 +11,78 @@ using Bonsai.ML.Data; using TorchSharp; -namespace Bonsai.ML.Torch +namespace Bonsai.ML.Torch; + +/// +/// Creates a tensor from the specified values. +/// Uses Python-like syntax to specify the tensor values. +/// For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". +/// +[Combinator] +[ResetCombinator] +[Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class CreateTensor : IScalarTypeProvider { /// - /// Creates a tensor from the specified values. - /// Uses Python-like syntax to specify the tensor values. - /// For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". + /// The data type of the tensor elements. /// - [Combinator] - [ResetCombinator] - [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] - [WorkflowElementCategory(ElementCategory.Source)] - public class CreateTensor : IScalarTypeProvider - { - /// - /// The data type of the tensor elements. - /// - [Description("The data type of the tensor elements.")] - [TypeConverter(typeof(ScalarTypeConverter))] - public ScalarType Type - { - get => _scalarType; - set - { - _scalarType = value; - _tensor = ConvertTensorScalarType(_tensor, value); - } - } - private ScalarType _scalarType = ScalarType.Float32; - - /// - /// The values of the tensor elements. - /// Uses Python-like syntax to specify the tensor values. - /// For example: "[[1, 2], [3, 4]]". - /// - [XmlIgnore] - [Description("The values of the tensor elements. Uses Python-like syntax to specify the tensor values. For example: \"[[1, 2], [3, 4]]\".")] - [TypeConverter(typeof(TensorConverter))] - public Tensor Values - { - get => _tensor; - set => _tensor = ConvertTensorScalarType(value, _scalarType); - } + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; - private Tensor _tensor = zeros(1, dtype: ScalarType.Float32); + /// + /// The values of the tensor elements. + /// Uses Python-like syntax to specify the tensor values. + /// For example: "[[1, 2], [3, 4]]". + /// + [XmlIgnore] + [Description("The values of the tensor elements. Uses Python-like syntax to specify the tensor values. For example: \"[[1, 2], [3, 4]]\".")] + [TypeConverter(typeof(TensorConverter))] + public Tensor Values + { + get => _values; + set => _values = value; + } - /// - /// This method converts the tensor to the specified scalar type. - /// - /// - /// We use this method in the setter of the and properties to ensure that the tensor is converted to the appropriate type and then returned. - /// - private static Tensor ConvertTensorScalarType(Tensor value, ScalarType scalarType) - { - return value.to_type(scalarType); - } + /// + /// The values of the tensor elements in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Values))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ValuesXml + { + get => TensorConverter.ConvertToString(Values, Type); + set => Values = TensorConverter.ConvertFromString(value, Type); + } - /// - /// The values of the tensor elements in XML string format. - /// - [Browsable(false)] - [XmlElement(nameof(Values))] - [EditorBrowsable(EditorBrowsableState.Never)] - public string ValuesXml - { - get => TensorConverter.ConvertToString(Values, _scalarType); - set => Values = TensorConverter.ConvertFromString(value, _scalarType); - } + /// + /// The device on which to create the tensor. + /// + [XmlIgnore] + [Description("The device on which to create the tensor.")] + public Device Device { get; set; } = null; - /// - /// The device on which to create the tensor. - /// - [XmlIgnore] - [Description("The device on which to create the tensor.")] - public Device Device - { - get => _device; - set => _device = value; - } - private Device _device = null; + private Tensor _values; - /// - /// Returns an observable sequence that creates a tensor from the specified values. - /// - public IObservable Process() - { - return Observable.Return(_device != null ? _tensor.to(_device).clone() : _tensor.clone()); - } + /// + /// Returns an observable sequence that creates a tensor from the specified values. + /// + public IObservable Process() + { + var device = Device ?? CPU; + return Observable.Return(_values.to(device).clone()); + } - /// - /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. - /// - public IObservable Process(IObservable source) - { - var tensor = _tensor.clone(); - return source.Take(1) - .Select(_ => { - tensor = tensor.to(_device); - return tensor; - }) - .Concat( - source.Skip(1) - .Select(_ => tensor) - ); - } + /// + /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. + /// + public IObservable Process(IObservable source) + { + var device = Device ?? CPU; + var tensor = _values.to(device).clone(); + return source.Select(_ => tensor); } } diff --git a/src/Bonsai.ML.Torch/Diagonal.cs b/src/Bonsai.ML.Torch/Diagonal.cs new file mode 100644 index 00000000..e45c56c1 --- /dev/null +++ b/src/Bonsai.ML.Torch/Diagonal.cs @@ -0,0 +1,99 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Creates a diagonal matrix. If input is a 1D tensor, it creates a diagonal matrix with the elements of the tensor on the diagonal. +/// If input is a 2D tensor, it returns the diagonal elements as a 1D tensor. +/// +[Combinator] +[ResetCombinator] +[Description("Creates a diagonal matrix with the given data type, size, and value.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Diagonal : IScalarTypeProvider +{ + private Tensor _values; + + /// + /// Gets or sets the values to include in the diagonal. + /// + [Description("The values to include in the diagonal.")] + [TypeConverter(typeof(TensorConverter))] + [XmlIgnore] + public Tensor Values + { + get => _values; + set => _values = value; + } + + /// + /// The values of the tensor elements in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Values))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ValuesXml + { + get => TensorConverter.ConvertToString(_values, Type); + set => _values = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Gets or sets the device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// Gets or sets the diagonal offset. Default is 0, which means the main diagonal. + /// + [Description("The diagonal offset. Default is 0, which means the main diagonal.")] + public int Offset { get; set; } = 0; + + /// + /// Creates an observable sequence containing a single diagonal matrix constructed from the specified data type, size and values. + /// + public IObservable Process() + { + var device = Device ?? CPU; + var inputTensor = _values.to(device); + return Observable.Return(diag(inputTensor, Offset)); + } + + /// + /// Generates an observable sequence of tensors by extracting the diagonal of the input. + /// + /// + /// + public IObservable Process(IObservable source) + { + var device = Device ?? CPU; + return source.Select(value => _values is not null + ? diag(_values, Offset).to(device) + : diag(value, Offset).to(device)); + } + + /// + /// Generates an observable sequence of tensors by extracting the diagonal of the input. + /// + /// + /// + public IObservable Process(IObservable source) + { + var device = Device ?? CPU; + var inputTensor = _values.to(device); + return source.Select(value => diag(inputTensor, Offset)); + } +} diff --git a/src/Bonsai.ML.Torch/Eye.cs b/src/Bonsai.ML.Torch/Eye.cs new file mode 100644 index 00000000..6db15b31 --- /dev/null +++ b/src/Bonsai.ML.Torch/Eye.cs @@ -0,0 +1,55 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Represents an operator that creates a sequence of identity matrices with the specified data type and size. + /// + [Combinator] + [ResetCombinator] + [Description("Creates a sequence of identity matrices with the specified data type and size.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Eye + { + /// + /// Gets or sets the size of the identity matrix. + /// + [Description("The size of the identity matrix.")] + public long Size { get; set; } = 0; + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } = null; + + /// + /// Gets or sets the device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// Creates an observable sequence containing a single identity matrix with the specified data type and size. + /// + public IObservable Process() + { + return Observable.Return(eye(Size, dtype: Type, device: Device)); + } + + /// + /// Generates an observable sequence of identity matrices for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => eye(Size, dtype: Type, device: Device)); + } + } +} diff --git a/src/Bonsai.ML.Torch/FormatTensor.cs b/src/Bonsai.ML.Torch/FormatTensor.cs new file mode 100644 index 00000000..d07cdd19 --- /dev/null +++ b/src/Bonsai.ML.Torch/FormatTensor.cs @@ -0,0 +1,35 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Represents an operator that applies a string formatting operation to all tensors in the sequence. +/// +[Combinator] +[Description("Applies a string formatting operation to all tensors in the sequence.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class FormatTensor +{ + /// + /// Gets or sets the string style used to format the tensor output. + /// + [Description("The string style used to format the tensor output.")] + public TensorStringStyle StringStyle { get; set; } + + /// + /// Applies a string formatting operation to all tensors in an observable sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => value.ToString(StringStyle)); + } +} diff --git a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs index 889ae4b7..e18d0606 100644 --- a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.Torch { /// - /// Initializes the Torch device with the specified device type. + /// Represents an operator that initializes the Torch device with the specified device type. /// [Combinator] [Description("Initializes the Torch device with the specified device type.")] @@ -32,18 +32,30 @@ public class InitializeTorchDevice /// public IObservable Process() { - InitializeDeviceType(DeviceType); - return Observable.Return(new Device(DeviceType, DeviceIndex)); + return Observable.Defer(() => + { + var deviceType = DeviceType; + var deviceIndex = DeviceIndex; + InitializeDeviceType(deviceType); + return Observable.Return(new Device(deviceType, deviceIndex)); + }); } /// /// Initializes the Torch device when the input sequence produces an element. /// + /// + /// /// public IObservable Process(IObservable source) { - InitializeDeviceType(DeviceType); - return source.Select((_) => new Device(DeviceType, DeviceIndex)); + return source.Select((_) => + { + var deviceType = DeviceType; + var deviceIndex = DeviceIndex; + InitializeDeviceType(deviceType); + return new Device(deviceType, deviceIndex); + }); } } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinSpace.cs b/src/Bonsai.ML.Torch/LinSpace.cs index aaded985..954c311e 100644 --- a/src/Bonsai.ML.Torch/LinSpace.cs +++ b/src/Bonsai.ML.Torch/LinSpace.cs @@ -19,13 +19,13 @@ public class LinSpace /// The start of the range. /// [Description("The start of the range.")] - public int Start { get; set; } = 0; + public double Start { get; set; } = 0; /// /// The end of the range. /// [Description("The end of the range.")] - public int End { get; set; } = 1; + public double End { get; set; } = 1; /// /// The number of points to generate. diff --git a/src/Bonsai.ML.Torch/LoadTensor.cs b/src/Bonsai.ML.Torch/LoadTensor.cs index af1e7f05..4694e184 100644 --- a/src/Bonsai.ML.Torch/LoadTensor.cs +++ b/src/Bonsai.ML.Torch/LoadTensor.cs @@ -16,7 +16,7 @@ public class LoadTensor /// /// The path to the file containing the tensor. /// - [FileNameFilter("Binary files(*.bin)|*.bin|All files|*.*")] + [FileNameFilter("Binary files(*.bin)|*.bin|Tensor files(*.pt)|*.pt|All files|*.*")] [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] [Description("The path to the file containing the tensor.")] public string Path { get; set; } @@ -30,4 +30,4 @@ public IObservable Process() return Observable.Return(Tensor.Load(Path)); } } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs b/src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs new file mode 100644 index 00000000..4d9d90da --- /dev/null +++ b/src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs @@ -0,0 +1,33 @@ +using System; +using System.Reactive; +using System.Reactive.Linq; +using TorchSharp; +using Bonsai; +using System.ComponentModel; + +/// +/// Represents an operator that ensures all tensor operations within the observable sequence are executed with gradient tracking enabled. +/// +[Combinator] +[Description("Ensures all tensor operations within the observable sequence are executed with gradient tracking enabled.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class ObserveWithGradientTracking +{ + /// + /// Returns an observable sequence which is identical to the source sequence, but where all tensor operations are executed with gradient tracking enabled. + /// + public IObservable Process(IObservable source) + { + return Observable.Create(observer => + { + var sourceObserver = Observer.Create(value => + { + using var enabledGrad = torch.enable_grad(); + observer.OnNext(value); + }, + observer.OnError, + observer.OnCompleted); + return source.SubscribeSafe(sourceObserver); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs b/src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs new file mode 100644 index 00000000..dc341522 --- /dev/null +++ b/src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs @@ -0,0 +1,33 @@ +using System; +using System.Reactive; +using System.Reactive.Linq; +using TorchSharp; +using Bonsai; +using System.ComponentModel; + +/// +/// Represents an operator that ensures all tensor operations within the observable sequence are executed in inference mode. +/// +[Combinator] +[Description("Ensures all tensor operations within the observable sequence are executed in inference mode.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class ObserveWithInferenceMode +{ + /// + /// Returns an observable sequence which is identical to the source sequence, but where all tensor operations are executed in inference mode. + /// + public IObservable Process(IObservable source) + { + return Observable.Create(observer => + { + var sourceObserver = Observer.Create(value => + { + using var inferenceMode = torch.inference_mode(); + observer.OnNext(value); + }, + observer.OnError, + observer.OnCompleted); + return source.SubscribeSafe(sourceObserver); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs b/src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs new file mode 100644 index 00000000..5dcb34b0 --- /dev/null +++ b/src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs @@ -0,0 +1,33 @@ +using System; +using System.Reactive; +using System.Reactive.Linq; +using TorchSharp; +using Bonsai; +using System.ComponentModel; + +/// +/// Represents an operator that ensures all tensor operations within the observable sequence are executed without tracking gradients. +/// +[Combinator] +[Description("Ensures all tensor operations within the observable sequence are executed without tracking gradients.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class ObserveWithNoGradientTracking +{ + /// + /// Returns an observable sequence which is identical to the source sequence, but where all tensor operations are executed without tracking gradients. + /// + public IObservable Process(IObservable source) + { + return Observable.Create(observer => + { + var sourceObserver = Observer.Create(value => + { + using var noGrad = torch.no_grad(); + observer.OnNext(value); + }, + observer.OnError, + observer.OnCompleted); + return source.SubscribeSafe(sourceObserver); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/SaveTensor.cs b/src/Bonsai.ML.Torch/SaveTensor.cs index 1a3c4772..3edd03b5 100644 --- a/src/Bonsai.ML.Torch/SaveTensor.cs +++ b/src/Bonsai.ML.Torch/SaveTensor.cs @@ -1,6 +1,7 @@ using System; using System.ComponentModel; using System.Reactive.Linq; +using TorchSharp; using static TorchSharp.torch; namespace Bonsai.ML.Torch @@ -16,7 +17,7 @@ public class SaveTensor /// /// The path to the file where the tensor will be saved. /// - [FileNameFilter("Binary files(*.bin)|*.bin|All files|*.*")] + [FileNameFilter("Binary files(*.bin)|*.bin|Tensor files(*.pt)|*.pt|All files|*.*")] [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] [Description("The path to the file where the tensor will be saved.")] public string Path { get; set; } = string.Empty; @@ -28,7 +29,7 @@ public class SaveTensor /// public IObservable Process(IObservable source) { - return source.Do(tensor => tensor.save(Path)); + return source.Do(tensor => tensor.Save(Path)); } } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/TensorConverter.cs b/src/Bonsai.ML.Torch/TensorConverter.cs index fba053c3..cd0343f5 100644 --- a/src/Bonsai.ML.Torch/TensorConverter.cs +++ b/src/Bonsai.ML.Torch/TensorConverter.cs @@ -56,7 +56,7 @@ public static Tensor ConvertFromString(string value, ScalarType scalarType) if (string.IsNullOrEmpty(value)) { - return empty(0, dtype: scalarType); + return null; } var tensorData = PythonDataHelper.Parse(value, returnType); @@ -133,6 +133,9 @@ public static string ConvertToString(Tensor tensor, ScalarType scalarType) { object tensorData; + if (tensor is null) + return string.Empty; + if (tensor.Dimensions == 0) { if (scalarType == ScalarType.Byte) diff --git a/src/Bonsai.ML.Torch/TensorOperatorConverter.cs b/src/Bonsai.ML.Torch/TensorOperatorConverter.cs new file mode 100644 index 00000000..6ac17150 --- /dev/null +++ b/src/Bonsai.ML.Torch/TensorOperatorConverter.cs @@ -0,0 +1,78 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reflection; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Provides a type converter that automatically converts all tensor properties with the attribute when the scalar type changes. +/// Apply this converter to classes using [TypeConverter(typeof(TensorOperatorConverter))]. +/// +public class TensorOperatorConverter : TypeConverter +{ + /// + public override PropertyDescriptorCollection GetProperties(ITypeDescriptorContext context, object value, Attribute[] attributes) + { + var properties = TypeDescriptor.GetProperties(value, attributes); + return new PropertyDescriptorCollection( + properties.Cast() + .Select(p => p.Name == nameof(IScalarTypeProvider.Type) + ? new ScalarTypePropertyDescriptor(p) + : p) + .ToArray()); + } + + /// + public override bool GetPropertiesSupported(ITypeDescriptorContext context) => true; + + private class ScalarTypePropertyDescriptor(PropertyDescriptor baseDescriptor) : PropertyDescriptor(baseDescriptor) + { + private readonly PropertyDescriptor _baseDescriptor = baseDescriptor; + public override Type ComponentType => _baseDescriptor.ComponentType; + public override bool IsReadOnly => _baseDescriptor.IsReadOnly; + public override Type PropertyType => _baseDescriptor.PropertyType; + + public override bool CanResetValue(object component) => _baseDescriptor.CanResetValue(component); + public override object GetValue(object component) => _baseDescriptor.GetValue(component); + public override void ResetValue(object component) => _baseDescriptor.ResetValue(component); + public override bool ShouldSerializeValue(object component) => _baseDescriptor.ShouldSerializeValue(component); + + public override void SetValue(object component, object value) + { + var oldValue = _baseDescriptor.GetValue(component); + if (Equals(oldValue, value)) + return; + + _baseDescriptor.SetValue(component, value); + + if (value is ScalarType newScalarType && component is IScalarTypeProvider) + { + ConvertAllTensorProperties(component, newScalarType); + } + } + + private static void ConvertAllTensorProperties(object component, ScalarType scalarType) + { + var properties = TypeDescriptor.GetProperties(component); + foreach (PropertyDescriptor property in properties) + { + // Check if this property uses TensorConverter + var converterAttr = property.Attributes.OfType().FirstOrDefault(); + + if (converterAttr?.ConverterTypeName?.Contains(nameof(TensorConverter)) != true) + continue; + + if (property.PropertyType != typeof(Tensor)) + continue; + + if (property.GetValue(component) is not Tensor tensor || tensor.dtype == scalarType) + continue; + + property.SetValue(component, tensor.to_type(scalarType)); + } + } + } +} diff --git a/src/Bonsai.ML.Torch/ToTensor.cs b/src/Bonsai.ML.Torch/ToTensor.cs index d0d3b7a8..bfc5624d 100644 --- a/src/Bonsai.ML.Torch/ToTensor.cs +++ b/src/Bonsai.ML.Torch/ToTensor.cs @@ -6,128 +6,127 @@ using OpenCV.Net; using static TorchSharp.torch; -namespace Bonsai.ML.Torch +namespace Bonsai.ML.Torch; + +/// +/// Converts the input value into a tensor. +/// +[Combinator] +[ResetCombinator] +[Description("Converts the input value into a tensor.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class ToTensor { /// - /// Converts the input value into a tensor. + /// The device on which to create the tensor. /// - [Combinator] - [ResetCombinator] - [Description("Converts the input value into a tensor.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ToTensor - { - /// - /// The device on which to create the tensor. - /// - [Description("The device on which to create the tensor.")] - [XmlIgnore] - public Device Device { get; set; } = null; + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; - /// - /// The data type of the tensor. - /// - [Description("The data type of the tensor.")] - public ScalarType? Type { get; set; } = null; + /// + /// The data type of the tensor. + /// + [Description("The data type of the tensor.")] + public ScalarType? Type { get; set; } = null; - /// - /// Converts an int into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts an int into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a double into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a double into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a byte into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a byte into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a bool into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a bool into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a float into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a float into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a long into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a long into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a short into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a short into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts an array into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts an array into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts an IplImage into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => OpenCVHelper.ToTensor(value, device: Device)); - } + /// + /// Converts an IplImage into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => OpenCVHelper.ToTensor(value, device: Device)); + } - /// - /// Converts a Mat into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => OpenCVHelper.ToTensor(value, device: Device)); - } + /// + /// Converts a Mat into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => OpenCVHelper.ToTensor(value, device: Device)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Unbind.cs b/src/Bonsai.ML.Torch/Unbind.cs new file mode 100644 index 00000000..465267e6 --- /dev/null +++ b/src/Bonsai.ML.Torch/Unbind.cs @@ -0,0 +1,40 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch; + +/// +/// Represents an operator that deconstructs each tensor in the sequence into one or more tensors by splitting it along the specified dimension. +/// +[Combinator] +[Description("Deconstructs each tensor in the sequence into one or more tensors by splitting it along the specified dimension.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class Unbind +{ + private int _dimension = 0; + /// + /// Gets or sets the dimension along which to deconstruct the tensor. + /// + [Description("The dimension along which to deconstruct the tensor.")] + public int Dimension + { + get => _dimension; + set => _dimension = value; + } + + /// + /// Processes an observable sequence of tensors, deconstructing each tensor into a sequence of tensors by splitting along the specified dimension. + /// + public IObservable Process(IObservable source) + { + return source.SelectMany((input) => + { + if (input is null) + return null; + return input.unbind(_dimension).ToObservable(); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ValueTupleConverter.cs b/src/Bonsai.ML.Torch/ValueTupleConverter.cs new file mode 100644 index 00000000..73b57ed4 --- /dev/null +++ b/src/Bonsai.ML.Torch/ValueTupleConverter.cs @@ -0,0 +1,592 @@ +using System; +using System.ComponentModel; + +namespace Bonsai.ML.Torch; + +/// +/// Type converter for single-element value tuples. +/// +/// The type of the element in the value tuple. +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 1) + { + var item = (T)TypeDescriptor.GetConverter(typeof(T)).ConvertFromString(context, culture, elements[0].Trim()); + return ValueTuple.Create(item); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item = TypeDescriptor.GetConverter(typeof(T)).ConvertToString(context, culture, tuple.Item1); + return $"({item})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for single-element nullable value tuples. +/// +/// The type of the element in the value tuple. +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for two-element value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 2) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + return ValueTuple.Create(item1, item2); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + return $"({item1}, {item2})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for two-element nullable value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for three-element value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 3) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + var item3 = (T3)TypeDescriptor.GetConverter(typeof(T3)).ConvertFromString(context, culture, elements[2].Trim()); + return ValueTuple.Create(item1, item2, item3); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}, {typeof(T3).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + var item3 = TypeDescriptor.GetConverter(typeof(T3)).ConvertToString(context, culture, tuple.Item3); + return $"({item1}, {item2}, {item3})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for three-element nullable value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for four-element value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 3) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + var item3 = (T3)TypeDescriptor.GetConverter(typeof(T3)).ConvertFromString(context, culture, elements[2].Trim()); + var item4 = (T4)TypeDescriptor.GetConverter(typeof(T4)).ConvertFromString(context, culture, elements[3].Trim()); + return ValueTuple.Create(item1, item2, item3, item4); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}, {typeof(T3).Name}, {typeof(T4).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + var item3 = TypeDescriptor.GetConverter(typeof(T3)).ConvertToString(context, culture, tuple.Item3); + var item4 = TypeDescriptor.GetConverter(typeof(T4)).ConvertToString(context, culture, tuple.Item4); + return $"({item1}, {item2}, {item3}, {item4})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for four-element nullable value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for five-element value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 3) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + var item3 = (T3)TypeDescriptor.GetConverter(typeof(T3)).ConvertFromString(context, culture, elements[2].Trim()); + var item4 = (T4)TypeDescriptor.GetConverter(typeof(T4)).ConvertFromString(context, culture, elements[3].Trim()); + var item5 = (T5)TypeDescriptor.GetConverter(typeof(T5)).ConvertFromString(context, culture, elements[4].Trim()); + return ValueTuple.Create(item1, item2, item3, item4, item5); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}, {typeof(T3).Name}, {typeof(T4).Name}, {typeof(T5).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + var item3 = TypeDescriptor.GetConverter(typeof(T3)).ConvertToString(context, culture, tuple.Item3); + var item4 = TypeDescriptor.GetConverter(typeof(T4)).ConvertToString(context, culture, tuple.Item4); + var item5 = TypeDescriptor.GetConverter(typeof(T5)).ConvertToString(context, culture, tuple.Item5); + return $"({item1}, {item2}, {item3}, {item4}, {item5})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for five-element nullable value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for six-element value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. +/// The type of the sixth element in the value tuple. +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 3) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + var item3 = (T3)TypeDescriptor.GetConverter(typeof(T3)).ConvertFromString(context, culture, elements[2].Trim()); + var item4 = (T4)TypeDescriptor.GetConverter(typeof(T4)).ConvertFromString(context, culture, elements[3].Trim()); + var item5 = (T5)TypeDescriptor.GetConverter(typeof(T5)).ConvertFromString(context, culture, elements[4].Trim()); + var item6 = (T6)TypeDescriptor.GetConverter(typeof(T6)).ConvertFromString(context, culture, elements[5].Trim()); + return ValueTuple.Create(item1, item2, item3, item4, item5, item6); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}, {typeof(T3).Name}, {typeof(T4).Name}, {typeof(T5).Name}, {typeof(T6).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + var item3 = TypeDescriptor.GetConverter(typeof(T3)).ConvertToString(context, culture, tuple.Item3); + var item4 = TypeDescriptor.GetConverter(typeof(T4)).ConvertToString(context, culture, tuple.Item4); + var item5 = TypeDescriptor.GetConverter(typeof(T5)).ConvertToString(context, culture, tuple.Item5); + var item6 = TypeDescriptor.GetConverter(typeof(T6)).ConvertToString(context, culture, tuple.Item6); + return $"({item1}, {item2}, {item3}, {item4}, {item5}, {item6})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for six-element nullable value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. +/// The type of the sixth element in the value tuple. +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for seven-element value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. +/// The type of the sixth element in the value tuple. +/// The type of the seventh element in the value tuple. +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 3) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + var item3 = (T3)TypeDescriptor.GetConverter(typeof(T3)).ConvertFromString(context, culture, elements[2].Trim()); + var item4 = (T4)TypeDescriptor.GetConverter(typeof(T4)).ConvertFromString(context, culture, elements[3].Trim()); + var item5 = (T5)TypeDescriptor.GetConverter(typeof(T5)).ConvertFromString(context, culture, elements[4].Trim()); + var item6 = (T6)TypeDescriptor.GetConverter(typeof(T6)).ConvertFromString(context, culture, elements[5].Trim()); + var item7 = (T7)TypeDescriptor.GetConverter(typeof(T7)).ConvertFromString(context, culture, elements[6].Trim()); + return ValueTuple.Create(item1, item2, item3, item4, item5, item6, item7); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}, {typeof(T3).Name}, {typeof(T4).Name}, {typeof(T5).Name}, {typeof(T6).Name}, {typeof(T7).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + var item3 = TypeDescriptor.GetConverter(typeof(T3)).ConvertToString(context, culture, tuple.Item3); + var item4 = TypeDescriptor.GetConverter(typeof(T4)).ConvertToString(context, culture, tuple.Item4); + var item5 = TypeDescriptor.GetConverter(typeof(T5)).ConvertToString(context, culture, tuple.Item5); + var item6 = TypeDescriptor.GetConverter(typeof(T6)).ConvertToString(context, culture, tuple.Item6); + var item7 = TypeDescriptor.GetConverter(typeof(T7)).ConvertToString(context, culture, tuple.Item7); + return $"({item1}, {item2}, {item3}, {item4}, {item5}, {item6}, {item7})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for seven-element nullable value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. +/// The type of the sixth element in the value tuple. +/// The type of the seventh element in the value tuple. +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +}