alexandrebarachant/pyRiemann

View on GitHub
pyriemann/transfer/_estimators.py

Summary

Maintainability
D
2 days
Test Coverage
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import (
    BaseEstimator,
    TransformerMixin,
    is_classifier,
    is_regressor
)
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, r2_score

from ..utils.mean import mean_covariance, mean_riemann
from ..utils.distance import distance
from ..utils.base import invsqrtm, powm, sqrtm
from ..utils.geodesic import geodesic
from ..utils.utils import check_weights, check_metric
from ._rotate import _get_rotation_matrix
from ..classification import MDM
from ..preprocessing import Whitening
from ._tools import decode_domains


class TLDummy(BaseEstimator, TransformerMixin):
    """Identity transformer for transfer learning.

    The matrices are handed back untouched, whatever their domain.
    This is what we call the Direct Center Transfer (DCT) method.

    Notes
    -----
    .. versionadded:: 0.4
    """

    def __init__(self):
        """Init (no hyper-parameter)."""

    def fit(self, X, y_enc):
        """Fit nothing and return the instance.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices. Not used.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix. Not used.

        Returns
        -------
        self : TLDummy instance
            The TLDummy instance.
        """
        return self

    def transform(self, X, y_enc=None):
        """Return the input as is.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            The very same object that was given as input.
        """
        return X

    def fit_transform(self, X, y_enc):
        """Call ``fit`` then ``transform``; the input is returned as is.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.

        Returns
        -------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            The very same object that was given as input.
        """
        fitted = self.fit(X, y_enc)
        return fitted.transform(X, y_enc)

class TLCenter(BaseEstimator, TransformerMixin):
    """Recenter data for transfer learning.

    Recenter the data points from each domain to the Identity on manifold, ie
    make the mean of the datasets become the identity. This operation
    corresponds to a whitening step if the SPD matrices represent the spatial
    covariance matrices of multivariate signals.

    .. note::
       Using .fit() and then .transform() will give different results than
       .fit_transform(): .fit_transform() recenters every domain with its own
       fitted whitening, whereas .transform() applies the whitening of a
       single (target) domain to all matrices. In practice, .fit_transform()
       should be applied on the training dataset (target and source) and
       .transform() on the test partition of the target dataset.

    Parameters
    ----------
    target_domain : str
        Domain to consider as target in ``transform()`` function:

        * if not empty, ``transform()`` recenters matrices to the specified
          target domain;
        * else, ``transform()`` recenters matrices to the last fitted domain.
    metric : str, default="riemann"
        Metric used for mean estimation. For the list of supported metrics,
        see :func:`pyriemann.utils.mean.mean_covariance`.
        Note, however, that only when using the "riemann" metric that we are
        ensured to re-center the matrices precisely to the Identity.

    Attributes
    ----------
    recenter_ : dict
        If fit, dictionary with key=domain_name and value=the ``Whitening``
        transformer fitted on the matrices of this domain.

    References
    ----------
    .. [1] `Transfer Learning: A Riemannian Geometry Framework With
        Applications to Brain–Computer Interfaces
        <https://hal.archives-ouvertes.fr/hal-01923278/>`_
        P Zanini et al, IEEE Transactions on Biomedical Engineering, vol. 65,
        no. 5, pp. 1107-1116, August, 2017
    .. [2] `Transfer Learning for Brain-Computer Interfaces:
        A Euclidean Space Data Alignment Approach
        <https://arxiv.org/abs/1808.05464>`_
        He He and Dongrui Wu, IEEE Transactions on Biomedical Engineering, 2019

    Notes
    -----
    .. versionadded:: 0.4
    """

    def __init__(self, target_domain, metric="riemann"):
        """Init"""
        self.target_domain = target_domain
        self.metric = metric

    def fit(self, X, y_enc, sample_weight=None):
        """Fit TLCenter.

        Fit one ``Whitening`` transformer per domain, on the matrices of this
        domain only.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.
        sample_weight : None | ndarray, shape (n_matrices,), default=None
            Weights for each matrix. If None, it uses equal weights.

        Returns
        -------
        self : TLCenter instance
            The TLCenter instance.
        """
        _, _, domain_of = decode_domains(X, y_enc)
        n_matrices, _, _ = X.shape
        weights = check_weights(sample_weight, n_matrices)

        self.recenter_ = {}
        for name in np.unique(domain_of):
            mask = domain_of == name
            whitener = Whitening(metric=self.metric)
            self.recenter_[name] = whitener.fit(
                X[mask], sample_weight=weights[mask]
            )
        return self

    def transform(self, X, y_enc=None):
        """Re-center matrices in the target domain.

        .. note::
           This method is designed for using at test time,
           recentering all matrices in target domain.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of recentered SPD matrices.
        """
        # an empty target name means "fall back on the last fitted domain"
        domain = (
            self.target_domain
            if self.target_domain != ""
            else list(self.recenter_.keys())[-1]
        )
        return self.recenter_[domain].transform(X)

    def fit_transform(self, X, y_enc, sample_weight=None):
        """Fit TLCenter and then transform matrices.

        Fit one whitening per domain and apply to the matrices of each domain
        the whitening fitted on that same domain.

        .. note::
           This method is designed for using at training time. The output for
           .fit_transform() will be different than using .fit() and
           .transform() separately.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.
        sample_weight : None | ndarray, shape (n_matrices,), default=None
            Weights for each matrix. If None, it uses equal weights.

        Returns
        -------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of recentered SPD matrices.
        """
        self.fit(X, y_enc, sample_weight)
        _, _, domain_of = decode_domains(X, y_enc)

        recentered = np.zeros_like(X)
        for name in np.unique(domain_of):
            mask = domain_of == name
            whitener = self.recenter_[name]
            recentered[mask] = whitener.transform(X[mask])
        return recentered


