pyriemann/estimation.py
"""Estimation of SPD matrices."""
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.covariance import shrunk_covariance
from sklearn.metrics.pairwise import pairwise_kernels
from .spatialfilters import Xdawn
from .utils.covariance import (covariances, covariances_EP, cross_spectrum,
coherence, block_covariances)
from .utils import deprecated
def _nextpow2(i):
"""Find next power of 2."""
n = 1
while n < i:
n *= 2
return n
class Covariances(BaseEstimator, TransformerMixin):
    """Estimation of covariance matrices.

    Perform a simple covariance matrix estimation for each given input.

    Parameters
    ----------
    estimator : string, default='scm'
        Covariance matrix estimator, see
        :func:`pyriemann.utils.covariance.covariances`.
    **kwds : dict
        Any further parameters are passed directly to the covariance estimator.

    See Also
    --------
    ERPCovariances
    XdawnCovariances
    """

    def __init__(self, estimator='scm', **kwds):
        """Init."""
        self.estimator = estimator
        self.kwds = kwds

    def fit(self, X, y=None):
        """Fit.

        Nothing is learned here; the method only exists so that the
        transformer can be used in scikit-learn pipelines.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : Covariances instance
            The Covariances instance.
        """
        return self

    def transform(self, X):
        """Estimate covariance matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.

        Returns
        -------
        covmats : ndarray, shape (n_matrices, n_channels, n_channels)
            Covariance matrices.
        """
        # All the work is delegated to the functional estimator.
        return covariances(X, estimator=self.estimator, **self.kwds)
class ERPCovariances(BaseEstimator, TransformerMixin):
    r"""Estimate special form covariance matrices for ERP.

    Estimation of special form covariance matrix dedicated to event-related
    potentials (ERP) processing.

    For each class, a prototyped response is obtained by average across trials:

    .. math::
        \mathbf{P} = \frac{1}{m} \sum_{i=1}^{m} \mathbf{X}_i

    and a super trial is built using the concatenation of :math:`\mathbf{P}`
    and the trial :math:`\mathbf{X}_i`:

    .. math::
        \mathbf{\tilde{X}}_i = \left[ \begin{array}{c} \mathbf{P} \\
        \mathbf{X}_i \end{array} \right]

    This super trial :math:`\mathbf{\tilde{X}}_i` will be used for covariance
    estimation.
    This allows to take into account the spatial structure of the signal, as
    described in [1]_.

    Parameters
    ----------
    classes : list of int | None, default=None
        List of classes to take into account for prototype estimation.
        If None, all classes will be accounted.
    estimator : string, default='scm'
        Covariance matrix estimator, see
        :func:`pyriemann.utils.covariance.covariances`.
    svd : int | None, default=None
        If not None, number of components of SVD used to reduce prototype
        responses.
    **kwds : dict
        Any further parameters are passed directly to the covariance estimator.

    Attributes
    ----------
    P_ : ndarray, shape (n_components, n_times)
        If fit, prototyped responses for each class, where `n_components` is
        equal to `n_classes x n_channels` if `svd` is None,
        and to `n_classes x min(svd, n_channels)` otherwise.

    See Also
    --------
    Covariances
    XdawnCovariances

    References
    ----------
    .. [1] `A Plug and Play P300 BCI Using Information Geometry
        <https://arxiv.org/abs/1409.0107>`_
        A. Barachant, M. Congedo. Research report, 2014.
    .. [2] `A New generation of Brain-Computer Interface Based on Riemannian
        Geometry
        <https://hal.archives-ouvertes.fr/hal-00879050>`_
        M. Congedo, A. Barachant, A. Andreev. Research report, 2013.
    .. [3] `Classification de potentiels evoques P300 par geometrie
        riemannienne pour les interfaces cerveau-machine EEG
        <https://hal.archives-ouvertes.fr/hal-00877447>`_
        A. Barachant, M. Congedo, G. van Veen, and C. Jutten, 24eme colloque
        GRETSI, 2013.
    """

    def __init__(self, classes=None, estimator='scm', svd=None, **kwds):
        """Init."""
        self.classes = classes
        self.estimator = estimator
        self.svd = svd
        self.kwds = kwds

    def fit(self, X, y):
        """Fit.

        Estimate the prototyped responses for each class.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        self : ERPCovariances instance
            The ERPCovariances instance.

        Raises
        ------
        TypeError
            If `svd` is neither None nor an integer.
        ValueError
            If a class to account for has no matrix in `y`.
        """
        if self.svd is not None:
            # numpy integers (e.g. from np.arange in a grid search) are valid
            if not isinstance(self.svd, (int, np.integer)):
                raise TypeError('svd must be None or int')

        # A plain list would make `y == c` a scalar comparison, selecting
        # no trial at all; work on an array to get an element-wise mask.
        y = np.asarray(y)
        if self.classes is not None:
            classes = self.classes
        else:
            classes = np.unique(y)

        prototypes = []
        for c in classes:
            mask = y == c
            if not np.any(mask):
                # Averaging an empty selection would yield a NaN prototype.
                raise ValueError(
                    'Class {} is not present in y.'.format(c))
            # Prototyped response for each class
            P = np.mean(X[mask], axis=0)
            # Apply svd if requested: project onto the leading left singular
            # vectors of the prototype
            if self.svd is not None:
                U, _, _ = np.linalg.svd(P)
                P = U[:, 0:self.svd].T @ P
            prototypes.append(P)
        self.P_ = np.concatenate(prototypes, axis=0)
        return self

    def transform(self, X):
        """Estimate special form covariance matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.

        Returns
        -------
        covmats : ndarray, shape (n_matrices, n_components, n_components)
            Covariance matrices for ERP, where the size of matrices
            `n_components` is equal to `(1 + n_classes) x n_channels` if `svd`
            is None, and to `n_channels + n_classes x min(svd, n_channels)`
            otherwise.
        """
        covmats = covariances_EP(
            X,
            self.P_,
            estimator=self.estimator,
            **self.kwds
        )
        return covmats
