Reduce memory footprint of dss_line
#58
|
Thanks @eort. I have looked at the changes only quickly, but I can already tell you that the CI does not pass any more (you can check that locally by running […]). I think only a subset of those changes are beneficial (the "inplace" recentring for instance). Some others I am less sure about (for instance when you compute […]). I will try to post proper benchmarks (computation time and memory consumption) soon to get to the bottom of this. |
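To make the "inplace" recentring point concrete, here is a minimal sketch of the idea. `demean_sketch` is a hypothetical stand-in, not the library's `demean` (weights and trial handling are left out), and it assumes a floating-point array with time on axis 0.

```python
import numpy as np


def demean_sketch(X, inplace=False):
    """Remove the per-channel mean along the time axis (axis 0).

    Hypothetical stand-in for the library's ``demean``; assumes a float
    array (in-place subtraction would fail on integer input).
    """
    mean = X.mean(axis=0, keepdims=True)
    if inplace:
        X -= mean      # overwrites the caller's buffer, no full-size copy
        return X
    return X - mean    # allocates a second array as large as X


rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 4)) + 5.0

Y = demean_sketch(X.copy())           # copying variant, on a private copy of X
Z = demean_sketch(X, inplace=True)    # in-place variant: X itself is modified
```

The in-place variant avoids one array the size of `X`, at the cost of destroying the caller's original data, so it only fits when the uncentred input is not needed afterwards.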
|
Yeah I wasn't sure about those ones either, but for me saving memory was more critical than saving time. In any case, looking forward to those benchmarks! (Hope CI is happier now) |
Bah, I meant run […]. I'll post the benchmarks shortly
|
Ok, I've run a memory profiler to assess changes. Using simulated data with 100 channels and 1e6 time points @ 200 Hz.
With your changes
Memory:
Line # Mem usage Increment Occurences Line Contents
============================================================
139 3058.7 MiB 3058.7 MiB 1 @profile
140 def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
141 show=False):
142 """Apply DSS to remove power line artifacts.
143
144 Implements the ZapLine algorithm described in [1]_.
145
146 Parameters
147 ----------
148 X : data, shape=(n_samples, n_chans, n_trials)
149 Input data.
150 fline : float
151 Line frequency (normalized to sfreq, if ``sfreq`` == 1).
152 sfreq : float
153 Sampling frequency (default=1, which assymes ``fline`` is normalised).
154 nremove : int
155 Number of line noise components to remove (default=1).
156 nfft : int
157 FFT size (default=1024).
158 nkeep : int
159 Number of components to keep in DSS (default=None).
160 blocksize : int
161 If not None (default), covariance is computed on blocks of
162 ``blocksize`` samples. This may improve performance for large datasets.
163 show: bool
164 If True, show DSS results (default=False).
165
166 Returns
167 -------
168 y : array, shape=(n_samples, n_chans, n_trials)
169 Denoised data.
170 artifact : array, shape=(n_samples, n_chans, n_trials)
171 Artifact
172
173 Examples
174 --------
175 Apply to X, assuming line frequency=50Hz and sampling rate=1000Hz, plot
176 results:
177 >>> dss_line(X, 50/1000)
178
179 Removing 4 line-dominated components:
180 >>> dss_line(X, 50/1000, 4)
181
182 Truncating PCs beyond the 30th to avoid overfitting:
183 >>> dss_line(X, 50/1000, 4, nkeep=30);
184
185 Return cleaned data in y, noise in yy, do not plot:
186 >>> [y, artifact] = dss_line(X, 60/1000)
187
188 References
189 ----------
190 .. [1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to
191 remove power line artifacts [Preprint]. https://doi.org/10.1101/782029
192
193 """
194 3058.7 MiB 0.0 MiB 1 if X.shape[0] < nfft:
195 print('Reducing nfft to {}'.format(X.shape[0]))
196 nfft = X.shape[0]
197 3058.7 MiB 0.0 MiB 1 n_samples, n_chans, n_trials = theshapeof(X)
198 3058.7 MiB 0.0 MiB 1 if blocksize is None:
199 3058.7 MiB 0.0 MiB 1 blocksize = n_samples
200
201 # Recentre data
202 3058.7 MiB 0.0 MiB 1 X = demean(X, inplace=True)
203
204 # Cancel line_frequency and harmonics + light lowpass
205 5370.5 MiB 2311.7 MiB 1 X_filt = smooth(X, sfreq / fline)
206
207 # X - X_filt results in the artifact plus some residual biological signal
208 # Reduce dimensionality to avoid overfitting
209 5370.5 MiB 0.0 MiB 1 if nkeep is not None:
210 cov_X_res = tscov(X - X_filt)[0]
211 V, _ = pca(cov_X_res, nkeep)
212 X_noise_pca = (X - X_filt) @ V
213 else:
214 7636.4 MiB 2265.9 MiB 1 X_noise_pca = (X - X_filt).copy()
215 7636.4 MiB 0.0 MiB 1 nkeep = n_chans
216
217 # Compute blockwise covariances of raw and biased data
218 7636.4 MiB 0.0 MiB 1 n_harm = np.floor((sfreq / 2) / fline).astype(int)
219 7636.4 MiB 0.0 MiB 1 c0 = np.zeros((nkeep, nkeep))
220 7636.5 MiB 0.1 MiB 1 c1 = np.zeros((nkeep, nkeep))
221 7777.3 MiB 0.0 MiB 4 for X_block in sliding_window_view(X_noise_pca, (blocksize, nkeep),
222 7636.5 MiB 0.0 MiB 2 axis=(0, 1))[::blocksize, 0]:
223 # if n_trials>1, reshape to (n_samples, nkeep, n_trials)
224 7636.5 MiB 0.0 MiB 1 if X_block.ndim == 3:
225 X_block = X_block.transpose(1, 2, 0)
226
227 # bias data
228 7637.0 MiB 0.5 MiB 1 c0 += tscov(X_block)[0]
229 7777.3 MiB 140.3 MiB 1 c1 += tscov(gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm))[0]
230
231 # DSS to isolate line components from residual
232 7778.1 MiB 0.8 MiB 1 todss, _, pwr0, pwr1 = dss0(c0, c1)
233
234 7778.1 MiB 0.0 MiB 1 if show:
235 import matplotlib.pyplot as plt
236 plt.plot(pwr1 / pwr0, '.-')
237 plt.xlabel('component')
238 plt.ylabel('score')
239 plt.title('DSS to enhance line frequencies')
240 plt.show()
241
242 # Remove line components from X_noise
243 7778.1 MiB 0.0 MiB 1 idx_remove = np.arange(nremove)
244 7778.1 MiB 0.0 MiB 1 X_artifact = matmul3d(X_noise_pca, todss[:, idx_remove])
245 10056.1 MiB 2278.0 MiB 1 X_res = tsr(X - X_filt, X_artifact)[0] # project them out
246 # reconstruct clean signal
247 12322.0 MiB 2265.9 MiB 1 y = X_filt + X_res
248
249 # Power of components
250 12322.0 MiB 0.0 MiB 1 p = wpwr(X - y)[0] / wpwr(X)[0]
251 12322.1 MiB 0.0 MiB 1 print('Power of components removed by DSS: {:.2f}'.format(p))
252 # return the reconstructed clean signal, and the artifact
253 14588.0 MiB 2265.9 MiB 1 return y, X - y
Computation time:
458451 function calls (455790 primitive calls) in 64.340 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 64.340 64.340 /memory_profiler.py:1140(wrapper)
1 0.313 0.313 64.017 64.017 /memory_profiler.py:715(f)
1 7.629 7.629 63.703 63.703 /meegkit/meegkit/dss.py:138(dss_line)
1473/272 1.176 0.001 36.986 0.136 {built-in method numpy.core._multiarray_umath.implement_array_function}
1 4.796 4.796 29.480 29.480 /meegkit/meegkit/utils/sig.py:279(gaussfilt)
2 0.000 0.000 24.615 12.308 /numpy/fft/_pocketfft.py:49(_raw_fft)
2 24.615 12.307 24.615 12.307 {built-in method numpy.fft._pocketfft_internal.execute}
1 0.000 0.000 14.306 14.306 <__array_function__ internals>:2(fft)
1 0.000 0.000 14.306 14.306 /numpy/fft/_pocketfft.py:122(fft)
1 0.000 0.000 10.309 10.309 <__array_function__ internals>:2(ifft)
1 0.000 0.000 10.309 10.309 /numpy/fft/_pocketfft.py:219(ifft)
1 4.082 4.082 10.040 10.040 /meegkit/meegkit/tspca.py:71(tsr)
1 0.000 0.000 7.492 7.492 /meegkit/meegkit/utils/sig.py:114(smooth)
100/1 0.001 0.000 7.492 7.492 <__array_function__ internals>:2(apply_along_axis)
100/1 1.763 0.018 7.490 7.490 /numpy/lib/shape_base.py:267(apply_along_axis)
99 0.008 0.000 6.035 0.061 /meegkit/meegkit/utils/sig.py:171(_smooth1d)
99 0.008 0.000 5.997 0.061 /scipy/signal/signaltools.py:1866(lfilter)
99 0.001 0.000 5.536 0.056 /scipy/signal/signaltools.py:2038(<lambda>)
99 0.001 0.000 5.536 0.056 <__array_function__ internals>:2(convolve)
99 0.003 0.000 5.534 0.056 /numpy/core/numeric.py:753(convolve)
99 5.531 0.056 5.531 0.056 {built-in method numpy.core._multiarray_umath.correlate}
4 3.269 0.817 4.477 1.119 /meegkit/meegkit/utils/denoise.py:10(demean)
4 0.000 0.000 4.386 1.097 /meegkit/meegkit/utils/covariances.py:170(tscov)
17 3.882 0.228 3.882 0.228 {method 'copy' of 'numpy.ndarray' objects}
6 0.000 0.000 2.974 0.496 /meegkit/meegkit/utils/matrix.py:211(multishift)
2 2.283 1.142 2.765 1.382 /meegkit/meegkit/utils/denoise.py:102(wpwr)
64 0.530 0.008 1.899 0.030 /meegkit/meegkit/utils/matrix.py:653(_check_data)
259 1.757 0.007 1.757 0.007 {method 'reduce' of 'numpy.ufunc' objects}
1 0.000 0.000 1.713 1.713 /meegkit/meegkit/utils/covariances.py:103(tsxcov)
50 0.000 0.000 1.500 0.030 /meegkit/meegkit/utils/matrix.py:472(theshapeof)
122 0.002 0.000 1.484 0.012 /numpy/core/fromnumeric.py:70(_wrapreduction)
2 0.000 0.000 1.201 0.600 <__array_function__ internals>:2(einsum)
2 0.000 0.000 1.201 0.600 /numpy/core/einsumfunc.py:997(einsum)
2 1.201 0.600 1.201 0.600 {built-in method numpy.core._multiarray_umath.c_einsum}
10 0.000 0.000 1.166 0.117 <__array_function__ internals>:2(dot)
Without your changes (#57)
Memory:
Line # Mem usage Increment Occurences Line Contents
============================================================
138 3050.5 MiB 3050.5 MiB 1 @profile
139 def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
140 show=False):
141 """Apply DSS to remove power line artifacts.
142
143 Implements the ZapLine algorithm described in [1]_.
144
145 Parameters
146 ----------
147 X : data, shape=(n_samples, n_chans, n_trials)
148 Input data.
149 fline : float
150 Line frequency (normalized to sfreq, if ``sfreq`` == 1).
151 sfreq : float
152 Sampling frequency (default=1, which assymes ``fline`` is normalised).
153 nremove : int
154 Number of line noise components to remove (default=1).
155 nfft : int
156 FFT size (default=1024).
157 nkeep : int
158 Number of components to keep in DSS (default=None).
159 blocksize : int
160 If not None (default), covariance is computed on blocks of
161 ``blocksize`` samples. This may improve performance for large datasets.
162 show: bool
163 If True, show DSS results (default=False).
164
165 Returns
166 -------
167 y : array, shape=(n_samples, n_chans, n_trials)
168 Denoised data.
169 artifact : array, shape=(n_samples, n_chans, n_trials)
170 Artifact
171
172 Examples
173 --------
174 Apply to X, assuming line frequency=50Hz and sampling rate=1000Hz, plot
175 results:
176 >>> dss_line(X, 50/1000)
177
178 Removing 4 line-dominated components:
179 >>> dss_line(X, 50/1000, 4)
180
181 Truncating PCs beyond the 30th to avoid overfitting:
182 >>> dss_line(X, 50/1000, 4, nkeep=30);
183
184 Return cleaned data in y, noise in yy, do not plot:
185 >>> [y, artifact] = dss_line(X, 60/1000)
186
187 References
188 ----------
189 .. [1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to
190 remove power line artifacts [Preprint]. https://doi.org/10.1101/782029
191
192 """
193 3050.5 MiB 0.0 MiB 1 if X.shape[0] < nfft:
194 print('Reducing nfft to {}'.format(X.shape[0]))
195 nfft = X.shape[0]
196 3050.5 MiB 0.0 MiB 1 n_samples, n_chans, n_trials = theshapeof(X)
197 3050.5 MiB 0.0 MiB 1 if blocksize is None:
198 3050.5 MiB 0.0 MiB 1 blocksize = n_samples
199
200 # Recentre data
201 5316.5 MiB 2265.9 MiB 1 X = demean(X)
202
203 # Cancel line_frequency and harmonics + light lowpass
204 7628.2 MiB 2311.7 MiB 1 X_filt = smooth(X, sfreq / fline)
205
206 # Subtract clean data from original data. The result is the artifact plus
207 # some residual biological signal
208 9894.1 MiB 2265.9 MiB 1 X_noise = X - X_filt
209
210 # Reduce dimensionality to avoid overfitting
211 9894.1 MiB 0.0 MiB 1 if nkeep is not None:
212 cov_X_res = tscov(X_noise)[0]
213 V, _ = pca(cov_X_res, nkeep)
214 X_noise_pca = X_noise @ V
215 else:
216 12160.1 MiB 2265.9 MiB 1 X_noise_pca = X_noise.copy()
217 12160.1 MiB 0.0 MiB 1 nkeep = n_chans
218
219 # Compute blockwise covariances of raw and biased data
220 12160.1 MiB 0.0 MiB 1 n_harm = np.floor((sfreq / 2) / fline).astype(int)
221 12160.1 MiB 0.0 MiB 1 c0 = np.zeros((nkeep, nkeep))
222 12160.1 MiB 0.0 MiB 1 c1 = np.zeros((nkeep, nkeep))
223 12160.1 MiB 0.0 MiB 4 for X_block in sliding_window_view(X_noise_pca, (blocksize, nkeep),
224 12160.1 MiB 0.0 MiB 2 axis=(0, 1))[::blocksize, 0]:
225 # if n_trials>1, reshape to (n_samples, nkeep, n_trials)
226 12160.1 MiB 0.0 MiB 1 if X_block.ndim == 3:
227 X_block = X_block.transpose(1, 2, 0)
228
229 # bias data
230 4661.1 MiB -7499.0 MiB 1 X_bias = gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm)
231 6728.9 MiB 2067.8 MiB 1 c0 += tscov(X_block)[0]
232 6729.2 MiB 0.3 MiB 1 c1 += tscov(X_bias)[0]
233
234 # DSS to isolate line components from residual
235 6731.8 MiB -5428.2 MiB 1 todss, _, pwr0, pwr1 = dss0(c0, c1)
236
237 6731.8 MiB 0.0 MiB 1 if show:
238 import matplotlib.pyplot as plt
239 plt.plot(pwr1 / pwr0, '.-')
240 plt.xlabel('component')
241 plt.ylabel('score')
242 plt.title('DSS to enhance line frequencies')
243 plt.show()
244
245 # Remove line components from X_noise
246 6731.8 MiB 0.0 MiB 1 idx_remove = np.arange(nremove)
247 6754.8 MiB 23.0 MiB 1 X_artifact = matmul3d(X_noise_pca, todss[:, idx_remove])
248 6822.9 MiB 68.1 MiB 1 X_res = tsr(X_noise, X_artifact)[0] # project them out
249
250 # reconstruct clean signal
251 12540.8 MiB 5717.9 MiB 1 y = X_filt + X_res
252 17072.3 MiB 4531.4 MiB 1 artifact = X - y
253
254 # Power of components
255 9515.1 MiB -7557.2 MiB 1 p = wpwr(X - y)[0] / wpwr(X)[0]
256 9515.3 MiB 0.2 MiB 1 print('Power of components removed by DSS: {:.2f}'.format(p))
257 9515.3 MiB 0.0 MiB 1 return y, artifact
Computation time:
458499 function calls (455838 primitive calls) in 99.589 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.002 0.002 99.589 99.589 /memory_profiler.py:1140(wrapper)
1 1.202 1.202 99.242 99.242 /memory_profiler.py:715(f)
1 18.173 18.173 98.039 98.039 /meegkit/meegkit/dss.py:138(dss_line)
1478/277 1.039 0.001 47.523 0.172 {built-in method numpy.core._multiarray_umath.implement_array_function}
1 3.782 3.782 37.715 37.715 /meegkit/meegkit/utils/sig.py:279(gaussfilt)
2 0.001 0.000 33.862 16.931 /numpy/fft/_pocketfft.py:49(_raw_fft)
2 33.861 16.931 33.861 16.931 {built-in method numpy.fft._pocketfft_internal.execute}
1 4.969 4.969 20.961 20.961 /meegkit/meegkit/tspca.py:71(tsr)
1 0.000 0.000 18.838 18.838 <__array_function__ internals>:2(ifft)
1 0.000 0.000 18.838 18.838 /numpy/fft/_pocketfft.py:219(ifft)
4 7.452 1.863 15.295 3.824 /meegkit/meegkit/utils/denoise.py:10(demean)
1 0.000 0.000 15.024 15.024 <__array_function__ internals>:2(fft)
1 0.000 0.000 15.024 15.024 /numpy/fft/_pocketfft.py:122(fft)
22 11.456 0.521 11.456 0.521 {method 'copy' of 'numpy.ndarray' objects}
1 0.000 0.000 7.932 7.932 /meegkit/meegkit/utils/sig.py:114(smooth)
100/1 0.001 0.000 7.932 7.932 <__array_function__ internals>:2(apply_along_axis)
100/1 1.763 0.018 7.931 7.931 /numpy/lib/shape_base.py:267(apply_along_axis)
5 0.001 0.000 6.655 1.331 /meegkit/meegkit/utils/matrix.py:497(fold)
99 0.008 0.000 6.465 0.065 /meegkit/meegkit/utils/sig.py:171(_smooth1d)
99 0.008 0.000 6.427 0.065 /scipy/signal/signaltools.py:1866(lfilter)
99 0.001 0.000 5.971 0.060 /scipy/signal/signaltools.py:2038(<lambda>)
99 0.001 0.000 5.971 0.060 <__array_function__ internals>:2(convolve)
99 0.003 0.000 5.969 0.060 /numpy/core/numeric.py:753(convolve)
99 5.966 0.060 5.966 0.060 {built-in method numpy.core._multiarray_umath.correlate}
2 3.924 1.962 5.269 2.634 /meegkit/meegkit/utils/denoise.py:93(wpwr)
4 0.001 0.000 5.217 1.304 /meegkit/meegkit/utils/covariances.py:170(tscov)
6 0.000 0.000 3.875 0.646 /meegkit/meegkit/utils/matrix.py:211(multishift)
64 0.454 0.007 2.657 0.042 /meegkit/meegkit/utils/matrix.py:652(_check_data)
50 0.000 0.000 2.263 0.045 /meegkit/meegkit/utils/matrix.py:472(theshapeof)
259 1.915 0.007 1.915 0.007 {method 'reduce' of 'numpy.ufunc' objects}
1 0.000 0.000 1.721 1.721 /meegkit/meegkit/utils/covariances.py:103(tsxcov)
385 1.698 0.004 1.698 0.004 {built-in method numpy.zeros}
122 0.002 0.000 1.627 0.013 /numpy/core/fromnumeric.py:70(_wrapreduction)
73 0.000 0.000 1.570 0.022 <__array_function__ internals>:2(iscomplex)
73 0.002 0.000 1.569 0.021 /numpy/lib/type_check.py:210(iscomplex)
2 0.000 0.000 1.223 0.612 <__array_function__ internals>:2(einsum)
2 0.000 0.000 1.223 0.612 /numpy/core/einsumfunc.py:997(einsum)
2 1.223 0.612 1.223 0.612 {built-in method numpy.core._multiarray_umath.c_einsum}
15 0.000 0.000 1.075 0.072 /meegkit/meegkit/utils/matrix.py:511(unfold)
10 0.000 0.000 1.027 0.103 <__array_function__ internals>:2(dot) |
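For anyone who wants to repeat this kind of comparison without installing a profiler, the following is a rough sketch using only `tracemalloc` and `time` from the standard library. It is not the setup used for the tables above: the data size is much smaller, `simulate`, `measure` and `stand_in` are hypothetical names, and `stand_in` should be swapped for the function actually under test.

```python
import time
import tracemalloc

import numpy as np


def simulate(n_samples=20_000, n_chans=8, sfreq=200.0, fline=50.0, seed=0):
    """White noise plus a line component shared by all channels (assumed setup)."""
    rng = np.random.default_rng(seed)
    t = np.arange(n_samples) / sfreq
    line = np.sin(2 * np.pi * fline * t)[:, None]
    return rng.standard_normal((n_samples, n_chans)) + line


def measure(func, *args, **kwargs):
    """Run ``func`` once; return (result, peak traced memory in MiB, seconds).

    Only allocations made while tracing is on are counted, so arrays
    created before the call (such as the input) are not included.
    """
    tracemalloc.start()
    start = time.perf_counter()
    try:
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        _, peak_bytes = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, peak_bytes / 2**20, elapsed


def stand_in(X):
    # Placeholder for the function being benchmarked.
    return X - X.mean(axis=0, keepdims=True)


X = simulate()
cleaned, peak_mib, seconds = measure(stand_in, X)
print(f"peak: {peak_mib:.1f} MiB, time: {seconds:.3f} s")
```

Tracing adds overhead, so the timing from a traced run is best treated as indicative only; a separate untraced run gives a fairer time comparison.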
|
Okay, that is surprising.
That's it, right? Pytest is passing now locally. The issue was the change to matrix.py |
|
I've made some changes. The only thing I reversed from your code is the multiple X-X_filt, which cannot possibly be beneficial in terms of computations. If this is ok with you then let's merge this into #57 |
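The trade-off behind the repeated `X - X_filt` can be sketched as follows. This is an illustration only: the function names are hypothetical, the two statistics are arbitrary examples of "uses" of the difference, and the 4-sample boxcar is just a stand-in for the smoothing step.

```python
import numpy as np


def stats_recompute(X, X_filt):
    """Evaluate ``X - X_filt`` at every use.

    No named difference array is kept alive, but each use still allocates
    a full-size temporary and repeats the subtraction.
    """
    power = np.sum((X - X_filt) ** 2)
    peak = np.abs(X - X_filt).max()
    return power, peak


def stats_reuse(X, X_filt):
    """Compute the difference once and keep it for all later uses."""
    X_noise = X - X_filt
    power = np.sum(X_noise ** 2)
    peak = np.abs(X_noise).max()
    return power, peak


def stats_reuse_buffer(X, X_filt, out):
    """Same as ``stats_reuse``, writing the difference into a caller-owned buffer."""
    np.subtract(X, X_filt, out=out)
    power = np.sum(out ** 2)
    peak = np.abs(out).max()
    return power, peak


rng = np.random.default_rng(0)
X = rng.standard_normal((5000, 6))
# Stand-in for the smoothing step: 4-sample boxcar along time (sfreq / fline = 200 / 50).
kernel = np.ones(4) / 4
X_filt = np.apply_along_axis(lambda v: np.convolve(v, kernel, mode="same"), 0, X)

r1 = stats_recompute(X, X_filt)
r2 = stats_reuse(X, X_filt)
r3 = stats_reuse_buffer(X, X_filt, np.empty_like(X))
```

All three variants return the same numbers; they differ only in how long a full-size difference array stays alive and how often the subtraction is repeated.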
|
Sure, sounds good. |
* [ENH] make dss_line() faster
- add blocksize parameter to dss_line()
- matmul3d works with 2d data
* Update requirements.txt
* require python 3.7+
* Reduce memory footprint of `dss_line` (#58)
* memsaving suggestions
* remove testing code
* fix pep and flake
* undo matrix change
* undo matrix reshaping
* compromise
Co-authored-by: nbara <10333715+nbara@users.noreply.github.com>
* doc
Co-authored-by: eort <eduardxort@gmail.com>
Here is the promised PR. I don't have a good estimate of the saved memory (as it never worked with the original code), but the gain is at least 50% (reduced from far more than 24 GB to 15 GB at the max).
In a nutshell:
[…] `fft` results, but feed it directly into `ifft`
Generally, my motivation was quite selfish. I wanted to make the algorithm work for my data. So, while I obviously tried not to break anything, there is a chance that for other data things might not work anymore. Particularly, the edit in `matrix.py` I suggested, because I simply don't understand why the code was what it was. The new version works for my purposes, but perhaps not for others.
If the PR is of any use to you, I can try to add a few tests.
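The point about feeding `fft` results directly into `ifft` can be illustrated with a small sketch. This is a guess at the general idea, not the code from the PR: `line_gain` and both filter functions are hypothetical, and the Gaussian gain is only a simplified example of a frequency-domain weighting.

```python
import numpy as np


def line_gain(n_samples, sfreq, fline, fwhm=1.0):
    """Gaussian gain centred on +/- fline (hypothetical helper)."""
    freqs = np.fft.fftfreq(n_samples, d=1.0 / sfreq)
    sigma = fwhm / (2.0 * np.sqrt(2.0 * np.log(2.0)))
    return np.exp(-0.5 * ((np.abs(freqs) - fline) / sigma) ** 2)


def filter_named(X, gain):
    """Keep every intermediate under a name until the function returns."""
    spectrum = np.fft.fft(X, axis=0)      # complex array, twice the bytes of X
    shaped = spectrum * gain[:, None]     # a second complex array
    return np.fft.ifft(shaped, axis=0).real


def filter_chained(X, gain):
    """Pass the spectrum straight on, so temporaries can be freed as soon
    as the next step has consumed them."""
    return np.fft.ifft(np.fft.fft(X, axis=0) * gain[:, None], axis=0).real


rng = np.random.default_rng(0)
X = rng.standard_normal((2000, 3))
gain = line_gain(X.shape[0], sfreq=200.0, fline=50.0)

out_named = filter_named(X, gain)
out_chained = filter_chained(X, gain)
```

Both functions compute the same result. In the named version the unweighted spectrum is still referenced while `ifft` runs, whereas the chained version lets the interpreter release it after the multiplication, which can lower the peak by roughly one complex array.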
Overall, the results look really nice, I think:

ps. I am not the most versatile github user, so sorry that this is a separate PR and not an upgrade to #57 (if this was desirable)