class TLStretch(BaseEstimator, TransformerMixin):
    """Stretch data for transfer learning.

    Change the dispersion of the datapoints around their geometric mean
    for each dataset so that they all have the same desired value.

    .. note::
       Using .fit() and then .transform() will give different results than
       .fit_transform(): .fit_transform() stretches every domain with its own
       mean and dispersion, whereas .transform() applies the mean and
       dispersion of the target domain to all matrices. In practice,
       .fit_transform() should be applied on the training dataset (target and
       source) and .transform() on the test partition of the target dataset.

    Parameters
    ----------
    target_domain : str
        Domain to consider as target.
    final_dispersion : float, default=1.0
        Target value for the dispersion of the data points.
    centered_data : bool, default=False
        Whether the data has been re-centered to the Identity beforehand.
        If True, the Identity is used as the mean of every domain.
    metric : str, default="riemann"
        Metric used for calculating the dispersion.
        For the list of supported metrics,
        see :func:`pyriemann.utils.distance.distance`.

    Attributes
    ----------
    dispersions_ : dict
        Dictionary with key=domain_name and value=domain_dispersion.

    References
    ----------
    .. [1] `Riemannian Procrustes analysis: transfer learning for
        brain-computer interfaces
        <https://hal.archives-ouvertes.fr/hal-01971856>`_
        PLC Rodrigues et al, IEEE Transactions on Biomedical Engineering,
        vol. 66, no. 8, pp. 2390-2401, December, 2018

    Notes
    -----
    .. versionadded:: 0.4
    """

    def __init__(self, target_domain, final_dispersion=1.0,
                 centered_data=False, metric="riemann"):
        """Init"""
        self.target_domain = target_domain
        self.final_dispersion = final_dispersion
        self.centered_data = centered_data
        self.metric = metric

    def fit(self, X, y_enc, sample_weight=None):
        """Fit TLStretch.

        Calculate the dispersion around the mean for each domain.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.
        sample_weight : None | ndarray, shape (n_matrices,), default=None
            Weights for each matrix. If None, it uses equal weights.

        Returns
        -------
        self : TLStretch instance
            The TLStretch instance.
        """
        _, _, domains = decode_domains(X, y_enc)
        n_matrices, n_channels, _ = X.shape
        sample_weight = check_weights(sample_weight, n_matrices)

        self._means, self.dispersions_ = {}, {}
        for d in np.unique(domains):
            idx = domains == d
            # weights are checked again on the subset of this domain
            sample_weight_d = check_weights(sample_weight[idx], np.sum(idx))
            if self.centered_data:
                self._means[d] = np.eye(n_channels)
            else:
                self._means[d] = mean_riemann(
                    X[idx], sample_weight=sample_weight_d
                )
            dist = distance(
                X[idx],
                self._means[d],
                metric=self.metric,
                squared=True,
            )
            # dispersion = weighted sum of squared distances to the mean
            self.dispersions_[d] = np.sum(sample_weight_d * np.squeeze(dist))

        return self

    def _center(self, X, mean):
        """Apply the congruence by ``invsqrtm(mean)`` to matrices X."""
        Mean_isqrt = invsqrtm(mean)
        return Mean_isqrt @ X @ Mean_isqrt

    def _uncenter(self, X, mean):
        """Apply the congruence by ``sqrtm(mean)`` to matrices X."""
        Mean_sqrt = sqrtm(mean)
        return Mean_sqrt @ X @ Mean_sqrt

    def _strech(self, X, dispersion_in, dispersion_out):
        """Raise X to the power ``sqrt(dispersion_out / dispersion_in)``."""
        return powm(X, np.sqrt(dispersion_out / dispersion_in))

    def _stretch_domain(self, X, domain):
        """Stretch X with the mean and dispersion fitted for ``domain``.

        A new array is returned, the input X is never written to.
        """
        dispersion = self.dispersions_[domain]
        if self.centered_data:
            # data already around the Identity: stretch directly
            return self._strech(X, dispersion, self.final_dispersion)

        mean = self._means[domain]
        # center matrices to Identity, stretch, then go back to previous mean
        X_str = self._strech(
            self._center(X, mean), dispersion, self.final_dispersion
        )
        return self._uncenter(X_str, mean)

    def transform(self, X, y_enc=None):
        """Stretch the data points in the target domain.

        All matrices are stretched using the mean and the dispersion fitted
        on the target domain.

        .. note::
           The stretching operation is properly defined only for the riemann
           metric.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices with desired final dispersion.
        """
        return self._stretch_domain(X, self.target_domain)

    def fit_transform(self, X, y_enc, sample_weight=None):
        """Fit TLStretch and then transform data points.

        Calculate the dispersion around the mean for each domain and then
        stretch the data points to the desired final dispersion. The input
        array X is left unmodified.

        .. note::
           This method is designed for using at training time. The output for
           .fit_transform() will be different than using .fit() and
           .transform() separately.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.
        sample_weight : None | ndarray, shape (n_matrices,), default=None
            Weights for each matrix. If None, it uses equal weights.

        Returns
        -------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices with desired final dispersion.
        """
        # used during fit, in pipeline
        self.fit(X, y_enc, sample_weight)
        _, _, domains = decode_domains(X, y_enc)

        X_str = np.zeros_like(X)
        for d in np.unique(domains):
            idx = domains == d
            # X[idx] is a copy (boolean indexing), so the centering done in
            # the helper never leaks into the array owned by the caller
            X_str[idx] = self._stretch_domain(X[idx], d)

        return X_str