class XdawnCovariances(BaseEstimator, TransformerMixin):
    """Estimate special form covariance matrices for ERP combined with Xdawn.

    Estimation of special form covariance matrix dedicated to ERP processing
    combined with `Xdawn` spatial filtering.
    This is similar to :class:`pyriemann.estimation.ERPCovariances` but data
    are spatially filtered with :class:`pyriemann.spatialfilters.Xdawn`.
    A complete description of the method is available in [1]_.

    The advantage of this estimation is to reduce dimensionality of the
    covariance matrices in a supervised manner.

    Parameters
    ----------
    nfilter : int, default=4
        Number of Xdawn filters per class.
    applyfilters : bool, default=True
        If set to True, spatial filters are applied to the prototypes and the
        signals. When set to False, filters are applied only to the ERP
        prototypes, allowing for a better generalization across subjects and
        sessions at the expense of a dimensionality increase. In that case,
        the estimation is similar to
        :class:`pyriemann.estimation.ERPCovariances` with `svd=nfilter` but
        with a more compact prototype reduction.
    classes : list of int | None, default=None
        List of classes to take into account for prototype estimation.
        If None, all classes will be accounted.
    estimator : string, default='scm'
        Covariance matrix estimator, see
        :func:`pyriemann.utils.covariance.covariances`.
    xdawn_estimator : string, default='scm'
        Covariance matrix estimator for `Xdawn` spatial filtering.
        Should be regularized using 'lwf' or 'oas', see
        :func:`pyriemann.utils.covariance.covariances`.
    baseline_cov : array, shape (n_channels, n_channels) | None, default=None
        Baseline covariance for `Xdawn` spatial filtering,
        see :class:`pyriemann.spatialfilters.Xdawn`.
    **kwds : dict
        Any further parameters are passed directly to the covariance estimator.

    Attributes
    ----------
    P_ : ndarray, shape (n_classes x min(n_channels, n_filters), n_times)
        If fit, the evoked response for each event type, concatenated.

    See Also
    --------
    ERPCovariances
    Xdawn

    References
    ----------
    .. [1] `MEG decoding using Riemannian Geometry and
        Unsupervised classification
        <https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.713.5131>`_
        A. Barachant. Technical report with the solution of the DecMeg 2014
        challenge.
    """

    def __init__(self,
                 nfilter=4,
                 applyfilters=True,
                 classes=None,
                 estimator='scm',
                 xdawn_estimator='scm',
                 baseline_cov=None,
                 **kwds):
        """Init."""
        self.nfilter = nfilter
        self.applyfilters = applyfilters
        self.classes = classes
        self.estimator = estimator
        self.xdawn_estimator = xdawn_estimator
        self.baseline_cov = baseline_cov
        self.kwds = kwds

    def fit(self, X, y):
        """Fit.

        Learn the Xdawn spatial filters and the prototyped response of each
        class.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        self : XdawnCovariances instance
            The XdawnCovariances instance.
        """
        xdawn = Xdawn(
            nfilter=self.nfilter,
            classes=self.classes,
            estimator=self.xdawn_estimator,
            baseline_cov=self.baseline_cov,
        )
        # The filter object is stored before fitting, as in the original flow
        self.Xd_ = xdawn
        xdawn.fit(X, y)
        self.P_ = xdawn.evokeds_
        return self

    def transform(self, X):
        """Estimate Xdawn covariance matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.

        Returns
        -------
        covmats : ndarray, shape (n_matrices, n_components, n_components)
            Covariance matrices filtered by Xdawn, where n_components is equal
            to `2 x n_classes x min(n_channels, nfilter)` if `applyfilters` is
            True, and to `n_channels + n_classes x min(n_channels, nfilter)`
            otherwise.
        """
        signals = self.Xd_.transform(X) if self.applyfilters else X
        return covariances_EP(
            signals,
            self.P_,
            estimator=self.estimator,
            **self.kwds
        )
