From 18153ee5e3a67842e2dbc6171cf1a83c3d19ae66 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 7 Feb 2025 13:29:25 +0000 Subject: [PATCH 01/20] Added normalization step to likelihood calculation --- .../Likelihood/ClusterlessLikelihood.cs | 7 +++++-- .../Likelihood/PoissonLikelihood.cs | 13 +++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs index e13780c..aecf547 100644 --- a/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs +++ b/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs @@ -73,9 +73,12 @@ IEnumerable conditionalIntensities logLikelihood -= logLikelihood .max(dim: -1, keepdim: true) .values; - return logLikelihood + logLikelihood = logLikelihood .exp() - .nan_to_num() + .nan_to_num(); + logLikelihood /= logLikelihood + .sum(dim: -1, keepdim: true); + return logLikelihood .MoveToOuterDisposeScope(); } diff --git a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs index 4908e47..790cf18 100644 --- a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs +++ b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs @@ -33,13 +33,18 @@ IEnumerable conditionalIntensities using var _ = NewDisposeScope(); var conditionalIntensity = conditionalIntensities.First(); var conditionalIntensityTensor = conditionalIntensity.flatten(1).T.unsqueeze(0); - var logLikelihood = (xlogy(inputs.unsqueeze(1), conditionalIntensityTensor) - conditionalIntensityTensor) + var logLikelihood = (inputs.unsqueeze(1) * conditionalIntensityTensor - conditionalIntensityTensor.exp()) .nan_to_num() .sum(dim: -1); - logLikelihood -= logLikelihood.max(dim: -1, keepdim: true).values; - return logLikelihood + logLikelihood -= logLikelihood + .max(dim: -1, keepdim: true) + .values; + logLikelihood = logLikelihood .exp() - .nan_to_num() + .nan_to_num(); + logLikelihood /= logLikelihood + .sum(dim: -1, keepdim: true); + return logLikelihood .MoveToOuterDisposeScope(); } } From d6abc992c133194f22203ff4e1e12d1a544b71d5 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 7 Feb 2025 13:29:52 +0000 Subject: [PATCH 02/20] Removed extra steps to move to outer dispose scopes --- .../Encoder/ClusterlessMarkEncoder.cs | 3 +-- .../Encoder/SortedSpikeEncoder.cs | 21 +++++++------------ 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs index 8b300a7..b0a1ed3 100644 --- a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs +++ b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs @@ -261,8 +261,7 @@ public void Encode(Tensor observations, Tensor marks) _samples += observations.shape[0]; } - _rates = (_spikeCounts.log() - _samples.log()) - .MoveToOuterDisposeScope(); + _rates = _spikeCounts.log() - _samples.log(); var mask = ~marks.isnan().all(dim: 1); diff --git a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs index 5c5bdf0..dba1046 100644 --- a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs +++ b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs @@ -148,26 +148,19 @@ public void Encode(Tensor observations, Tensor inputs) if (_spikeCounts.numel() == 0) { _spikeCounts = inputs.nan_to_num() - .sum(dim: 0); - 
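Note: the normalization introduced in PATCH 01 follows the standard exp-normalize pattern: shift the log-likelihood by its row maximum so exp() cannot overflow, then divide by the row sum so each time bin carries a proper probability over state-space bins. A minimal, self-contained sketch of the pattern (assuming TorchSharp; LikelihoodMath and ExpNormalize are illustrative names, not part of the library):

using static TorchSharp.torch;

public static class LikelihoodMath
{
    // Sketch: exp-normalize a [time, states] log-likelihood row by row.
    public static Tensor ExpNormalize(Tensor logLikelihood)
    {
        using var _ = NewDisposeScope();
        // Subtract the per-row max so the largest exponent is exactly 0.
        logLikelihood -= logLikelihood
            .max(dim: -1, keepdim: true)
            .values;
        // Exponentiate and guard against NaN/Inf entries.
        var likelihood = logLikelihood
            .exp()
            .nan_to_num();
        // Divide by the row sum so each row is a probability measure.
        likelihood /= likelihood
            .sum(dim: -1, keepdim: true);
        return likelihood.MoveToOuterDisposeScope();
    }
}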
_samples = observations.shape[0];
+ .sum(dim: 0)
+ .to(_device);
+ _samples = tensor(observations.shape[0], device: _device);
 }
 else
 {
 _spikeCounts += inputs.nan_to_num()
- .sum(dim: 0);
+ .sum(dim: 0)
+ .to(_device);
 _samples += observations.shape[0];
 }

- _spikeCounts = _spikeCounts
- .to(_device)
- .MoveToOuterDisposeScope();
-
- _samples = _samples
- .to(_device)
- .MoveToOuterDisposeScope();
-
- _rates = (_spikeCounts.log() - _samples.log())
- .MoveToOuterDisposeScope();
+ _rates = _spikeCounts.log() - _samples.log();

 var inputMask = inputs.to_type(ScalarType.Bool);

@@ -201,7 +194,7 @@ public IEnumerable Evaluate(params Tensor[] inputs)
 var unitDensity = _unitEstimation[i].Evaluate(_stateSpace.Points)
 .log();

- unitConditionalIntensities[i] = exp(_rates[i] + unitDensity - observationDensity)
+ unitConditionalIntensities[i] = (_rates[i] + unitDensity - observationDensity)
 .reshape(_stateSpace.Shape);
 }
 var output = stack(unitConditionalIntensities, dim: 0)

From 2e4fa6ec829716147b579d73372edbd088be83dd Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Fri, 7 Feb 2025 13:42:37 +0000
Subject: [PATCH 03/20] Changed inputs to bool type in Poisson likelihood calculation

---
 .../Likelihood/PoissonLikelihood.cs | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
index 790cf18..589e141 100644
--- a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
+++ b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
@@ -33,9 +33,13 @@ IEnumerable conditionalIntensities
 using var _ = NewDisposeScope();
 var conditionalIntensity = conditionalIntensities.First();
 var conditionalIntensityTensor = conditionalIntensity.flatten(1).T.unsqueeze(0);
- var logLikelihood = (inputs.unsqueeze(1) * conditionalIntensityTensor - conditionalIntensityTensor.exp())
+ var logLikelihood = (inputs
+ .to_type(ScalarType.Bool)
+ .unsqueeze(1)
+ * conditionalIntensityTensor
+ - conditionalIntensityTensor.exp())
 .nan_to_num()
 .sum(dim: -1);
 logLikelihood -= logLikelihood
 .max(dim: -1, keepdim: true)
 .values;

From 94210b0b4d66fec459a8d42508509df88ec7bde0 Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Wed, 12 Feb 2025 14:41:10 +0000
Subject: [PATCH 04/20] Removed use of IDisposable interface and dispose methods since it's not needed

---
 .../Decoder/StateSpaceDecoder.cs | 8 ------
 .../Encoder/ClusterlessMarkEncoder.cs | 22 ----------------
 .../Encoder/SortedSpikeEncoder.cs | 26 -------------------
 .../Estimation/KernelCompression.cs | 7 -----
 .../Estimation/KernelDensity.cs | 7 -----
 .../IModelComponent.cs | 2 +-
 .../Likelihood/ClusterlessLikelihood.cs | 6 -----
 src/PointProcessDecoder.Core/ModelBase.cs | 3 ---
 .../PointProcessModel.cs | 9 -------
 .../StateSpace/DiscreteUniformStateSpace.cs | 6 -----
 .../Transitions/RandomWalkTransitions.cs | 7 -----
 .../Transitions/UniformTransitions.cs | 6 -----
 .../PointProcessDecoder.Cpu.Test/TestModel.cs | 4 ---
 13 files changed, 1 insertion(+), 112 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs b/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs
index e16678c..718c8af 100644
--- a/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs
+++ b/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs
@@ -106,12 +106,4 @@ public Tensor Decode(Tensor inputs, Tensor likelihood)
 _posterior.MoveToOuterDisposeScope();
 return
output.MoveToOuterDisposeScope(); } - - /// - public override void Dispose() - { - _stateTransitions.Dispose(); - _initialState.Dispose(); - _posterior.Dispose(); - } } diff --git a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs index b0a1ed3..2b1aacf 100644 --- a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs +++ b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs @@ -469,26 +469,4 @@ public override IModelComponent Load(string basePath) return this; } - - /// - public override void Dispose() - { - _observationEstimation.Dispose(); - foreach (var estimation in _channelEstimation) - { - estimation.Dispose(); - } - foreach (var estimation in _markEstimation) - { - estimation.Dispose(); - } - _estimations = []; - _updateConditionalIntensities = true; - _conditionalIntensities = [empty(0)]; - _markConditionalIntensities.Dispose(); - _channelConditionalIntensities.Dispose(); - _spikeCounts.Dispose(); - _samples.Dispose(); - _rates.Dispose(); - } } diff --git a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs index dba1046..913f30f 100644 --- a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs +++ b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs @@ -279,30 +279,4 @@ public override IModelComponent Load(string basePath) return this; } - - /// - public override void Dispose() - { - _observationEstimation.Dispose(); - foreach (var estimation in _unitEstimation) - { - estimation.Dispose(); - } - _estimations = []; - - _updateConditionalIntensities = true; - _conditionalIntensities = [empty(0)]; - - _unitConditionalIntensities.Dispose(); - _unitConditionalIntensities = empty(0); - - _spikeCounts.Dispose(); - _spikeCounts = empty(0); - - _samples.Dispose(); - _samples = empty(0); - - _rates.Dispose(); - _rates = empty(0); - } } diff --git a/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs b/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs index 799a8ef..f7abe99 100644 --- a/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs +++ b/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs @@ -269,12 +269,5 @@ public override IModelComponent Load(string basePath) _kernels = Tensor.Load(Path.Combine(basePath, "kernels.bin")).to(_device); return this; } - - /// - public override void Dispose() - { - _kernels.Dispose(); - _kernels = empty(0); - } } diff --git a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs index 5aad927..1b80da3 100644 --- a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs +++ b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs @@ -217,11 +217,4 @@ public override IModelComponent Load(string basePath) _kernels = Tensor.Load(Path.Combine(basePath, "kernels.bin")).to(_device); return this; } - - /// - public override void Dispose() - { - _kernels.Dispose(); - _kernels = empty(0); - } } diff --git a/src/PointProcessDecoder.Core/IModelComponent.cs b/src/PointProcessDecoder.Core/IModelComponent.cs index fbec348..2395c0b 100644 --- a/src/PointProcessDecoder.Core/IModelComponent.cs +++ b/src/PointProcessDecoder.Core/IModelComponent.cs @@ -5,7 +5,7 @@ namespace PointProcessDecoder.Core; /// /// Represents a single component of the model. 
/// -public interface IModelComponent : IDisposable +public interface IModelComponent { /// /// The device on which the model component in located. diff --git a/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs index aecf547..27624a4 100644 --- a/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs +++ b/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs @@ -81,10 +81,4 @@ IEnumerable conditionalIntensities return logLikelihood .MoveToOuterDisposeScope(); } - - /// - public override void Dispose() - { - _noSpikeLikelihood.Dispose(); - } } diff --git a/src/PointProcessDecoder.Core/ModelBase.cs b/src/PointProcessDecoder.Core/ModelBase.cs index 3dce655..975dbb3 100644 --- a/src/PointProcessDecoder.Core/ModelBase.cs +++ b/src/PointProcessDecoder.Core/ModelBase.cs @@ -44,7 +44,4 @@ public virtual void Save(string basePath) { } /// /// public static IModelComponent Load(string basePath, Device? device = null) => throw new NotImplementedException(); - - /// - public virtual void Dispose() { } } diff --git a/src/PointProcessDecoder.Core/PointProcessModel.cs b/src/PointProcessDecoder.Core/PointProcessModel.cs index d326b30..f4c5df8 100644 --- a/src/PointProcessDecoder.Core/PointProcessModel.cs +++ b/src/PointProcessDecoder.Core/PointProcessModel.cs @@ -282,13 +282,4 @@ public override void Save(string basePath) return model; } - - /// - public override void Dispose() - { - _encoderModel.Dispose(); - _decoderModel.Dispose(); - _likelihood.Dispose(); - _stateSpace.Dispose(); - } } diff --git a/src/PointProcessDecoder.Core/StateSpace/DiscreteUniformStateSpace.cs b/src/PointProcessDecoder.Core/StateSpace/DiscreteUniformStateSpace.cs index 4a5d419..de44dc0 100644 --- a/src/PointProcessDecoder.Core/StateSpace/DiscreteUniformStateSpace.cs +++ b/src/PointProcessDecoder.Core/StateSpace/DiscreteUniformStateSpace.cs @@ -91,10 +91,4 @@ ScalarType scalarType .to(device) .MoveToOuterDisposeScope(); } - - /// - public override void Dispose() - { - _points.Dispose(); - } } diff --git a/src/PointProcessDecoder.Core/Transitions/RandomWalkTransitions.cs b/src/PointProcessDecoder.Core/Transitions/RandomWalkTransitions.cs index 293683d..5637d7f 100644 --- a/src/PointProcessDecoder.Core/Transitions/RandomWalkTransitions.cs +++ b/src/PointProcessDecoder.Core/Transitions/RandomWalkTransitions.cs @@ -80,11 +80,4 @@ private static Tensor ComputeRandomWalkTransitions( .to(device: device) .MoveToOuterDisposeScope(); } - - /// - public override void Dispose() - { - _transitions.Dispose(); - _sigma?.Dispose(); - } } diff --git a/src/PointProcessDecoder.Core/Transitions/UniformTransitions.cs b/src/PointProcessDecoder.Core/Transitions/UniformTransitions.cs index e3c670f..234f3a7 100644 --- a/src/PointProcessDecoder.Core/Transitions/UniformTransitions.cs +++ b/src/PointProcessDecoder.Core/Transitions/UniformTransitions.cs @@ -62,10 +62,4 @@ ScalarType scalarType .to(device) .MoveToOuterDisposeScope(); } - - /// - public override void Dispose() - { - _transitions.Dispose(); - } } diff --git a/test/PointProcessDecoder.Cpu.Test/TestModel.cs b/test/PointProcessDecoder.Cpu.Test/TestModel.cs index 2775ab7..f1198a1 100644 --- a/test/PointProcessDecoder.Cpu.Test/TestModel.cs +++ b/test/PointProcessDecoder.Cpu.Test/TestModel.cs @@ -417,8 +417,6 @@ public void CompareClusterlessEncodingBatchSizes() var prediction1 = pointProcessModel.Decode(marks[TensorIndex.Slice(nBatches * batchSize, (nBatches + 1) * batchSize)]) .sum(dim: 0); - 
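Note: the Dispose removals in PATCH 04 lean on TorchSharp's dispose scopes, which already reclaim every intermediate tensor created inside a scope, so per-component Dispose methods add nothing. A small sketch of the scope pattern the codebase relies on (ScopeExample and ScopedSum are illustrative names):

using static TorchSharp.torch;

public static class ScopeExample
{
    // Sketch: tensors created inside the scope are freed when it ends;
    // only the tensor explicitly moved out survives for the caller.
    public static Tensor ScopedSum(Tensor input)
    {
        using var _ = NewDisposeScope();
        var intermediate = input.exp();         // disposed at scope exit
        var result = intermediate.sum(dim: -1); // kept via the move below
        return result.MoveToOuterDisposeScope();
    }
}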
pointProcessModel.Dispose(); - pointProcessModel = new PointProcessModel( estimationMethod: EstimationMethod.KernelCompression, transitionsType: TransitionsType.RandomWalk, @@ -499,8 +497,6 @@ public void CompareSortedUnitsEncodingBatchSizes() var prediction1 = pointProcessModel.Decode(spikingData[TensorIndex.Slice(nBatches * batchSize, (nBatches + 1) * batchSize)]) .sum(dim: 0); - pointProcessModel.Dispose(); - pointProcessModel = new PointProcessModel( estimationMethod: EstimationMethod.KernelCompression, transitionsType: TransitionsType.RandomWalk, From c9405294b6c385cc5dbb461acd0efff2130a08a4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 12 Feb 2025 14:42:31 +0000 Subject: [PATCH 05/20] Updated likelihood interface to the method name Likelihood instead of logLikelihood since it actually returns likelihood as a probability measure --- src/PointProcessDecoder.Core/ILikelihood.cs | 2 +- .../Likelihood/ClusterlessLikelihood.cs | 2 +- .../Likelihood/PoissonLikelihood.cs | 8 +++----- src/PointProcessDecoder.Core/PointProcessModel.cs | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/PointProcessDecoder.Core/ILikelihood.cs b/src/PointProcessDecoder.Core/ILikelihood.cs index 6eb2056..96e0270 100644 --- a/src/PointProcessDecoder.Core/ILikelihood.cs +++ b/src/PointProcessDecoder.Core/ILikelihood.cs @@ -18,5 +18,5 @@ public interface ILikelihood : IModelComponent /// /// /// - public Tensor LogLikelihood(Tensor inputs, IEnumerable conditionalIntensities); + public Tensor Likelihood(Tensor inputs, IEnumerable conditionalIntensities); } diff --git a/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs index 27624a4..d3da64d 100644 --- a/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs +++ b/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs @@ -57,7 +57,7 @@ public ClusterlessLikelihood( } /// - public Tensor LogLikelihood( + public Tensor Likelihood( Tensor inputs, IEnumerable conditionalIntensities ) diff --git a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs index 589e141..c2697df 100644 --- a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs +++ b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs @@ -25,19 +25,17 @@ public class PoissonLikelihood( public LikelihoodType LikelihoodType => LikelihoodType.Poisson; /// - public Tensor LogLikelihood( + public Tensor Likelihood( Tensor inputs, IEnumerable conditionalIntensities ) { using var _ = NewDisposeScope(); var conditionalIntensity = conditionalIntensities.First(); - var conditionalIntensityTensor = conditionalIntensity.flatten(1).T.unsqueeze(0); var logLikelihood = (inputs - .to_type(ScalarType.Bool) .unsqueeze(1) - * conditionalIntensityTensor - - conditionalIntensityTensor.exp()) + * conditionalIntensity + - conditionalIntensity.exp()) .nan_to_num() .sum(dim: -1); logLikelihood -= logLikelihood diff --git a/src/PointProcessDecoder.Core/PointProcessModel.cs b/src/PointProcessDecoder.Core/PointProcessModel.cs index f4c5df8..17ad9f2 100644 --- a/src/PointProcessDecoder.Core/PointProcessModel.cs +++ b/src/PointProcessDecoder.Core/PointProcessModel.cs @@ -203,7 +203,7 @@ public override void Encode(Tensor observations, Tensor inputs) public override Tensor Decode(Tensor inputs) { var conditionalIntensities = _encoderModel.Evaluate(inputs); - var likelihood = _likelihood.LogLikelihood(inputs, 
conditionalIntensities); + var likelihood = _likelihood.Likelihood(inputs, conditionalIntensities); return _decoderModel.Decode(inputs, likelihood); } From 20f8cb0a776aae801f24aa8455d0497b56a2f52f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 13 Feb 2025 10:56:53 +0000 Subject: [PATCH 06/20] Modified to use call to tensor.size instead of tensor.shape since tensor.shape just calls tensor.size --- .../Decoder/StateSpaceDecoder.cs | 6 ++--- .../Encoder/ClusterlessMarkEncoder.cs | 18 +++++++-------- .../Encoder/SortedSpikeEncoder.cs | 8 +++---- .../Estimation/KernelCompression.cs | 23 +++++++++---------- .../Estimation/KernelDensity.cs | 20 ++++++++-------- .../PointProcessModel.cs | 4 ++-- .../Transitions/RandomWalkTransitions.cs | 4 ++-- .../Transitions/UniformTransitions.cs | 2 +- 8 files changed, 41 insertions(+), 44 deletions(-) diff --git a/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs b/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs index 718c8af..637831b 100644 --- a/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs +++ b/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs @@ -71,7 +71,7 @@ public StateSpaceDecoder( _ => throw new ArgumentException("Invalid transitions type.") }; - var n = _stateSpace.Points.shape[0]; + var n = _stateSpace.Points.size(0); _initialState = ones(n, dtype: _scalarType, device: _device) / n; _posterior = empty(0); } @@ -81,7 +81,7 @@ public Tensor Decode(Tensor inputs, Tensor likelihood) { using var _ = NewDisposeScope(); - var outputShape = new long[] { inputs.shape[0] } + var outputShape = new long[] { inputs.size(0) } .Concat(_stateSpace.Shape) .ToArray(); @@ -95,7 +95,7 @@ public Tensor Decode(Tensor inputs, Tensor likelihood) output[0] = _posterior.reshape(_stateSpace.Shape); } - for (int i = 1; i < inputs.shape[0]; i++) + for (int i = 1; i < inputs.size(0); i++) { _posterior = (_stateTransitions.Transitions.matmul(_posterior) * likelihood[i].flatten()) .nan_to_num() diff --git a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs index 2b1aacf..7440584 100644 --- a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs +++ b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs @@ -206,7 +206,7 @@ Tensor marks using var _ = NewDisposeScope(); var markKernelEstimate = _markEstimation[i].Estimate(marks); var markDensity = markKernelEstimate.matmul(_channelEstimates[i].T) - / markKernelEstimate.shape[1]; + / markKernelEstimate.size(1); return (_rates[i] + markDensity.log() - _observationDensity) .nan_to_num() .MoveToOuterDisposeScope(); @@ -220,7 +220,7 @@ Tensor marks using var _ = NewDisposeScope(); var markKernelEstimate = _markEstimation[i].Estimate(marks, _stateSpace.Dimensions); var markDensity = markKernelEstimate.matmul(_markStateSpaceKernelEstimates[i].T) - / markKernelEstimate.shape[1]; + / markKernelEstimate.size(1); return (_rates[i] + markDensity.log() - _observationDensity) .nan_to_num() .MoveToOuterDisposeScope(); @@ -229,17 +229,17 @@ Tensor marks /// public void Encode(Tensor observations, Tensor marks) { - if (marks.shape[1] != _markDimensions) + if (marks.size(1) != _markDimensions) { throw new ArgumentException("The number of mark dimensions must match the shape of the marks tensor on dimension 1."); } - if (marks.shape[2] != _markChannels) + if (marks.size(2) != _markChannels) { throw new ArgumentException("The number of mark channels must match the shape of the marks tensor on dimension 2."); } - 
if (observations.shape[1] != _stateSpace.Dimensions) + if (observations.size(1) != _stateSpace.Dimensions) { throw new ArgumentException("The number of observation dimensions must match the dimensions of the state space."); } @@ -251,14 +251,14 @@ public void Encode(Tensor observations, Tensor marks) _spikeCounts = (marks.sum(dim: 1) > 0) .sum(dim: 0) .to(_device); - _samples = tensor(observations.shape[0], device: _device); + _samples = tensor(observations.size(0), device: _device); } else { _spikeCounts += (marks.sum(dim: 1) > 0) .sum(dim: 0); - _samples += observations.shape[0]; + _samples += observations.size(0); } _rates = _spikeCounts.log() - _samples.log(); @@ -289,7 +289,7 @@ private Tensor EvaluateMarkConditionalIntensities(Tensor inputs) using var _ = NewDisposeScope(); var markConditionalIntensities = zeros( - [_markChannels, inputs.shape[0], _stateSpace.Points.shape[0]], + [_markChannels, inputs.size(0), _stateSpace.Points.size(0)], device: _device, dtype: _scalarType ); @@ -322,7 +322,7 @@ private Tensor EvaluateChannelConditionalIntensities() .log(); var channelConditionalIntensities = zeros( - [_markChannels, _stateSpace.Points.shape[0]], + [_markChannels, _stateSpace.Points.size(0)], device: _device, dtype: _scalarType ); diff --git a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs index 913f30f..e54d8ec 100644 --- a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs +++ b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs @@ -133,12 +133,12 @@ private static IEstimation GetEstimationMethod( /// public void Encode(Tensor observations, Tensor inputs) { - if (inputs.shape[1] != _nUnits) + if (inputs.size(1) != _nUnits) { throw new ArgumentException("The number of units in the input tensor must match the expected number of units."); } - if (observations.shape[1] != _stateSpace.Dimensions) + if (observations.size(1) != _stateSpace.Dimensions) { throw new ArgumentException("The number of observation dimensions must match the dimensions of the state space."); } @@ -150,14 +150,14 @@ public void Encode(Tensor observations, Tensor inputs) _spikeCounts = inputs.nan_to_num() .sum(dim: 0) .to(_device); - _samples = tensor(observations.shape[0], device: _device); + _samples = tensor(observations.size(0), device: _device); } else { _spikeCounts += inputs.nan_to_num() .sum(dim: 0) .to(_device); - _samples += observations.shape[0]; + _samples += observations.size(0); } _rates = _spikeCounts.log() - _samples.log(); diff --git a/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs b/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs index f7abe99..4f4c47b 100644 --- a/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs +++ b/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs @@ -115,11 +115,11 @@ public KernelCompression( /// public void Fit(Tensor data) { - if (data.shape[1] != _dimensions) + if (data.size(1) != _dimensions) { throw new ArgumentException("Data shape must match expected dimensions"); } - if (data.shape[0] == 0) return; + if (data.size(0) == 0) return; using var _ = NewDisposeScope(); @@ -127,7 +127,7 @@ public void Fit(Tensor data) { _kernels = stack([_weight, data[0], _kernelBandwidth], dim: 1) .unsqueeze(0); - if (data.shape[0] == 1) + if (data.size(0) == 1) { _kernels.MoveToOuterDisposeScope(); return; @@ -135,13 +135,13 @@ public void Fit(Tensor data) data = data[TensorIndex.Slice(1)]; } - for (int i = 0; i < data.shape[0]; i++) + for (int i = 0; 
i < data.size(0); i++) { var kernel = stack([_weight, data[i], _kernelBandwidth], dim: 1); var dist = CalculateMahalanobisDistance(data[i]); var (minDist, argminDist) = dist.min(0); if ((minDist > _distanceThreshold).item() - && _kernels.shape[0] < _kernelLimit) + && _kernels.size(0) < _kernelLimit) { _kernels = concat([_kernels, kernel.unsqueeze(0)], dim: 0); continue; @@ -196,9 +196,8 @@ public Tensor Estimate(Tensor points, int? dimensionStart = null, int? dimension public Tensor Normalize(Tensor points) { using var _ = NewDisposeScope(); - var density = (points.sum(dim: -1) - / points.shape[1]) - .clamp_min(_eps); + var density = points.sum(dim: -1) + / points.size(1); density /= density.sum(); return density .to_type(_scalarType) @@ -209,7 +208,7 @@ public Tensor Normalize(Tensor points) /// public Tensor Evaluate(Tensor min, Tensor max, Tensor steps) { - if (min.shape[0] != _dimensions || max.shape[0] != _dimensions || steps.shape[0] != _dimensions) + if (min.size(0) != _dimensions || max.size(0) != _dimensions || steps.size(0) != _dimensions) { throw new ArgumentException("The lengths of min, max, and steps must be equal to the number of dimensions."); } @@ -225,8 +224,8 @@ public Tensor Evaluate(Tensor min, Tensor max, Tensor steps) } using var _ = NewDisposeScope(); - var arrays = new Tensor[min.shape[0]]; - for (int i = 0; i < min.shape[0]; i++) + var arrays = new Tensor[min.size(0)]; + for (int i = 0; i < min.size(0); i++) { arrays[i] = linspace(min[i].item(), max[i].item(), steps[i].item()).to(_device); } @@ -241,7 +240,7 @@ public Tensor Evaluate(Tensor min, Tensor max, Tensor steps) /// public Tensor Evaluate(Tensor points) { - if (points.shape[1] != _dimensions) + if (points.size(1) != _dimensions) { throw new ArgumentException("The number of dimensions must match the shape of the data."); } diff --git a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs index 1b80da3..ef381c3 100644 --- a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs +++ b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs @@ -89,7 +89,7 @@ public KernelDensity( /// public void Fit(Tensor data) { - if (data.shape[1] != _dimensions) + if (data.size(1) != _dimensions) { throw new ArgumentException("The number of dimensions must match the shape of the data."); } @@ -104,9 +104,9 @@ public void Fit(Tensor data) _kernels = cat([ _kernels, data ], dim: 0); - if (_kernels.shape[0] > _kernelLimit) + if (_kernels.size(0) > _kernelLimit) { - var start = _kernels.shape[0] - _kernelLimit; + var start = _kernels.size(0) - _kernelLimit; _kernels = _kernels[TensorIndex.Slice(start)]; } @@ -140,10 +140,8 @@ public Tensor Estimate(Tensor points, int? dimensionStart = null, int? 
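Note: for orientation while reading the estimation changes, KernelDensity evaluates a diagonal Gaussian kernel between every query point and every stored center by broadcasting, mirroring the expressions in the diffs here. A simplified sketch under those assumptions (KdeSketch and GaussianKde are illustrative names; kernels are [n, d], points [m, d], bandwidth [d]):

using static TorchSharp.torch;

public static class KdeSketch
{
    // Sketch: Gaussian kernel density estimate of m query points
    // against n stored kernel centers, per-dimension bandwidth.
    public static Tensor GaussianKde(Tensor kernels, Tensor bandwidth, Tensor points)
    {
        using var _ = NewDisposeScope();
        // Broadcast to [m, n, d]: bandwidth-scaled differences.
        var dist = (kernels.unsqueeze(0) - points.unsqueeze(1)) / bandwidth;
        var sumSquaredDiff = dist
            .pow(exponent: 2)
            .sum(dim: -1);                  // [m, n]
        var estimate = exp(-0.5 * sumSquaredDiff);
        // Normalizing constant as used by the codebase for the product kernel.
        var norm = sqrt(Math.Pow(2 * Math.PI, kernels.size(1)) * bandwidth.prod(dim: -1));
        return (estimate / norm).MoveToOuterDisposeScope();
    }
}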
dimension public Tensor Normalize(Tensor points) { using var _ = NewDisposeScope(); - - var density = (points.sum(dim: -1) - / points.shape[1]) - .clamp_min(_eps); + var density = points.sum(dim: -1) + / points.size(1); density /= density.sum(); return density .to_type(_scalarType) @@ -154,7 +152,7 @@ public Tensor Normalize(Tensor points) /// public Tensor Evaluate(Tensor min, Tensor max, Tensor steps) { - if (min.shape[0] != _dimensions || max.shape[0] != _dimensions || steps.shape[0] != _dimensions) + if (min.size(0) != _dimensions || max.size(0) != _dimensions || steps.size(0) != _dimensions) { throw new ArgumentException("The lengths of min, max, and steps must be equal to the number of dimensions."); } @@ -170,8 +168,8 @@ public Tensor Evaluate(Tensor min, Tensor max, Tensor steps) } using var _ = NewDisposeScope(); - var arrays = new Tensor[min.shape[0]]; - for (int i = 0; i < min.shape[0]; i++) + var arrays = new Tensor[min.size(0)]; + for (int i = 0; i < min.size(0); i++) { arrays[i] = linspace(min[i].item(), max[i].item(), steps[i].item(), dtype: _scalarType, device: _device); } @@ -189,7 +187,7 @@ public Tensor Evaluate(Tensor min, Tensor max, Tensor steps) /// public Tensor Evaluate(Tensor points) { - if (points.shape[1] != _dimensions) + if (points.size(1) != _dimensions) { throw new ArgumentException("The number of dimensions must match the shape of the data."); } diff --git a/src/PointProcessDecoder.Core/PointProcessModel.cs b/src/PointProcessDecoder.Core/PointProcessModel.cs index 17ad9f2..92bc2a1 100644 --- a/src/PointProcessDecoder.Core/PointProcessModel.cs +++ b/src/PointProcessDecoder.Core/PointProcessModel.cs @@ -187,11 +187,11 @@ public PointProcessModel( /// public override void Encode(Tensor observations, Tensor inputs) { - if (observations.shape[1] != _stateSpace.Dimensions) + if (observations.size(1) != _stateSpace.Dimensions) { throw new ArgumentException("The number of latent dimensions must match the shape of the observations."); } - if (observations.shape[0] != inputs.shape[0]) + if (observations.size(0) != inputs.size(0)) { throw new ArgumentException("The number of observations must match the number of inputs."); } diff --git a/src/PointProcessDecoder.Core/Transitions/RandomWalkTransitions.cs b/src/PointProcessDecoder.Core/Transitions/RandomWalkTransitions.cs index 5637d7f..7677386 100644 --- a/src/PointProcessDecoder.Core/Transitions/RandomWalkTransitions.cs +++ b/src/PointProcessDecoder.Core/Transitions/RandomWalkTransitions.cs @@ -73,8 +73,8 @@ private static Tensor ComputeRandomWalkTransitions( .sum(dim: 2); var estimate = exp(-0.5 * sumSquaredDiff / bandwidth); var weights = estimate / sqrt(pow(2 * Math.PI, stateSpace.Dimensions) * bandwidth); - var transitions = weights / points.shape[1]; - + var transitions = weights / points.size(1); + transitions /= transitions.sum(dim: 1, keepdim: true); return transitions .to_type(type: scalarType) .to(device: device) diff --git a/src/PointProcessDecoder.Core/Transitions/UniformTransitions.cs b/src/PointProcessDecoder.Core/Transitions/UniformTransitions.cs index 234f3a7..6be3737 100644 --- a/src/PointProcessDecoder.Core/Transitions/UniformTransitions.cs +++ b/src/PointProcessDecoder.Core/Transitions/UniformTransitions.cs @@ -54,7 +54,7 @@ ScalarType scalarType ) { using var _ = NewDisposeScope(); - var n = stateSpace.Points.shape[0]; + var n = stateSpace.Points.size(0); var transitions = ones(n, n, device: device, dtype: scalarType); transitions /= transitions.sum(1, true); return transitions From 
c797fc83fb598e9acd4427dc8f58ff380c11902f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 13 Feb 2025 10:59:04 +0000 Subject: [PATCH 07/20] Removed eps declaration since it is no longer needed --- src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs | 2 -- src/PointProcessDecoder.Core/Estimation/KernelCompression.cs | 4 ---- src/PointProcessDecoder.Core/Estimation/KernelDensity.cs | 4 ---- 3 files changed, 10 deletions(-) diff --git a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs index e54d8ec..54130c1 100644 --- a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs +++ b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs @@ -37,7 +37,6 @@ public class SortedSpikeEncoder : ModelComponent, IEncoder private Tensor _samples = empty(0); private Tensor _rates = empty(0); private readonly IStateSpace _stateSpace; - private readonly double _eps; /// /// Initializes a new instance of the class. @@ -68,7 +67,6 @@ public SortedSpikeEncoder( _device = device ?? CPU; _scalarType = scalarType ?? ScalarType.Float32; - _eps = finfo(_scalarType).eps; _stateSpace = stateSpace; _nUnits = nUnits; diff --git a/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs b/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs index 4f4c47b..e217419 100644 --- a/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs +++ b/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs @@ -51,8 +51,6 @@ public class KernelCompression : ModelComponent, IEstimation /// public int Dimensions => _dimensions; - private readonly double _eps; - /// /// Initializes a new instance of the class. /// @@ -71,7 +69,6 @@ public KernelCompression( { _device = device ?? CPU; _scalarType = scalarType ?? ScalarType.Float32; - _eps = finfo(_scalarType).eps; _distanceThreshold = distanceThreshold ?? double.NegativeInfinity; _kernelLimit = kernelLimit ?? int.MaxValue; _dimensions = dimensions ?? 1; @@ -104,7 +101,6 @@ public KernelCompression( _device = device ?? CPU; _scalarType = scalarType ?? ScalarType.Float32; - _eps = finfo(_scalarType).eps; _distanceThreshold = distanceThreshold ?? double.NegativeInfinity; _kernelLimit = kernelLimit ?? int.MaxValue; _dimensions = dimensions; diff --git a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs index ef381c3..d50f811 100644 --- a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs +++ b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs @@ -31,8 +31,6 @@ public class KernelDensity : ModelComponent, IEstimation private Tensor _kernels = empty(0); /// public Tensor Kernels => _kernels; - - private readonly double _eps; private readonly int _kernelLimit; /// @@ -52,7 +50,6 @@ public KernelDensity( _dimensions = dimensions ?? 1; _device = device ?? CPU; _scalarType = scalarType ?? ScalarType.Float32; - _eps = finfo(_scalarType).eps; _kernelBandwidth = tensor(bandwidth ?? 1.0, device: _device, dtype: _scalarType) .repeat(_dimensions); _kernelLimit = kernelLimit ?? int.MaxValue; @@ -81,7 +78,6 @@ public KernelDensity( _dimensions = dimensions; _device = device ?? CPU; _scalarType = scalarType ?? ScalarType.Float32; - _eps = finfo(_scalarType).eps; _kernelBandwidth = tensor(bandwidth, device: _device, dtype: _scalarType); _kernelLimit = kernelLimit ?? 
int.MaxValue; }

From b00ba63ac6a22e7c2887670d7673f09cc91c93d2 Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Thu, 13 Feb 2025 11:00:58 +0000
Subject: [PATCH 08/20] Updated Poisson likelihood calculation to remove redundant steps of normalizing in log space. Also removed step to reshape conditional intensity in sorted spike encoder since this is not necessary

---
 .../Encoder/SortedSpikeEncoder.cs | 4 +---
 .../Likelihood/PoissonLikelihood.cs | 18 +++++++-----------
 2 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs
index 54130c1..4b660ff 100644
--- a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs
+++ b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs
@@ -191,9 +191,7 @@ public IEnumerable Evaluate(params Tensor[] inputs)
 {
 var unitDensity = _unitEstimation[i].Evaluate(_stateSpace.Points)
 .log();
-
- unitConditionalIntensities[i] = (_rates[i] + unitDensity - observationDensity)
- .reshape(_stateSpace.Shape);
+ unitConditionalIntensities[i] = _rates[i] + unitDensity - observationDensity;
 }
 var output = stack(unitConditionalIntensities, dim: 0)
 .MoveToOuterDisposeScope();

diff --git a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
index c2697df..8080a70 100644
--- a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
+++ b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
@@ -31,19 +31,15 @@ IEnumerable conditionalIntensities
 )
 {
 using var _ = NewDisposeScope();
- var conditionalIntensity = conditionalIntensities.First();
- var logLikelihood = (inputs
- .unsqueeze(1)
- * conditionalIntensity
+ var conditionalIntensity = conditionalIntensities.First()
+ .unsqueeze(0);
+ var logLikelihood = (inputs.unsqueeze(-1)
+ * conditionalIntensity
 - conditionalIntensity.exp())
 .nan_to_num()
- .sum(dim: -1);
- logLikelihood -= logLikelihood
- .max(dim: -1, keepdim: true)
- .values;
- logLikelihood = logLikelihood
- .exp()
- .nan_to_num();
+ .sum(dim: 1)
+ .exp()
+ .nan_to_num();
 logLikelihood /= logLikelihood
 .sum(dim: -1, keepdim: true);
 return logLikelihood

From 5504c92e331ad1e8692deebb186726b92f8e2ea0 Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Thu, 13 Feb 2025 11:03:04 +0000
Subject: [PATCH 09/20] Modified estimate methods to return NaN tensor of correct size when there are no kernels

---
 src/PointProcessDecoder.Core/Estimation/KernelCompression.cs | 2 +-
 src/PointProcessDecoder.Core/Estimation/KernelDensity.cs | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs b/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs
index e217419..0814bce 100644
--- a/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs
+++ b/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs
@@ -172,7 +172,7 @@ public Tensor Estimate(Tensor points, int?
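Note: to unpack the Poisson rework in PATCH 08 above: the per-unit intensities are kept in log space, and the Poisson log-likelihood of count n under rate lambda is n*log(lambda) - lambda once the log(n!) term (constant across states) is dropped. Summing over units, exponentiating, and row-normalizing yields the posterior-ready likelihood. A compact sketch under those assumptions (PoissonSketch and the parameter names are illustrative):

using static TorchSharp.torch;

public static class PoissonSketch
{
    // counts: [t, u] spikes per time bin and unit.
    // logLambda: [u, s] log intensity per unit and state bin.
    public static Tensor PoissonLikelihoodSketch(Tensor counts, Tensor logLambda)
    {
        using var _ = NewDisposeScope();
        var intensity = logLambda.unsqueeze(0);            // [1, u, s]
        var likelihood = (counts.unsqueeze(-1) * intensity // n * log(lambda)
                - intensity.exp())                         // minus lambda
            .nan_to_num()
            .sum(dim: 1)                                   // over units -> [t, s]
            .exp()
            .nan_to_num();
        likelihood /= likelihood.sum(dim: -1, keepdim: true);
        return likelihood.MoveToOuterDisposeScope();
    }
}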
dimension using var _ = NewDisposeScope(); if (_kernels.numel() == 0) { - return (ones([1, 1], dtype: _scalarType, device: _device) * float.NaN) + return (ones([points.size(0), 1], dtype: _scalarType, device: _device) * float.NaN) .MoveToOuterDisposeScope(); } var kernels = _kernels[TensorIndex.Colon, TensorIndex.Slice(dimensionStart, dimensionEnd)]; diff --git a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs index d50f811..9f8cac8 100644 --- a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs +++ b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs @@ -116,7 +116,7 @@ public Tensor Estimate(Tensor points, int? dimensionStart = null, int? dimension using var _ = NewDisposeScope(); if (_kernels.numel() == 0) { - return (ones([1, 1], dtype: _scalarType, device: _device) * float.NaN) + return (ones([points.size(0), 1], dtype: _scalarType, device: _device) * float.NaN) .MoveToOuterDisposeScope(); } var kernels = _kernels[TensorIndex.Colon, TensorIndex.Slice(dimensionStart, dimensionEnd)]; From 732655f6ad4417d43343cb4ffe6588d44cbf61bf Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 17 Feb 2025 11:07:53 +0000 Subject: [PATCH 10/20] Updated with new naming --- src/PointProcessDecoder.Core/IEncoder.cs | 6 +++--- src/PointProcessDecoder.Core/ILikelihood.cs | 6 +++--- .../Likelihood/PoissonLikelihood.cs | 18 +++++++++++------- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/PointProcessDecoder.Core/IEncoder.cs b/src/PointProcessDecoder.Core/IEncoder.cs index 957697e..9aa2c87 100644 --- a/src/PointProcessDecoder.Core/IEncoder.cs +++ b/src/PointProcessDecoder.Core/IEncoder.cs @@ -13,9 +13,9 @@ public interface IEncoder : IModelComponent public Encoder.EncoderType EncoderType { get; } /// - /// The conditional intensities of the model. + /// The intensities of the model. /// - public Tensor[] ConditionalIntensities { get; } + public Tensor[] Intensities { get; } /// /// The estimations of the model. @@ -23,7 +23,7 @@ public interface IEncoder : IModelComponent public IEstimation[] Estimations { get; } /// - /// Evaluates the conditional intensities of the model given the inputs. + /// Evaluates the intensities of the model given the inputs. /// /// /// diff --git a/src/PointProcessDecoder.Core/ILikelihood.cs b/src/PointProcessDecoder.Core/ILikelihood.cs index 96e0270..d6efb02 100644 --- a/src/PointProcessDecoder.Core/ILikelihood.cs +++ b/src/PointProcessDecoder.Core/ILikelihood.cs @@ -13,10 +13,10 @@ public interface ILikelihood : IModelComponent public Likelihood.LikelihoodType LikelihoodType { get; } /// - /// Measures the likelihood of the model given the inputs and the conditional intensities. + /// Measures the likelihood of the model given the inputs and the intensities. 
///
 ///
- ///
+ ///
 ///
- public Tensor Likelihood(Tensor inputs, IEnumerable conditionalIntensities);
+ public Tensor Likelihood(Tensor inputs, IEnumerable intensities);
 }
diff --git a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
index 8080a70..bbfe397 100644
--- a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
+++ b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
@@ -27,22 +27,26 @@ public class PoissonLikelihood(
 /// public Tensor Likelihood(
 Tensor inputs,
- IEnumerable conditionalIntensities
+ IEnumerable intensities
 )
 {
 using var _ = NewDisposeScope();
- var conditionalIntensity = conditionalIntensities.First()
+
+ var intensity = intensities.First()
 .unsqueeze(0);
- var logLikelihood = (inputs.unsqueeze(-1)
- * conditionalIntensity
- - conditionalIntensity.exp())
+
+ var likelihood = ((inputs.unsqueeze(-1)
+ * intensity)
+ - intensity.exp())
 .nan_to_num()
 .sum(dim: 1)
 .exp()
 .nan_to_num();
- logLikelihood /= logLikelihood
+
+ likelihood /= likelihood
 .sum(dim: -1, keepdim: true);
- return logLikelihood
+
+ return likelihood
 .MoveToOuterDisposeScope();
 }
 }

From 3d370bb9e18e2b86f9dd47173580e0883655cb41 Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Mon, 17 Feb 2025 11:08:13 +0000
Subject: [PATCH 11/20] Removed call to flatten since these should be flattened anyway

---
 src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs b/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs
index 637831b..1ab60a3 100644
--- a/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs
+++ b/src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs
@@ -88,7 +88,7 @@ public Tensor Decode(Tensor inputs, Tensor likelihood)
 var output = zeros(outputShape, dtype: _scalarType, device: _device);
 if (_posterior.numel() == 0)
 {
- _posterior = (_initialState * likelihood[0].flatten())
+ _posterior = (_initialState * likelihood[0])
 .nan_to_num()
 .clamp_min(_eps);
 _posterior /= _posterior.sum();
@@ -97,7 +97,7 @@
 for (int i = 1; i < inputs.size(0); i++)
 {
- _posterior = (_stateTransitions.Transitions.matmul(_posterior) * likelihood[i].flatten())
+ _posterior = (_stateTransitions.Transitions.matmul(_posterior) * likelihood[i])
 .nan_to_num()
 .clamp_min(_eps);
 _posterior /= _posterior.sum();

From a70fb10b369c1764a7503fd44748b4b272820eaf Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Mon, 17 Feb 2025 11:10:26 +0000
Subject: [PATCH 12/20] Updated clusterless mark encoder

---
 .../Encoder/ClusterlessMarkEncoder.cs | 204 +++++------------
 1 file changed, 56 insertions(+), 148 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs
index 7440584..11972fd 100644
--- a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs
+++ b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs
@@ -20,35 +20,28 @@ public class ClusterlessMarkEncoder : ModelComponent, IEncoder
 /// public EncoderType EncoderType => EncoderType.ClusterlessMarkEncoder;

- private Tensor[] _conditionalIntensities = [empty(0)];
 ///
- public Tensor[] ConditionalIntensities => _conditionalIntensities;
+ public Tensor[] Intensities => [_channelIntensities, _markIntensities];

- private IEstimation[] _estimations = [];
 ///
- public IEstimation[]
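Note: PATCH 11 above strips the flatten calls from the decoder's filter loop; the recursion itself is the usual Bayesian filter, posterior_t proportional to likelihood_t elementwise-times (transitions matmul posterior_{t-1}), renormalized each step. A one-step sketch matching that code (FilterSketch and FilterStep are illustrative names):

using static TorchSharp.torch;

public static class FilterSketch
{
    // transitions: [s, s]; posterior: [s]; likelihood: [s].
    public static Tensor FilterStep(Tensor transitions, Tensor posterior, Tensor likelihood, double eps)
    {
        using var _ = NewDisposeScope();
        var updated = (transitions.matmul(posterior) * likelihood)
            .nan_to_num()
            .clamp_min(eps);      // keep the posterior strictly positive
        updated /= updated.sum(); // renormalize to a probability
        return updated.MoveToOuterDisposeScope();
    }
}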
Estimations => _estimations; + public IEstimation[] Estimations => [_observationEstimation, .. _markEstimation]; private readonly IEstimation _observationEstimation; - private readonly IEstimation[] _channelEstimation; private readonly IEstimation[] _markEstimation; private readonly IStateSpace _stateSpace; - private bool _updateConditionalIntensities = true; - - private Tensor[] _markStateSpaceKernelEstimates = []; - private Tensor _markConditionalIntensities = empty(0); - private Tensor[] _channelEstimates = []; - private Tensor _channelConditionalIntensities = empty(0); + private bool _updateIntensities = true; + private Tensor _markIntensities = empty(0); + private Tensor _channelIntensities = empty(0); private Tensor _observationDensity = empty(0); + private Tensor[] _channelEstimates = []; private Tensor _spikeCounts = empty(0); private Tensor _samples = empty(0); private Tensor _rates = empty(0); + private readonly int _markDimensions; - private int _markChannels; - private readonly Action _markFitMethod; - private readonly Func _estimateMarkConditionalIntensityMethod; - private readonly Func _estimateMarkStateSpaceKernelMethod; + private readonly int _markChannels; /// /// Initializes a new instance of the class. @@ -92,11 +85,11 @@ public ClusterlessMarkEncoder( _markChannels = markChannels; _stateSpace = stateSpace; - _channelEstimation = new IEstimation[_markChannels]; _markEstimation = new IEstimation[_markChannels]; - _channelEstimates = new Tensor[_markChannels]; - _markStateSpaceKernelEstimates = new Tensor[_markChannels]; + + var bandwidth = observationBandwidth.Concat(markBandwidth).ToArray(); + var jointDimensions = _stateSpace.Dimensions + _markDimensions; switch (estimationMethod) { @@ -111,28 +104,18 @@ public ClusterlessMarkEncoder( for (int i = 0; i < _markChannels; i++) { - _channelEstimation[i] = new KernelDensity( - bandwidth: observationBandwidth, - dimensions: _stateSpace.Dimensions, - device: device, - scalarType: scalarType - ); - _markEstimation[i] = new KernelDensity( - bandwidth: markBandwidth, - dimensions: _markDimensions, + bandwidth: bandwidth, + dimensions: jointDimensions, device: device, scalarType: scalarType ); } - _markFitMethod = FitMarksFactoredMethod; - _estimateMarkConditionalIntensityMethod = EstimateMarksFactoredMethod; - _estimateMarkStateSpaceKernelMethod = (_) => empty(0); - break; case EstimationMethod.KernelCompression: + _observationEstimation = new KernelCompression( bandwidth: observationBandwidth, dimensions: _stateSpace.Dimensions, @@ -142,20 +125,8 @@ public ClusterlessMarkEncoder( scalarType: scalarType ); - var bandwidth = observationBandwidth.Concat(markBandwidth).ToArray(); - var jointDimensions = _stateSpace.Dimensions + _markDimensions; - for (int i = 0; i < _markChannels; i++) { - _channelEstimation[i] = new KernelCompression( - bandwidth: observationBandwidth, - dimensions: _stateSpace.Dimensions, - distanceThreshold: distanceThreshold, - kernelLimit: kernelLimit, - device: device, - scalarType: scalarType - ); - _markEstimation[i] = new KernelCompression( bandwidth: bandwidth, dimensions: jointDimensions, @@ -166,64 +137,11 @@ public ClusterlessMarkEncoder( ); } - _markFitMethod = FitMarksUnfactoredMethod; - _estimateMarkConditionalIntensityMethod = EstimateMarksUnfactoredMethod; - _estimateMarkStateSpaceKernelMethod = (IEstimation estimation) => - estimation.Estimate(_stateSpace.Points, 0, _stateSpace.Dimensions); - break; default: throw new ArgumentException("Invalid estimation method."); }; - - _estimations = 
[_observationEstimation, .. _channelEstimation, .. _markEstimation]; - } - - private void FitMarksFactoredMethod( - int i, - Tensor observations, - Tensor marks - ) - { - _markEstimation[i].Fit(marks); - } - - private void FitMarksUnfactoredMethod( - int i, - Tensor observations, - Tensor marks - ) - { - _markEstimation[i].Fit(concat([observations, marks], dim: 1)); - } - - private Tensor EstimateMarksFactoredMethod( - int i, - Tensor marks - ) - { - using var _ = NewDisposeScope(); - var markKernelEstimate = _markEstimation[i].Estimate(marks); - var markDensity = markKernelEstimate.matmul(_channelEstimates[i].T) - / markKernelEstimate.size(1); - return (_rates[i] + markDensity.log() - _observationDensity) - .nan_to_num() - .MoveToOuterDisposeScope(); - } - - private Tensor EstimateMarksUnfactoredMethod( - int i, - Tensor marks - ) - { - using var _ = NewDisposeScope(); - var markKernelEstimate = _markEstimation[i].Estimate(marks, _stateSpace.Dimensions); - var markDensity = markKernelEstimate.matmul(_markStateSpaceKernelEstimates[i].T) - / markKernelEstimate.size(1); - return (_rates[i] + markDensity.log() - _observationDensity) - .nan_to_num() - .MoveToOuterDisposeScope(); } /// @@ -248,7 +166,8 @@ public void Encode(Tensor observations, Tensor marks) if (_spikeCounts.numel() == 0) { - _spikeCounts = (marks.sum(dim: 1) > 0) + _spikeCounts = (~marks.isnan()) + .any(dim: 1) .sum(dim: 0) .to(_device); _samples = tensor(observations.size(0), device: _device); @@ -256,7 +175,8 @@ public void Encode(Tensor observations, Tensor marks) } else { - _spikeCounts += (marks.sum(dim: 1) > 0) + _spikeCounts += (~marks.isnan()) + .any(dim: 1) .sum(dim: 0); _samples += observations.size(0); } @@ -267,28 +187,28 @@ public void Encode(Tensor observations, Tensor marks) for (int i = 0; i < _markChannels; i++) { - if (mask[TensorIndex.Colon, i].sum().item() == 0) + if ((~mask[TensorIndex.Colon, i].any()).item()) { continue; } - var channelObservation = observations[mask[TensorIndex.Colon, i]]; - var markObservation = marks[TensorIndex.Tensor(mask[TensorIndex.Colon, i]), TensorIndex.Colon, i]; - _channelEstimation[i].Fit(channelObservation); - _markFitMethod(i, channelObservation, markObservation); + _markEstimation[i].Fit( + concat([ + observations[mask[TensorIndex.Colon, i]], + marks[TensorIndex.Tensor(mask[TensorIndex.Colon, i]), TensorIndex.Colon, i] + ], dim: 1) + ); } - _updateConditionalIntensities = true; - _channelConditionalIntensities = Evaluate() - .First() - .MoveToOuterDisposeScope(); + _updateIntensities = true; + Evaluate(); } - private Tensor EvaluateMarkConditionalIntensities(Tensor inputs) + private void EvaluateMarkIntensities(Tensor inputs) { using var _ = NewDisposeScope(); - var markConditionalIntensities = zeros( + _markIntensities = zeros( [_markChannels, inputs.size(0), _stateSpace.Points.size(0)], device: _device, dtype: _scalarType @@ -298,30 +218,36 @@ private Tensor EvaluateMarkConditionalIntensities(Tensor inputs) for (int i = 0; i < _markChannels; i++) { - if (mask[TensorIndex.Colon, i].sum().item() == 0) + if ((~mask[TensorIndex.Colon, i].any()).item()) { continue; } - var marks = inputs[TensorIndex.Tensor(mask[TensorIndex.Colon, i]), TensorIndex.Colon, i]; - markConditionalIntensities[i, TensorIndex.Tensor(mask[TensorIndex.Colon, i])] = _estimateMarkConditionalIntensityMethod( - i, - marks + var markKernelEstimate = _markEstimation[i].Estimate( + inputs[TensorIndex.Tensor(mask[TensorIndex.Colon, i]), TensorIndex.Colon, i], + _stateSpace.Dimensions ); + + var markDensity = 
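Note: the mark evaluation continuing below, like the channel evaluation elsewhere in this patch, reduces to the ratio-of-densities identity lambda(x) = mu * p(x | spike) / pi(x): mean firing rate times the spike-conditioned density over the occupancy. In log space that is a three-term sum, as sketched here (IntensitySketch and the parameter names are illustrative):

using static TorchSharp.torch;

public static class IntensitySketch
{
    // logRate: scalar log mean firing rate.
    // logDensity: [s] log p(x | spike) over state bins.
    // logOccupancy: [s] log pi(x), the occupancy density.
    public static Tensor LogIntensity(Tensor logRate, Tensor logDensity, Tensor logOccupancy)
    {
        // lambda(x) = mu * p(x | spike) / pi(x) becomes a sum of logs.
        return logRate + logDensity - logOccupancy;
    }
}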
markKernelEstimate.matmul(_channelEstimates[i].T); + markDensity /= markDensity.sum(dim: 1, keepdim: true); + markDensity = markDensity + .log(); + + _markIntensities[i, TensorIndex.Tensor(mask[TensorIndex.Colon, i])] = _rates[i] + markDensity - _observationDensity; } - return markConditionalIntensities - .MoveToOuterDisposeScope(); + _markIntensities.MoveToOuterDisposeScope(); } - private Tensor EvaluateChannelConditionalIntensities() + private void EvaluateChannelIntensities() { using var _ = NewDisposeScope(); _observationDensity = _observationEstimation.Evaluate(_stateSpace.Points) - .log(); + .log() + .MoveToOuterDisposeScope(); - var channelConditionalIntensities = zeros( + _channelIntensities = zeros( [_markChannels, _stateSpace.Points.size(0)], device: _device, dtype: _scalarType @@ -329,43 +255,32 @@ private Tensor EvaluateChannelConditionalIntensities() for (int i = 0; i < _markChannels; i++) { - _channelEstimates[i] = _channelEstimation[i].Estimate(_stateSpace.Points) + _channelEstimates[i] = _markEstimation[i].Estimate(_stateSpace.Points, 0, _stateSpace.Dimensions) .MoveToOuterDisposeScope(); - var channelDensity = _channelEstimation[i].Normalize(_channelEstimates[i]) - .log(); + var channelDensity = _markEstimation[i].Normalize(_channelEstimates[i]); - channelConditionalIntensities[i] = exp(_rates[i] + channelDensity - _observationDensity); - - _markStateSpaceKernelEstimates[i] = _estimateMarkStateSpaceKernelMethod(_markEstimation[i]) - .MoveToOuterDisposeScope(); + _channelIntensities[i] = _rates[i] + channelDensity.log() - _observationDensity; } - _observationDensity.MoveToOuterDisposeScope(); - _updateConditionalIntensities = false; - - return channelConditionalIntensities - .MoveToOuterDisposeScope(); + _channelIntensities.MoveToOuterDisposeScope(); + _updateIntensities = false; } /// public IEnumerable Evaluate(params Tensor[] inputs) { - if (_updateConditionalIntensities) + if (_updateIntensities) { - _channelConditionalIntensities = EvaluateChannelConditionalIntensities() - .MoveToOuterDisposeScope(); + EvaluateChannelIntensities(); } if (inputs.Length > 0) { - _markConditionalIntensities = EvaluateMarkConditionalIntensities(inputs[0]) - .MoveToOuterDisposeScope(); + EvaluateMarkIntensities(inputs[0]); } - _conditionalIntensities = [_channelConditionalIntensities, _markConditionalIntensities]; - - return _conditionalIntensities; + return Intensities; } /// @@ -382,7 +297,7 @@ public override void Save(string basePath) _samples.Save(Path.Combine(path, "samples.bin")); _rates.Save(Path.Combine(path, "rates.bin")); _observationDensity.Save(Path.Combine(path, "observationDensity.bin")); - _channelConditionalIntensities.Save(Path.Combine(path, "channelConditionalIntensities.bin")); + _channelIntensities.Save(Path.Combine(path, "channelIntensities.bin")); var observationEstimationPath = Path.Combine(path, $"observationEstimation"); @@ -401,8 +316,6 @@ public override void Save(string basePath) { Directory.CreateDirectory(channelEstimationPath); } - - _channelEstimation[i].Save(channelEstimationPath); var markEstimationPath = Path.Combine(path, $"markEstimation{i}"); @@ -412,9 +325,7 @@ public override void Save(string basePath) } _markEstimation[i].Save(markEstimationPath); - _channelEstimates[i].Save(Path.Combine(path, $"channelEstimates{i}.bin")); - _markStateSpaceKernelEstimates[i].Save(Path.Combine(path, $"markStateSpaceKernelEstimates{i}.bin")); } } @@ -432,7 +343,7 @@ public override IModelComponent Load(string basePath) _samples = Tensor.Load(Path.Combine(path, 
"samples.bin")).to(_device); _rates = Tensor.Load(Path.Combine(path, "rates.bin")).to(_device); _observationDensity = Tensor.Load(Path.Combine(path, "observationDensity.bin")).to(_device); - _channelConditionalIntensities = Tensor.Load(Path.Combine(path, "channelConditionalIntensities.bin")).to(_device); + _channelIntensities = Tensor.Load(Path.Combine(path, "channelIntensities.bin")).to(_device); var observationEstimationPath = Path.Combine(path, $"observationEstimation"); @@ -452,8 +363,6 @@ public override IModelComponent Load(string basePath) throw new ArgumentException("The channel estimation directory does not exist."); } - _channelEstimation[i].Load(channelEstimationPath); - var markEstimationPath = Path.Combine(path, $"markEstimation{i}"); if (!Directory.Exists(markEstimationPath)) @@ -464,7 +373,6 @@ public override IModelComponent Load(string basePath) _markEstimation[i].Load(markEstimationPath); _channelEstimates[i] = Tensor.Load(Path.Combine(path, $"channelEstimates{i}.bin")).to(_device); - _markStateSpaceKernelEstimates[i] = Tensor.Load(Path.Combine(path, $"markStateSpaceKernelEstimates{i}.bin")).to(_device); } return this; From 9402c2717758a9ef647d6269dd3ca3c966718994 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 17 Feb 2025 11:11:37 +0000 Subject: [PATCH 13/20] Updated sorted spike encoder --- .../Encoder/SortedSpikeEncoder.cs | 61 +++++++++++-------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs index 4b660ff..1febf15 100644 --- a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs +++ b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs @@ -20,16 +20,15 @@ public class SortedSpikeEncoder : ModelComponent, IEncoder /// public EncoderType EncoderType => EncoderType.SortedSpikeEncoder; - private Tensor[] _conditionalIntensities = [empty(0)]; /// - public Tensor[] ConditionalIntensities => _conditionalIntensities; + public Tensor[] Intensities => [_unitIntensities]; private IEstimation[] _estimations = []; /// public IEstimation[] Estimations => _estimations; - private Tensor _unitConditionalIntensities = empty(0); - private bool _updateConditionalIntensities = true; + private Tensor _unitIntensities = empty(0); + private bool _updateIntensities = true; private readonly int _nUnits; private readonly IEstimation[] _unitEstimation; private readonly IEstimation _observationEstimation; @@ -167,37 +166,45 @@ public void Encode(Tensor observations, Tensor inputs) _unitEstimation[i].Fit(observations[inputMask[TensorIndex.Colon, i]]); } - _updateConditionalIntensities = true; - _unitConditionalIntensities = Evaluate() - .First() - .MoveToOuterDisposeScope(); + _updateIntensities = true; + Evaluate(); } - /// - public IEnumerable Evaluate(params Tensor[] inputs) + private void EvaluateUnitIntensities() { - if (_unitConditionalIntensities.numel() != 0 && !_updateConditionalIntensities) - { - _conditionalIntensities = [_unitConditionalIntensities]; - return _conditionalIntensities; - } - using var _ = NewDisposeScope(); + var observationDensity = _observationEstimation.Evaluate(_stateSpace.Points) .log(); - var unitConditionalIntensities = new Tensor[_nUnits]; + + _unitIntensities = zeros( + [_nUnits, _stateSpace.Points.size(0)], + device: _device, + dtype: _scalarType + ); for (int i = 0; i < _nUnits; i++) { - var unitDensity = _unitEstimation[i].Evaluate(_stateSpace.Points) - .log(); - unitConditionalIntensities[i] = _rates[i] 
From bbce420df7e9c2bf2a4440feb0c2a4850a9677ef Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Mon, 17 Feb 2025 11:13:14 +0000
Subject: [PATCH 14/20] Updated to correctly use dimensions when calculating
 bandwidth

---
 src/PointProcessDecoder.Core/Estimation/KernelDensity.cs | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs
index 9f8cac8..153b574 100644
--- a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs
+++ b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs
@@ -120,12 +120,13 @@ public Tensor Estimate(Tensor points, int? dimensionStart = null, int? dimension
                 .MoveToOuterDisposeScope();
         }
         var kernels = _kernels[TensorIndex.Colon, TensorIndex.Slice(dimensionStart, dimensionEnd)];
-        var dist = (kernels.unsqueeze(0) - points.unsqueeze(1)) / _kernelBandwidth;
+        var bandwidth = _kernelBandwidth[TensorIndex.Slice(dimensionStart, dimensionEnd)];
+        var dist = (kernels.unsqueeze(0) - points.unsqueeze(1)) / bandwidth;
         var sumSquaredDiff = dist
             .pow(exponent: 2)
             .sum(dim: -1);
         var estimate = exp(-0.5 * sumSquaredDiff);
-        var sqrtDiagonalCovariance = sqrt(pow(2 * Math.PI, _dimensions) * _kernelBandwidth.prod(dim: -1));
+        var sqrtDiagonalCovariance = sqrt(pow(2 * Math.PI, _dimensions) * bandwidth.prod(dim: -1));
         return (estimate / sqrtDiagonalCovariance)
             .to_type(_scalarType)
             .to(_device)
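The fix above slices the bandwidth vector to the same dimension range as the kernels, so an evaluation over a subset of dimensions scales each coordinate by its own bandwidth. A self-contained sketch of the idea with invented names (EstimateSlice is not the library's API), written so the Gaussian normalizer uses the sliced dimensionality d; note the patch itself keeps pow(2 * Math.PI, _dimensions):

    using System;
    using static TorchSharp.torch;

    // kernels: [nKernels, dTotal], points: [nPoints, dSlice], bandwidth: [dTotal]
    static Tensor EstimateSlice(Tensor kernels, Tensor points, Tensor bandwidth, long start, long end)
    {
        var k = kernels[TensorIndex.Colon, TensorIndex.Slice(start, end)];
        var bw = bandwidth[TensorIndex.Slice(start, end)];          // slice the bandwidth too
        var dist = (k.unsqueeze(0) - points.unsqueeze(1)) / bw;     // per-dimension scaling
        var sumSquaredDiff = dist.pow(2).sum(dim: -1);
        var d = end - start;                                        // dimensionality of the slice
        var norm = sqrt(Math.Pow(2 * Math.PI, d) * bw.prod(dim: -1));  // diagonal Gaussian normalizer
        return exp(-0.5 * sumSquaredDiff) / norm;
    }

Without the bandwidth slice, a joint mark/state kernel evaluated only on its state dimensions would divide by bandwidths belonging to the wrong coordinates.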
From 97c64b92f395a9ed1745f2a63a0c97c21ab6aec3 Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Mon, 17 Feb 2025 11:14:26 +0000
Subject: [PATCH 15/20] Updated clusterless likelihood

---
 .../Likelihood/ClusterlessLikelihood.cs | 49 +++++++++----------
 1 file changed, 22 insertions(+), 27 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs
index d3da64d..1bfb0bc 100644
--- a/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs
+++ b/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs
@@ -17,20 +17,13 @@ public class ClusterlessLikelihood : ModelComponent, ILikelihood
     public override ScalarType ScalarType => _scalarType;
 
     private bool _ignoreNoSpikes;
-    private Tensor _noSpikeLikelihood;
     ///
     /// Whether to ignore the contribution of channels with no spikes to the likelihood.
     ///
     public bool IgnoreNoSpikes
     {
         get => _ignoreNoSpikes;
-        set
-        {
-            _ignoreNoSpikes = value;
-            _noSpikeLikelihood = _ignoreNoSpikes ?
-                zeros(1, device: _device, dtype: _scalarType)
-                : ones(1, dtype: _scalarType, device: _device);
-        }
+        set => _ignoreNoSpikes = value;
     }
 
     ///
@@ -51,34 +44,36 @@ public ClusterlessLikelihood(
         _device = device ?? CPU;
         _scalarType = scalarType ?? ScalarType.Float32;
         _ignoreNoSpikes = ignoreNoSpikes;
-        _noSpikeLikelihood = _ignoreNoSpikes ?
-            zeros(1, dtype: _scalarType, device: _device)
-            : ones(1, dtype: _scalarType, device: _device);
     }
 
     ///
     public Tensor Likelihood(
         Tensor inputs,
-        IEnumerable<Tensor> conditionalIntensities
+        IEnumerable<Tensor> intensities
     )
     {
         using var _ = NewDisposeScope();
-        var channelConditionalIntensities = conditionalIntensities.ElementAt(0);
-        var markConditionalIntensities = conditionalIntensities.ElementAt(1);
-        var logLikelihood = markConditionalIntensities
-            .nan_to_num()
-            .sum(dim: 0) - channelConditionalIntensities
-            .nan_to_num()
-            .sum(dim: 0) * _noSpikeLikelihood;
-        logLikelihood -= logLikelihood
-            .max(dim: -1, keepdim: true)
-            .values;
-        logLikelihood = logLikelihood
-            .exp()
-            .nan_to_num();
-        logLikelihood /= logLikelihood
+
+        var channelIntensities = intensities.ElementAt(0);
+        var markIntensities = intensities.ElementAt(1);
+
+        var likelihood = markIntensities
+            .sum(dim: 0);
+
+        if (!_ignoreNoSpikes)
+        {
+            likelihood -= channelIntensities
+                .exp()
+                .sum(dim: 0);
+        }
+
+        likelihood = likelihood
+            .exp();
+
+        likelihood /= likelihood
             .sum(dim: -1, keepdim: true);
-        return logLikelihood
+
+        return likelihood
            .MoveToOuterDisposeScope();
    }
}
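The rewritten Likelihood sums the log mark intensities, optionally subtracts the no-spike contribution, exponentiates, and normalizes across the state space so each time bin carries a proper probability vector. A toy illustration of that final normalization step, with invented values:

    using static TorchSharp.torch;

    // Unnormalized log likelihood over 4 state bins for 2 time steps (toy values)
    var logLikelihood = tensor(new float[] { -1.0f, -2.0f, -0.5f, -3.0f,
                                             -2.5f, -0.1f, -1.5f, -2.0f })
        .reshape(2, 4);

    var likelihood = logLikelihood.exp();
    likelihood /= likelihood.sum(dim: -1, keepdim: true);  // each row now sums to 1

Because the normalization divides out any constant factor per time step, terms that are constant across state bins drop out of the posterior automatically.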
From 24257939d473aa0aa8a22a2a33793d8ecb00658b Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Mon, 17 Feb 2025 16:15:18 +0000
Subject: [PATCH 16/20] Ensured nan_to_num is called on log tensors so that
 Infinity values convert properly. Without this, the computed values turn to
 NaN

---
 .../Encoder/ClusterlessMarkEncoder.cs | 20 ++++++++++++++++---
 .../Encoder/SortedSpikeEncoder.cs | 14 ++++++++++---
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs
index 11972fd..1149482 100644
--- a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs
+++ b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs
@@ -228,10 +228,16 @@ private void EvaluateMarkIntensities(Tensor inputs)
                 _stateSpace.Dimensions
             );
 
+            if (markKernelEstimate.numel() == 0)
+            {
+                continue;
+            }
+
             var markDensity = markKernelEstimate.matmul(_channelEstimates[i].T);
             markDensity /= markDensity.sum(dim: 1, keepdim: true);
             markDensity = markDensity
-                .log();
+                .log()
+                .nan_to_num();
 
             _markIntensities[i, TensorIndex.Tensor(mask[TensorIndex.Colon, i])] = _rates[i] + markDensity - _observationDensity;
         }
@@ -245,6 +251,7 @@ private void EvaluateChannelIntensities()
 
         _observationDensity = _observationEstimation.Evaluate(_stateSpace.Points)
             .log()
+            .nan_to_num()
             .MoveToOuterDisposeScope();
 
         _channelIntensities = zeros(
@@ -258,9 +265,16 @@ private void EvaluateChannelIntensities()
             _channelEstimates[i] = _markEstimation[i].Estimate(_stateSpace.Points, 0, _stateSpace.Dimensions)
                 .MoveToOuterDisposeScope();
 
-            var channelDensity = _markEstimation[i].Normalize(_channelEstimates[i]);
+            if (_channelEstimates[i].numel() == 0)
+            {
+                continue;
+            }
+
+            var channelDensity = _markEstimation[i].Normalize(_channelEstimates[i])
+                .log()
+                .nan_to_num();
 
-            _channelIntensities[i] = _rates[i] + channelDensity.log() - _observationDensity;
+            _channelIntensities[i] = _rates[i] + channelDensity - _observationDensity;
         }
 
         _channelIntensities.MoveToOuterDisposeScope();

diff --git a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs
index 1febf15..a205c00 100644
--- a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs
+++ b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs
@@ -175,7 +175,8 @@ private void EvaluateUnitIntensities()
         using var _ = NewDisposeScope();
 
         var observationDensity = _observationEstimation.Evaluate(_stateSpace.Points)
-            .log();
+            .log()
+            .nan_to_num();
 
         _unitIntensities = zeros(
             [_nUnits, _stateSpace.Points.size(0)],
@@ -187,8 +188,15 @@ private void EvaluateUnitIntensities()
         {
             var unitDensity = _unitEstimation[i].Evaluate(_stateSpace.Points);
 
-            _unitIntensities[i] = (_rates[i] + unitDensity.log() - observationDensity)
-                .MoveToOuterDisposeScope();
+            if (unitDensity.numel() == 0)
+            {
+                continue;
+            }
+
+            unitDensity = unitDensity.log()
+                .nan_to_num();
+
+            _unitIntensities[i] = _rates[i] + unitDensity - observationDensity;
         }
 
         _unitIntensities.MoveToOuterDisposeScope();
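The failure mode this patch guards against is easy to reproduce in isolation: the log of a zero density is -Infinity, and subtracting two -Infinity values (as the rate + unitDensity - observationDensity expressions can) yields NaN. A toy demonstration with invented values:

    using static TorchSharp.torch;

    var density = tensor(new float[] { 0.5f, 0.0f, 0.25f });

    var logDensity = density.log();            // [-0.69, -Inf, -1.39]
    var bad = logDensity - logDensity;         // -Inf - (-Inf) = NaN at index 1

    var safeLog = density.log().nan_to_num();  // -Inf becomes the most negative finite float
    var good = safeLog - safeLog;              // finite everywhere (all zeros)

Calling nan_to_num right after log, as the patch does, converts the infinities while they are still infinities, before any arithmetic can turn them into NaN.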
From c49c8413052ed5e5198f9742f967cbae08e8c744 Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Mon, 17 Feb 2025 16:16:50 +0000
Subject: [PATCH 17/20] Ensured empty tensors are returned if no kernels are
 present

---
 .../Estimation/KernelCompression.cs | 8 ++++++--
 src/PointProcessDecoder.Core/Estimation/KernelDensity.cs | 8 ++++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs b/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs
index 0814bce..932b51c 100644
--- a/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs
+++ b/src/PointProcessDecoder.Core/Estimation/KernelCompression.cs
@@ -172,7 +172,7 @@ public Tensor Estimate(Tensor points, int? dimension
         using var _ = NewDisposeScope();
         if (_kernels.numel() == 0)
         {
-            return (ones([points.size(0), 1], dtype: _scalarType, device: _device) * float.NaN)
+            return empty(0, dtype: _scalarType, device: _device)
                 .MoveToOuterDisposeScope();
         }
         var kernels = _kernels[TensorIndex.Colon, TensorIndex.Slice(dimensionStart, dimensionEnd)];
@@ -192,6 +192,10 @@ public Tensor Estimate(Tensor points, int? dimension
     public Tensor Normalize(Tensor points)
     {
         using var _ = NewDisposeScope();
+        if (points.numel() == 0)
+        {
+            return points;
+        }
 
         var density = points.sum(dim: -1) / points.size(1);
         density /= density.sum();
@@ -243,7 +247,7 @@ public Tensor Evaluate(Tensor points)
 
         if (_kernels.numel() == 0)
         {
-            return zeros(points.shape[0], dtype: _scalarType, device: _device);
+            return empty(0, dtype: _scalarType, device: _device);
        }
 
         using var _ = NewDisposeScope();

diff --git a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs
index 153b574..e98fee1 100644
--- a/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs
+++ b/src/PointProcessDecoder.Core/Estimation/KernelDensity.cs
@@ -116,7 +116,7 @@ public Tensor Estimate(Tensor points, int? dimensionStart = null, int? dimension
         using var _ = NewDisposeScope();
         if (_kernels.numel() == 0)
         {
-            return (ones([points.size(0), 1], dtype: _scalarType, device: _device) * float.NaN)
+            return empty(0, dtype: _scalarType, device: _device)
                .MoveToOuterDisposeScope();
        }
        var kernels = _kernels[TensorIndex.Colon, TensorIndex.Slice(dimensionStart, dimensionEnd)];
@@ -137,6 +137,10 @@ public Tensor Estimate(Tensor points, int? dimension
     public Tensor Normalize(Tensor points)
     {
         using var _ = NewDisposeScope();
+        if (points.numel() == 0)
+        {
+            return points;
+        }
 
         var density = points.sum(dim: -1) / points.size(1);
         density /= density.sum();
@@ -191,7 +195,7 @@ public Tensor Evaluate(Tensor points)
 
         if (_kernels.numel() == 0)
         {
-            return zeros(points.shape[0], dtype: _scalarType, device: _device);
+            return empty(0, dtype: _scalarType, device: _device);
         }
 
         using var _ = NewDisposeScope();
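With these sentinel changes, an estimator with no fitted kernels reports "no estimate" as a zero-element tensor rather than NaN-filled or zero-filled output, and the guards added in the previous patch skip such channels. A minimal sketch of the calling convention, with a hypothetical Estimate stand-in for the real estimators:

    using static TorchSharp.torch;

    static Tensor Estimate(bool fitted) =>
        fitted ? rand(5) : empty(0);       // empty(0) signals "nothing fit yet"

    foreach (var fitted in new[] { true, false })
    {
        var estimate = Estimate(fitted);
        if (estimate.numel() == 0)
        {
            continue;                      // skip channels with no kernels, as the encoders do
        }
        var logDensity = estimate.log().nan_to_num();
    }

A zero-element tensor is an unambiguous sentinel: unlike a NaN-filled tensor, it cannot silently propagate through arithmetic, so a forgotten guard fails loudly instead of corrupting the result.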
From e0821fd10e96876e7b4c9b8c55a5db494965dd29 Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Mon, 17 Feb 2025 16:42:11 +0000
Subject: [PATCH 18/20] Update testing utilities

---
 .../EncoderUtilities.cs | 13 +-
 .../SimulationUtilities.cs | 111 +++++++++++++++---
 .../SortedUnitsUtilities.cs | 2 +-
 3 files changed, 106 insertions(+), 20 deletions(-)

diff --git a/test/PointProcessDecoder.Test.Common/EncoderUtilities.cs b/test/PointProcessDecoder.Test.Common/EncoderUtilities.cs
index e2d26e8..b23fb0f 100644
--- a/test/PointProcessDecoder.Test.Common/EncoderUtilities.cs
+++ b/test/PointProcessDecoder.Test.Common/EncoderUtilities.cs
@@ -95,7 +95,8 @@ public static void SortedSpikeEncoder(
         );
 
         RunSortedSpikeEncoder2D(
-            sortedSpikeEncoder,
+            sortedSpikeEncoder,
+            evaluationSteps,
             position2D,
             spikingData,
             outputDirectory,
@@ -249,7 +250,8 @@ public static void RunSortedSpikeEncoder1D(
         for (int i = 0; i < densities.shape[0]; i++)
         {
-            var density = densities[i];
+            var density = densities[i]
+                .exp();
 
             var density1DExpanded = vstack([arange(evaluationSteps), density]).T;
 
             var directoryScatterPlot1D = Path.Combine(encoderDirectory, "ScatterPlot1D");
@@ -283,7 +285,8 @@ public static void RunSortedSpikeEncoder1D(
     }
 
     public static void RunSortedSpikeEncoder2D(
-        IEncoder encoder,
+        IEncoder encoder,
+        long[] evaluationSteps,
         Tensor observations,
         Tensor spikes,
         string encoderDirectory,
@@ -296,7 +299,9 @@ public static void RunSortedSpikeEncoder2D(
         for (int i = 0; i < densities.shape[0]; i++)
         {
-            var density = densities[i];
+            var density = densities[i]
+                .exp()
+                .reshape(evaluationSteps);
 
             var directoryHeatmap2D = Path.Combine(encoderDirectory, "Heatmap2D");
 
             Heatmap plotDensity2D = new(

diff --git a/test/PointProcessDecoder.Test.Common/SimulationUtilities.cs b/test/PointProcessDecoder.Test.Common/SimulationUtilities.cs
index c99cb18..724e2dd 100644
--- a/test/PointProcessDecoder.Test.Common/SimulationUtilities.cs
+++ b/test/PointProcessDecoder.Test.Common/SimulationUtilities.cs
@@ -37,7 +37,14 @@ public static void SpikingNeurons1D(
         var minPosition = 0;
         var maxPosition = position1D.shape[0];
 
-        ScatterPlot plotPosition1D = new(minPosition, maxPosition, min, max, "Position1D");
+        ScatterPlot plotPosition1D = new(
+            xMin: minPosition,
+            xMax: maxPosition,
+            yMin: min,
+            yMax: max,
+            title: "Position1D"
+        );
+
         plotPosition1D.OutputDirectory = Path.Combine(plotPosition1D.OutputDirectory, outputDirectory);
         plotPosition1D.Show(position1DExpandedTime);
         plotPosition1D.Save(png: true);
@@ -52,7 +59,14 @@ public static void SpikingNeurons1D(
         );
 
         var placeFieldCenters2D = concat([zeros_like(placeFieldCenters), placeFieldCenters], dim: 1);
-        ScatterPlot plotPlaceFieldCenters = new(-1, 1, min, max, "PlaceFieldCenters1D");
+        ScatterPlot plotPlaceFieldCenters = new(
+            xMin: -1,
+            xMax: 1,
+            yMin: min,
+            yMax: max,
+            title: "PlaceFieldCenters1D"
+        );
+
         plotPlaceFieldCenters.OutputDirectory = Path.Combine(plotPlaceFieldCenters.OutputDirectory, outputDirectory);
         plotPlaceFieldCenters.Show(placeFieldCenters2D);
         plotPlaceFieldCenters.Save(png: true);
@@ -66,14 +80,22 @@ public static void SpikingNeurons1D(
             device: device
         );
 
-        ScatterPlot plotSpikingNeurons = new(0, position1D.shape[0], min, max, title: "SpikingNeurons1D");
+        ScatterPlot plotSpikingNeurons = new(
+            xMin: 0,
+            xMax: position1D.shape[0],
+            yMin: min,
+            yMax: max,
+            title: "SpikingNeurons1D"
+        );
+
         plotSpikingNeurons.OutputDirectory = Path.Combine(plotSpikingNeurons.OutputDirectory, outputDirectory);
 
         var colors = Plot.Utilities.GenerateRandomColors(numNeurons, seed);
 
         for (int i = 0; i < numNeurons; i++)
         {
-            var positionsAtSpikes = position1DExpandedTime[spikingData[TensorIndex.Ellipsis, i]];
+            var spikesMask = spikingData[TensorIndex.Ellipsis, i] != 0;
+            var positionsAtSpikes = position1DExpandedTime[spikesMask];
             plotSpikingNeurons.Show(positionsAtSpikes, colors[i]);
         }
         plotSpikingNeurons.Save(png: true);
@@ -110,7 +132,14 @@ public static void SpikingNeurons2D(
             device: device
         );
 
-        ScatterPlot plotPosition2D = new(xMin, xMax, yMin, yMax, "Position2D");
+        ScatterPlot plotPosition2D = new(
+            xMin: xMin,
+            xMax: xMax,
+            yMin: yMin,
+            yMax: yMax,
+            title: "Position2D"
+        );
+
         plotPosition2D.OutputDirectory = Path.Combine(plotPosition2D.OutputDirectory, outputDirectory);
         plotPosition2D.Show(position2D);
         plotPosition2D.Save(png: true);
@@ -126,7 +155,14 @@ public static void SpikingNeurons2D(
             device: device
         );
 
-        ScatterPlot plotPlaceFieldCenters = new(xMin, xMax, yMin, yMax, "PlaceFieldCenters");
+        ScatterPlot plotPlaceFieldCenters = new(
+            xMin: xMin,
+            xMax: xMax,
+            yMin: yMin,
+            yMax: yMax,
+            title: "PlaceFieldCenters"
+        );
+
         plotPlaceFieldCenters.OutputDirectory = Path.Combine(plotPlaceFieldCenters.OutputDirectory, outputDirectory);
         plotPlaceFieldCenters.Show(placeFieldCenters);
         plotPlaceFieldCenters.Save(png: true);
@@ -140,14 +176,22 @@ public static void SpikingNeurons2D(
             device: device
         );
 
-        ScatterPlot plotSpikingNeurons = new(xMin, xMax, yMin, yMax, title: "SpikingNeurons");
+        ScatterPlot plotSpikingNeurons = new(
+            xMin: xMin,
+            xMax: xMax,
+            yMin: yMin,
+            yMax: yMax,
+            title: "SpikingNeurons"
+        );
+
         plotSpikingNeurons.OutputDirectory = Path.Combine(plotSpikingNeurons.OutputDirectory, outputDirectory);
 
         var colors = Plot.Utilities.GenerateRandomColors(numNeurons, seed);
 
         for (int i = 0; i < numNeurons; i++)
         {
-            var positionsAtSpikes = position2D[spikingData[TensorIndex.Colon, i]];
+            var spikesMask = spikingData[TensorIndex.Ellipsis, i] != 0;
+            var positionsAtSpikes = position2D[spikesMask];
             plotSpikingNeurons.Show(positionsAtSpikes, colors[i]);
         }
         plotSpikingNeurons.Save(png: true);
@@ -187,12 +231,26 @@ public static void SpikingNeurons2DFirstAndLastSteps(
             device: device
         );
 
-        ScatterPlot plotPositionFirst = new(xMin, xMax, yMin, yMax, "Position2DFirst");
+        ScatterPlot plotPositionFirst = new(
+            xMin: xMin,
+            xMax: xMax,
+            yMin: yMin,
+            yMax: yMax,
+            title: "Position2DFirst"
+        );
+
         plotPositionFirst.OutputDirectory = Path.Combine(plotPositionFirst.OutputDirectory, outputDirectory);
         plotPositionFirst.Show(position2D[TensorIndex.Slice(0, stepsToSeperate)]);
         plotPositionFirst.Save(png: true);
 
-        ScatterPlot plotPositionLast = new(xMin, xMax, yMin, yMax, "Position2DLast");
+        ScatterPlot plotPositionLast = new(
+            xMin: xMin,
+            xMax: xMax,
+            yMin: yMin,
+            yMax: yMax,
+            title: "Position2DLast"
+        );
+
         plotPositionLast.OutputDirectory = Path.Combine(plotPositionLast.OutputDirectory, outputDirectory);
         plotPositionLast.Show(position2D[TensorIndex.Slice(stepsToSeperate)]);
         plotPositionLast.Save(png: true);
@@ -208,7 +266,14 @@ public static void SpikingNeurons2DFirstAndLastSteps(
             device: device
         );
 
-        ScatterPlot plotPlaceFieldCenters = new(xMin, xMax, yMin, yMax, "PlaceFieldCenters");
+        ScatterPlot plotPlaceFieldCenters = new(
+            xMin: xMin,
+            xMax: xMax,
+            yMin: yMin,
+            yMax: yMax,
+            title: "PlaceFieldCenters"
+        );
+
         plotPlaceFieldCenters.OutputDirectory = Path.Combine(plotPlaceFieldCenters.OutputDirectory, outputDirectory);
         plotPlaceFieldCenters.Show(placeFieldCenters);
         plotPlaceFieldCenters.Save(png: true);
@@ -222,20 +287,36 @@ public static void SpikingNeurons2DFirstAndLastSteps(
             device: device
         );
 
-        ScatterPlot plotSpikingNeuronsFirst = new(xMin, xMax, yMin, yMax, title: "SpikingNeuronsFirst");
+        ScatterPlot plotSpikingNeuronsFirst = new(
+            xMin: xMin,
+            xMax: xMax,
+            yMin: yMin,
+            yMax: yMax,
+            title: "SpikingNeuronsFirst"
+        );
+
         plotSpikingNeuronsFirst.OutputDirectory = Path.Combine(plotSpikingNeuronsFirst.OutputDirectory, outputDirectory);
 
-        ScatterPlot plotSpikingNeuronsLast = new(xMin, xMax, yMin, yMax, title: "SpikingNeuronsLast");
+        ScatterPlot plotSpikingNeuronsLast = new(
+            xMin: xMin,
+            xMax: xMax,
+            yMin: yMin,
+            yMax: yMax,
+            title: "SpikingNeuronsLast"
+        );
+
         plotSpikingNeuronsLast.OutputDirectory = Path.Combine(plotSpikingNeuronsLast.OutputDirectory, outputDirectory);
 
         var colors = Plot.Utilities.GenerateRandomColors(numNeurons, seed);
 
         for (int i = 0; i < numNeurons; i++)
         {
-            var positionsAtSpikesFirst = position2D[TensorIndex.Slice(0, stepsToSeperate)][spikingData[TensorIndex.Slice(0, stepsToSeperate), i]];
+            var spikesMaskFirst = spikingData[TensorIndex.Slice(0, stepsToSeperate), i] != 0;
+            var positionsAtSpikesFirst = position2D[TensorIndex.Slice(0, stepsToSeperate)][spikesMaskFirst];
             plotSpikingNeuronsFirst.Show(positionsAtSpikesFirst, colors[i]);
 
-            var positionsAtSpikesLast = position2D[TensorIndex.Slice(stepsToSeperate)][spikingData[TensorIndex.Slice(stepsToSeperate), i]];
+            var spikesMaskLast = spikingData[TensorIndex.Slice(stepsToSeperate), i] != 0;
+            var positionsAtSpikesLast = position2D[TensorIndex.Slice(stepsToSeperate)][spikesMaskLast];
             plotSpikingNeuronsLast.Show(positionsAtSpikesLast, colors[i]);
         }

diff --git a/test/PointProcessDecoder.Test.Common/SortedUnitsUtilities.cs b/test/PointProcessDecoder.Test.Common/SortedUnitsUtilities.cs
index 82c8b74..725e460 100644
--- a/test/PointProcessDecoder.Test.Common/SortedUnitsUtilities.cs
+++ b/test/PointProcessDecoder.Test.Common/SortedUnitsUtilities.cs
@@ -171,7 +171,7 @@ public static void BayesianStateSpaceSortedUnitsRealData(
         position = position.reshape(-1, 2);
         spikingData = spikingData.reshape(position.shape[0], -1)
-            .to_type(ScalarType.Bool);
+            .to_type(ScalarType.Int32);
 
         var numNeurons = (int)spikingData.shape[1];
 
         double[] heatmapRange = [minVals[0], maxVals[0], minVals[1], maxVals[1]];
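Since the spiking data are now integer counts rather than booleans (see the SortedUnitsUtilities change above), indexing positions with the raw tensor would gather by value instead of masking; comparing against zero restores a proper boolean mask. A toy sketch with invented values:

    using static TorchSharp.torch;

    var positions = arange(5);                          // [0, 1, 2, 3, 4]
    var counts = tensor(new long[] { 0, 2, 0, 1, 0 });  // spike counts per time step

    // Integer indexing would gather elements by value; comparison yields a boolean mask.
    var mask = counts != 0;                             // [false, true, false, true, false]
    var positionsAtSpikes = positions[mask];            // [1, 3]

This is why every plotting helper in the patch builds an explicit spikesMask before indexing.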
From 6cd058db1b75ec4d1070038f8f9c34cd6b712233 Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Tue, 18 Feb 2025 10:41:23 +0000
Subject: [PATCH 19/20] Added ignore no spikes flag to Poisson likelihood

---
 .../Likelihood/PoissonLikelihood.cs | 30 ++++++++++++++-----
 .../PointProcessModel.cs | 1 +
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
index bbfe397..71f535b 100644
--- a/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
+++ b/src/PointProcessDecoder.Core/Likelihood/PoissonLikelihood.cs
@@ -10,7 +10,8 @@ namespace PointProcessDecoder.Core.Likelihood;
 ///
 public class PoissonLikelihood(
     Device? device = null,
-    ScalarType? scalarType = null
+    ScalarType? scalarType = null,
+    bool ignoreNoSpikes = false
 ) : ModelComponent, ILikelihood
 {
     private readonly Device _device = device ?? CPU;
@@ -21,6 +22,16 @@ public class PoissonLikelihood(
     ///
     public override ScalarType ScalarType => _scalarType;
 
+    private bool _ignoreNoSpikes = ignoreNoSpikes;
+    ///
+    /// Whether to ignore the contribution of no spikes to the likelihood.
+    ///
+    public bool IgnoreNoSpikes
+    {
+        get => _ignoreNoSpikes;
+        set => _ignoreNoSpikes = value;
+    }
+
     ///
     public LikelihoodType LikelihoodType => LikelihoodType.Poisson;
 
@@ -35,13 +46,16 @@ IEnumerable<Tensor> intensities
         var intensity = intensities.First()
             .unsqueeze(0);
 
-        var likelihood = ((inputs.unsqueeze(-1)
-            * intensity)
-            - intensity.exp())
-            .nan_to_num()
-            .sum(dim: 1)
-            .exp()
-            .nan_to_num();
+        var likelihood = inputs.unsqueeze(-1)
+            * intensity;
+
+        if (!_ignoreNoSpikes) {
+            likelihood -= intensity.exp();
+        }
+
+        likelihood = likelihood
+            .sum(dim: 1)
+            .exp();
 
         likelihood /= likelihood
             .sum(dim: -1, keepdim: true);

diff --git a/src/PointProcessDecoder.Core/PointProcessModel.cs b/src/PointProcessDecoder.Core/PointProcessModel.cs
index 92bc2a1..fb92527 100644
--- a/src/PointProcessDecoder.Core/PointProcessModel.cs
+++ b/src/PointProcessDecoder.Core/PointProcessModel.cs
@@ -149,6 +149,7 @@ public PointProcessModel(
         _likelihood = likelihoodType switch
         {
             LikelihoodType.Poisson => new PoissonLikelihood(
+                ignoreNoSpikes: ignoreNoSpikes,
                 device: _device,
                 scalarType: _scalarType
             ),
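For a Poisson observation model the per-bin log likelihood is n * log(lambda) - lambda - log(n!); the log(n!) term is constant across state bins and cancels in the normalization, and when IgnoreNoSpikes is set the -lambda (no-spike) term is dropped as well. A toy sketch of the computation with invented numbers, where the intensity is stored in log space as the encoders produce it:

    using static TorchSharp.torch;

    // logLambda: log firing rate over 3 state bins; spikes: counts over 3 time steps
    var logLambda = tensor(new float[] { -1.0f, 0.0f, 1.0f });
    var spikes = tensor(new float[] { 0f, 1f, 2f }).unsqueeze(-1);  // [T, 1]

    var logLik = spikes * logLambda.unsqueeze(0);  // n * log(lambda) term, [T, states]
    var includeNoSpikes = true;                    // mirrors !_ignoreNoSpikes
    if (includeNoSpikes)
    {
        logLik -= logLambda.exp();                 // subtract lambda: no-spike contribution
    }

    var lik = logLik.exp();
    lik /= lik.sum(dim: -1, keepdim: true);        // normalize over state bins

Dropping the -lambda term trades exactness for robustness when rate estimates are unreliable in rarely visited state bins.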
From 8fdfbc615797b8887d6f24e134c624b1cac644ec Mon Sep 17 00:00:00 2001
From: ncguilbeault
Date: Tue, 18 Feb 2025 10:41:52 +0000
Subject: [PATCH 20/20] Updated clusterless likelihood with class constructor
 syntax

---
 .../Likelihood/ClusterlessLikelihood.cs | 37 ++++++------------
 1 file changed, 12 insertions(+), 25 deletions(-)

diff --git a/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs b/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs
index 1bfb0bc..2a561b1 100644
--- a/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs
+++ b/src/PointProcessDecoder.Core/Likelihood/ClusterlessLikelihood.cs
@@ -6,17 +6,21 @@ namespace PointProcessDecoder.Core.Likelihood;
 /// Represents a clusterless likelihood.
 /// Expected to be used when the encoder is set to the .
 ///
-public class ClusterlessLikelihood : ModelComponent, ILikelihood
+public class ClusterlessLikelihood(
+    Device? device = null,
+    ScalarType? scalarType = null,
+    bool ignoreNoSpikes = false
+) : ModelComponent, ILikelihood
 {
-    private readonly Device _device;
+    private readonly Device _device = device ?? CPU;
     ///
     public override Device Device => _device;
 
-    private readonly ScalarType _scalarType;
+    private readonly ScalarType _scalarType = scalarType ?? ScalarType.Float32;
     ///
     public override ScalarType ScalarType => _scalarType;
 
-    private bool _ignoreNoSpikes;
+    private bool _ignoreNoSpikes = ignoreNoSpikes;
     ///
     /// Whether to ignore the contribution of channels with no spikes to the likelihood.
     ///
@@ -29,23 +33,6 @@ public bool IgnoreNoSpikes
     ///
     public LikelihoodType LikelihoodType => LikelihoodType.Clusterless;
 
-    ///
-    /// Initializes a new instance of the class.
-    ///
-    ///
-    ///
-    ///
-    public ClusterlessLikelihood(
-        Device? device = null,
-        ScalarType? scalarType = null,
-        bool ignoreNoSpikes = false
-    )
-    {
-        _device = device ?? CPU;
-        _scalarType = scalarType ?? ScalarType.Float32;
-        _ignoreNoSpikes = ignoreNoSpikes;
-    }
-
     ///
     public Tensor Likelihood(
         Tensor inputs,
@@ -57,17 +44,17 @@ IEnumerable<Tensor> intensities
         var channelIntensities = intensities.ElementAt(0);
         var markIntensities = intensities.ElementAt(1);
 
-        var likelihood = markIntensities
-            .sum(dim: 0);
+        var likelihood = markIntensities;
 
         if (!_ignoreNoSpikes)
         {
             likelihood -= channelIntensities
-                .exp()
-                .sum(dim: 0);
+                .unsqueeze(1)
+                .exp();
         }
 
         likelihood = likelihood
+            .sum(dim: 0)
             .exp();
 
         likelihood /= likelihood