class TLRotate(BaseEstimator, TransformerMixin):
    """Rotate data for transfer learning.

    Rotate the data points from each source domain so to match its class means
    with those from the target domain. The loss function for this matching was
    first proposed in [1]_ and the optimization procedure for minimizing it
    follows the presentation from [2]_.

    .. note::
       The data points from each domain must have been re-centered to the
       identity before calculating the rotation.

    .. note::
       Using .fit() and then .transform() will give different results than
       .fit_transform(): .fit_transform() rotates the matrices of the source
       domains, whereas .transform() returns its input unchanged. In practice,
       .fit_transform() should be applied on the training dataset (target and
       source) and .transform() on the test partition of the target dataset.

    Parameters
    ----------
    target_domain : str
        Domain to consider as target.
    weights : None | array, shape (n_classes,), default=None
        Weights to assign for each class. If None, then give the same weight
        for each class.
    metric : {"euclid", "riemann"}, default="euclid"
        Metric for the distance to minimize between class means.
    n_jobs : int, default=1
        The number of jobs to use for the computation. This works by computing
        the rotation matrix for each source domain in parallel. If -1 all CPUs
        are used.

    Attributes
    ----------
    rotations_ : dict
        Dictionary with key=domain_name and value=domain_rotation_matrix.
        Only source domains have an entry.

    References
    ----------
    .. [1] `Riemannian Procrustes analysis: transfer learning for
        brain-computer interfaces
        <https://hal.archives-ouvertes.fr/hal-01971856>`_
        PLC Rodrigues et al, IEEE Transactions on Biomedical Engineering,
        vol. 66, no. 8, pp. 2390-2401, December, 2018
    .. [2] `An introduction to optimization on smooth manifolds
        <https://www.nicolasboumal.net/book/>`_
        N. Boumal. To appear with Cambridge University Press. June, 2022

    Notes
    -----
    .. versionadded:: 0.4
    """

    def __init__(self, target_domain, weights=None, metric='euclid', n_jobs=1):
        """Init"""
        self.target_domain = target_domain
        self.weights = weights
        self.metric = metric
        self.n_jobs = n_jobs

    @staticmethod
    def _class_means(X, y_enc, sample_weight, mask):
        """Stack the class means of the matrices selected by ``mask``.

        One mean is computed with ``mean_riemann`` per unique extended label
        found among the selected matrices, in the order of ``np.unique``.
        """
        X_sel, y_sel, w_sel = X[mask], y_enc[mask], sample_weight[mask]
        means = []
        for label in np.unique(y_sel):
            in_class = y_sel == label
            means.append(
                mean_riemann(X_sel[in_class], sample_weight=w_sel[in_class])
            )
        return np.stack(means)

    def fit(self, X, y_enc, sample_weight=None):
        """Fit TLRotate.

        Calculate the rotations matrices to transform each source domain into
        the target domain.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.
        sample_weight : None | ndarray, shape (n_matrices,), default=None
            Weights for each matrix. If None, it uses equal weights.

        Returns
        -------
        self : TLRotate instance
            The TLRotate instance.
        """
        _, _, domain_of = decode_domains(X, y_enc)
        n_matrices, _, _ = X.shape
        weights = check_weights(sample_weight, n_matrices)

        # class means of the target domain, reference of every rotation
        M_target = self._class_means(
            X, y_enc, weights, domain_of == self.target_domain
        )

        all_domains = np.unique(domain_of)
        source_domains = all_domains[all_domains != self.target_domain]

        # one job per source domain; its class means are computed lazily when
        # the job is created, only the rotation estimation is dispatched
        jobs = (
            delayed(_get_rotation_matrix)(
                self._class_means(X, y_enc, weights, domain_of == name),
                M_target,
                self.weights,
                metric=self.metric,
            )
            for name in source_domains
        )
        rotations = Parallel(n_jobs=self.n_jobs)(jobs)

        self.rotations_ = dict(zip(source_domains, rotations))
        return self

    def transform(self, X, y_enc=None):
        """Rotate the data points in the target domain.

        The rotations are done from source to target, so in this step the data
        points suffer no transformation at all.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : None
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Same set of SPD matrices as in the input.
        """
        # inference happens on the target domain, which is never rotated
        return X

    def fit_transform(self, X, y_enc, sample_weight=None):
        """Fit TLRotate and then transform data points.

        Calculate the rotation matrix for matching each source domain to the
        target domain, then rotate the matrices of each source domain.
        Matrices of the target domain are copied unchanged.

        .. note::
           This method is designed for using at training time. The output for
           .fit_transform() will be different than using .fit() and
           .transform() separately.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.
        sample_weight : None | ndarray, shape (n_matrices,), default=None
            Weights for each matrix. If None, it uses equal weights.

        Returns
        -------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices after rotation step.
        """
        self.fit(X, y_enc, sample_weight)
        _, _, domain_of = decode_domains(X, y_enc)

        rotated = np.zeros_like(X)
        for name in np.unique(domain_of):
            mask = domain_of == name
            if name == self.target_domain:
                rotated[mask] = X[mask]
                continue
            rot = self.rotations_[name]
            rotated[mask] = rot @ X[mask] @ rot.T
        return rotated