class BlockCovariances(BaseEstimator, TransformerMixin):
    """Estimation of block covariance matrices.

    Perform a block covariance estimation for each given matrix. The
    resulting matrices are block diagonal matrices.

    The blocks on the diagonal are calculated as individual covariance
    matrices for a subset of channels using the given estimator.
    Varying block sizes are possible by passing a list, to allow incorporation
    of different modalities with different numbers of channels (e.g. EEG,
    ECoG, LFP, EMG) with their own respective covariance matrices.

    Parameters
    ----------
    block_size : int | list of int
        Sizes of individual blocks given as int for same-size block, or list
        (or tuple, or array) for varying block sizes.
    estimator : string, default='scm'
        Covariance matrix estimator, see
        :func:`pyriemann.utils.covariance.covariances`.
    **kwds : dict
        Any further parameters are passed directly to the covariance estimator.

    Notes
    -----
    .. versionadded:: 0.3

    See Also
    --------
    Covariances
    """

    def __init__(self, block_size, estimator='scm', **kwds):
        """Init."""
        self.estimator = estimator
        self.block_size = block_size
        self.kwds = kwds

    def fit(self, X, y=None):
        """Fit.

        Do nothing. For compatibility purpose.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : BlockCovariances instance
            The BlockCovariances instance.
        """
        return self

    def transform(self, X):
        """Estimate block covariance matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.

        Returns
        -------
        covmats : ndarray, shape (n_matrices, n_channels, n_channels)
            Covariance matrices.

        Raises
        ------
        ValueError
            If `block_size` is neither an integer nor a sequence of integers.
        """
        # Unpacking also enforces a 3D input
        _, n_channels, _ = X.shape

        if isinstance(self.block_size, (int, np.integer)):
            # Same-size blocks covering as many channels as possible
            n_blocks = n_channels // self.block_size
            blocks = [self.block_size] * n_blocks
        elif isinstance(self.block_size, tuple):
            blocks = list(self.block_size)
        elif isinstance(self.block_size, (list, np.ndarray)):
            blocks = self.block_size
        else:
            raise ValueError("Parameter block_size must be int or list.")

        return block_covariances(X, blocks, self.estimator, **self.kwds)
