-
Notifications
You must be signed in to change notification settings - Fork 5
New package to extract latents from high-dimensional data using LDS and TorchSharp #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ncguilbeault
wants to merge
90
commits into
bonsai-rx:main
Choose a base branch
from
ncguilbeault:dev/torch-lds
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
90 commits
Select commit
Hold shift + click to select a range
3504951
Added kalman filter neural latents package to the Bonsai.ML project
ncguilbeault 348c51c
Added `Bonsai.ML.Torch` project to references
ncguilbeault 8c40f3f
Added KalmanFilter class
ncguilbeault ec8e092
Updated KalmanFilter class for better matrix, vector, and scalar vali…
ncguilbeault 3d369ba
Adding classes from original scripted demo
ncguilbeault cf5a2cc
Added test repo for `Torch.LDS` to compare output with Python impleme…
ncguilbeault 1c43518
Updated Torch.LDS test for running test against Python script for est…
ncguilbeault 4905ac7
Added null checks to properties
ncguilbeault 55874bf
Updated target framework to include netstandard2.0
ncguilbeault 6d90172
Updated `CreateKalmanFilter` class to use `TensorConverter` for tenso…
ncguilbeault 3cd5cf0
Refactored to use the KalmanFilterModelManager class
ncguilbeault ece0560
Removed attempt to move EM algorithm to background process in favor o…
ncguilbeault d567ff0
Removed extra properties of structs to streamline filter and smooth p…
ncguilbeault 81bca79
Refactored kalman filter class to support static methods, to streamli…
ncguilbeault 0fbf082
Added reserve method to model manager to support model creation from …
ncguilbeault f3c7190
Added project references to tests
ncguilbeault 62333df
Removed unused variables that were previously there for plotting
ncguilbeault 927343a
Modified requirements.txt to use ssm from github instead of local
ncguilbeault 055fd9c
Corrected bug with initialization of state and covariance
ncguilbeault 5aa7fe8
Refactored EM function for improved readability
ncguilbeault 3f71c1d
Updated test to correctly compare python and bonsai tensor results
ncguilbeault 4913e21
Removed unused import
ncguilbeault e4366b2
Removed unused imports and added XML docs
ncguilbeault c205ad2
Removed convoluted lock mechanism in favor of no lock
ncguilbeault e85105e
Added cleanup to test
ncguilbeault babc70d
Updated package installs and removed requirements
ncguilbeault 872bcca
Added the package to the documentation and included a basic article f…
ncguilbeault c9267e5
Updated variable naming from state to mean to more accurately represe…
ncguilbeault fa695a1
Removed the line declaring requirements.txt is a deployment item
ncguilbeault fb0d67b
Updated test workflow to use the variable name mean instead of state
ncguilbeault 3948ccd
Added functionality to allow fine grained control over which paramete…
ncguilbeault cdd38ed
Updated package info with better description and shared package tags
ncguilbeault ef70e3d
Added `Bonsai.ML.Torch.LDS.Design` package for visualizing latents
ncguilbeault 6d41ad7
Moved color cycle class to shared `Bonsai.ML.Design` library
ncguilbeault a3d5af0
Added keyword arguments to function call for precision
ncguilbeault fcc6855
Added `StateVisualizer` class to design package to support visualizin…
ncguilbeault e872dc9
Added explicit conversion to null for empty tensors
ncguilbeault da51eb4
Removed initial values from non-static Kalman smoother
ncguilbeault 7208179
Removed explicit null conversion from empty tensor
ncguilbeault a4fa853
Updated `KalmanFilter` with to allow automatically populating null pa…
ncguilbeault 7911e31
Updated to allow parameters to contain null tensor values
ncguilbeault 5688343
Added XML docs to class
ncguilbeault ca6c3ae
Refactored to use autosize and expose plot control
ncguilbeault cfa9c80
Updated `NeuralLatentsTest` with default null values
ncguilbeault 94b7042
Added categories to class properties and renamed `ModelName` to just …
ncguilbeault da37be0
Added `ResetCombinator` to class attributes
ncguilbeault 372af60
Added generic class and interface to represent LDS state
ncguilbeault 34aa3da
Removed `ResetCombinator` attribute from classes where it is not needed
ncguilbeault a114b02
Changed naming from `xResult` to `xState` to for improved naming cons…
ncguilbeault fa0b615
Added property `Bonsai.ML.Torch.LDS.Design` project to ignore repacka…
ncguilbeault 3543865
Refactored `ExpectationMaximization` to emit values on each iteration…
ncguilbeault 9344dd5
Changed `ExpectationMaximization` operator to a type of `Combinator` …
ncguilbeault 54240cb
Updated name of `Bonsai.ML.Torch.LDS` package to `Bonsai.ML.Lds.Torch…
ncguilbeault adb27aa
Added `Bonsai.ML.Torch` using statements to classes that depend on th…
ncguilbeault ddc60a6
Added operators to save and load the parameters of a Kalman filter model
ncguilbeault 995fbc9
Added Stochastic Subspace Identification method
ncguilbeault 9959b61
Added method for `StochasticSubspaceIdentification` which is much fas…
ncguilbeault f113ed9
Refactored implementation to use explicit `KalmanFilter` property whi…
ncguilbeault c669f74
Ensure data are centered in SSID method
ncguilbeault 7c2e6f1
Updated test workflow after changing packge to use explicit model pro…
ncguilbeault da0b8a6
Refactored EM to avoid potential mismatch between numStates, numObser…
ncguilbeault 1fdd64a
Updated workflow to correctly update the parameters of the model
ncguilbeault 85814ee
Modified python test script to install packages without caching packa…
ncguilbeault caabc61
Modified test case to download data and only run Bonsai script as opp…
ncguilbeault cdc1a45
Removed redundant dependency
ncguilbeault 8d1bf46
Updated documentation with correct package naming
ncguilbeault acf9c6c
Changed struct and interface names from LdsState to full LinearDynami…
ncguilbeault 0f34ab4
Removed redundant classes for orthogonalized and smoothed states in f…
ncguilbeault 5db9d3c
Changed name from `UpdateParameters` to `UpdateKalmanFilterParameters…
ncguilbeault b59abfe
Updated `StateVisualizer` to match changes in naming
ncguilbeault a30dc36
Fixed test workflow after making changes
ncguilbeault 6e8645c
Refactored `KalmanFilter` for improved parameter validation when `num…
ncguilbeault e9571a3
Added overload to create a `LinearDynamicalSystemState` from a stream…
ncguilbeault 5280204
Refactored `SaveKalmanFilterParameters` class to save to a folder rat…
ncguilbeault 6b3be7e
Refactored `KalmanFilterParameters` class to handle validation and mo…
ncguilbeault 33b762a
Updated `LoadKalmanFilterParameters` operator to load parameters from…
ncguilbeault 4438fe7
Updated test to use `null` value in `NumStates` property
ncguilbeault b499b76
Removed unnecessary `PredictedState` struct and used `LinearDynamical…
ncguilbeault f6cf0c0
Removed extra check when calling `Validate` on parameters
ncguilbeault ac99961
Fixed issue with setting the incorrect number of iterations from the …
ncguilbeault b584d55
Used tensorhelper method to extract float instead of explicitly movin…
ncguilbeault ce80cf3
Refactored `KalmanFilter` class to rely on device and scalar type pro…
ncguilbeault e7add12
Refactored `KalmanFilterParameters` to manage operations on tensors, …
ncguilbeault 2478c08
Refactored `CreateKalmanFilter` operator to correctly pass in `Type` …
ncguilbeault d8deec5
Updated to allow specifying `Device` property when loading parameters
ncguilbeault e44e17d
Removed device and scalartype overrides in `Initialize` method
ncguilbeault 77d9b13
Refactored `LoadKalmanFilterParameters` operator to allow tensors to …
ncguilbeault 65674dd
Updated filtering step to support missing nan values
ncguilbeault 6260dc1
Updated neural latents test to use new load/save tensor method
ncguilbeault faca4ca
Updated to use `TensorOperatorConverter` class
ncguilbeault File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| # Bonsai.ML.Lds.Torch - Overview | ||
|
|
||
| This package provides an implementation of the Kalman filter, Rauch-Tung-Striebel (RTS) smoother, expectation maximization (EM) algorithm, and stochastic subspace identification, developed for online filtering, smoothing, and parameter estimation from data streams in Bonsai using the TorchSharp package. | ||
|
|
||
| ## Installation Guide | ||
|
|
||
| Install the `Bonsai.ML.Lds.Torch` package from the Bonsai package manager. You will also need to follow the [instructions for setting up the Bonsai.ML.Torch package](../Torch/torch-overview.md) for running on the CPU or GPU. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...cessDecoder.Design/OxyColorPresetCycle.cs → src/Bonsai.ML.Design/OxyColorPresetCycle.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
19 changes: 19 additions & 0 deletions
19
src/Bonsai.ML.Lds.Torch.Design/Bonsai.ML.Lds.Torch.Design.csproj
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| <Project Sdk="Microsoft.NET.Sdk"> | ||
| <PropertyGroup> | ||
| <Description>Visualizers for the Bonsai.ML.Lds.Torch library.</Description> | ||
| <PackageTags>$(PackageTags) Torch LDS Design</PackageTags> | ||
| <TargetFramework>net472</TargetFramework> | ||
| <UseWindowsForms>true</UseWindowsForms> | ||
| </PropertyGroup> | ||
| <ItemGroup> | ||
| <PackageReference Include="Bonsai.Core" Version="2.9.0" /> | ||
| </ItemGroup> | ||
| <ItemGroup> | ||
| <ProjectReference Include="..\Bonsai.ML.Lds.Torch\Bonsai.ML.Lds.Torch.csproj" /> | ||
| <ProjectReference Include="..\Bonsai.ML.Design\Bonsai.ML.Design.csproj" /> | ||
| </ItemGroup> | ||
| <PropertyGroup> | ||
| <!-- This property is needed to avoid repacking the native SkiaSharp libraries, which should already be included with the Bonsai.ML.Torch --> | ||
| <ShouldIncludeNativeSkiaSharp>false</ShouldIncludeNativeSkiaSharp> | ||
| </PropertyGroup> | ||
| </Project> |
90 changes: 90 additions & 0 deletions
90
src/Bonsai.ML.Lds.Torch.Design/ExpectationMaximizationVisualizer.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| using System; | ||
| using System.Reactive; | ||
| using System.Linq; | ||
| using System.Windows.Forms; | ||
| using System.Collections.Generic; | ||
|
|
||
| using Bonsai; | ||
| using Bonsai.Design; | ||
| using Bonsai.ML.Design; | ||
|
|
||
| using OxyPlot; | ||
| using OxyPlot.Series; | ||
|
|
||
| using static TorchSharp.torch; | ||
|
|
||
| [assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.ExpectationMaximizationVisualizer), | ||
| Target = typeof(Bonsai.ML.Lds.Torch.ExpectationMaximizationResult))] | ||
|
|
||
| namespace Bonsai.ML.Lds.Torch.Design; | ||
|
|
||
| /// <summary> | ||
| /// Provides a visualizer for the state means and covariances from a Kalman filter or smoother. | ||
| /// </summary> | ||
| public class ExpectationMaximizationVisualizer : BufferedVisualizer | ||
| { | ||
| private TimeSeriesOxyPlotBase _plot; | ||
| private LineSeries _lineSeries; | ||
|
|
||
| /// <summary> | ||
| /// Gets the underlying plot control. | ||
| /// </summary> | ||
| public TimeSeriesOxyPlotBase Plot => _plot; | ||
|
|
||
| /// <inheritdoc/> | ||
| public override void Load(IServiceProvider provider) | ||
| { | ||
| _plot = new TimeSeriesOxyPlotBase() | ||
| { | ||
| Dock = DockStyle.Fill, | ||
| StartTime = DateTime.Now, | ||
| BufferData = true, | ||
| ValueLabel = "Log Likelihood" | ||
| }; | ||
|
|
||
| _lineSeries = _plot.AddNewLineSeries("Log Likelihood", OxyColors.Blue); | ||
|
|
||
| var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); | ||
| visualizerService?.AddControl(_plot); | ||
| } | ||
|
|
||
| /// <inheritdoc/> | ||
| public override void Show(object value) | ||
| { | ||
| } | ||
|
|
||
| /// <inheritdoc/> | ||
| protected override void Show(DateTime time, object value) | ||
| { | ||
| if (value is null) return; | ||
|
|
||
| if (value is not ExpectationMaximizationResult result) return; | ||
|
|
||
| var logLikelihood = result.LogLikelihood; | ||
| if (logLikelihood is null) return; | ||
|
|
||
| var ll = logLikelihood[-1].to_type(ScalarType.Float64).item<double>(); | ||
|
|
||
| _plot.AddToLineSeries( | ||
| lineSeries: _lineSeries, | ||
| time: time, | ||
| value: ll | ||
| ); | ||
| } | ||
|
|
||
| /// <inheritdoc/> | ||
| protected override void ShowBuffer(IList<Timestamped<object>> values) | ||
| { | ||
| base.ShowBuffer(values); | ||
| if (values.Count > 0) | ||
| { | ||
| _plot.UpdatePlot(); | ||
| } | ||
| } | ||
|
|
||
| /// <inheritdoc/> | ||
| public override void Unload() | ||
| { | ||
| if (!_plot.IsDisposed) _plot.Dispose(); | ||
| } | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| using Bonsai; | ||
|
|
||
| // General Information about an assembly is controlled through the following | ||
| // set of attributes. Change these attribute values to modify the information | ||
| // associated with an assembly. | ||
| [assembly: XmlNamespacePrefix("clr-namespace:Bonsai.ML.Lds.Torch.Design", null)] |
10 changes: 10 additions & 0 deletions
10
src/Bonsai.ML.Lds.Torch.Design/Properties/launchSettings.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| { | ||
| "profiles": { | ||
| "Bonsai": { | ||
| "commandName": "Executable", | ||
| "executablePath": "$(BonsaiExecutablePath)", | ||
| "commandLineArgs": "--lib:\"$(TargetDir).\"", | ||
| "nativeDebugging": true | ||
| } | ||
| } | ||
| } |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a thought but given the current decision to lean more on acronyms, this could also be called
EMVisualizer.However, we do have a class already that is named
ExpectationMaximizationand we probably don't want to contract that toEM, and I do also like the verbose name as it is, so I am torn and more than happy to go either way and leave it as-is.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest we keep the full name for now and we can revisit down the line if we want to change it