class TLEstimator(BaseEstimator):
    """Transfer learning wrapper for estimators.

    This is a wrapper for any BaseEstimator (i.e. classifier or regressor) that
    converts extended labels used in Transfer Learning into the usual y array
    to train a classifier/regressor of choice.

    Parameters
    ----------
    target_domain : str
        Domain to consider as target.
    estimator : BaseEstimator
        The estimator to apply on matrices. It can be any regressor or
        classifier from pyRiemann. It is fitted in place.
    domain_weight : None | dict, default=None
        Weights to combine matrices from each domain to train the estimator.
        The dict contains key=domain_name and value=weight_to_assign, and
        must have an entry for every domain present at fit time.
        If None, it uses equal weights.

    Notes
    -----
    .. versionadded:: 0.4
    """

    def __init__(self, target_domain, estimator, domain_weight=None):
        """Init."""
        self.target_domain = target_domain
        self.domain_weight = domain_weight
        self.estimator = estimator

    def fit(self, X, y_enc):
        """Fit TLEstimator.

        The extended labels are decoded, then the wrapped estimator is fitted
        on the decoded labels. Per-matrix weights derived from
        ``domain_weight`` are forwarded as ``sample_weight`` only where the
        ``fit`` method supports this parameter:

        * for a Pipeline, only to the steps whose ``fit`` accepts
          ``sample_weight``;
        * for another estimator, always when ``domain_weight`` is given (an
          estimator not supporting it then raises its own error), and when
          ``domain_weight`` is None only if ``fit`` accepts ``sample_weight``.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.

        Returns
        -------
        self : TLEstimator instance
            The TLEstimator instance.

        Raises
        ------
        TypeError
            If the estimator is neither a classifier nor a regressor.
        """
        # function-scope import: keeps this method self-contained
        from sklearn.utils.validation import has_fit_parameter

        def accepts_weights(est):
            # pipeline steps can be placeholders (None, "passthrough")
            return hasattr(est, "fit") and has_fit_parameter(
                est, "sample_weight"
            )

        if not (is_regressor(self.estimator) or is_classifier(self.estimator)):
            raise TypeError(
                'Estimator has to be either a classifier or a regressor.')

        X_dec, y_dec, domains = decode_domains(X, y_enc)

        if is_regressor(self.estimator):
            y_dec = y_dec.astype(float)

        if self.domain_weight is not None:
            # every matrix gets the weight of the domain it belongs to
            w = np.zeros(len(X_dec))
            for d in np.unique(domains):
                w[domains == d] = self.domain_weight[d]
        else:
            w = None

        if isinstance(self.estimator, Pipeline):
            # route the weights only to the steps able to consume them,
            # other steps would raise on the unexpected keyword
            fit_params = {
                f"{step_name}__sample_weight": w
                for step_name, step in self.estimator.steps
                if accepts_weights(step)
            }
            self.estimator.fit(X_dec, y_dec, **fit_params)
        elif w is not None or accepts_weights(self.estimator):
            self.estimator.fit(X_dec, y_dec, sample_weight=w)
        else:
            # no weights requested and none supported: plain fit
            self.estimator.fit(X_dec, y_dec)

        return self

    def predict(self, X):
        """Get the predictions.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.

        Returns
        -------
        pred : ndarray, shape (n_matrices,)
            Predictions for each matrix according to the estimator.
        """
        return self.estimator.predict(X)