###############################################################################
class CrossSpectra(BaseEstimator, TransformerMixin):
    """Estimation of cross-spectral matrices.

    Complex cross-spectral matrices are HPD matrices estimated as the spectrum
    covariance in the frequency domain [1]_. It returns a 4-d array with a
    cross-spectral matrix for each input and in each frequency bin of the
    Fourier transform.

    Parameters
    ----------
    window : int, default=128
        The length of the FFT window used for spectral estimation. It is
        rounded up to the next power of 2 when estimating the spectra.
    overlap : float, default=0.75
        The percentage of overlap between windows.
    fmin : float | None, default=None
        The minimal frequency to be returned.
    fmax : float | None, default=None
        The maximal frequency to be returned.
    fs : float | None, default=None
        The sampling frequency of the signal.

    Attributes
    ----------
    freqs_ : ndarray, shape (n_freqs,)
        If transformed, the frequencies associated to cross-spectra.
        None if ``fs`` is None.

    Notes
    -----
    .. versionadded:: 0.6

    See Also
    --------
    CoSpectra
    Coherences

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Cross-spectrum
    """

    def __init__(self, window=128, overlap=0.75, fmin=None, fmax=None,
                 fs=None):
        """Init."""
        # Parameters are stored untouched, as required by the sklearn API:
        # modifying them here breaks `clone` (identity check on parameters)
        # and is bypassed by `set_params`.
        self.window = window
        self.overlap = overlap
        self.fmin = fmin
        self.fmax = fmax
        self.fs = fs

    def fit(self, X, y=None):
        """Fit.

        Do nothing. For compatibility purpose.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : CrossSpectra instance
            The CrossSpectra instance.
        """
        return self

    def transform(self, X):
        """Estimate cross-spectral matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.

        Returns
        -------
        X_new : ndarray, shape (n_matrices, n_channels, n_channels, n_freqs)
            Cross-spectral matrices for each input and for each frequency bin.
        """
        # The FFT window is rounded up to a power of 2 at estimation time
        window = _nextpow2(self.window)

        X_new = []
        for x in X:
            S, freqs = cross_spectrum(
                x,
                window=window,
                overlap=self.overlap,
                fmin=self.fmin,
                fmax=self.fmax,
                fs=self.fs)
            X_new.append(S)
            self.freqs_ = freqs
        return np.array(X_new)
class CoSpectra(CrossSpectra):
    """Estimation of co-spectral matrices.

    Co-spectral matrices are SPD matrices estimated as the real part of the
    :class:`pyriemann.estimation.CrossSpectra`. It returns a 4-d array with a
    co-spectral matrix for each input and in each frequency bin of the
    Fourier transform.

    Parameters
    ----------
    window : int, default=128
        The length of the FFT window used for spectral estimation.
    overlap : float, default=0.75
        The percentage of overlap between windows.
    fmin : float | None, default=None
        The minimal frequency to be returned.
    fmax : float | None, default=None
        The maximal frequency to be returned.
    fs : float | None, default=None
        The sampling frequency of the signal.

    Attributes
    ----------
    freqs_ : ndarray, shape (n_freqs,)
        If transformed, the frequencies associated to cospectra.
        None if ``fs`` is None.

    See Also
    --------
    CrossSpectra
    Coherences
    """

    def transform(self, X):
        """Estimate co-spectral matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.

        Returns
        -------
        X_new : ndarray, shape (n_matrices, n_channels, n_channels, n_freqs)
            Co-spectral matrices for each input and for each frequency bin.
        """
        # Co-spectra are the real part of the complex cross-spectra
        cross_spectra = super().transform(X)
        return np.real(cross_spectra)
