Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
18153ee
Added normalization step to likelihood calculation
ncguilbeault Feb 7, 2025
d6abc99
Removed extra steps to move to outer dispose scopes
ncguilbeault Feb 7, 2025
2e4fa6e
Changed inputs to bool type in poisson likelihood calculation
ncguilbeault Feb 7, 2025
94210b0
Removed use of IDisposable interface and dispose methods since its no…
ncguilbeault Feb 12, 2025
c940529
Updated likelihood interface to the method name Likelihood instead of…
ncguilbeault Feb 12, 2025
20f8cb0
Modified to use call to tensor.size instead of tensor.shape since ten…
ncguilbeault Feb 13, 2025
c797fc8
Removed eps declaration since it is no longer needed
ncguilbeault Feb 13, 2025
b00ba63
Updated poisson likelihood calculation to remove redundant steps of n…
ncguilbeault Feb 13, 2025
5504c92
Modifed estimate methods to return NaN tensor of correct size when th…
ncguilbeault Feb 13, 2025
732655f
Updated with new naming
ncguilbeault Feb 17, 2025
3d370bb
Remove call to flatten since these should be falttened anyways
ncguilbeault Feb 17, 2025
a70fb10
Updated clusterless marks encoder
ncguilbeault Feb 17, 2025
9402c27
Updated sorted spike encoder
ncguilbeault Feb 17, 2025
bbce420
Updated to correctly use dimensions when calculating bandwidth
ncguilbeault Feb 17, 2025
97c64b9
Updated clusterless likelihood
ncguilbeault Feb 17, 2025
2425793
Ensured nan to num is called on log tensors for Infinity values to co…
ncguilbeault Feb 17, 2025
c49c841
Ensured empty tensoros are returned if no kernels are present
ncguilbeault Feb 17, 2025
e0821fd
Update testing utilities
ncguilbeault Feb 17, 2025
6cd058d
Added ignore no spikes flag to poisson likelihood
ncguilbeault Feb 18, 2025
8fdfbc6
Updated clusterless likelihood with class constructor syntax
ncguilbeault Feb 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions src/PointProcessDecoder.Core/Decoder/StateSpaceDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public StateSpaceDecoder(
_ => throw new ArgumentException("Invalid transitions type.")
};

var n = _stateSpace.Points.shape[0];
var n = _stateSpace.Points.size(0);
_initialState = ones(n, dtype: _scalarType, device: _device) / n;
_posterior = empty(0);
}
Expand All @@ -81,23 +81,23 @@ public Tensor Decode(Tensor inputs, Tensor likelihood)
{
using var _ = NewDisposeScope();

var outputShape = new long[] { inputs.shape[0] }
var outputShape = new long[] { inputs.size(0) }
.Concat(_stateSpace.Shape)
.ToArray();

var output = zeros(outputShape, dtype: _scalarType, device: _device);

if (_posterior.numel() == 0) {
_posterior = (_initialState * likelihood[0].flatten())
_posterior = (_initialState * likelihood[0])
.nan_to_num()
.clamp_min(_eps);
_posterior /= _posterior.sum();
output[0] = _posterior.reshape(_stateSpace.Shape);
}

for (int i = 1; i < inputs.shape[0]; i++)
for (int i = 1; i < inputs.size(0); i++)
{
_posterior = (_stateTransitions.Transitions.matmul(_posterior) * likelihood[i].flatten())
_posterior = (_stateTransitions.Transitions.matmul(_posterior) * likelihood[i])
.nan_to_num()
.clamp_min(_eps);
_posterior /= _posterior.sum();
Expand All @@ -106,12 +106,4 @@ public Tensor Decode(Tensor inputs, Tensor likelihood)
_posterior.MoveToOuterDisposeScope();
return output.MoveToOuterDisposeScope();
}

/// <inheritdoc/>
public override void Dispose()
{
_stateTransitions.Dispose();
_initialState.Dispose();
_posterior.Dispose();
}
}
Loading