class TLClassifier(TLEstimator):
    """Transfer learning wrapper for classifiers.

    Wrap a classifier so that it can be trained and scored with the extended
    labels used in Transfer Learning: they are converted into the usual y
    array before reaching the classifier of choice.

    Parameters
    ----------
    target_domain : str
        Domain to consider as target.
    estimator : BaseClassifier
        The classifier to apply on matrices.
    domain_weight : None | dict, default=None
        Weights to combine matrices from each domain to train the classifier.
        The dict contains key=domain_name and value=weight_to_assign.
        If None, it uses equal weights.

    Notes
    -----
    .. versionadded:: 0.4
    """

    def fit(self, X, y_enc):
        """Fit TLClassifier.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.

        Returns
        -------
        self : TLClassifier instance
            The TLClassifier instance.

        Raises
        ------
        TypeError
            If the wrapped estimator is not a classifier.
        """
        if is_classifier(self.estimator):
            return super().fit(X, y_enc)

        raise TypeError('Estimator has to be a classifier.')

    def predict_proba(self, X):
        """Get the probability.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.

        Returns
        -------
        pred : ndarray, shape (n_matrices, n_classes)
            Output of ``predict_proba`` of the wrapped classifier.
        """
        proba = self.estimator.predict_proba(X)
        return proba

    def score(self, X, y_enc):
        """Return the mean accuracy on the given test data and labels.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Test set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended true labels for each matrix.

        Returns
        -------
        score : float
            Mean accuracy of self.predict(X) wrt. the decoded labels.
        """
        # only the decoded labels are needed to compare with predictions
        _, expected, _ = decode_domains(X, y_enc)
        predicted = self.predict(X)
        return accuracy_score(expected, predicted)