@deprecated(
    "CospCovariances is deprecated and will be removed in 0.8.0; "
    "please use CoSpectra."
)
class CospCovariances(CoSpectra):
    """Deprecated alias of :class:`pyriemann.estimation.CoSpectra`."""
class Coherences(CoSpectra):
    """Estimation of squared coherence matrices.

    Squared coherence matrices estimation [1]_. This method will return a 4-d
    array with a squared coherence matrix estimation for each input and in
    each frequency bin of the FFT.

    Parameters
    ----------
    window : int, default=128
        The length of the FFT window used for spectral estimation. It is
        rounded up to the next power of 2 when estimating the coherences.
    overlap : float, default=0.75
        The percentage of overlap between windows.
    fmin : float | None, default=None
        The minimal frequency to be returned.
    fmax : float | None, default=None
        The maximal frequency to be returned.
    fs : float | None, default=None
        The sampling frequency of the signal.
    coh : {'ordinary', 'instantaneous', 'lagged', 'imaginary'}, \
            default='ordinary'
        The coherence type:

        * 'ordinary' for the ordinary coherence, defined in Eq.(22) of [1]_;
          this normalization of cross-spectral matrices captures both in-phase
          and out-of-phase correlations. However it is inflated by the
          artificial in-phase (zero-lag) correlation engendered by volume
          conduction.
        * 'instantaneous' for the instantaneous coherence, Eq.(26) of [1]_,
          capturing only in-phase correlation.
        * 'lagged' for the lagged-coherence, Eq.(28) of [1]_, capturing only
          out-of-phase correlation (not defined for DC and Nyquist bins).
        * 'imaginary' for the imaginary coherence [2]_, Eq.(0.16) of [3]_,
          capturing out-of-phase correlation but still affected by in-phase
          correlation.

    Attributes
    ----------
    freqs_ : ndarray, shape (n_freqs,)
        If transformed, the frequencies associated to coherences.
        None if ``fs`` is None.

    Notes
    -----
    .. versionadded:: 0.3

    See Also
    --------
    CrossSpectra
    TimeDelayCovariances

    References
    ----------
    .. [1] `Instantaneous and lagged measurements of linear
        and nonlinear dependence between groups of multivariate time series:
        frequency decomposition
        <https://arxiv.org/ftp/arxiv/papers/0711/0711.1455.pdf>`_
        R. Pascual-Marqui. Technical report, 2007.
    .. [2] `Identifying true brain interaction from EEG data using the
        imaginary part of coherency
        <https://doi.org/10.1016/j.clinph.2004.04.029>`_
        G. Nolte, O. Bai, L. Wheaton, Z. Mari, S. Vorbach, M. Hallett.
        Clinical Neurophysiology, Volume 115, Issue 10, October 2004,
        Pages 2292-2307
    .. [3] `Non-Parametric Synchronization Measures used in EEG
        and MEG
        <https://hal.archives-ouvertes.fr/hal-01868538v2>`_
        M. Congedo. Technical Report, 2018.
    """

    def __init__(self, window=128, overlap=0.75, fmin=None, fmax=None,
                 fs=None, coh='ordinary'):
        """Init."""
        # Parameters are stored untouched, as required by the sklearn API:
        # modifying them here breaks `clone` (identity check on parameters)
        # and is bypassed by `set_params`.
        self.window = window
        self.overlap = overlap
        self.fmin = fmin
        self.fmax = fmax
        self.fs = fs
        self.coh = coh

    def transform(self, X):
        """Estimate the squared coherences matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.

        Returns
        -------
        covmats : ndarray, shape (n_matrices, n_channels, n_channels, n_freqs)
            Squared coherence matrices for each input and for each frequency
            bin.
        """
        # The FFT window is rounded up to a power of 2 at estimation time
        window = _nextpow2(self.window)

        out = []
        for x in X:
            S, freqs = coherence(
                x,
                window=window,
                overlap=self.overlap,
                fmin=self.fmin,
                fmax=self.fmax,
                fs=self.fs,
                coh=self.coh)
            out.append(S)
            self.freqs_ = freqs
        return np.array(out)
