Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples
Submodule examples updated 136 files
74 changes: 74 additions & 0 deletions src/Bonsai.ML.Torch/Distributions/Bernoulli.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Xml.Serialization;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.Distributions;

/// <summary>
/// Represents an operator that creates a Bernoulli probability distribution parameterized by event probabilities.
/// </summary>
[Combinator]
[Description("Creates a Bernoulli distribution with event probabilities and emits a TorchSharp distribution module.")]
[WorkflowElementCategory(ElementCategory.Source)]
[TypeConverter(typeof(TensorOperatorConverter))]
public class Bernoulli : IScalarTypeProvider
{
/// <summary>
/// The event probabilities in [0, 1]. Can be a scalar or a tensor; the shape determines the batch/event shape.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("The event probabilities in [0, 1]. Can be a scalar or a tensor; shape sets the batch/event shape of the distribution.")]
public Tensor Probabilities { get; set; } = null;

/// <summary>
/// The values of the probabilities in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Probabilities))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string ProbabilitiesXml
{
get => TensorConverter.ConvertToString(Probabilities, Type);
set => Probabilities = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Gets or sets the data type of the tensor elements.
/// </summary>
[Description("The data type of the tensor elements.")]
[TypeConverter(typeof(ScalarTypeConverter))]
public ScalarType Type { get; set; } = ScalarType.Float32;

/// <summary>
/// Creates a Bernoulli distribution.
/// </summary>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Bernoulli> Process()
{
return Observable.Return(distributions.Bernoulli(Probabilities));
}

/// <summary>
/// Creates a Bernoulli distribution using the incoming RNG Generator.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Bernoulli> Process(IObservable<Generator> source)
{
return source.Select(generator => distributions.Bernoulli(Probabilities, generator: generator));
}

/// <summary>
/// For each element of the source stream, emits a Bernoulli distribution.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Bernoulli> Process<T>(IObservable<T> source)
{
return source.Select(_ => distributions.Bernoulli(Probabilities));
}
}
95 changes: 95 additions & 0 deletions src/Bonsai.ML.Torch/Distributions/Beta.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Xml.Serialization;
using TorchSharp;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.Distributions;

/// <summary>
/// Represents an operator that creates a beta probability distribution parameterized by two concentration parameters (alpha, beta).
/// </summary>
[Combinator]
[Description("Creates a Beta distribution with concentration parameters (alpha, beta).")]
[WorkflowElementCategory(ElementCategory.Source)]
[TypeConverter(typeof(TensorOperatorConverter))]
public class Beta : IScalarTypeProvider
{
/// <summary>
/// The first concentration parameter alpha (> 0). Can be a scalar or tensor; the shape determines the batch/event shape.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("Concentration alpha (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")]
public Tensor Concentration1 { get; set; } = null;

/// <summary>
/// The values of concentration 1 in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Concentration1))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string Concentration1Xml
{
get => TensorConverter.ConvertToString(Concentration1, Type);
set => Concentration1 = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Concentration parameter beta (> 0). Can be a scalar or tensor; the shape determines the batch/event shape.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("Concentration beta (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")]
public Tensor Concentration0 { get; set; } = null;

/// <summary>
/// The values of concentration 0 in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Concentration0))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string Concentration0Xml
{
get => TensorConverter.ConvertToString(Concentration0, Type);
set => Concentration0 = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Gets or sets the data type of the tensor elements.
/// </summary>
[Description("The data type of the tensor elements.")]
[TypeConverter(typeof(ScalarTypeConverter))]
public ScalarType Type { get; set; } = ScalarType.Float32;

/// <summary>
/// Creates a Beta distribution.
/// </summary>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Beta> Process()
{
return Observable.Return(distributions.Beta(Concentration1, Concentration0));
}

/// <summary>
/// Creates a Beta distribution using the incoming RNG generator.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Beta> Process(IObservable<Generator> source)
{
return source.Select(generator => distributions.Beta(Concentration1, Concentration0, generator: generator));
}

/// <summary>
/// For each element of the source stream, emits a Beta distribution.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Beta> Process<T>(IObservable<T> source)
{
return source.Select(_ => distributions.Beta(Concentration1, Concentration0));
}
}
94 changes: 94 additions & 0 deletions src/Bonsai.ML.Torch/Distributions/Binomial.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Xml.Serialization;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.Distributions;

/// <summary>
/// Creates a Binomial probability distribution with a given number of trials and success probability.
/// </summary>
[Combinator]
[Description("Creates a Binomial distribution with count (number of trials) and probability of success.")]
[WorkflowElementCategory(ElementCategory.Source)]
[TypeConverter(typeof(TensorOperatorConverter))]
public class Binomial : IScalarTypeProvider
{
/// <summary>
/// The number of trials (non-negative). Can be a scalar or tensor. If it is a tensor, values should be non-negative integers.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("The number of trials (non-negative). Can be a scalar or tensor. If it is a tensor, values should be non-negative integers.")]
public Tensor Count { get; set; } = null;

/// <summary>
/// The values of count in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Count))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string CountXml
{
get => TensorConverter.ConvertToString(Count, Type);
set => Count = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Probability of success p in [0, 1]. Can be a scalar or tensor; the shape should be broadcastable to <see cref="Count"/>.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("Probability of success in [0, 1]. Can be a scalar or tensor; the shape should be broadcastable to Count.")]
public Tensor Probabilities { get; set; } = null;

/// <summary>
/// The values of probabilities in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Probabilities))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string ProbabilitiesXml
{
get => TensorConverter.ConvertToString(Probabilities, Type);
set => Probabilities = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Gets or sets the data type of the tensor elements.
/// </summary>
[Description("The data type of the tensor elements.")]
[TypeConverter(typeof(ScalarTypeConverter))]
public ScalarType Type { get; set; } = ScalarType.Float32;

/// <summary>
/// Creates a Binomial distribution.
/// </summary>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Binomial> Process()
{
return Observable.Return(distributions.Binomial(Count, Probabilities));
}

/// <summary>
/// Creates a Binomial distribution for each incoming RNG generator.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Binomial> Process(IObservable<Generator> source)
{
return source.Select(generator => distributions.Binomial(Count, Probabilities, generator: generator));
}

/// <summary>
/// For each element of the source stream, emits a Binomial distribution.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Binomial> Process<T>(IObservable<T> source)
{
return source.Select(_ => distributions.Binomial(Count, Probabilities));
}
}
75 changes: 75 additions & 0 deletions src/Bonsai.ML.Torch/Distributions/Categorical.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Xml.Serialization;
using TorchSharp;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.Distributions;

/// <summary>
/// Creates a categorical (discrete) distribution over classes given event probabilities.
/// </summary>
[Combinator]
[Description("Creates a categorical (discrete) distribution over classes given event probabilities.")]
[WorkflowElementCategory(ElementCategory.Source)]
[TypeConverter(typeof(TensorOperatorConverter))]
public class Categorical : IScalarTypeProvider
{
/// <summary>
/// The class probabilities. Values must be non-negative and typically sum to 1 per row. Can be a 1D vector or higher-rank tensor for batched distributions.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("The class probabilities. Values must be non-negative and typically sum to 1 per row. Can be a 1D vector or higher-rank tensor for batched distributions.")]
public Tensor Probabilities { get; set; } = null;

/// <summary>
/// The values of probabilities in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Probabilities))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string ProbabilitiesXml
{
get => TensorConverter.ConvertToString(Probabilities, Type);
set => Probabilities = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Gets or sets the data type of the tensor elements.
/// </summary>
[Description("The data type of the tensor elements.")]
[TypeConverter(typeof(ScalarTypeConverter))]
public ScalarType Type { get; set; } = ScalarType.Float32;

/// <summary>
/// Creates a categorical distribution.
/// </summary>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Categorical> Process()
{
return Observable.Return(distributions.Categorical(Probabilities));
}

/// <summary>
/// Creates a categorical distribution for each incoming RNG generator.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Categorical> Process(IObservable<Generator> source)
{
return source.Select(generator => distributions.Categorical(Probabilities, generator: generator));
}

/// <summary>
/// For each element of the source stream, emits a categorical distribution.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Categorical> Process<T>(IObservable<T> source)
{
return source.Select(_ => distributions.Categorical(Probabilities));
}
}
Loading
Loading