class TLRegressor(TLEstimator):
    """Transfer learning wrapper for regressors.

    Wrap a regressor so that it can be trained and scored with the extended
    labels used in Transfer Learning: they are converted into the usual y
    array before reaching the regressor of choice.

    Parameters
    ----------
    target_domain : str
        Domain to consider as target.
    estimator : BaseRegressor
        The regressor to apply on matrices.
    domain_weight : None | dict, default=None
        Weights to combine matrices from each domain to train the regressor.
        The dict contains key=domain_name and value=weight_to_assign.
        If None, it uses equal weights.

    Notes
    -----
    .. versionadded:: 0.4
    """

    def fit(self, X, y_enc):
        """Fit TLRegressor.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.

        Returns
        -------
        self : TLRegressor instance
            The TLRegressor instance.

        Raises
        ------
        TypeError
            If the wrapped estimator is not a regressor.
        """
        if is_regressor(self.estimator):
            return super().fit(X, y_enc)

        raise TypeError('Estimator has to be a regressor.')

    def score(self, X, y_enc):
        """Return the coefficient of determination of the prediction.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Test set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended true values for each matrix.

        Returns
        -------
        score : float
            R2 of self.predict(X) wrt. the decoded values cast to float.
        """
        _, expected, _ = decode_domains(X, y_enc)
        predicted = self.predict(X)
        # decoded values are cast to float before being compared
        expected_values = expected.astype(float)
        return r2_score(expected_values, predicted)