class TimeDelayCovariances(BaseEstimator, TransformerMixin):
    """Estimation of covariance matrices with time delay matrices.

    Time delay covariance matrices are useful to catch spectral dynamics of
    the signal, similarly to the CSSP method [1]_. It is done by concatenating
    time delayed versions of the signal before covariance estimation.

    Parameters
    ----------
    delays : int | list of int, default=4
        The delays to apply for the Hankel matrices. If `int`, it uses a range
        of delays up to the given value. A list (or tuple, or array) of int
        can be given.
    estimator : string, default='scm'
        Covariance matrix estimator, see
        :func:`pyriemann.utils.covariance.covariances`.
    **kwds : dict
        Any further parameters are passed directly to the covariance estimator.

    Attributes
    ----------
    Xtd_ : ndarray, shape (n_matrices, n_channels x n_delays, n_times)
        Time delay multi-channel time-series, where `n_delays` is equal to:
        `delays` when it is a int, and `1 + len(delays)` when it is a list.

    See Also
    --------
    Covariances
    ERPCovariances
    CoSpectra

    References
    ----------
    .. [1] `Spatio-spectral filters for improving the classification of single
        trial EEG
        <http://doc.ml.tu-berlin.de/bbci/publications/LemBlaCurMue05.pdf>`_
        S. Lemm, B. Blankertz, B. Curio, K-R. Muller. IEEE Transactions on
        Biomedical Engineering 52(9), 1541-1548, 2005.
    """

    def __init__(self, delays=4, estimator='scm', **kwds):
        """Init."""
        self.delays = delays
        self.estimator = estimator
        self.kwds = kwds

    def fit(self, X, y=None):
        """Fit.

        Do nothing. For compatibility purpose.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : TimeDelayCovariances instance
            The TimeDelayCovariances instance.
        """
        return self

    def transform(self, X):
        """Estimate the time delay covariance matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.

        Returns
        -------
        covmats : ndarray, shape (n_matrices, n_channels x n_delays, \
                n_channels x n_delays)
            Time delay covariance matrices, where `n_delays` is equal to:
            `delays` when it is a int, and `1 + len(delays)` when it is a list.

        Raises
        ------
        ValueError
            If `delays` is neither an integer nor a sequence of integers.
        """
        if isinstance(self.delays, (int, np.integer)):
            delays = range(1, self.delays)
        elif isinstance(self.delays, (list, tuple, np.ndarray)):
            delays = self.delays
        else:
            raise ValueError('delays must be an integer or a list')

        # Stack the original signal with its delayed copies along channels;
        # np.roll performs a circular shift along the time axis.
        Xtd = [X]
        for d in delays:
            Xtd.append(np.roll(X, d, axis=-1))
        self.Xtd_ = np.concatenate(Xtd, axis=-2)

        covmats = covariances(self.Xtd_, estimator=self.estimator, **self.kwds)
        return covmats
@deprecated(
    "HankelCovariances is deprecated and will be removed in 0.8.0; "
    "please use TimeDelayCovariances."
)
class HankelCovariances(TimeDelayCovariances):
    """Deprecated alias of :class:`pyriemann.estimation.TimeDelayCovariances`.
    """
