Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
bfaa374
Added classes to generate diagonal tensors and identity (eye) matrices
ncguilbeault Aug 1, 2025
c1dd4ad
Added a sink for debugging with extended print capabilities
ncguilbeault Aug 1, 2025
2fde348
Updated main gitignore to ignore nested bonsai environments
ncguilbeault Aug 6, 2025
f9ae733
Updated `LoadTensor` operator to support the native load method from …
ncguilbeault Aug 6, 2025
835657a
Added class to `Buffer` tensors along the first dimension
ncguilbeault Oct 8, 2025
b5e943f
Added class to `Decompose` a large tensor into a sequence of smaller …
ncguilbeault Oct 8, 2025
3613429
Added feature to save a tensor using TorchSharp's .NET format or the …
ncguilbeault Oct 8, 2025
048b9a7
Added `.pt` file filters for load/save tensor objects
ncguilbeault Oct 8, 2025
e49bf08
Added operators to manage downstream tensor operations in different m…
ncguilbeault Oct 8, 2025
787925f
Updated `TensorConverter` to handle case when tensor is null and retu…
ncguilbeault Oct 8, 2025
8fc4947
Remove unnecessary reset of current index in `Buffer` class
ncguilbeault Oct 15, 2025
5a19f85
Refactored `InitializeTorchDevice` to ensure the device initializatio…
ncguilbeault Oct 30, 2025
a4c15b3
Added an abstract base class for operators that contain tensor proper…
ncguilbeault Nov 1, 2025
9be41b8
Added XML docs to classes
ncguilbeault Nov 10, 2025
31cee0d
Updated property name from `UseNativeMethod` to `UseNativeTorchMethod…
ncguilbeault Nov 10, 2025
26900fd
Renamed `Decompose` to `Deconstruct` for better alignment with operat…
ncguilbeault Nov 11, 2025
8188866
Apply suggestions from code review
ncguilbeault Nov 12, 2025
b3c9b17
Renamed `PrintTensor` to `FormatTensor`
ncguilbeault Nov 12, 2025
e7b71de
Renamed `TensorContainer` to `TensorOperator`
ncguilbeault Nov 12, 2025
a932185
Renamed `Buffer` to `Bind` and `Deconstruct` to `Unbind`
ncguilbeault Nov 12, 2025
f062b66
Refactored `TensorOperator` into a true `TypeConverter` class and app…
ncguilbeault Nov 12, 2025
d68a524
Added operator to explicitly convert single valued tensors to .NET da…
ncguilbeault Nov 13, 2025
8ccaa22
Modified linspace to correctly use double properties for start and en…
ncguilbeault Nov 28, 2025
768d5a3
Updated `Bind` to ensure that tensors used for indices are the correc…
ncguilbeault Dec 10, 2025
de1066e
Added type converter class for value tuple and nullable value tuple p…
ncguilbeault Dec 10, 2025
585a572
Updated XML docs with improved descriptions for type parameters
ncguilbeault Dec 12, 2025
5cee2e0
Refactored `LoadTensor` and `SaveTensor` operators to exclusively use…
ncguilbeault Dec 16, 2025
b5952eb
Slightly reworded XML docs in `InitializeTorchDevice` and ensured var…
ncguilbeault Dec 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.bonsai/Bonsai.exe*
.bonsai/Packages/
.bonsai/Settings/
**/.bonsai/Bonsai.exe*
**/.bonsai/Packages/
**/.bonsai/Settings/
.vs/
/artifacts/
.venv
Expand Down
129 changes: 129 additions & 0 deletions src/Bonsai.ML.Torch/Bind.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
using System;
using System.ComponentModel;
using System.Linq;
using System.Reactive.Linq;
using System.Collections.Generic;
using TorchSharp;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch;

/// <summary>
/// Represents an operator that gathers incoming tensors into zero or more tensors by concatenating them along the first dimension.
/// </summary>
/// <remarks>
/// The operator maintains an internal buffer that accumulates incoming tensors until it reaches the specified count.
/// When the buffer reaches the specified count, it is emitted as a single tensor. After emitting the buffer, the operator skips a specified number of incoming tensors before starting to fill the buffer again.
/// </remarks>
[Combinator]
[Description("Gathers incoming tensors into zero or more tensors by concatenating them along the first dimension.")]
[WorkflowElementCategory(ElementCategory.Combinator)]
public class Bind
{
private int _count = 1;
/// <summary>
/// Gets or sets the number of tensors to accumulate in the buffer before emitting.
/// </summary>
[Description("The number of tensors to accumulate in the buffer before emitting.")]
public int Count
{
get => _count;
set => _count = value <= 0
? throw new ArgumentOutOfRangeException("Count must be greater than zero.")
: value;
}

private int _skip = 1;
/// <summary>
/// Gets or sets the number of tensors to skip after emitting the buffer.
/// </summary>
[Description("The number of tensors to skip after emitting the buffer.")]
public int Skip
{
get => _skip;
set => _skip = value < 0
? throw new ArgumentOutOfRangeException("Skip must be non-negative.")
: value;
}

/// <summary>
/// Processes an observable sequence of tensors, buffering them and concatenating along the first dimension.
/// </summary>
public IObservable<Tensor> Process(IObservable<Tensor> source)
{
return Observable.Create<Tensor>(observer =>
{
var count = Count;
var skip = Skip;

Tensor buffer = null;
int current = 0;
Tensor idxSrc = null;
Tensor idxDst = null;

return source.Subscribe(input =>
{
if (input is null)
return;

if (buffer is null)
{
var shape = input.shape.Prepend(count).ToArray();
buffer = empty(shape, dtype: input.dtype, device: input.device);

if (skip < count)
{
idxSrc = arange(skip, count, dtype: ScalarType.Int64, device: input.device);
idxDst = arange(0, count - skip, dtype: ScalarType.Int64, device: input.device);
}
}

if (current >= 0)
{
buffer[current] = input;
}

current++;

if (current >= count)
{
var output = buffer.clone();
if (skip < count)
{
var src = index_select(buffer, 0, idxSrc);
buffer.index_copy_(0, idxDst, src);
buffer[torch.TensorIndex.Slice(count - skip, null)].zero_();
}
else
{
buffer.zero_();
}
current = count - skip;
observer.OnNext(output);
}
},
observer.OnError,
() =>
{
var remainder = current;

if (remainder > 0 && buffer is not null)
{
var outputShape = buffer.shape.ToArray();
outputShape[0] = remainder;
var output = empty(outputShape, dtype: buffer.dtype, device: buffer.device);
for (int i = 0; i < remainder; i++)
{
output[i] = buffer[i];
}
observer.OnNext(output.clone());
}

buffer?.Dispose();
idxSrc?.Dispose();
idxDst?.Dispose();
observer.OnCompleted();
});
});
}
}
2 changes: 1 addition & 1 deletion src/Bonsai.ML.Torch/ConvertScalarType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class ConvertScalarType
public ScalarType Type { get; set; } = ScalarType.Float32;

/// <summary>
/// Returns an observable sequence that converts the input tensor to the specified scalar type.
/// Converts the data type of the input tensor to the newly specified scalar type.
/// </summary>
public IObservable<Tensor> Process(IObservable<Tensor> source)
{
Expand Down
12 changes: 6 additions & 6 deletions src/Bonsai.ML.Torch/ConvertToArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public override Expression Build(IEnumerable<Expression> arguments)
MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance);
methodInfo = methodInfo.MakeGenericMethod(ScalarTypeLookup.GetTypeFromScalarType(Type));
Expression sourceExpression = arguments.First();

return Expression.Call(
Expression.Constant(this),
methodInfo,
Expand All @@ -43,12 +43,12 @@ public override Expression Build(IEnumerable<Expression> arguments)
/// <summary>
/// Converts the input tensor into a flattened array of the specified element type.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
/// <typeparam name="T">The element type of the output item.</typeparam>
/// <param name="source">The sequence of input tensors.</param>
/// <returns>The sequence of output arrays of the specified element type.</returns>
public IObservable<T[]> Process<T>(IObservable<Tensor> source) where T : unmanaged
{
return source.Select(tensor =>
return source.Select(tensor =>
{
if (tensor.dtype != Type)
{
Expand All @@ -58,4 +58,4 @@ public IObservable<T[]> Process<T>(IObservable<Tensor> source) where T : unmanag
});
}
}
}
}
61 changes: 61 additions & 0 deletions src/Bonsai.ML.Torch/ConvertToItem.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Reactive.Linq;
using System.Xml.Serialization;
using System.Linq.Expressions;
using System.Reflection;
using Bonsai.Expressions;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch;

/// <summary>
/// Represents an operator that converts the input tensor into a single value of the specified element type. The input tensor must only contain a single element.
/// </summary>
[Combinator]
[Description("Converts the input tensor into a single value of the specified element type.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class ConvertToItem : SingleArgumentExpressionBuilder
{
/// <summary>
/// Gets or sets the type of the item.
/// </summary>
[Description("Gets or sets the type of the item.")]
[TypeConverter(typeof(ScalarTypeConverter))]
public ScalarType Type { get; set; } = ScalarType.Float32;

/// <inheritdoc/>
public override Expression Build(IEnumerable<Expression> arguments)
{
MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance);
var type = ScalarTypeLookup.GetTypeFromScalarType(Type);
methodInfo = methodInfo.MakeGenericMethod(type);
Expression sourceExpression = arguments.First();

return Expression.Call(
Expression.Constant(this),
methodInfo,
sourceExpression
);
}

/// <summary>
/// Converts the input tensor into a single item.
/// </summary>
/// <typeparam name="T">The element type of the output item.</typeparam>
/// <param name="source">The sequence of input tensors.</param>
/// <returns>The sequence of output items of the specified element type.</returns>
public IObservable<T> Process<T>(IObservable<Tensor> source) where T : unmanaged
{
return source.Select(tensor =>
{
if (tensor.dtype != Type)
{
tensor = tensor.to_type(Type);
}
return tensor.item<T>();
});
}
}
14 changes: 7 additions & 7 deletions src/Bonsai.ML.Torch/ConvertToNDArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public override Expression Build(IEnumerable<Expression> arguments)
Type arrayType = Array.CreateInstance(type, lengths).GetType();
methodInfo = methodInfo.MakeGenericMethod(type, arrayType);
Expression sourceExpression = arguments.First();

return Expression.Call(
Expression.Constant(this),
methodInfo,
Expand All @@ -52,13 +52,13 @@ public override Expression Build(IEnumerable<Expression> arguments)
/// <summary>
/// Converts the input tensor into an array of the specified element type.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <typeparam name="TResult"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
/// <typeparam name="T">The element type of the output item.</typeparam>
/// <typeparam name="TResult">The type of the output array.</typeparam>
/// <param name="source">The sequence of input tensors.</param>
/// <returns>The sequence of output arrays of the specified element type and rank.</returns>
public IObservable<TResult> Process<T, TResult>(IObservable<Tensor> source) where T : unmanaged
{
return source.Select(tensor =>
return source.Select(tensor =>
{
if (tensor.dtype != Type)
{
Expand All @@ -68,4 +68,4 @@ public IObservable<TResult> Process<T, TResult>(IObservable<Tensor> source) wher
});
}
}
}
}
Loading