From bfaa374033c663033380e38a4e7a2f056cd18de3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 1 Aug 2025 18:37:28 +0100 Subject: [PATCH 01/28] Added classes to generate diagonal tensors and identity (eye) matrices --- src/Bonsai.ML.Torch/Diagonal.cs | 76 +++++++++++++++++++++++++++++++++ src/Bonsai.ML.Torch/Eye.cs | 55 ++++++++++++++++++++++++ 2 files changed, 131 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Diagonal.cs create mode 100644 src/Bonsai.ML.Torch/Eye.cs diff --git a/src/Bonsai.ML.Torch/Diagonal.cs b/src/Bonsai.ML.Torch/Diagonal.cs new file mode 100644 index 00000000..362d40f8 --- /dev/null +++ b/src/Bonsai.ML.Torch/Diagonal.cs @@ -0,0 +1,76 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Creates a diagonal matrix. If input is a 1D tensor, it creates a diagonal matrix with the elements of the tensor on the diagonal. + /// If input is a 2D tensor, it returns the diagonal elements as a 1D tensor. + /// + [Combinator] + [ResetCombinator] + [Description("Creates a diagonal matrix with the given data type, size, and value.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Diagonal + { + /// + /// The input matrix. + /// + [Description("The input matrix.")] + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public double[] Input { get; set; } + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// The diagonal offset. Default is 0, which means the main diagonal. + /// + [Description("The diagonal offset. Default is 0, which means the main diagonal.")] + public int Offset { get; set; } = 0; + + /// + /// Creates a diagonal matrix. + /// + public IObservable Process() + { + var inputTensor = tensor(Input, dtype: Type, device: Device); + return Observable.Return(diag(inputTensor, Offset)); + } + + /// + /// Generates an observable sequence of tensors by extracting the diagonal of the input. + /// + /// + /// + public IObservable Process(IObservable source) + { + var inputTensor = tensor(Input, dtype: Type, device: Device); + return source.Select(value => diag(inputTensor, Offset)); + } + + /// + /// Generates an observable sequence of tensors by extracting the diagonal of the input. + /// + /// + /// + public IObservable Process(IObservable source) + { + var inputTensor = tensor(Input, dtype: Type, device: Device); + return source.Select(value => diag(inputTensor, Offset)); + } + } +} diff --git a/src/Bonsai.ML.Torch/Eye.cs b/src/Bonsai.ML.Torch/Eye.cs new file mode 100644 index 00000000..8c8252a0 --- /dev/null +++ b/src/Bonsai.ML.Torch/Eye.cs @@ -0,0 +1,55 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Creates an identity matrix with the given data type and size. + /// + [Combinator] + [ResetCombinator] + [Description("Creates an identity matrix with the given data type and size.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Eye + { + /// + /// The size of the identity matrix. + /// + [Description("The size of the identity matrix.")] + public long Size { get; set; } = 0; + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } = null; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// Creates an identity matrix with the given data type and size. + /// + public IObservable Process() + { + return Observable.Return(eye(Size, dtype: Type, device: Device)); + } + + /// + /// Generates an observable sequence of identity matrices for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => eye(Size, dtype: Type, device: Device)); + } + } +} From c1dd4ad9d5b9cb217c7ea5e835f0844255a856d3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 1 Aug 2025 18:38:18 +0100 Subject: [PATCH 02/28] Added a sink for debugging with extended print capabilities --- src/Bonsai.ML.Torch/PrintTensor.cs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 src/Bonsai.ML.Torch/PrintTensor.cs diff --git a/src/Bonsai.ML.Torch/PrintTensor.cs b/src/Bonsai.ML.Torch/PrintTensor.cs new file mode 100644 index 00000000..6775a11c --- /dev/null +++ b/src/Bonsai.ML.Torch/PrintTensor.cs @@ -0,0 +1,22 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Sink)] +public class PrintTensor +{ + public TensorStringStyle StringStyle { get; set; } + public IObservable Process(IObservable source) + { + return source.Do(value => Console.WriteLine(value.ToString(StringStyle))); + } +} From 2fde34848b2dc6a742d0ecbd4836034313e4ded7 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Aug 2025 17:15:32 +0100 Subject: [PATCH 03/28] Updated main gitignore to ignore nested bonsai environments --- .gitignore | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index cace7651..26a3d0c4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ -.bonsai/Bonsai.exe* -.bonsai/Packages/ -.bonsai/Settings/ +**/.bonsai/Bonsai.exe* +**/.bonsai/Packages/ +**/.bonsai/Settings/ .vs/ /artifacts/ .venv From f9ae7332fc06b87b0c2d494c796185a3147ff38b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Aug 2025 17:17:15 +0100 Subject: [PATCH 04/28] Updated `LoadTensor` operator to support the native load method from torch --- src/Bonsai.ML.Torch/LoadTensor.cs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Torch/LoadTensor.cs b/src/Bonsai.ML.Torch/LoadTensor.cs index af1e7f05..0f210ae4 100644 --- a/src/Bonsai.ML.Torch/LoadTensor.cs +++ b/src/Bonsai.ML.Torch/LoadTensor.cs @@ -21,13 +21,29 @@ public class LoadTensor [Description("The path to the file containing the tensor.")] public string Path { get; set; } + /// + /// Indicates whether to use the native torch load method for the tensor. + /// + /// + /// If set to true, the native torch load method will be used. + /// If set to false, the tensor will be loaded using the TorchSharp method which is specific to .NET formats. + /// + [Description("Indicates whether to use the native torch load method for the tensor.")] + public bool UseNativeMethod { get; set; } = true; + /// /// Loads a tensor from the specified file. /// /// public IObservable Process() { - return Observable.Return(Tensor.Load(Path)); + switch (UseNativeMethod) + { + case true: + return Observable.Return(load(Path)); + case false: + return Observable.Return(Tensor.Load(Path)); + } } } } \ No newline at end of file From 835657afaf06134bbbaaf95cde1e555388cd0cf7 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 8 Oct 2025 14:29:00 +0100 Subject: [PATCH 05/28] Added class to `Buffer` tensors along the first dimension --- src/Bonsai.ML.Torch/Buffer.cs | 120 ++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Buffer.cs diff --git a/src/Bonsai.ML.Torch/Buffer.cs b/src/Bonsai.ML.Torch/Buffer.cs new file mode 100644 index 00000000..41b206d5 --- /dev/null +++ b/src/Bonsai.ML.Torch/Buffer.cs @@ -0,0 +1,120 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch; + +/// +/// This operator collects incoming tensors into a buffer and concatenates them along the first dimension. +/// +/// +/// The operator maintains an internal buffer that accumulates incoming tensors until it reaches the specified count. +/// When the buffer reaches the specified count, it is emitted as a single tensor. After emitting the buffer, the operator skips a specified number of incoming tensors before starting to fill the buffer again. +/// +[Combinator] +[Description("Buffers the incoming tensors and concatenates them into a single tensor along the first dimension.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class Buffer +{ + private int _count = 1; + /// + /// Gets or sets the number of tensors to accumulate in the buffer before emitting. + /// + [Description("The number of tensors to accumulate in the buffer before emitting.")] + public int Count + { + get => _count; + set => _count = value <= 0 + ? throw new ArgumentOutOfRangeException("Count must be greater than zero.") + : value; + } + + private int _skip = 1; + /// + /// Gets or sets the number of tensors to skip after emitting the buffer. + /// + [Description("The number of tensors to skip after emitting the buffer.")] + public int Skip + { + get => _skip; + set => _skip = value < 0 + ? throw new ArgumentOutOfRangeException("Skip must be non-negative.") + : value; + } + + private torch.Tensor _buffer = null; + private int _current = 0; + private torch.Tensor _idxSrc = null; + private torch.Tensor _idxDst = null; + + /// + /// Processes an observable sequence of tensors, buffering them and concatenating along the first dimension. + /// + public IObservable Process(IObservable source) + { + var count = _count; + var skip = _skip; + var send = false; + _current = 0; + return source.Select((input) => + { + if (input is null) return false; + + if (_buffer is null) + { + var shape = input.shape.Prepend(count).ToArray(); + _buffer = torch.empty(shape, dtype: input.dtype, device: input.device); + + if (skip < count) + { + _idxSrc = torch.arange(skip, count, dtype: torch.ScalarType.Int64, device: input.device); + _idxDst = torch.arange(0, count - skip, dtype: torch.ScalarType.Int64, device: input.device); + } + } + + if (_current >= 0) + { + _buffer[_current] = input; + } + + _current++; + + if (_current >= count) + { + send = true; + } + return send; + }) + .Where(x => x) + .Select(x => + { + var output = _buffer.clone(); + if (skip < count) + { + var src = torch.index_select(_buffer, 0, _idxSrc); + _buffer.index_copy_(0, _idxDst, src); + _buffer[torch.TensorIndex.Slice(count - skip, null)].zero_(); + } + else + { + _buffer.zero_(); + } + _current = count - skip; + send = false; + return output; + }) + .Finally(() => + { + _buffer?.Dispose(); + _buffer = null; + _idxSrc?.Dispose(); + _idxSrc = null; + _idxDst?.Dispose(); + _idxDst = null; + _current = 0; + send = false; + }); + } +} \ No newline at end of file From b5e943ff67396b448cd9de7cd92dd0242fd2d280 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 8 Oct 2025 14:29:50 +0100 Subject: [PATCH 06/28] Added class to `Decompose` a large tensor into a sequence of smaller tensors by decomposing along a dimension --- src/Bonsai.ML.Torch/Decompose.cs | 39 ++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Decompose.cs diff --git a/src/Bonsai.ML.Torch/Decompose.cs b/src/Bonsai.ML.Torch/Decompose.cs new file mode 100644 index 00000000..cdc80075 --- /dev/null +++ b/src/Bonsai.ML.Torch/Decompose.cs @@ -0,0 +1,39 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch; + +/// +/// This operator decomposes each incoming tensor into a sequence of tensors by splitting it along the specified dimension. +/// +[Combinator] +[Description("Decomposes each incoming tensor into a sequence of tensors by splitting it along the specified dimension.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class Decompose +{ + private int _dimension = 0; + /// + /// Gets or sets the dimension along which to split the tensor. + /// + [Description("The dimension along which to split the tensor.")] + public int Dimension + { + get => _dimension; + set => _dimension = value; + } + + /// + /// Processes an observable sequence of tensors, decomposing each tensor into a sequence of tensors by splitting along the specified dimension. + /// + public IObservable Process(IObservable source) + { + return source.SelectMany((input) => + { + if (input is null) return null; + return input.unbind(_dimension).ToObservable(); + }); + } +} \ No newline at end of file From 3613429684ae7ebe23569f46755d8863afc31e6e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 8 Oct 2025 14:30:34 +0100 Subject: [PATCH 07/28] Added feature to save a tensor using TorchSharp's .NET format or the native libtorch method --- src/Bonsai.ML.Torch/SaveTensor.cs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Torch/SaveTensor.cs b/src/Bonsai.ML.Torch/SaveTensor.cs index 1a3c4772..0e306941 100644 --- a/src/Bonsai.ML.Torch/SaveTensor.cs +++ b/src/Bonsai.ML.Torch/SaveTensor.cs @@ -21,6 +21,16 @@ public class SaveTensor [Description("The path to the file where the tensor will be saved.")] public string Path { get; set; } = string.Empty; + /// + /// Indicates whether to use the native torch save method for the tensor. + /// + /// + /// If set to true, the native torch save method will be used. + /// If set to false, the tensor will be saved using the TorchSharp method which is specific to .NET formats. + /// + [Description("Indicates whether to use the native torch save method for the tensor.")] + public bool UseNativeMethod { get; set; } = true; + /// /// Saves the input tensor to the specified file. /// @@ -28,7 +38,13 @@ public class SaveTensor /// public IObservable Process(IObservable source) { - return source.Do(tensor => tensor.save(Path)); + return source.Do(tensor => + { + if (UseNativeMethod) + tensor.save(Path); + else + tensor.Save(Path); + }); } } } \ No newline at end of file From 048b9a7c9cbc53059962108ec82b893619eabf96 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 8 Oct 2025 14:31:00 +0100 Subject: [PATCH 08/28] Added `.pt` file filters for load/save tensor objects --- src/Bonsai.ML.Torch/LoadTensor.cs | 2 +- src/Bonsai.ML.Torch/SaveTensor.cs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/LoadTensor.cs b/src/Bonsai.ML.Torch/LoadTensor.cs index 0f210ae4..0d6718c7 100644 --- a/src/Bonsai.ML.Torch/LoadTensor.cs +++ b/src/Bonsai.ML.Torch/LoadTensor.cs @@ -16,7 +16,7 @@ public class LoadTensor /// /// The path to the file containing the tensor. /// - [FileNameFilter("Binary files(*.bin)|*.bin|All files|*.*")] + [FileNameFilter("Binary files(*.bin)|*.bin|Tensor files(*.pt)|*.pt|All files|*.*")] [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] [Description("The path to the file containing the tensor.")] public string Path { get; set; } diff --git a/src/Bonsai.ML.Torch/SaveTensor.cs b/src/Bonsai.ML.Torch/SaveTensor.cs index 0e306941..1d568f8c 100644 --- a/src/Bonsai.ML.Torch/SaveTensor.cs +++ b/src/Bonsai.ML.Torch/SaveTensor.cs @@ -1,6 +1,7 @@ using System; using System.ComponentModel; using System.Reactive.Linq; +using TorchSharp; using static TorchSharp.torch; namespace Bonsai.ML.Torch @@ -16,7 +17,7 @@ public class SaveTensor /// /// The path to the file where the tensor will be saved. /// - [FileNameFilter("Binary files(*.bin)|*.bin|All files|*.*")] + [FileNameFilter("Binary files(*.bin)|*.bin|Tensor files(*.pt)|*.pt|All files|*.*")] [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] [Description("The path to the file where the tensor will be saved.")] public string Path { get; set; } = string.Empty; From e49bf085835382bef8a31754497ac9cd57400c98 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 8 Oct 2025 14:35:30 +0100 Subject: [PATCH 09/28] Added operators to manage downstream tensor operations in different modes (enabled grad, inference, no grad) --- .../ObserveWithGradientTracking.cs | 33 +++++++++++++++++++ .../ObserveWithInferenceMode.cs | 33 +++++++++++++++++++ .../ObserveWithNoGradientTracking.cs | 33 +++++++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs create mode 100644 src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs create mode 100644 src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs diff --git a/src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs b/src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs new file mode 100644 index 00000000..0fc1fd05 --- /dev/null +++ b/src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs @@ -0,0 +1,33 @@ +using System; +using System.Reactive; +using System.Reactive.Linq; +using TorchSharp; +using Bonsai; +using System.ComponentModel; + +/// +/// This operator ensures that all tensor operations within the observable sequence are executed with gradient tracking enabled. +/// +[Combinator] +[Description("Ensures that all tensor operations within the observable sequence are executed with gradient tracking enabled.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class ObserveWithGradientTracking +{ + /// + /// Processes an observable sequence, ensuring all tensor operations are executed with gradient tracking enabled. + /// + public IObservable Process(IObservable source) + { + return Observable.Create(observer => + { + var sourceObserver = Observer.Create(value => + { + using var enabledGrad = torch.enable_grad(); + observer.OnNext(value); + }, + observer.OnError, + observer.OnCompleted); + return source.SubscribeSafe(sourceObserver); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs b/src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs new file mode 100644 index 00000000..d34aa39d --- /dev/null +++ b/src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs @@ -0,0 +1,33 @@ +using System; +using System.Reactive; +using System.Reactive.Linq; +using TorchSharp; +using Bonsai; +using System.ComponentModel; + +/// +/// This operator ensures that all tensor operations within the observable sequence are executed in inference mode. +/// +[Combinator] +[Description("Ensures that all tensor operations within the observable sequence are executed in inference mode.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class ObserveWithInferenceMode +{ + /// + /// Processes an observable sequence, executing all tensor operations in inference mode. + /// + public IObservable Process(IObservable source) + { + return Observable.Create(observer => + { + var sourceObserver = Observer.Create(value => + { + using var inferenceMode = torch.inference_mode(); + observer.OnNext(value); + }, + observer.OnError, + observer.OnCompleted); + return source.SubscribeSafe(sourceObserver); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs b/src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs new file mode 100644 index 00000000..3d3fc507 --- /dev/null +++ b/src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs @@ -0,0 +1,33 @@ +using System; +using System.Reactive; +using System.Reactive.Linq; +using TorchSharp; +using Bonsai; +using System.ComponentModel; + +/// +/// This operator ensures that all tensor operations within the observable sequence are executed without tracking gradients. +/// +[Combinator] +[Description("Ensures that all tensor operations within the observable sequence are executed without tracking gradients.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class ObserveWithNoGradientTracking +{ + /// + /// Processes an observable sequence, executing all tensor operations without tracking gradients. + /// + public IObservable Process(IObservable source) + { + return Observable.Create(observer => + { + var sourceObserver = Observer.Create(value => + { + using var noGrad = torch.no_grad(); + observer.OnNext(value); + }, + observer.OnError, + observer.OnCompleted); + return source.SubscribeSafe(sourceObserver); + }); + } +} \ No newline at end of file From 787925f7caf7bb98feefc49700c82874c0312949 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 8 Oct 2025 14:36:41 +0100 Subject: [PATCH 10/28] Updated `TensorConverter` to handle case when tensor is null and return empty string rather than empty tensor --- src/Bonsai.ML.Torch/TensorConverter.cs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Torch/TensorConverter.cs b/src/Bonsai.ML.Torch/TensorConverter.cs index fba053c3..cd0343f5 100644 --- a/src/Bonsai.ML.Torch/TensorConverter.cs +++ b/src/Bonsai.ML.Torch/TensorConverter.cs @@ -56,7 +56,7 @@ public static Tensor ConvertFromString(string value, ScalarType scalarType) if (string.IsNullOrEmpty(value)) { - return empty(0, dtype: scalarType); + return null; } var tensorData = PythonDataHelper.Parse(value, returnType); @@ -133,6 +133,9 @@ public static string ConvertToString(Tensor tensor, ScalarType scalarType) { object tensorData; + if (tensor is null) + return string.Empty; + if (tensor.Dimensions == 0) { if (scalarType == ScalarType.Byte) From 8fc49471d286bca6b0bb95612351ebe4f94ecffe Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 15 Oct 2025 13:39:28 +0100 Subject: [PATCH 11/28] Remove unnecessary reset of current index in `Buffer` class --- src/Bonsai.ML.Torch/Buffer.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Bonsai.ML.Torch/Buffer.cs b/src/Bonsai.ML.Torch/Buffer.cs index 41b206d5..45133269 100644 --- a/src/Bonsai.ML.Torch/Buffer.cs +++ b/src/Bonsai.ML.Torch/Buffer.cs @@ -113,7 +113,6 @@ public int Skip _idxSrc = null; _idxDst?.Dispose(); _idxDst = null; - _current = 0; send = false; }); } From 5a19f854ac80573da13c85c721df6244f8dc4978 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 30 Oct 2025 18:13:15 +0000 Subject: [PATCH 12/28] Refactored `InitializeTorchDevice` to ensure the device initialization is only happening on subscription and not before that --- src/Bonsai.ML.Torch/InitializeTorchDevice.cs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs index 889ae4b7..9ea67c16 100644 --- a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -32,8 +32,11 @@ public class InitializeTorchDevice /// public IObservable Process() { - InitializeDeviceType(DeviceType); - return Observable.Return(new Device(DeviceType, DeviceIndex)); + return Observable.Defer(() => + { + InitializeDeviceType(DeviceType); + return Observable.Return(new Device(DeviceType, DeviceIndex)); + }); } /// @@ -42,8 +45,11 @@ public IObservable Process() /// public IObservable Process(IObservable source) { - InitializeDeviceType(DeviceType); - return source.Select((_) => new Device(DeviceType, DeviceIndex)); + return source.Select((_) => + { + InitializeDeviceType(DeviceType); + return new Device(DeviceType, DeviceIndex); + }); } } } \ No newline at end of file From a4c15b344c4e756f2cfdb8746df2b981dc05381a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Sat, 1 Nov 2025 21:39:49 +0000 Subject: [PATCH 13/28] Added an abstract base class for operators that contain tensor properties --- src/Bonsai.ML.Torch/TensorContainerBase.cs | 75 ++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 src/Bonsai.ML.Torch/TensorContainerBase.cs diff --git a/src/Bonsai.ML.Torch/TensorContainerBase.cs b/src/Bonsai.ML.Torch/TensorContainerBase.cs new file mode 100644 index 00000000..9ecfc36d --- /dev/null +++ b/src/Bonsai.ML.Torch/TensorContainerBase.cs @@ -0,0 +1,75 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Base class for operators that contain tensor properties. Provides automatic scalar type conversion on registered tensors. +/// +public abstract class TensorContainerBase : IScalarTypeProvider +{ + private ScalarType _scalarType = ScalarType.Float32; + private readonly List _registeredTensors = new(); + private readonly object _sync = new(); + + /// + public ScalarType Type + { + get => _scalarType; + set + { + if (_scalarType == value) return; + _scalarType = value; + ConvertAllTensors(); + } + } + + /// + /// Registers a tensor property for automatic scalar type conversion. + /// + /// + /// + protected void RegisterTensor(Func getter, Action setter) + { + _registeredTensors.Add(new RegisteredTensor(getter, setter)); + } + + /// + /// Convert a single tensor to the current scalar type. + /// + protected Tensor ConvertTensor(Tensor value) + { + return value.dtype == _scalarType ? value : value.to_type(_scalarType); + } + + /// + /// Converts all tensor properties to the current scalar type. + /// + protected void ConvertAllTensors() + { + lock (_sync) + { + foreach (var registeredTensor in _registeredTensors) + { + var tensor = registeredTensor.Getter(); + if (tensor is null || tensor.dtype == _scalarType) continue; + registeredTensor.Setter(tensor.to_type(_scalarType)); + } + } + } + + private readonly struct RegisteredTensor( + Func getter, + Action setter) + { + public Func Getter { get; } = getter; + public Action Setter { get; } = setter; + } +} From 9be41b867256a0c4a2a8275ac2a9ecbf73ffd149 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 10 Nov 2025 14:55:57 +0000 Subject: [PATCH 14/28] Added XML docs to classes --- src/Bonsai.ML.Torch/PrintTensor.cs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Torch/PrintTensor.cs b/src/Bonsai.ML.Torch/PrintTensor.cs index 6775a11c..a7509855 100644 --- a/src/Bonsai.ML.Torch/PrintTensor.cs +++ b/src/Bonsai.ML.Torch/PrintTensor.cs @@ -9,12 +9,25 @@ namespace Bonsai.ML.Torch; +/// +/// Prints the string representation of incoming tensors to the console. +/// [Combinator] -[Description("")] +[Description("Prints the string representation of incoming tensors to the console.")] [WorkflowElementCategory(ElementCategory.Sink)] public class PrintTensor { + /// + /// Gets or sets the string style used to format the tensor output. + /// + [Description("The string style used to format the tensor output.")] public TensorStringStyle StringStyle { get; set; } + + /// + /// Processes the input sequence of tensors and prints their string representations to the console. + /// + /// + /// public IObservable Process(IObservable source) { return source.Do(value => Console.WriteLine(value.ToString(StringStyle))); From 31cee0d57e1d576867925bc25965c71104840572 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 10 Nov 2025 19:06:59 +0000 Subject: [PATCH 15/28] Updated property name from `UseNativeMethod` to `UseNativeTorchMethod` to avoid confusion with what is considered "native" in this case (is .NET native or torch native?) --- src/Bonsai.ML.Torch/LoadTensor.cs | 4 ++-- src/Bonsai.ML.Torch/SaveTensor.cs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Torch/LoadTensor.cs b/src/Bonsai.ML.Torch/LoadTensor.cs index 0d6718c7..6b179ac4 100644 --- a/src/Bonsai.ML.Torch/LoadTensor.cs +++ b/src/Bonsai.ML.Torch/LoadTensor.cs @@ -29,7 +29,7 @@ public class LoadTensor /// If set to false, the tensor will be loaded using the TorchSharp method which is specific to .NET formats. /// [Description("Indicates whether to use the native torch load method for the tensor.")] - public bool UseNativeMethod { get; set; } = true; + public bool UseNativeTorchMethod { get; set; } = false; /// /// Loads a tensor from the specified file. @@ -37,7 +37,7 @@ public class LoadTensor /// public IObservable Process() { - switch (UseNativeMethod) + switch (UseNativeTorchMethod) { case true: return Observable.Return(load(Path)); diff --git a/src/Bonsai.ML.Torch/SaveTensor.cs b/src/Bonsai.ML.Torch/SaveTensor.cs index 1d568f8c..72127b11 100644 --- a/src/Bonsai.ML.Torch/SaveTensor.cs +++ b/src/Bonsai.ML.Torch/SaveTensor.cs @@ -30,7 +30,7 @@ public class SaveTensor /// If set to false, the tensor will be saved using the TorchSharp method which is specific to .NET formats. /// [Description("Indicates whether to use the native torch save method for the tensor.")] - public bool UseNativeMethod { get; set; } = true; + public bool UseNativeTorchMethod { get; set; } = false; /// /// Saves the input tensor to the specified file. @@ -41,7 +41,7 @@ public IObservable Process(IObservable source) { return source.Do(tensor => { - if (UseNativeMethod) + if (UseNativeTorchMethod) tensor.save(Path); else tensor.Save(Path); From 26900fd92b03055780768b9be555b737b79f51cd Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 11 Nov 2025 13:36:57 +0000 Subject: [PATCH 16/28] Renamed `Decompose` to `Deconstruct` for better alignment with operators true function (splitting along dimension) --- src/Bonsai.ML.Torch/{Decompose.cs => Deconstruct.cs} | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) rename src/Bonsai.ML.Torch/{Decompose.cs => Deconstruct.cs} (55%) diff --git a/src/Bonsai.ML.Torch/Decompose.cs b/src/Bonsai.ML.Torch/Deconstruct.cs similarity index 55% rename from src/Bonsai.ML.Torch/Decompose.cs rename to src/Bonsai.ML.Torch/Deconstruct.cs index cdc80075..e454db22 100644 --- a/src/Bonsai.ML.Torch/Decompose.cs +++ b/src/Bonsai.ML.Torch/Deconstruct.cs @@ -7,18 +7,18 @@ namespace Bonsai.ML.Torch; /// -/// This operator decomposes each incoming tensor into a sequence of tensors by splitting it along the specified dimension. +/// This operator deconstructs each incoming tensor into a sequence of tensors by splitting it along the specified dimension. /// [Combinator] -[Description("Decomposes each incoming tensor into a sequence of tensors by splitting it along the specified dimension.")] +[Description("Deconstructs each incoming tensor into a sequence of tensors by splitting it along the specified dimension.")] [WorkflowElementCategory(ElementCategory.Combinator)] -public class Decompose +public class Deconstruct { private int _dimension = 0; /// - /// Gets or sets the dimension along which to split the tensor. + /// Gets or sets the dimension along which to deconstruct the tensor. /// - [Description("The dimension along which to split the tensor.")] + [Description("The dimension along which to deconstruct the tensor.")] public int Dimension { get => _dimension; @@ -26,7 +26,7 @@ public int Dimension } /// - /// Processes an observable sequence of tensors, decomposing each tensor into a sequence of tensors by splitting along the specified dimension. + /// Processes an observable sequence of tensors, deconstructing each tensor into a sequence of tensors by splitting along the specified dimension. /// public IObservable Process(IObservable source) { From 81888662cd62c09f3fafd6013244ec2764ee33aa Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 12 Nov 2025 12:42:43 +0000 Subject: [PATCH 17/28] Apply suggestions from code review Co-authored-by: glopesdev --- src/Bonsai.ML.Torch/Buffer.cs | 7 ++++--- src/Bonsai.ML.Torch/Deconstruct.cs | 4 ++-- src/Bonsai.ML.Torch/Diagonal.cs | 14 +++++++------- src/Bonsai.ML.Torch/Eye.cs | 12 ++++++------ src/Bonsai.ML.Torch/InitializeTorchDevice.cs | 6 ++++-- src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs | 6 +++--- src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs | 6 +++--- .../ObserveWithNoGradientTracking.cs | 6 +++--- src/Bonsai.ML.Torch/PrintTensor.cs | 14 +++++++------- src/Bonsai.ML.Torch/TensorContainerBase.cs | 4 ++-- 10 files changed, 41 insertions(+), 38 deletions(-) diff --git a/src/Bonsai.ML.Torch/Buffer.cs b/src/Bonsai.ML.Torch/Buffer.cs index 45133269..56a7de62 100644 --- a/src/Bonsai.ML.Torch/Buffer.cs +++ b/src/Bonsai.ML.Torch/Buffer.cs @@ -7,14 +7,14 @@ namespace Bonsai.ML.Torch; /// -/// This operator collects incoming tensors into a buffer and concatenates them along the first dimension. +/// Represents an operator that gathers incoming tensors into zero or more tensors by concatenating them along the first dimension. /// /// /// The operator maintains an internal buffer that accumulates incoming tensors until it reaches the specified count. /// When the buffer reaches the specified count, it is emitted as a single tensor. After emitting the buffer, the operator skips a specified number of incoming tensors before starting to fill the buffer again. /// [Combinator] -[Description("Buffers the incoming tensors and concatenates them into a single tensor along the first dimension.")] +[Description("Gathers incoming tensors into zero or more tensors by concatenating them along the first dimension.")] [WorkflowElementCategory(ElementCategory.Combinator)] public class Buffer { @@ -60,7 +60,8 @@ public int Skip _current = 0; return source.Select((input) => { - if (input is null) return false; + if (input is null) + return false; if (_buffer is null) { diff --git a/src/Bonsai.ML.Torch/Deconstruct.cs b/src/Bonsai.ML.Torch/Deconstruct.cs index e454db22..d3962668 100644 --- a/src/Bonsai.ML.Torch/Deconstruct.cs +++ b/src/Bonsai.ML.Torch/Deconstruct.cs @@ -7,10 +7,10 @@ namespace Bonsai.ML.Torch; /// -/// This operator deconstructs each incoming tensor into a sequence of tensors by splitting it along the specified dimension. +/// Represents an operator that deconstructs each tensor in the sequence into one or more tensors by splitting it along the specified dimension. /// [Combinator] -[Description("Deconstructs each incoming tensor into a sequence of tensors by splitting it along the specified dimension.")] +[Description("Deconstructs each tensor in the sequence into one or more tensors by splitting it along the specified dimension.")] [WorkflowElementCategory(ElementCategory.Combinator)] public class Deconstruct { diff --git a/src/Bonsai.ML.Torch/Diagonal.cs b/src/Bonsai.ML.Torch/Diagonal.cs index 362d40f8..d5c5cfc2 100644 --- a/src/Bonsai.ML.Torch/Diagonal.cs +++ b/src/Bonsai.ML.Torch/Diagonal.cs @@ -17,33 +17,33 @@ namespace Bonsai.ML.Torch public class Diagonal { /// - /// The input matrix. + /// Gets or sets the values to include in the diagonal. /// - [Description("The input matrix.")] + [Description("The values to include in the diagonal.")] [TypeConverter(typeof(UnidimensionalArrayConverter))] - public double[] Input { get; set; } + public double[] Values { get; set; } /// - /// The data type of the tensor elements. + /// Gets or sets the data type of the tensor elements. /// [Description("The data type of the tensor elements.")] public ScalarType? Type { get; set; } /// - /// The device on which to create the tensor. + /// Gets or sets the device on which to create the tensor. /// [Description("The device on which to create the tensor.")] [XmlIgnore] public Device Device { get; set; } = null; /// - /// The diagonal offset. Default is 0, which means the main diagonal. + /// Gets or sets the diagonal offset. Default is 0, which means the main diagonal. /// [Description("The diagonal offset. Default is 0, which means the main diagonal.")] public int Offset { get; set; } = 0; /// - /// Creates a diagonal matrix. + /// Creates an observable sequence containing a single diagonal matrix constructed from the specified data type, size and values. /// public IObservable Process() { diff --git a/src/Bonsai.ML.Torch/Eye.cs b/src/Bonsai.ML.Torch/Eye.cs index 8c8252a0..6db15b31 100644 --- a/src/Bonsai.ML.Torch/Eye.cs +++ b/src/Bonsai.ML.Torch/Eye.cs @@ -7,35 +7,35 @@ namespace Bonsai.ML.Torch { /// - /// Creates an identity matrix with the given data type and size. + /// Represents an operator that creates a sequence of identity matrices with the specified data type and size. /// [Combinator] [ResetCombinator] - [Description("Creates an identity matrix with the given data type and size.")] + [Description("Creates a sequence of identity matrices with the specified data type and size.")] [WorkflowElementCategory(ElementCategory.Source)] public class Eye { /// - /// The size of the identity matrix. + /// Gets or sets the size of the identity matrix. /// [Description("The size of the identity matrix.")] public long Size { get; set; } = 0; /// - /// The data type of the tensor elements. + /// Gets or sets the data type of the tensor elements. /// [Description("The data type of the tensor elements.")] public ScalarType? Type { get; set; } = null; /// - /// The device on which to create the tensor. + /// Gets or sets the device on which to create the tensor. /// [Description("The device on which to create the tensor.")] [XmlIgnore] public Device Device { get; set; } = null; /// - /// Creates an identity matrix with the given data type and size. + /// Creates an observable sequence containing a single identity matrix with the specified data type and size. /// public IObservable Process() { diff --git a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs index 9ea67c16..69131c54 100644 --- a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -34,8 +34,10 @@ public IObservable Process() { return Observable.Defer(() => { - InitializeDeviceType(DeviceType); - return Observable.Return(new Device(DeviceType, DeviceIndex)); + var deviceType = DeviceType; + var deviceIndex = DeviceIndex; + InitializeDeviceType(deviceType); + return Observable.Return(new Device(deviceType, deviceIndex)); }); } diff --git a/src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs b/src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs index 0fc1fd05..4d9d90da 100644 --- a/src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs +++ b/src/Bonsai.ML.Torch/ObserveWithGradientTracking.cs @@ -6,15 +6,15 @@ using System.ComponentModel; /// -/// This operator ensures that all tensor operations within the observable sequence are executed with gradient tracking enabled. +/// Represents an operator that ensures all tensor operations within the observable sequence are executed with gradient tracking enabled. /// [Combinator] -[Description("Ensures that all tensor operations within the observable sequence are executed with gradient tracking enabled.")] +[Description("Ensures all tensor operations within the observable sequence are executed with gradient tracking enabled.")] [WorkflowElementCategory(ElementCategory.Combinator)] public class ObserveWithGradientTracking { /// - /// Processes an observable sequence, ensuring all tensor operations are executed with gradient tracking enabled. + /// Returns an observable sequence which is identical to the source sequence, but where all tensor operations are executed with gradient tracking enabled. /// public IObservable Process(IObservable source) { diff --git a/src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs b/src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs index d34aa39d..dc341522 100644 --- a/src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs +++ b/src/Bonsai.ML.Torch/ObserveWithInferenceMode.cs @@ -6,15 +6,15 @@ using System.ComponentModel; /// -/// This operator ensures that all tensor operations within the observable sequence are executed in inference mode. +/// Represents an operator that ensures all tensor operations within the observable sequence are executed in inference mode. /// [Combinator] -[Description("Ensures that all tensor operations within the observable sequence are executed in inference mode.")] +[Description("Ensures all tensor operations within the observable sequence are executed in inference mode.")] [WorkflowElementCategory(ElementCategory.Combinator)] public class ObserveWithInferenceMode { /// - /// Processes an observable sequence, executing all tensor operations in inference mode. + /// Returns an observable sequence which is identical to the source sequence, but where all tensor operations are executed in inference mode. /// public IObservable Process(IObservable source) { diff --git a/src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs b/src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs index 3d3fc507..5dcb34b0 100644 --- a/src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs +++ b/src/Bonsai.ML.Torch/ObserveWithNoGradientTracking.cs @@ -6,15 +6,15 @@ using System.ComponentModel; /// -/// This operator ensures that all tensor operations within the observable sequence are executed without tracking gradients. +/// Represents an operator that ensures all tensor operations within the observable sequence are executed without tracking gradients. /// [Combinator] -[Description("Ensures that all tensor operations within the observable sequence are executed without tracking gradients.")] +[Description("Ensures all tensor operations within the observable sequence are executed without tracking gradients.")] [WorkflowElementCategory(ElementCategory.Combinator)] public class ObserveWithNoGradientTracking { /// - /// Processes an observable sequence, executing all tensor operations without tracking gradients. + /// Returns an observable sequence which is identical to the source sequence, but where all tensor operations are executed without tracking gradients. /// public IObservable Process(IObservable source) { diff --git a/src/Bonsai.ML.Torch/PrintTensor.cs b/src/Bonsai.ML.Torch/PrintTensor.cs index a7509855..d07cdd19 100644 --- a/src/Bonsai.ML.Torch/PrintTensor.cs +++ b/src/Bonsai.ML.Torch/PrintTensor.cs @@ -10,12 +10,12 @@ namespace Bonsai.ML.Torch; /// -/// Prints the string representation of incoming tensors to the console. +/// Represents an operator that applies a string formatting operation to all tensors in the sequence. /// [Combinator] -[Description("Prints the string representation of incoming tensors to the console.")] -[WorkflowElementCategory(ElementCategory.Sink)] -public class PrintTensor +[Description("Applies a string formatting operation to all tensors in the sequence.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class FormatTensor { /// /// Gets or sets the string style used to format the tensor output. @@ -24,12 +24,12 @@ public class PrintTensor public TensorStringStyle StringStyle { get; set; } /// - /// Processes the input sequence of tensors and prints their string representations to the console. + /// Applies a string formatting operation to all tensors in an observable sequence. /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { - return source.Do(value => Console.WriteLine(value.ToString(StringStyle))); + return source.Select(value => value.ToString(StringStyle)); } } diff --git a/src/Bonsai.ML.Torch/TensorContainerBase.cs b/src/Bonsai.ML.Torch/TensorContainerBase.cs index 9ecfc36d..bb3e6a72 100644 --- a/src/Bonsai.ML.Torch/TensorContainerBase.cs +++ b/src/Bonsai.ML.Torch/TensorContainerBase.cs @@ -11,9 +11,9 @@ namespace Bonsai.ML.Torch; /// -/// Base class for operators that contain tensor properties. Provides automatic scalar type conversion on registered tensors. +/// Provides an abstract base class for operators that contain tensor properties. Automatic scalar type conversion is provided for all registered tensors. /// -public abstract class TensorContainerBase : IScalarTypeProvider +public abstract class TensorOperatorBase : IScalarTypeProvider { private ScalarType _scalarType = ScalarType.Float32; private readonly List _registeredTensors = new(); From b3c9b17b53cf400b63f9676050aa70a7097c9018 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 12 Nov 2025 13:52:44 +0000 Subject: [PATCH 18/28] Renamed `PrintTensor` to `FormatTensor` --- src/Bonsai.ML.Torch/{PrintTensor.cs => FormatTensor.cs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/Bonsai.ML.Torch/{PrintTensor.cs => FormatTensor.cs} (100%) diff --git a/src/Bonsai.ML.Torch/PrintTensor.cs b/src/Bonsai.ML.Torch/FormatTensor.cs similarity index 100% rename from src/Bonsai.ML.Torch/PrintTensor.cs rename to src/Bonsai.ML.Torch/FormatTensor.cs From e7b71de4ed0e52e41541470bc604c268774837b4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 12 Nov 2025 13:53:23 +0000 Subject: [PATCH 19/28] Renamed `TensorContainer` to `TensorOperator` --- .../{TensorContainerBase.cs => TensorOperatorBase.cs} | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) rename src/Bonsai.ML.Torch/{TensorContainerBase.cs => TensorOperatorBase.cs} (95%) diff --git a/src/Bonsai.ML.Torch/TensorContainerBase.cs b/src/Bonsai.ML.Torch/TensorOperatorBase.cs similarity index 95% rename from src/Bonsai.ML.Torch/TensorContainerBase.cs rename to src/Bonsai.ML.Torch/TensorOperatorBase.cs index bb3e6a72..0767fa3a 100644 --- a/src/Bonsai.ML.Torch/TensorContainerBase.cs +++ b/src/Bonsai.ML.Torch/TensorOperatorBase.cs @@ -25,7 +25,8 @@ public ScalarType Type get => _scalarType; set { - if (_scalarType == value) return; + if (_scalarType == value) + return; _scalarType = value; ConvertAllTensors(); } @@ -59,7 +60,8 @@ protected void ConvertAllTensors() foreach (var registeredTensor in _registeredTensors) { var tensor = registeredTensor.Getter(); - if (tensor is null || tensor.dtype == _scalarType) continue; + if (tensor is null || tensor.dtype == _scalarType) + continue; registeredTensor.Setter(tensor.to_type(_scalarType)); } } From a93218551ae98342845081f7b38325bf299372ff Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 12 Nov 2025 13:56:09 +0000 Subject: [PATCH 20/28] Renamed `Buffer` to `Bind` and `Deconstruct` to `Unbind` --- src/Bonsai.ML.Torch/Bind.cs | 129 ++++++++++++++++++ src/Bonsai.ML.Torch/Buffer.cs | 120 ---------------- .../{Deconstruct.cs => Unbind.cs} | 5 +- 3 files changed, 132 insertions(+), 122 deletions(-) create mode 100644 src/Bonsai.ML.Torch/Bind.cs delete mode 100644 src/Bonsai.ML.Torch/Buffer.cs rename src/Bonsai.ML.Torch/{Deconstruct.cs => Unbind.cs} (93%) diff --git a/src/Bonsai.ML.Torch/Bind.cs b/src/Bonsai.ML.Torch/Bind.cs new file mode 100644 index 00000000..37ed075e --- /dev/null +++ b/src/Bonsai.ML.Torch/Bind.cs @@ -0,0 +1,129 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Collections.Generic; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Represents an operator that gathers incoming tensors into zero or more tensors by concatenating them along the first dimension. +/// +/// +/// The operator maintains an internal buffer that accumulates incoming tensors until it reaches the specified count. +/// When the buffer reaches the specified count, it is emitted as a single tensor. After emitting the buffer, the operator skips a specified number of incoming tensors before starting to fill the buffer again. +/// +[Combinator] +[Description("Gathers incoming tensors into zero or more tensors by concatenating them along the first dimension.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class Bind +{ + private int _count = 1; + /// + /// Gets or sets the number of tensors to accumulate in the buffer before emitting. + /// + [Description("The number of tensors to accumulate in the buffer before emitting.")] + public int Count + { + get => _count; + set => _count = value <= 0 + ? throw new ArgumentOutOfRangeException("Count must be greater than zero.") + : value; + } + + private int _skip = 1; + /// + /// Gets or sets the number of tensors to skip after emitting the buffer. + /// + [Description("The number of tensors to skip after emitting the buffer.")] + public int Skip + { + get => _skip; + set => _skip = value < 0 + ? throw new ArgumentOutOfRangeException("Skip must be non-negative.") + : value; + } + + /// + /// Processes an observable sequence of tensors, buffering them and concatenating along the first dimension. + /// + public IObservable Process(IObservable source) + { + return Observable.Create(observer => + { + var count = Count; + var skip = Skip; + + Tensor buffer = null; + int current = 0; + Tensor idxSrc = null; + Tensor idxDst = null; + + return source.Subscribe(input => + { + if (input is null) + return; + + if (buffer is null) + { + var shape = input.shape.Prepend(count).ToArray(); + buffer = empty(shape, dtype: input.dtype, device: input.device); + + if (skip < count) + { + idxSrc = arange(skip, count, dtype: ScalarType.Int32, device: input.device); + idxDst = arange(0, count - skip, dtype: ScalarType.Int32, device: input.device); + } + } + + if (current >= 0) + { + buffer[current] = input; + } + + current++; + + if (current >= count) + { + var output = buffer.clone(); + if (skip < count) + { + var src = index_select(buffer, 0, idxSrc); + buffer.index_copy_(0, idxDst, src); + buffer[torch.TensorIndex.Slice(count - skip, null)].zero_(); + } + else + { + buffer.zero_(); + } + current = count - skip; + observer.OnNext(output); + } + }, + observer.OnError, + () => + { + var remainder = current; + + if (remainder > 0 && buffer is not null) + { + var outputShape = buffer.shape.ToArray(); + outputShape[0] = remainder; + var output = empty(outputShape, dtype: buffer.dtype, device: buffer.device); + for (int i = 0; i < remainder; i++) + { + output[i] = buffer[i]; + } + observer.OnNext(output.clone()); + } + + buffer?.Dispose(); + idxSrc?.Dispose(); + idxDst?.Dispose(); + observer.OnCompleted(); + }); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Buffer.cs b/src/Bonsai.ML.Torch/Buffer.cs deleted file mode 100644 index 56a7de62..00000000 --- a/src/Bonsai.ML.Torch/Buffer.cs +++ /dev/null @@ -1,120 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; - -namespace Bonsai.ML.Torch; - -/// -/// Represents an operator that gathers incoming tensors into zero or more tensors by concatenating them along the first dimension. -/// -/// -/// The operator maintains an internal buffer that accumulates incoming tensors until it reaches the specified count. -/// When the buffer reaches the specified count, it is emitted as a single tensor. After emitting the buffer, the operator skips a specified number of incoming tensors before starting to fill the buffer again. -/// -[Combinator] -[Description("Gathers incoming tensors into zero or more tensors by concatenating them along the first dimension.")] -[WorkflowElementCategory(ElementCategory.Combinator)] -public class Buffer -{ - private int _count = 1; - /// - /// Gets or sets the number of tensors to accumulate in the buffer before emitting. - /// - [Description("The number of tensors to accumulate in the buffer before emitting.")] - public int Count - { - get => _count; - set => _count = value <= 0 - ? throw new ArgumentOutOfRangeException("Count must be greater than zero.") - : value; - } - - private int _skip = 1; - /// - /// Gets or sets the number of tensors to skip after emitting the buffer. - /// - [Description("The number of tensors to skip after emitting the buffer.")] - public int Skip - { - get => _skip; - set => _skip = value < 0 - ? throw new ArgumentOutOfRangeException("Skip must be non-negative.") - : value; - } - - private torch.Tensor _buffer = null; - private int _current = 0; - private torch.Tensor _idxSrc = null; - private torch.Tensor _idxDst = null; - - /// - /// Processes an observable sequence of tensors, buffering them and concatenating along the first dimension. - /// - public IObservable Process(IObservable source) - { - var count = _count; - var skip = _skip; - var send = false; - _current = 0; - return source.Select((input) => - { - if (input is null) - return false; - - if (_buffer is null) - { - var shape = input.shape.Prepend(count).ToArray(); - _buffer = torch.empty(shape, dtype: input.dtype, device: input.device); - - if (skip < count) - { - _idxSrc = torch.arange(skip, count, dtype: torch.ScalarType.Int64, device: input.device); - _idxDst = torch.arange(0, count - skip, dtype: torch.ScalarType.Int64, device: input.device); - } - } - - if (_current >= 0) - { - _buffer[_current] = input; - } - - _current++; - - if (_current >= count) - { - send = true; - } - return send; - }) - .Where(x => x) - .Select(x => - { - var output = _buffer.clone(); - if (skip < count) - { - var src = torch.index_select(_buffer, 0, _idxSrc); - _buffer.index_copy_(0, _idxDst, src); - _buffer[torch.TensorIndex.Slice(count - skip, null)].zero_(); - } - else - { - _buffer.zero_(); - } - _current = count - skip; - send = false; - return output; - }) - .Finally(() => - { - _buffer?.Dispose(); - _buffer = null; - _idxSrc?.Dispose(); - _idxSrc = null; - _idxDst?.Dispose(); - _idxDst = null; - send = false; - }); - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Deconstruct.cs b/src/Bonsai.ML.Torch/Unbind.cs similarity index 93% rename from src/Bonsai.ML.Torch/Deconstruct.cs rename to src/Bonsai.ML.Torch/Unbind.cs index d3962668..465267e6 100644 --- a/src/Bonsai.ML.Torch/Deconstruct.cs +++ b/src/Bonsai.ML.Torch/Unbind.cs @@ -12,7 +12,7 @@ namespace Bonsai.ML.Torch; [Combinator] [Description("Deconstructs each tensor in the sequence into one or more tensors by splitting it along the specified dimension.")] [WorkflowElementCategory(ElementCategory.Combinator)] -public class Deconstruct +public class Unbind { private int _dimension = 0; /// @@ -32,7 +32,8 @@ public int Dimension { return source.SelectMany((input) => { - if (input is null) return null; + if (input is null) + return null; return input.unbind(_dimension).ToObservable(); }); } From f062b66df8c2a65091675699c8f4f3b18627dd5a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 12 Nov 2025 15:55:34 +0000 Subject: [PATCH 21/28] Refactored `TensorOperator` into a true `TypeConverter` class and applied it to `CreateTensor` and `Diagonal` operators --- src/Bonsai.ML.Torch/CreateTensor.cs | 159 ++++++------- src/Bonsai.ML.Torch/Diagonal.cs | 139 ++++++----- src/Bonsai.ML.Torch/TensorOperatorBase.cs | 77 ------- .../TensorOperatorConverter.cs | 78 +++++++ src/Bonsai.ML.Torch/ToTensor.cs | 217 +++++++++--------- 5 files changed, 331 insertions(+), 339 deletions(-) delete mode 100644 src/Bonsai.ML.Torch/TensorOperatorBase.cs create mode 100644 src/Bonsai.ML.Torch/TensorOperatorConverter.cs diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index b531bd9e..e594d4a4 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -11,109 +11,78 @@ using Bonsai.ML.Data; using TorchSharp; -namespace Bonsai.ML.Torch +namespace Bonsai.ML.Torch; + +/// +/// Creates a tensor from the specified values. +/// Uses Python-like syntax to specify the tensor values. +/// For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". +/// +[Combinator] +[ResetCombinator] +[Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class CreateTensor : IScalarTypeProvider { /// - /// Creates a tensor from the specified values. - /// Uses Python-like syntax to specify the tensor values. - /// For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". + /// The data type of the tensor elements. /// - [Combinator] - [ResetCombinator] - [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] - [WorkflowElementCategory(ElementCategory.Source)] - public class CreateTensor : IScalarTypeProvider - { - /// - /// The data type of the tensor elements. - /// - [Description("The data type of the tensor elements.")] - [TypeConverter(typeof(ScalarTypeConverter))] - public ScalarType Type - { - get => _scalarType; - set - { - _scalarType = value; - _tensor = ConvertTensorScalarType(_tensor, value); - } - } - private ScalarType _scalarType = ScalarType.Float32; - - /// - /// The values of the tensor elements. - /// Uses Python-like syntax to specify the tensor values. - /// For example: "[[1, 2], [3, 4]]". - /// - [XmlIgnore] - [Description("The values of the tensor elements. Uses Python-like syntax to specify the tensor values. For example: \"[[1, 2], [3, 4]]\".")] - [TypeConverter(typeof(TensorConverter))] - public Tensor Values - { - get => _tensor; - set => _tensor = ConvertTensorScalarType(value, _scalarType); - } + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; - private Tensor _tensor = zeros(1, dtype: ScalarType.Float32); + /// + /// The values of the tensor elements. + /// Uses Python-like syntax to specify the tensor values. + /// For example: "[[1, 2], [3, 4]]". + /// + [XmlIgnore] + [Description("The values of the tensor elements. Uses Python-like syntax to specify the tensor values. For example: \"[[1, 2], [3, 4]]\".")] + [TypeConverter(typeof(TensorConverter))] + public Tensor Values + { + get => _values; + set => _values = value; + } - /// - /// This method converts the tensor to the specified scalar type. - /// - /// - /// We use this method in the setter of the and properties to ensure that the tensor is converted to the appropriate type and then returned. - /// - private static Tensor ConvertTensorScalarType(Tensor value, ScalarType scalarType) - { - return value.to_type(scalarType); - } + /// + /// The values of the tensor elements in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Values))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ValuesXml + { + get => TensorConverter.ConvertToString(Values, Type); + set => Values = TensorConverter.ConvertFromString(value, Type); + } - /// - /// The values of the tensor elements in XML string format. - /// - [Browsable(false)] - [XmlElement(nameof(Values))] - [EditorBrowsable(EditorBrowsableState.Never)] - public string ValuesXml - { - get => TensorConverter.ConvertToString(Values, _scalarType); - set => Values = TensorConverter.ConvertFromString(value, _scalarType); - } + /// + /// The device on which to create the tensor. + /// + [XmlIgnore] + [Description("The device on which to create the tensor.")] + public Device Device { get; set; } = null; - /// - /// The device on which to create the tensor. - /// - [XmlIgnore] - [Description("The device on which to create the tensor.")] - public Device Device - { - get => _device; - set => _device = value; - } - private Device _device = null; + private Tensor _values; - /// - /// Returns an observable sequence that creates a tensor from the specified values. - /// - public IObservable Process() - { - return Observable.Return(_device != null ? _tensor.to(_device).clone() : _tensor.clone()); - } + /// + /// Returns an observable sequence that creates a tensor from the specified values. + /// + public IObservable Process() + { + var device = Device ?? CPU; + return Observable.Return(_values.to(device).clone()); + } - /// - /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. - /// - public IObservable Process(IObservable source) - { - var tensor = _tensor.clone(); - return source.Take(1) - .Select(_ => { - tensor = tensor.to(_device); - return tensor; - }) - .Concat( - source.Skip(1) - .Select(_ => tensor) - ); - } + /// + /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. + /// + public IObservable Process(IObservable source) + { + var device = Device ?? CPU; + var tensor = _values.to(device).clone(); + return source.Select(_ => tensor); } } diff --git a/src/Bonsai.ML.Torch/Diagonal.cs b/src/Bonsai.ML.Torch/Diagonal.cs index d5c5cfc2..e45c56c1 100644 --- a/src/Bonsai.ML.Torch/Diagonal.cs +++ b/src/Bonsai.ML.Torch/Diagonal.cs @@ -4,73 +4,96 @@ using System.Xml.Serialization; using static TorchSharp.torch; -namespace Bonsai.ML.Torch +namespace Bonsai.ML.Torch; + +/// +/// Creates a diagonal matrix. If input is a 1D tensor, it creates a diagonal matrix with the elements of the tensor on the diagonal. +/// If input is a 2D tensor, it returns the diagonal elements as a 1D tensor. +/// +[Combinator] +[ResetCombinator] +[Description("Creates a diagonal matrix with the given data type, size, and value.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Diagonal : IScalarTypeProvider { + private Tensor _values; + + /// + /// Gets or sets the values to include in the diagonal. + /// + [Description("The values to include in the diagonal.")] + [TypeConverter(typeof(TensorConverter))] + [XmlIgnore] + public Tensor Values + { + get => _values; + set => _values = value; + } + /// - /// Creates a diagonal matrix. If input is a 1D tensor, it creates a diagonal matrix with the elements of the tensor on the diagonal. - /// If input is a 2D tensor, it returns the diagonal elements as a 1D tensor. + /// The values of the tensor elements in XML string format. /// - [Combinator] - [ResetCombinator] - [Description("Creates a diagonal matrix with the given data type, size, and value.")] - [WorkflowElementCategory(ElementCategory.Source)] - public class Diagonal + [Browsable(false)] + [XmlElement(nameof(Values))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ValuesXml { - /// - /// Gets or sets the values to include in the diagonal. - /// - [Description("The values to include in the diagonal.")] - [TypeConverter(typeof(UnidimensionalArrayConverter))] - public double[] Values { get; set; } + get => TensorConverter.ConvertToString(_values, Type); + set => _values = TensorConverter.ConvertFromString(value, Type); + } - /// - /// Gets or sets the data type of the tensor elements. - /// - [Description("The data type of the tensor elements.")] - public ScalarType? Type { get; set; } + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType Type { get; set; } = ScalarType.Float32; - /// - /// Gets or sets the device on which to create the tensor. - /// - [Description("The device on which to create the tensor.")] - [XmlIgnore] - public Device Device { get; set; } = null; + /// + /// Gets or sets the device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; - /// - /// Gets or sets the diagonal offset. Default is 0, which means the main diagonal. - /// - [Description("The diagonal offset. Default is 0, which means the main diagonal.")] - public int Offset { get; set; } = 0; + /// + /// Gets or sets the diagonal offset. Default is 0, which means the main diagonal. + /// + [Description("The diagonal offset. Default is 0, which means the main diagonal.")] + public int Offset { get; set; } = 0; - /// - /// Creates an observable sequence containing a single diagonal matrix constructed from the specified data type, size and values. - /// - public IObservable Process() - { - var inputTensor = tensor(Input, dtype: Type, device: Device); - return Observable.Return(diag(inputTensor, Offset)); - } + /// + /// Creates an observable sequence containing a single diagonal matrix constructed from the specified data type, size and values. + /// + public IObservable Process() + { + var device = Device ?? CPU; + var inputTensor = _values.to(device); + return Observable.Return(diag(inputTensor, Offset)); + } - /// - /// Generates an observable sequence of tensors by extracting the diagonal of the input. - /// - /// - /// - public IObservable Process(IObservable source) - { - var inputTensor = tensor(Input, dtype: Type, device: Device); - return source.Select(value => diag(inputTensor, Offset)); - } + /// + /// Generates an observable sequence of tensors by extracting the diagonal of the input. + /// + /// + /// + public IObservable Process(IObservable source) + { + var device = Device ?? CPU; + return source.Select(value => _values is not null + ? diag(_values, Offset).to(device) + : diag(value, Offset).to(device)); + } - /// - /// Generates an observable sequence of tensors by extracting the diagonal of the input. - /// - /// - /// - public IObservable Process(IObservable source) - { - var inputTensor = tensor(Input, dtype: Type, device: Device); - return source.Select(value => diag(inputTensor, Offset)); - } + /// + /// Generates an observable sequence of tensors by extracting the diagonal of the input. + /// + /// + /// + public IObservable Process(IObservable source) + { + var device = Device ?? CPU; + var inputTensor = _values.to(device); + return source.Select(value => diag(inputTensor, Offset)); } } diff --git a/src/Bonsai.ML.Torch/TensorOperatorBase.cs b/src/Bonsai.ML.Torch/TensorOperatorBase.cs deleted file mode 100644 index 0767fa3a..00000000 --- a/src/Bonsai.ML.Torch/TensorOperatorBase.cs +++ /dev/null @@ -1,77 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Runtime.CompilerServices; -using TorchSharp; -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch; - -/// -/// Provides an abstract base class for operators that contain tensor properties. Automatic scalar type conversion is provided for all registered tensors. -/// -public abstract class TensorOperatorBase : IScalarTypeProvider -{ - private ScalarType _scalarType = ScalarType.Float32; - private readonly List _registeredTensors = new(); - private readonly object _sync = new(); - - /// - public ScalarType Type - { - get => _scalarType; - set - { - if (_scalarType == value) - return; - _scalarType = value; - ConvertAllTensors(); - } - } - - /// - /// Registers a tensor property for automatic scalar type conversion. - /// - /// - /// - protected void RegisterTensor(Func getter, Action setter) - { - _registeredTensors.Add(new RegisteredTensor(getter, setter)); - } - - /// - /// Convert a single tensor to the current scalar type. - /// - protected Tensor ConvertTensor(Tensor value) - { - return value.dtype == _scalarType ? value : value.to_type(_scalarType); - } - - /// - /// Converts all tensor properties to the current scalar type. - /// - protected void ConvertAllTensors() - { - lock (_sync) - { - foreach (var registeredTensor in _registeredTensors) - { - var tensor = registeredTensor.Getter(); - if (tensor is null || tensor.dtype == _scalarType) - continue; - registeredTensor.Setter(tensor.to_type(_scalarType)); - } - } - } - - private readonly struct RegisteredTensor( - Func getter, - Action setter) - { - public Func Getter { get; } = getter; - public Action Setter { get; } = setter; - } -} diff --git a/src/Bonsai.ML.Torch/TensorOperatorConverter.cs b/src/Bonsai.ML.Torch/TensorOperatorConverter.cs new file mode 100644 index 00000000..6ac17150 --- /dev/null +++ b/src/Bonsai.ML.Torch/TensorOperatorConverter.cs @@ -0,0 +1,78 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reflection; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Provides a type converter that automatically converts all tensor properties with the attribute when the scalar type changes. +/// Apply this converter to classes using [TypeConverter(typeof(TensorOperatorConverter))]. +/// +public class TensorOperatorConverter : TypeConverter +{ + /// + public override PropertyDescriptorCollection GetProperties(ITypeDescriptorContext context, object value, Attribute[] attributes) + { + var properties = TypeDescriptor.GetProperties(value, attributes); + return new PropertyDescriptorCollection( + properties.Cast() + .Select(p => p.Name == nameof(IScalarTypeProvider.Type) + ? new ScalarTypePropertyDescriptor(p) + : p) + .ToArray()); + } + + /// + public override bool GetPropertiesSupported(ITypeDescriptorContext context) => true; + + private class ScalarTypePropertyDescriptor(PropertyDescriptor baseDescriptor) : PropertyDescriptor(baseDescriptor) + { + private readonly PropertyDescriptor _baseDescriptor = baseDescriptor; + public override Type ComponentType => _baseDescriptor.ComponentType; + public override bool IsReadOnly => _baseDescriptor.IsReadOnly; + public override Type PropertyType => _baseDescriptor.PropertyType; + + public override bool CanResetValue(object component) => _baseDescriptor.CanResetValue(component); + public override object GetValue(object component) => _baseDescriptor.GetValue(component); + public override void ResetValue(object component) => _baseDescriptor.ResetValue(component); + public override bool ShouldSerializeValue(object component) => _baseDescriptor.ShouldSerializeValue(component); + + public override void SetValue(object component, object value) + { + var oldValue = _baseDescriptor.GetValue(component); + if (Equals(oldValue, value)) + return; + + _baseDescriptor.SetValue(component, value); + + if (value is ScalarType newScalarType && component is IScalarTypeProvider) + { + ConvertAllTensorProperties(component, newScalarType); + } + } + + private static void ConvertAllTensorProperties(object component, ScalarType scalarType) + { + var properties = TypeDescriptor.GetProperties(component); + foreach (PropertyDescriptor property in properties) + { + // Check if this property uses TensorConverter + var converterAttr = property.Attributes.OfType().FirstOrDefault(); + + if (converterAttr?.ConverterTypeName?.Contains(nameof(TensorConverter)) != true) + continue; + + if (property.PropertyType != typeof(Tensor)) + continue; + + if (property.GetValue(component) is not Tensor tensor || tensor.dtype == scalarType) + continue; + + property.SetValue(component, tensor.to_type(scalarType)); + } + } + } +} diff --git a/src/Bonsai.ML.Torch/ToTensor.cs b/src/Bonsai.ML.Torch/ToTensor.cs index d0d3b7a8..bfc5624d 100644 --- a/src/Bonsai.ML.Torch/ToTensor.cs +++ b/src/Bonsai.ML.Torch/ToTensor.cs @@ -6,128 +6,127 @@ using OpenCV.Net; using static TorchSharp.torch; -namespace Bonsai.ML.Torch +namespace Bonsai.ML.Torch; + +/// +/// Converts the input value into a tensor. +/// +[Combinator] +[ResetCombinator] +[Description("Converts the input value into a tensor.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class ToTensor { /// - /// Converts the input value into a tensor. + /// The device on which to create the tensor. /// - [Combinator] - [ResetCombinator] - [Description("Converts the input value into a tensor.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ToTensor - { - /// - /// The device on which to create the tensor. - /// - [Description("The device on which to create the tensor.")] - [XmlIgnore] - public Device Device { get; set; } = null; + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; - /// - /// The data type of the tensor. - /// - [Description("The data type of the tensor.")] - public ScalarType? Type { get; set; } = null; + /// + /// The data type of the tensor. + /// + [Description("The data type of the tensor.")] + public ScalarType? Type { get; set; } = null; - /// - /// Converts an int into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts an int into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a double into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a double into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a byte into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a byte into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a bool into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a bool into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a float into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a float into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a long into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a long into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts a short into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts a short into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts an array into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => as_tensor(value, dtype: Type, device: Device)); - } + /// + /// Converts an array into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => as_tensor(value, dtype: Type, device: Device)); + } - /// - /// Converts an IplImage into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => OpenCVHelper.ToTensor(value, device: Device)); - } + /// + /// Converts an IplImage into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => OpenCVHelper.ToTensor(value, device: Device)); + } - /// - /// Converts a Mat into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => OpenCVHelper.ToTensor(value, device: Device)); - } + /// + /// Converts a Mat into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => OpenCVHelper.ToTensor(value, device: Device)); } } \ No newline at end of file From d68a524a0aeccb8f953f22a1a5cec0d850811fa6 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 13 Nov 2025 18:12:16 +0000 Subject: [PATCH 22/28] Added operator to explicitly convert single valued tensors to .NET data types --- src/Bonsai.ML.Torch/ConvertToItem.cs | 59 ++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 src/Bonsai.ML.Torch/ConvertToItem.cs diff --git a/src/Bonsai.ML.Torch/ConvertToItem.cs b/src/Bonsai.ML.Torch/ConvertToItem.cs new file mode 100644 index 00000000..55dfc93d --- /dev/null +++ b/src/Bonsai.ML.Torch/ConvertToItem.cs @@ -0,0 +1,59 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.Linq.Expressions; +using System.Reflection; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Represents an operator that converts the input tensor into a single value of the specified element type. The input tensor must only contain a single element. +/// +[Combinator] +[Description("Converts the input tensor into a single value of the specified element type.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class ConvertToItem : SingleArgumentExpressionBuilder +{ + /// + /// Gets or sets the type of the item. + /// + [Description("Gets or sets the type of the item.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + public override Expression Build(IEnumerable arguments) + { + MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); + var type = ScalarTypeLookup.GetTypeFromScalarType(Type); + methodInfo = methodInfo.MakeGenericMethod(type); + Expression sourceExpression = arguments.First(); + + return Expression.Call( + Expression.Constant(this), + methodInfo, + sourceExpression + ); + } + + /// + /// Converts the input tensor into a single item. + /// + /// + public IObservable Process(IObservable source) where T : unmanaged + { + return source.Select(tensor => + { + if (tensor.dtype != Type) + { + tensor = tensor.to_type(Type); + } + return tensor.item(); + }); + } +} \ No newline at end of file From 8ccaa22076544cdf5e4ebb342003f2a684defeff Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 27 Nov 2025 19:38:46 -0800 Subject: [PATCH 23/28] Modified linspace to correctly use double properties for start and end values of range --- src/Bonsai.ML.Torch/LinSpace.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/LinSpace.cs b/src/Bonsai.ML.Torch/LinSpace.cs index aaded985..954c311e 100644 --- a/src/Bonsai.ML.Torch/LinSpace.cs +++ b/src/Bonsai.ML.Torch/LinSpace.cs @@ -19,13 +19,13 @@ public class LinSpace /// The start of the range. /// [Description("The start of the range.")] - public int Start { get; set; } = 0; + public double Start { get; set; } = 0; /// /// The end of the range. /// [Description("The end of the range.")] - public int End { get; set; } = 1; + public double End { get; set; } = 1; /// /// The number of points to generate. From 768d5a3b6b5b4371f606625e81ff5ef7d4fc8d18 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 10 Dec 2025 15:24:57 +0000 Subject: [PATCH 24/28] Updated `Bind` to ensure that tensors used for indices are the correct data type --- src/Bonsai.ML.Torch/Bind.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/Bind.cs b/src/Bonsai.ML.Torch/Bind.cs index 37ed075e..c13ca1a8 100644 --- a/src/Bonsai.ML.Torch/Bind.cs +++ b/src/Bonsai.ML.Torch/Bind.cs @@ -73,8 +73,8 @@ public IObservable Process(IObservable source) if (skip < count) { - idxSrc = arange(skip, count, dtype: ScalarType.Int32, device: input.device); - idxDst = arange(0, count - skip, dtype: ScalarType.Int32, device: input.device); + idxSrc = arange(skip, count, dtype: ScalarType.Int64, device: input.device); + idxDst = arange(0, count - skip, dtype: ScalarType.Int64, device: input.device); } } From de1066e4db94f6d9a735146e80c5a4cbf976ec51 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 10 Dec 2025 15:27:49 +0000 Subject: [PATCH 25/28] Added type converter class for value tuple and nullable value tuple properties --- src/Bonsai.ML.Torch/ValueTupleConverter.cs | 496 +++++++++++++++++++++ 1 file changed, 496 insertions(+) create mode 100644 src/Bonsai.ML.Torch/ValueTupleConverter.cs diff --git a/src/Bonsai.ML.Torch/ValueTupleConverter.cs b/src/Bonsai.ML.Torch/ValueTupleConverter.cs new file mode 100644 index 00000000..0738d875 --- /dev/null +++ b/src/Bonsai.ML.Torch/ValueTupleConverter.cs @@ -0,0 +1,496 @@ +using System; +using System.ComponentModel; + +namespace Bonsai.ML.Torch; + +/// +/// Type converter for single-element value tuples. +/// +/// +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 1) + { + var item = (T)TypeDescriptor.GetConverter(typeof(T)).ConvertFromString(context, culture, elements[0].Trim()); + return ValueTuple.Create(item); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item = TypeDescriptor.GetConverter(typeof(T)).ConvertToString(context, culture, tuple.Item1); + return $"({item})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for single-element nullable value tuples. +/// +/// +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for two-element value tuples. +/// +/// +/// +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 2) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + return ValueTuple.Create(item1, item2); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + return $"({item1}, {item2})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for two-element nullable value tuples. +/// +/// +/// +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for three-element value tuples. +/// +/// +/// +/// +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 3) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + var item3 = (T3)TypeDescriptor.GetConverter(typeof(T3)).ConvertFromString(context, culture, elements[2].Trim()); + return ValueTuple.Create(item1, item2, item3); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}, {typeof(T3).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + var item3 = TypeDescriptor.GetConverter(typeof(T3)).ConvertToString(context, culture, tuple.Item3); + return $"({item1}, {item2}, {item3})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for three-element nullable value tuples. +/// +/// +/// +/// +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for four-element value tuples. +/// +/// +/// +/// +/// +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 3) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + var item3 = (T3)TypeDescriptor.GetConverter(typeof(T3)).ConvertFromString(context, culture, elements[2].Trim()); + var item4 = (T4)TypeDescriptor.GetConverter(typeof(T4)).ConvertFromString(context, culture, elements[3].Trim()); + return ValueTuple.Create(item1, item2, item3, item4); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}, {typeof(T3).Name}, {typeof(T4).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + var item3 = TypeDescriptor.GetConverter(typeof(T3)).ConvertToString(context, culture, tuple.Item3); + var item4 = TypeDescriptor.GetConverter(typeof(T4)).ConvertToString(context, culture, tuple.Item4); + return $"({item1}, {item2}, {item3}, {item4})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for four-element nullable value tuples. +/// +/// +/// +/// +/// +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for five-element value tuples. +/// +/// +/// +/// +/// +/// +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 3) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + var item3 = (T3)TypeDescriptor.GetConverter(typeof(T3)).ConvertFromString(context, culture, elements[2].Trim()); + var item4 = (T4)TypeDescriptor.GetConverter(typeof(T4)).ConvertFromString(context, culture, elements[3].Trim()); + var item5 = (T5)TypeDescriptor.GetConverter(typeof(T5)).ConvertFromString(context, culture, elements[4].Trim()); + return ValueTuple.Create(item1, item2, item3, item4, item5); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}, {typeof(T3).Name}, {typeof(T4).Name}, {typeof(T5).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + var item3 = TypeDescriptor.GetConverter(typeof(T3)).ConvertToString(context, culture, tuple.Item3); + var item4 = TypeDescriptor.GetConverter(typeof(T4)).ConvertToString(context, culture, tuple.Item4); + var item5 = TypeDescriptor.GetConverter(typeof(T5)).ConvertToString(context, culture, tuple.Item5); + return $"({item1}, {item2}, {item3}, {item4}, {item5})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for five-element nullable value tuples. +/// +/// +/// +/// +/// +/// +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} + +/// +/// Type converter for six-element value tuples. +/// +/// +/// +/// +/// +/// +/// +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 3) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + var item3 = (T3)TypeDescriptor.GetConverter(typeof(T3)).ConvertFromString(context, culture, elements[2].Trim()); + var item4 = (T4)TypeDescriptor.GetConverter(typeof(T4)).ConvertFromString(context, culture, elements[3].Trim()); + var item5 = (T5)TypeDescriptor.GetConverter(typeof(T5)).ConvertFromString(context, culture, elements[4].Trim()); + var item6 = (T6)TypeDescriptor.GetConverter(typeof(T6)).ConvertFromString(context, culture, elements[5].Trim()); + return ValueTuple.Create(item1, item2, item3, item4, item5, item6); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}, {typeof(T3).Name}, {typeof(T4).Name}, {typeof(T5).Name}, {typeof(T6).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + var item3 = TypeDescriptor.GetConverter(typeof(T3)).ConvertToString(context, culture, tuple.Item3); + var item4 = TypeDescriptor.GetConverter(typeof(T4)).ConvertToString(context, culture, tuple.Item4); + var item5 = TypeDescriptor.GetConverter(typeof(T5)).ConvertToString(context, culture, tuple.Item5); + var item6 = TypeDescriptor.GetConverter(typeof(T6)).ConvertToString(context, culture, tuple.Item6); + return $"({item1}, {item2}, {item3}, {item4}, {item5}, {item6})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for six-element nullable value tuples. +/// +/// +/// +/// +/// +/// +/// +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} \ No newline at end of file From 585a5720eaa6c82c271d9d460812ec76c7851a57 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 12 Dec 2025 12:22:04 +0000 Subject: [PATCH 26/28] Updated XML docs with improved descriptions for type parameters --- src/Bonsai.ML.Torch/ConvertScalarType.cs | 2 +- src/Bonsai.ML.Torch/ConvertToArray.cs | 12 +- src/Bonsai.ML.Torch/ConvertToItem.cs | 10 +- src/Bonsai.ML.Torch/ConvertToNDArray.cs | 14 +- src/Bonsai.ML.Torch/ValueTupleConverter.cs | 182 ++++++++++++++++----- 5 files changed, 159 insertions(+), 61 deletions(-) diff --git a/src/Bonsai.ML.Torch/ConvertScalarType.cs b/src/Bonsai.ML.Torch/ConvertScalarType.cs index 2af31b13..745fa722 100644 --- a/src/Bonsai.ML.Torch/ConvertScalarType.cs +++ b/src/Bonsai.ML.Torch/ConvertScalarType.cs @@ -20,7 +20,7 @@ public class ConvertScalarType public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Returns an observable sequence that converts the input tensor to the specified scalar type. + /// Converts the data type of the input tensor to the newly specified scalar type. /// public IObservable Process(IObservable source) { diff --git a/src/Bonsai.ML.Torch/ConvertToArray.cs b/src/Bonsai.ML.Torch/ConvertToArray.cs index 271cde16..2c6b3bb4 100644 --- a/src/Bonsai.ML.Torch/ConvertToArray.cs +++ b/src/Bonsai.ML.Torch/ConvertToArray.cs @@ -32,7 +32,7 @@ public override Expression Build(IEnumerable arguments) MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); methodInfo = methodInfo.MakeGenericMethod(ScalarTypeLookup.GetTypeFromScalarType(Type)); Expression sourceExpression = arguments.First(); - + return Expression.Call( Expression.Constant(this), methodInfo, @@ -43,12 +43,12 @@ public override Expression Build(IEnumerable arguments) /// /// Converts the input tensor into a flattened array of the specified element type. /// - /// - /// - /// + /// The element type of the output item. + /// The sequence of input tensors. + /// The sequence of output arrays of the specified element type. public IObservable Process(IObservable source) where T : unmanaged { - return source.Select(tensor => + return source.Select(tensor => { if (tensor.dtype != Type) { @@ -58,4 +58,4 @@ public IObservable Process(IObservable source) where T : unmanag }); } } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/ConvertToItem.cs b/src/Bonsai.ML.Torch/ConvertToItem.cs index 55dfc93d..e91bf6e9 100644 --- a/src/Bonsai.ML.Torch/ConvertToItem.cs +++ b/src/Bonsai.ML.Torch/ConvertToItem.cs @@ -33,7 +33,7 @@ public override Expression Build(IEnumerable arguments) var type = ScalarTypeLookup.GetTypeFromScalarType(Type); methodInfo = methodInfo.MakeGenericMethod(type); Expression sourceExpression = arguments.First(); - + return Expression.Call( Expression.Constant(this), methodInfo, @@ -44,10 +44,12 @@ public override Expression Build(IEnumerable arguments) /// /// Converts the input tensor into a single item. /// - /// + /// The element type of the output item. + /// The sequence of input tensors. + /// The sequence of output items of the specified element type. public IObservable Process(IObservable source) where T : unmanaged { - return source.Select(tensor => + return source.Select(tensor => { if (tensor.dtype != Type) { @@ -56,4 +58,4 @@ public IObservable Process(IObservable source) where T : unmanaged return tensor.item(); }); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/ConvertToNDArray.cs b/src/Bonsai.ML.Torch/ConvertToNDArray.cs index 171724e0..9c1ea652 100644 --- a/src/Bonsai.ML.Torch/ConvertToNDArray.cs +++ b/src/Bonsai.ML.Torch/ConvertToNDArray.cs @@ -41,7 +41,7 @@ public override Expression Build(IEnumerable arguments) Type arrayType = Array.CreateInstance(type, lengths).GetType(); methodInfo = methodInfo.MakeGenericMethod(type, arrayType); Expression sourceExpression = arguments.First(); - + return Expression.Call( Expression.Constant(this), methodInfo, @@ -52,13 +52,13 @@ public override Expression Build(IEnumerable arguments) /// /// Converts the input tensor into an array of the specified element type. /// - /// - /// - /// - /// + /// The element type of the output item. + /// The type of the output array. + /// The sequence of input tensors. + /// The sequence of output arrays of the specified element type and rank. public IObservable Process(IObservable source) where T : unmanaged { - return source.Select(tensor => + return source.Select(tensor => { if (tensor.dtype != Type) { @@ -68,4 +68,4 @@ public IObservable Process(IObservable source) wher }); } } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/ValueTupleConverter.cs b/src/Bonsai.ML.Torch/ValueTupleConverter.cs index 0738d875..73b57ed4 100644 --- a/src/Bonsai.ML.Torch/ValueTupleConverter.cs +++ b/src/Bonsai.ML.Torch/ValueTupleConverter.cs @@ -6,7 +6,7 @@ namespace Bonsai.ML.Torch; /// /// Type converter for single-element value tuples. /// -/// +/// The type of the element in the value tuple. public class ValueTupleConverter : TypeConverter { /// @@ -63,7 +63,7 @@ public override object ConvertTo(ITypeDescriptorContext context, System.Globaliz /// /// Type converter for single-element nullable value tuples. /// -/// +/// The type of the element in the value tuple. public class NullableValueTupleConverter : ValueTupleConverter { /// @@ -78,8 +78,8 @@ public override object ConvertFrom(ITypeDescriptorContext context, System.Global /// /// Type converter for two-element value tuples. /// -/// -/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. public class ValueTupleConverter : TypeConverter { /// @@ -138,8 +138,8 @@ public override object ConvertTo(ITypeDescriptorContext context, System.Globaliz /// /// Type converter for two-element nullable value tuples. /// -/// -/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. public class NullableValueTupleConverter : ValueTupleConverter { /// @@ -154,9 +154,9 @@ public override object ConvertFrom(ITypeDescriptorContext context, System.Global /// /// Type converter for three-element value tuples. /// -/// -/// -/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. public class ValueTupleConverter : TypeConverter { /// @@ -217,9 +217,9 @@ public override object ConvertTo(ITypeDescriptorContext context, System.Globaliz /// /// Type converter for three-element nullable value tuples. /// -/// -/// -/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. public class NullableValueTupleConverter : ValueTupleConverter { /// @@ -234,10 +234,10 @@ public override object ConvertFrom(ITypeDescriptorContext context, System.Global /// /// Type converter for four-element value tuples. /// -/// -/// -/// -/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. public class ValueTupleConverter : TypeConverter { /// @@ -300,10 +300,10 @@ public override object ConvertTo(ITypeDescriptorContext context, System.Globaliz /// /// Type converter for four-element nullable value tuples. /// -/// -/// -/// -/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. public class NullableValueTupleConverter : ValueTupleConverter { /// @@ -318,11 +318,11 @@ public override object ConvertFrom(ITypeDescriptorContext context, System.Global /// /// Type converter for five-element value tuples. /// -/// -/// -/// -/// -/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. public class ValueTupleConverter : TypeConverter { /// @@ -387,11 +387,11 @@ public override object ConvertTo(ITypeDescriptorContext context, System.Globaliz /// /// Type converter for five-element nullable value tuples. /// -/// -/// -/// -/// -/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. public class NullableValueTupleConverter : ValueTupleConverter { /// @@ -406,12 +406,12 @@ public override object ConvertFrom(ITypeDescriptorContext context, System.Global /// /// Type converter for six-element value tuples. /// -/// -/// -/// -/// -/// -/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. +/// The type of the sixth element in the value tuple. public class ValueTupleConverter : TypeConverter { /// @@ -478,12 +478,12 @@ public override object ConvertTo(ITypeDescriptorContext context, System.Globaliz /// /// Type converter for six-element nullable value tuples. /// -/// -/// -/// -/// -/// -/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. +/// The type of the sixth element in the value tuple. public class NullableValueTupleConverter : ValueTupleConverter { /// @@ -493,4 +493,100 @@ public override object ConvertFrom(ITypeDescriptorContext context, System.Global return null; return base.ConvertFrom(context, culture, value); } -} \ No newline at end of file +} + +/// +/// Type converter for seven-element value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. +/// The type of the sixth element in the value tuple. +/// The type of the seventh element in the value tuple. +public class ValueTupleConverter : TypeConverter +{ + /// + public override bool CanConvertFrom(ITypeDescriptorContext context, Type sourceType) + { + if (sourceType == typeof(string) || sourceType == typeof(ValueTuple)) + return true; + + return base.CanConvertFrom(context, sourceType); + } + + /// + public override bool CanConvertTo(ITypeDescriptorContext context, Type destinationType) + { + if (destinationType == typeof(string) || destinationType == typeof(ValueTuple)) + return true; + + return base.CanConvertTo(context, destinationType); + } + + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue) + { + var elements = stringValue.Trim('(', ')').Split(','); + if (elements.Length == 3) + { + var item1 = (T1)TypeDescriptor.GetConverter(typeof(T1)).ConvertFromString(context, culture, elements[0].Trim()); + var item2 = (T2)TypeDescriptor.GetConverter(typeof(T2)).ConvertFromString(context, culture, elements[1].Trim()); + var item3 = (T3)TypeDescriptor.GetConverter(typeof(T3)).ConvertFromString(context, culture, elements[2].Trim()); + var item4 = (T4)TypeDescriptor.GetConverter(typeof(T4)).ConvertFromString(context, culture, elements[3].Trim()); + var item5 = (T5)TypeDescriptor.GetConverter(typeof(T5)).ConvertFromString(context, culture, elements[4].Trim()); + var item6 = (T6)TypeDescriptor.GetConverter(typeof(T6)).ConvertFromString(context, culture, elements[5].Trim()); + var item7 = (T7)TypeDescriptor.GetConverter(typeof(T7)).ConvertFromString(context, culture, elements[6].Trim()); + return ValueTuple.Create(item1, item2, item3, item4, item5, item6, item7); + } + throw new ArgumentException($"Cannot convert '{stringValue}' to ValueTuple<{typeof(T1).Name}, {typeof(T2).Name}, {typeof(T3).Name}, {typeof(T4).Name}, {typeof(T5).Name}, {typeof(T6).Name}, {typeof(T7).Name}>."); + } + + return base.ConvertFrom(context, culture, value); + } + + /// + public override object ConvertTo(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value, Type destinationType) + { + if (value is ValueTuple tuple) + { + if (destinationType == typeof(string)) + { + var item1 = TypeDescriptor.GetConverter(typeof(T1)).ConvertToString(context, culture, tuple.Item1); + var item2 = TypeDescriptor.GetConverter(typeof(T2)).ConvertToString(context, culture, tuple.Item2); + var item3 = TypeDescriptor.GetConverter(typeof(T3)).ConvertToString(context, culture, tuple.Item3); + var item4 = TypeDescriptor.GetConverter(typeof(T4)).ConvertToString(context, culture, tuple.Item4); + var item5 = TypeDescriptor.GetConverter(typeof(T5)).ConvertToString(context, culture, tuple.Item5); + var item6 = TypeDescriptor.GetConverter(typeof(T6)).ConvertToString(context, culture, tuple.Item6); + var item7 = TypeDescriptor.GetConverter(typeof(T7)).ConvertToString(context, culture, tuple.Item7); + return $"({item1}, {item2}, {item3}, {item4}, {item5}, {item6}, {item7})"; + } + } + + return base.ConvertTo(context, culture, value, destinationType); + } +} + +/// +/// Type converter for seven-element nullable value tuples. +/// +/// The type of the first element in the value tuple. +/// The type of the second element in the value tuple. +/// The type of the third element in the value tuple. +/// The type of the fourth element in the value tuple. +/// The type of the fifth element in the value tuple. +/// The type of the sixth element in the value tuple. +/// The type of the seventh element in the value tuple. +public class NullableValueTupleConverter : ValueTupleConverter +{ + /// + public override object ConvertFrom(ITypeDescriptorContext context, System.Globalization.CultureInfo culture, object value) + { + if (value is string stringValue && string.IsNullOrEmpty(stringValue)) + return null; + return base.ConvertFrom(context, culture, value); + } +} From 5cee2e0c348cac4c2ff778037e2cdfac871ff678 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 16 Dec 2025 11:17:13 +0000 Subject: [PATCH 27/28] Refactored `LoadTensor` and `SaveTensor` operators to exclusively use TorchSharp's .NET specific serialization method --- src/Bonsai.ML.Torch/LoadTensor.cs | 20 ++------------------ src/Bonsai.ML.Torch/SaveTensor.cs | 20 ++------------------ 2 files changed, 4 insertions(+), 36 deletions(-) diff --git a/src/Bonsai.ML.Torch/LoadTensor.cs b/src/Bonsai.ML.Torch/LoadTensor.cs index 6b179ac4..4694e184 100644 --- a/src/Bonsai.ML.Torch/LoadTensor.cs +++ b/src/Bonsai.ML.Torch/LoadTensor.cs @@ -21,29 +21,13 @@ public class LoadTensor [Description("The path to the file containing the tensor.")] public string Path { get; set; } - /// - /// Indicates whether to use the native torch load method for the tensor. - /// - /// - /// If set to true, the native torch load method will be used. - /// If set to false, the tensor will be loaded using the TorchSharp method which is specific to .NET formats. - /// - [Description("Indicates whether to use the native torch load method for the tensor.")] - public bool UseNativeTorchMethod { get; set; } = false; - /// /// Loads a tensor from the specified file. /// /// public IObservable Process() { - switch (UseNativeTorchMethod) - { - case true: - return Observable.Return(load(Path)); - case false: - return Observable.Return(Tensor.Load(Path)); - } + return Observable.Return(Tensor.Load(Path)); } } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/SaveTensor.cs b/src/Bonsai.ML.Torch/SaveTensor.cs index 72127b11..3edd03b5 100644 --- a/src/Bonsai.ML.Torch/SaveTensor.cs +++ b/src/Bonsai.ML.Torch/SaveTensor.cs @@ -22,16 +22,6 @@ public class SaveTensor [Description("The path to the file where the tensor will be saved.")] public string Path { get; set; } = string.Empty; - /// - /// Indicates whether to use the native torch save method for the tensor. - /// - /// - /// If set to true, the native torch save method will be used. - /// If set to false, the tensor will be saved using the TorchSharp method which is specific to .NET formats. - /// - [Description("Indicates whether to use the native torch save method for the tensor.")] - public bool UseNativeTorchMethod { get; set; } = false; - /// /// Saves the input tensor to the specified file. /// @@ -39,13 +29,7 @@ public class SaveTensor /// public IObservable Process(IObservable source) { - return source.Do(tensor => - { - if (UseNativeTorchMethod) - tensor.save(Path); - else - tensor.Save(Path); - }); + return source.Do(tensor => tensor.Save(Path)); } } -} \ No newline at end of file +} From b5952eb92fa4b9736f7dc43e3ddea4e6892a5d27 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 16 Dec 2025 11:21:55 +0000 Subject: [PATCH 28/28] Slightly reworded XML docs in `InitializeTorchDevice` and ensured variables are set explicitly --- src/Bonsai.ML.Torch/InitializeTorchDevice.cs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs index 69131c54..e18d0606 100644 --- a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.Torch { /// - /// Initializes the Torch device with the specified device type. + /// Represents an operator that initializes the Torch device with the specified device type. /// [Combinator] [Description("Initializes the Torch device with the specified device type.")] @@ -44,14 +44,18 @@ public IObservable Process() /// /// Initializes the Torch device when the input sequence produces an element. /// + /// + /// /// public IObservable Process(IObservable source) { return source.Select((_) => { - InitializeDeviceType(DeviceType); - return new Device(DeviceType, DeviceIndex); + var deviceType = DeviceType; + var deviceIndex = DeviceIndex; + InitializeDeviceType(deviceType); + return new Device(deviceType, deviceIndex); }); } } -} \ No newline at end of file +}