###############################################################################
class Kernels(BaseEstimator, TransformerMixin):
    r"""Estimation of kernel matrices between channels of time series.

    Perform a kernel matrix estimation for each given time series, evaluating a
    kernel function between each pair of channels (rather than between pairs of
    time samples) and allowing to extract nonlinear channel relationship [1]_.

    For an input time series :math:`X \in \mathbb{R}^{c \times t}`, composed of
    :math:`c` channels and :math:`t` time samples, kernel function
    :math:`\kappa()` is computed between channels :math:`i` and :math:`j`:

    .. math::
        K_{i,j} = \kappa \left( X[i], X[j] \right)

    Linear kernel is related to :class:`pyriemann.estimation.Covariances` [1]_,
    but this class allows to generalize to nonlinear relationships.

    Parameters
    ----------
    metric : string, default='linear'
        The metric to use when computing kernel function between channels [2]_:
        'linear', 'poly', 'polynomial', 'rbf', 'laplacian', 'cosine'.
    n_jobs : int, default=None
        The number of jobs to use for the computation [2]_. This works by
        breaking down the pairwise matrix into n_jobs even slices and computing
        them in parallel.
    **kwds : dict
        Any further parameters are passed directly to the kernel function [2]_.

    See Also
    --------
    Covariances

    Notes
    -----
    .. versionadded:: 0.4

    References
    ----------
    .. [1] `Beyond Covariance: Feature Representation with Nonlinear Kernel
        Matrices
        <https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Wang_Beyond_Covariance_Feature_ICCV_2015_paper.pdf>`_
        L. Wang, J. Zhang, L. Zhou, C. Tang, W Li. ICCV, 2015.
    .. [2]
        https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.pairwise_kernels.html
    """  # noqa

    def __init__(self, metric='linear', n_jobs=None, **kwds):
        """Init."""
        self.metric = metric
        self.n_jobs = n_jobs
        self.kwds = kwds

    def fit(self, X, y=None):
        """Fit.

        Nothing is learned here; the method only exists so that the
        transformer can be used in scikit-learn pipelines.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : Kernels instance
            The Kernels instance.
        """
        return self

    def transform(self, X):
        """Estimate kernel matrices from time series.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_times)
            Multi-channel time-series.

        Returns
        -------
        K : ndarray, shape (n_matrices, n_channels, n_channels)
            Kernel matrices.
        """
        supported_metrics = (
            'linear', 'poly', 'polynomial', 'rbf', 'laplacian', 'cosine',
        )
        if self.metric not in supported_metrics:
            raise TypeError('Unsupported metric for kernel estimation.')

        # Channels are the "samples" of pairwise_kernels: one row per channel
        kernel_mats = []
        for series in X:
            kernel_mats.append(
                pairwise_kernels(
                    series,
                    None,
                    metric=self.metric,
                    n_jobs=self.n_jobs,
                    **self.kwds
                )
            )
        return np.asarray(kernel_mats)
###############################################################################
class Shrinkage(BaseEstimator, TransformerMixin):
    """Regularization of SPD matrices by shrinkage.

    This transformer applies a shrinkage regularization to any SPD matrix.
    It directly uses the `shrunk_covariance` function from scikit-learn [1]_,
    applied on each input.

    Parameters
    ----------
    shrinkage : float, default=0.1
        Coefficient in the convex combination used for the computation of the
        shrunk estimate. Must be between 0 and 1.

    Notes
    -----
    .. versionadded:: 0.2.5

    References
    ----------
    .. [1] https://scikit-learn.org/stable/modules/generated/sklearn.covariance.shrunk_covariance.html
    """  # noqa

    def __init__(self, shrinkage=0.1):
        """Init."""
        self.shrinkage = shrinkage

    def fit(self, X, y=None):
        """Fit.

        Do nothing. For compatibility purpose.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : Shrinkage instance
            The Shrinkage instance.
        """
        return self

    def transform(self, X):
        """Shrink and return the SPD matrices.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.

        Returns
        -------
        covmats : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of shrunk SPD matrices.
        """
        X = np.asarray(X)
        # Shrunk matrices have non-integer entries (convex combination with
        # a scaled identity): an integer-typed output buffer, as returned by
        # np.zeros_like on integer inputs, would silently truncate them.
        if np.issubdtype(X.dtype, np.inexact):
            dtype = X.dtype
        else:
            dtype = np.float64
        covmats = np.zeros(X.shape, dtype=dtype)
        for ii, x in enumerate(X):
            covmats[ii] = shrunk_covariance(x, self.shrinkage)
        return covmats