class MDWM(MDM):
    """Classification by Minimum Distance to Weighted Mean.

    Classification by nearest centroid. For each of the given classes, a
    centroid is estimated, according to the chosen metric, as a weighted mean
    of SPD matrices from the source domain, combined with the class centroid of
    the target domain [1]_ [2]_.
    For classification, a given new matrix is attributed to the class whose
    centroid is the nearest according to the chosen metric.

    Parameters
    ----------
    domain_tradeoff : float
        Coefficient in [0,1] controlling the transfer, ie the trade-off between
        source and target domains.
        At 0, there is no transfer, only matrices acquired from the source
        domain are used.
        At 1, this is a calibration-free system as no matrices are required
        from the source domain.
    target_domain : string
        Name of the target domain in extended labels.
    metric : string | dict, default="riemann"
        Metric used for mean estimation (for the list of supported metrics,
        see :func:`pyriemann.utils.mean.mean_covariance`) and
        for distance estimation
        (see :func:`pyriemann.utils.distance.distance`).
        The metric can be a dict with two keys, "mean" and "distance"
        in order to pass different metrics.
    n_jobs : int, default=1
        The number of jobs to use for the computation. This works by computing
        each of the class centroid in parallel.
        If -1 all CPUs are used. If 1 is given, no parallel computing code is
        used at all, which is useful for debugging. For n_jobs below -1,
        (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one
        are used.

    Attributes
    ----------
    classes_ : ndarray, shape (n_classes,)
        Labels for each class, as found in the source domains.
    source_means_ : ndarray, shape (n_classes, n_channels, n_channels)
        Class means estimated on the matrices of the source domains.
    target_means_ : ndarray, shape (n_classes, n_channels, n_channels)
        Class means estimated on the matrices of the target domain.
    covmeans_ : ndarray, shape (n_classes, n_channels, n_channels)
        Centroids for each class.

    See Also
    --------
    MDM

    References
    ----------
    .. [1] `Transfer learning for SSVEP-based BCI using Riemannian similarities
        between users
        <https://hal.archives-ouvertes.fr/hal-01911092/>`_
        E. Kalunga, S. Chevallier and Q. Barthelemy, in 26th European Signal
        Processing Conference (EUSIPCO), pp. 1685-1689. IEEE, 2018.
    .. [2] `Minimizing Subject-dependent Calibration for BCI with Riemannian
        Transfer Learning
        <https://hal.archives-ouvertes.fr/hal-03202360/>`_
        S. Khazem, S. Chevallier, Q. Barthelemy, K. Haroun and C. Nous, 10th
        International IEEE/EMBS Conference on Neural Engineering (NER), pp.
        523-526. IEEE, 2021.

    Notes
    -----
    .. versionadded:: 0.4
    """

    def __init__(
            self,
            domain_tradeoff,
            target_domain,
            metric="riemann",
            n_jobs=1):
        """Init."""
        self.domain_tradeoff = domain_tradeoff
        self.target_domain = target_domain
        self.metric = metric
        self.n_jobs = n_jobs

    def fit(self, X, y_enc, sample_weight=None):
        """Fit (estimates) the centroids.

        Class means are estimated separately on the source domains (every
        domain but the target one) and on the target domain, then combined
        with ``geodesic`` using ``domain_tradeoff``.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices from source and target domain.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.
        sample_weight : None | ndarray, shape (n_matrices_source,), default=None
            Weights for each matrix from the source domains.
            If None, it uses equal weights.

        Returns
        -------
        self : MDWM instance
            The MDWM instance.

        Raises
        ------
        ValueError
            If ``domain_tradeoff`` is outside [0, 1], or if it is not 0 and
            the classes of the target domain differ from the source ones.
        """
        self.metric_mean, self.metric_dist = check_metric(self.metric)

        if not 0 <= self.domain_tradeoff <= 1:
            # the value is formatted as is: "%d" would truncate a float and
            # report e.g. 1.5 as 1, which looks like a valid value
            raise ValueError(
                "Value domain_tradeoff must be included in [0, 1] "
                f"(Got {self.domain_tradeoff})"
            )

        X_dec, y_dec, domains = decode_domains(X, y_enc)
        is_target = domains == self.target_domain
        X_src, y_src = X_dec[~is_target], y_dec[~is_target]
        X_tgt, y_tgt = X_dec[is_target], y_dec[is_target]

        self.classes_ = np.unique(y_src)

        # target labels only need to match when the target is actually used
        if self.domain_tradeoff != 0 and set(y_tgt) != set(y_src):
            raise ValueError(
                "Classes in source domain must match classes in target "
                f"domain. Classes in source are {self.classes_} while "
                f"classes in target are {np.unique(y_tgt)}"
            )

        sample_weight = check_weights(sample_weight, X_src.shape[0])

        self.source_means_ = np.stack(
            Parallel(n_jobs=self.n_jobs)(
                delayed(mean_covariance)(
                    X_src[y_src == ll],
                    metric=self.metric_mean,
                    sample_weight=sample_weight[y_src == ll],
                ) for ll in self.classes_
            )
        )

        # target means are not weighted
        self.target_means_ = np.stack(
            Parallel(n_jobs=self.n_jobs)(
                delayed(mean_covariance)(
                    X_tgt[y_tgt == ll],
                    metric=self.metric_mean,
                ) for ll in self.classes_
            )
        )

        self.covmeans_ = geodesic(
            self.source_means_,
            self.target_means_,
            self.domain_tradeoff,
            metric=self.metric_mean,
        )
        return self

    def score(self, X, y_enc, sample_weight=None):
        """Return the mean accuracy on the given test data and labels.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y_enc : ndarray, shape (n_matrices,)
            Extended labels for each matrix.
        sample_weight : None | ndarray, shape (n_matrices,), default=None
            Weights for each matrix, forwarded to ``accuracy_score``.
            If None, it uses equal weights.

        Returns
        -------
        score : float
            Mean accuracy of clf.predict(X) wrt. the decoded labels of y_enc.
        """
        _, y_true, _ = decode_domains(X, y_enc)
        y_pred = self.predict(X)
        return accuracy_score(y_true, y_pred, sample_weight=sample_weight)