lucasmiranda42/deepof

View on GitHub
deepof/models.py

Summary

Maintainability
F
1 wk
Test Coverage
"""deep autoencoder models for unsupervised pose detection.

- VQ-VAE: a variational autoencoder with a vector quantization latent-space (https://arxiv.org/abs/1711.00937).
- VaDE: a variational autoencoder with a Gaussian mixture latent-space.
- Contrastive: an embedding model consisting of a single encoder, trained using a contrastive loss.

"""
# @author lucasmiranda42
# encoding: utf-8
# module deepof

from sklearn.mixture import GaussianMixture
from spektral.layers import CensNetConv
from tensorflow.keras import Input, Model
from tensorflow.keras.initializers import he_uniform
from tensorflow.keras.layers import Dense, GRU, RepeatVector, TimeDistributed
from tensorflow.keras.layers import LayerNormalization, Bidirectional
from tensorflow.keras.optimizers import Nadam
from typing import Any, NewType
import numpy as np
import tcn
import tensorflow as tf
import tensorflow_probability as tfp

import deepof.model_utils

# Short aliases for the TensorFlow Probability submodules
tfb = tfp.bijectors
tfd = tfp.distributions
tfpl = tfp.layers

# DEFINE CUSTOM ANNOTATED TYPES #
# Each of these is a NewType over Any: they carry a descriptive name for use in
# annotations, but place no constraint on the annotated value.
project = NewType("deepof_project", Any)
coordinates = NewType("deepof_coordinates", Any)
table_dict = NewType("deepof_table_dict", Any)


# noinspection PyCallingNonCallable
def get_recurrent_encoder(
    input_shape: tuple,
    edge_feature_shape: tuple,
    adjacency_matrix: np.ndarray,
    latent_dim: int,
    use_gnn: bool = True,
    gru_unroll: bool = False,
    bidirectional_merge: str = "concat",
    interaction_regularization: float = 0.0,
):
    """Build a recurrent encoder, optionally followed by a graph convolution.

    The returned model takes two inputs (node features and edge features) and maps
    every instance to a vector of size latent_dim, ready to be fed to one of the
    provided structured latent spaces.

    Args:
        input_shape (tuple): shape of the node features, without the batch axis. When use_gnn is True, the last axis is split evenly across the nodes of the adjacency matrix.
        edge_feature_shape (tuple): shape of the edge features, without the batch axis. When use_gnn is True, the last axis is used as the number of edges.
        adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph (nodes x nodes). Only read when use_gnn is True.
        latent_dim (int): dimension of the latent space.
        use_gnn (bool): If True, the encoder uses a graph representation of the input, with node and edge attributes. If False, the node features are processed as a regular 3D tensor and the edge input is left unused.
        gru_unroll (bool): whether to unroll the GRU layers. Defaults to False.
        bidirectional_merge (str): how to merge the forward and backward GRU layers. Defaults to "concat".
        interaction_regularization (float): L1 regularization factor applied to the node kernel of the graph convolution.

    Returns:
        keras.Model: model named "recurrent_encoder", mapping [node features, edge features] to a latent_dim vector.

    """
    # Both inputs are always declared, so the model signature does not depend on use_gnn
    node_input = Input(shape=input_shape)
    edge_input = Input(shape=edge_feature_shape)

    if use_gnn:
        n_nodes = adjacency_matrix.shape[-1]
        n_edges = edge_feature_shape[-1]

        # Transpose / reshape / transpose back, so that the final layout is
        # (-1, nodes, steps, features per node)
        node_layout = [
            input_shape[-1] // n_nodes,
            node_input.shape[1],
            n_nodes,
            -1,
        ]
        node_sequences = tf.transpose(node_input)
        node_sequences = tf.reshape(node_sequences, node_layout)
        node_sequences = tf.transpose(node_sequences)

        # Same trick for the edges, which carry a single feature each:
        # final layout is (-1, edges, steps, 1)
        edge_layout = [1, edge_input.shape[1], n_edges, -1]
        edge_sequences = tf.transpose(edge_input)
        edge_sequences = tf.reshape(edge_sequences, edge_layout)
        edge_sequences = tf.transpose(edge_sequences)

    else:
        # Without a graph, the whole input is handled as a single group on axis 1
        node_sequences = tf.expand_dims(node_input, axis=1)

    # Temporal block over the node sequences
    node_block = deepof.model_utils.get_recurrent_block(
        node_sequences, latent_dim, gru_unroll, bidirectional_merge
    )
    embedding = node_block(node_sequences)

    if use_gnn:

        # Temporal block over the edge sequences
        edge_block = deepof.model_utils.get_recurrent_block(
            edge_sequences, latent_dim, gru_unroll, bidirectional_merge
        )
        edge_embedding = edge_block(edge_sequences)

        # Spatial block: joint node / edge graph convolution
        graph_conv = CensNetConv(
            node_channels=latent_dim,
            edge_channels=latent_dim,
            activation="relu",
            node_regularizer=tf.keras.regularizers.l1(interaction_regularization),
        )

        # Derive the graph operators expected by the layer from the adjacency matrix
        graph_operators = graph_conv.preprocess(adjacency_matrix)
        laplacian, edge_laplacian, incidence = graph_operators

        node_out, edge_out = graph_conv(
            [embedding, (laplacian, edge_laplacian, incidence), edge_embedding],
            mask=None,
        )

        # Flatten both outputs to one vector per instance and join them
        node_out = tf.reshape(node_out, [-1, n_nodes * latent_dim])
        edge_out = tf.reshape(edge_out, [-1, n_edges * latent_dim])
        embedding = tf.concat([node_out, edge_out], axis=-1)

    else:
        # Drop the singleton axis introduced by expand_dims above
        embedding = tf.squeeze(embedding, axis=1)

    # Linear projection to the latent dimension
    projection = tf.keras.layers.Dense(latent_dim, kernel_initializer="he_uniform")
    latent = projection(embedding)

    return Model([node_input, edge_input], latent, name="recurrent_encoder")


# noinspection PyCallingNonCallable
def get_recurrent_decoder(
    input_shape: tuple,
    latent_dim: int,
    gru_unroll: bool = False,
    bidirectional_merge: str = "concat",
):
    """Build a recurrent decoder that turns latent vectors back into sequences.

    The model takes two inputs: the latent vector to decode, and the original encoder
    input, which is only used to derive a mask of valid positions. The latent vector is
    repeated input_shape[0] times and passed through two bidirectional GRU layers and a
    1D convolution (each followed by layer normalization) before reaching the
    probabilistic output layer.

    Args:
        input_shape (tuple): shape of the sequences to reconstruct, without the batch axis.
        latent_dim (int): dimensionality of the latent space.
        gru_unroll (bool): whether to unroll the GRU layers. Defaults to False.
        bidirectional_merge (str): how to merge the forward and backward GRU layers. Defaults to "concat".

    Returns:
        keras.Model: model named "recurrent_decoder", mapping [latent vector, encoder input] to the
        output of deepof.model_utils.ProbabilisticDecoder.

    """

    def bidirectional_gru(units):
        """Return a fresh bidirectional GRU layer with the given number of units."""
        return Bidirectional(
            GRU(
                units,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=gru_unroll,
                use_bias=True,
            ),
            merge_mode=bidirectional_merge,
        )

    latent_input = Input(shape=latent_dim)  # Vector to decode
    reference_input = Input(shape=input_shape)  # Encoder input, only used for masking

    # A position is valid unless all of its entries along axis 2 are exactly zero
    all_zero = tf.reduce_all(reference_input == 0.0, axis=2)
    validity_mask = tf.math.logical_not(all_zero)

    # Repeat the latent vector once per output step
    decoded = RepeatVector(input_shape[0])(latent_input)

    # Only the first recurrent layer receives the validity mask
    decoded = bidirectional_gru(latent_dim)(decoded, mask=validity_mask)
    decoded = LayerNormalization()(decoded)

    decoded = bidirectional_gru(2 * latent_dim)(decoded)
    decoded = LayerNormalization()(decoded)

    smoothing = tf.keras.layers.Conv1D(
        filters=2 * latent_dim,
        kernel_size=5,
        strides=1,
        padding="same",
        activation="relu",
        kernel_initializer=he_uniform(),
        use_bias=False,
    )
    decoded = smoothing(decoded)
    decoded = LayerNormalization()(decoded)

    output_layer = deepof.model_utils.ProbabilisticDecoder(input_shape)
    reconstruction = output_layer([decoded, validity_mask])

    return Model(
        [latent_input, reference_input], reconstruction, name="recurrent_decoder"
    )


def get_TCN_encoder(
    input_shape: tuple,
    edge_feature_shape: tuple,
    adjacency_matrix: np.ndarray,
    latent_dim: int,
    use_gnn: bool = True,
    conv_filters: int = 32,
    kernel_size: int = 4,
    conv_stacks: int = 2,
    conv_dilations: tuple = (1, 2, 4, 8),
    padding: str = "causal",
    use_skip_connections: bool = True,
    dropout_rate: float = 0.0,
    activation: str = "relu",
    interaction_regularization: float = 0.0,
):
    """Return a Temporal Convolutional Network (TCN) encoder.

    Builds a neural network that can be used to encode motion tracking instances into a
    vector. Each layer contains a residual block with a convolutional layer and a skip connection. See the following
    paper for more details: https://arxiv.org/pdf/1803.01271.pdf

    Args:
        input_shape (tuple): shape of the node features, without the batch axis. When use_gnn is True, the last axis is split evenly across the nodes of the adjacency matrix.
        edge_feature_shape (tuple): shape of the edge features, without the batch axis. When use_gnn is True, the last axis is used as the number of edges.
        adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph (nodes x nodes). Only read when use_gnn is True.
        latent_dim (int): dimensionality of the latent space.
        use_gnn (bool): If True, the encoder uses a graph representation of the input, with node and edge attributes. If False, the node features are processed as a regular 3D tensor and the edge input is left unused.
        conv_filters (int): number of filters in the TCN layers.
        kernel_size (int): size of the convolutional kernels.
        conv_stacks (int): number of stacks of residual blocks in each TCN.
        conv_dilations (tuple): dilation factors used within each TCN stack.
        padding (str): padding mode for the TCN layers.
        use_skip_connections (bool): whether to use skip connections between TCN layers.
        dropout_rate (float): dropout rate for the TCN layers.
        activation (str): activation function for the TCN layers.
        interaction_regularization (float): L1 regularization factor applied to the node kernel of the graph convolution.

    Returns:
        keras.Model: model named "TCN_encoder", mapping [node features, edge features] to a latent_dim vector.

    """

    def temporal_block():
        """Return a fresh TCN wrapped in TimeDistributed, so it is applied along axis 1."""
        # Arguments are passed by keyword so that the call does not depend on the
        # positional order of the TCN constructor.
        return TimeDistributed(
            tcn.TCN(
                nb_filters=conv_filters,
                kernel_size=kernel_size,
                nb_stacks=conv_stacks,
                dilations=conv_dilations,
                padding=padding,
                use_skip_connections=use_skip_connections,
                dropout_rate=dropout_rate,
                return_sequences=False,
                activation=activation,
                kernel_initializer="random_normal",
                use_batch_norm=True,
            )
        )

    # Define feature and adjacency inputs
    x = Input(shape=input_shape)
    a = Input(shape=edge_feature_shape)

    if use_gnn:
        # Transpose / reshape / transpose back, so that the final layout is
        # (-1, nodes, steps, features per node)
        x_reshaped = tf.transpose(
            tf.reshape(
                tf.transpose(x),
                [
                    -1,
                    adjacency_matrix.shape[-1],
                    x.shape[1],
                    input_shape[-1] // adjacency_matrix.shape[-1],
                ][::-1],
            )
        )
        # Same for the edges, with a single feature each: (-1, edges, steps, 1)
        a_reshaped = tf.transpose(
            tf.reshape(
                tf.transpose(a),
                [
                    -1,
                    edge_feature_shape[-1],
                    a.shape[1],
                    1,
                ][::-1],
            )
        )

    else:
        # Without a graph, the whole input is handled as a single group on axis 1
        x_reshaped = tf.expand_dims(x, axis=1)

    # Instantiate temporal block
    encoder = temporal_block()(x_reshaped)

    # Instantiate spatial graph block
    if use_gnn:

        # Embed edge features too
        a_encoder = temporal_block()(a_reshaped)

        spatial_block = CensNetConv(
            node_channels=latent_dim,
            edge_channels=latent_dim,
            activation="relu",
            node_regularizer=tf.keras.regularizers.l1(interaction_regularization),
        )

        # Process adjacency matrix
        laplacian, edge_laplacian, incidence = spatial_block.preprocess(
            adjacency_matrix
        )

        # Get and concatenate node and edge embeddings
        x_nodes, x_edges = spatial_block(
            [encoder, (laplacian, edge_laplacian, incidence), a_encoder], mask=None
        )

        x_nodes = tf.reshape(
            x_nodes,
            [-1, adjacency_matrix.shape[-1] * latent_dim],
        )

        x_edges = tf.reshape(
            x_edges,
            [-1, edge_feature_shape[-1] * latent_dim],
        )

        encoder = tf.concat([x_nodes, x_edges], axis=-1)

    else:
        # Drop the singleton axis introduced by expand_dims above
        encoder = tf.squeeze(encoder, axis=1)

    # Dense head projecting down to the latent dimension
    encoder = tf.keras.layers.Dense(2 * latent_dim, activation="relu")(encoder)
    encoder = tf.keras.layers.BatchNormalization()(encoder)
    encoder = Dense(latent_dim, activation="relu")(encoder)
    encoder = tf.keras.layers.BatchNormalization()(encoder)
    encoder = tf.keras.layers.Dense(latent_dim)(encoder)

    return Model([x, a], encoder, name="TCN_encoder")


def get_TCN_decoder(
    input_shape: tuple,
    latent_dim: int,
    conv_filters: int = 64,
    kernel_size: int = 4,
    conv_stacks: int = 1,
    conv_dilations: tuple = (8, 4, 2, 1),
    padding: str = "causal",
    use_skip_connections: bool = True,
    dropout_rate: float = 0.0,
    activation: str = "relu",
):
    """Return a Temporal Convolutional Network (TCN) decoder.

    Builds a neural network that can be used to decode a latent space into a sequence of
    motion tracking instances. Each layer contains a residual block with a convolutional layer and a skip connection. See
    the following paper for more details: https://arxiv.org/pdf/1803.01271.pdf

    The model takes two inputs: the latent vector to decode, and the original encoder
    input, which is only used to derive a mask of valid positions.

    Args:
        input_shape (tuple): shape of the sequences to reconstruct, without the batch axis.
        latent_dim (int): dimensionality of the latent space.
        conv_filters (int): number of filters in the TCN layers.
        kernel_size (int): size of the convolutional kernels.
        conv_stacks (int): number of stacks of residual blocks in the TCN.
        conv_dilations (tuple): dilation factors used within each TCN stack.
        padding (str): padding mode for the TCN layers.
        use_skip_connections (bool): whether to use skip connections between TCN layers.
        dropout_rate (float): dropout rate for the TCN layers.
        activation (str): activation function for the TCN layers.

    Returns:
        keras.Model: model named "TCN_decoder", mapping [latent vector, encoder input] to the
        output of deepof.model_utils.ProbabilisticDecoder.

    """
    # Define and instantiate generator
    g = Input(shape=latent_dim)  # Decoder input, shaped as the latent space
    x = Input(shape=input_shape)  # Encoder input, used to generate an output mask

    # A position is valid unless all of its entries along axis 2 are exactly zero
    validity_mask = tf.math.logical_not(tf.reduce_all(x == 0.0, axis=2))

    # Dense stack widening the latent vector before it is repeated over time
    generator = tf.keras.layers.Dense(latent_dim)(g)
    generator = tf.keras.layers.BatchNormalization()(generator)
    generator = tf.keras.layers.Dense(2 * latent_dim, activation="relu")(generator)
    generator = tf.keras.layers.BatchNormalization()(generator)
    generator = tf.keras.layers.Dense(4 * latent_dim, activation="relu")(generator)
    generator = tf.keras.layers.BatchNormalization()(generator)
    generator = tf.keras.layers.RepeatVector(input_shape[0])(generator)

    # Arguments are passed by keyword so that the call does not depend on the
    # positional order of the TCN constructor.
    generator = tcn.TCN(
        nb_filters=conv_filters,
        kernel_size=kernel_size,
        nb_stacks=conv_stacks,
        dilations=conv_dilations,
        padding=padding,
        use_skip_connections=use_skip_connections,
        dropout_rate=dropout_rate,
        return_sequences=True,
        activation=activation,
        kernel_initializer="random_normal",
        use_batch_norm=True,
    )(generator)

    x_decoded = deepof.model_utils.ProbabilisticDecoder(input_shape)(
        [generator, validity_mask]
    )

    return Model([g, x], x_decoded, name="TCN_decoder")


# noinspection PyCallingNonCallable
def get_transformer_encoder(
    input_shape: tuple,
    edge_feature_shape: tuple,
    adjacency_matrix: np.ndarray,
    latent_dim: int,
    use_gnn: bool = True,
    num_layers: int = 4,
    num_heads: int = 64,
    dff: int = 128,
    dropout_rate: float = 0.1,
    interaction_regularization: float = 0.0,
):
    """Build a Transformer encoder, optionally followed by a graph convolution.

    Based on https://www.tensorflow.org/text/tutorials/transformer.
    Adapted according to https://academic.oup.com/gigascience/article/8/11/giz134/5626377?login=true
    and https://arxiv.org/abs/1711.03905.

    Args:
        input_shape (tuple): shape of the node features, without the batch axis. When use_gnn is True, the last axis is split evenly across the nodes of the adjacency matrix.
        edge_feature_shape (tuple): shape of the edge features, without the batch axis. When use_gnn is True, the last axis is used as the number of edges.
        adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph (nodes x nodes). Only read when use_gnn is True.
        latent_dim (int): dimensionality of the latent space
        use_gnn (bool): If True, the encoder uses a graph representation of the input, with node and edge attributes. If False, the node features are processed as a regular 3D tensor and the edge input is left unused.
        num_layers (int): number of transformer layers to include
        num_heads (int): number of heads of the multi-head-attention layers used on the transformer encoder
        dff (int): dimensionality of the token embeddings
        dropout_rate (float): dropout rate
        interaction_regularization (float): L1 regularization factor applied to the node kernel of the graph convolution

    Returns:
        keras.Model: model named "transformer_encoder", mapping [node features, edge features] to a latent_dim vector.

    """
    # Settings shared by the node and the edge transformer blocks
    transformer_settings = dict(
        num_layers=num_layers,
        seq_dim=input_shape[-1],
        key_dim=input_shape[-1],
        num_heads=num_heads,
        dff=dff,
        maximum_position_encoding=input_shape[0],
        rate=dropout_rate,
    )
    flat_dim = input_shape[0] * input_shape[1]

    # Both inputs are always declared, so the model signature does not depend on use_gnn
    node_input = Input(shape=input_shape)
    edge_input = Input(shape=edge_feature_shape)

    if use_gnn:
        # Transpose / reshape / transpose back, so that the final layout is
        # (-1, nodes, steps, features per node)
        node_layout = [
            input_shape[-1] // adjacency_matrix.shape[-1],
            node_input.shape[1],
            adjacency_matrix.shape[-1],
            -1,
        ]
        node_sequences = tf.transpose(node_input)
        node_sequences = tf.reshape(node_sequences, node_layout)
        node_sequences = tf.transpose(node_sequences)

        # Same for the edges, with a single feature each: (-1, edges, steps, 1)
        edge_layout = [1, edge_input.shape[1], edge_feature_shape[-1], -1]
        edge_sequences = tf.transpose(edge_input)
        edge_sequences = tf.reshape(edge_sequences, edge_layout)
        edge_sequences = tf.transpose(edge_sequences)

    else:
        # Without a graph, the whole input is handled as a single group on axis 1
        node_sequences = tf.expand_dims(node_input, axis=1)

    # NOTE(review): the transformer blocks are always called with training=False —
    # confirm that running them in inference mode during fitting is intended.
    node_transformer = TimeDistributed(
        deepof.model_utils.TransformerEncoder(**transformer_settings)
    )
    embedding = node_transformer(node_sequences, training=False)

    # One flattened vector per group on axis 1 (a single group when there is no graph)
    if node_sequences.shape[1] != 1:
        n_groups = adjacency_matrix.shape[0]
    else:
        n_groups = 1
    embedding = tf.reshape(embedding, [-1, n_groups, flat_dim])

    if use_gnn:

        # NOTE(review): the edge branch is configured and reshaped with the node-side
        # dimensions (input_shape, adjacency_matrix.shape[0]) rather than with
        # edge_feature_shape — confirm this is intended.
        edge_transformer = TimeDistributed(
            deepof.model_utils.TransformerEncoder(**transformer_settings)
        )
        edge_embedding = edge_transformer(edge_sequences, training=False)
        edge_embedding = tf.reshape(
            edge_embedding, [-1, adjacency_matrix.shape[0], flat_dim]
        )

        # Spatial block: joint node / edge graph convolution
        graph_conv = CensNetConv(
            node_channels=latent_dim,
            edge_channels=latent_dim,
            activation="relu",
            node_regularizer=tf.keras.regularizers.l1(interaction_regularization),
        )

        # Derive the graph operators expected by the layer from the adjacency matrix
        graph_operators = graph_conv.preprocess(adjacency_matrix)
        laplacian, edge_laplacian, incidence = graph_operators

        node_out, edge_out = graph_conv(
            [embedding, (laplacian, edge_laplacian, incidence), edge_embedding],
            mask=None,
        )

        # Flatten both outputs to one vector per instance and join them
        node_out = tf.reshape(node_out, [-1, adjacency_matrix.shape[-1] * latent_dim])
        edge_out = tf.reshape(edge_out, [-1, edge_feature_shape[-1] * latent_dim])
        embedding = tf.concat([node_out, edge_out], axis=-1)

    else:
        # Drop the singleton axis introduced by expand_dims above
        embedding = tf.squeeze(embedding, axis=1)

    # Dense head projecting down to the latent dimension
    latent = tf.keras.layers.Dense(2 * latent_dim, activation="relu")(embedding)
    latent = tf.keras.layers.BatchNormalization()(latent)
    latent = tf.keras.layers.Dense(latent_dim, activation="relu")(latent)
    latent = tf.keras.layers.BatchNormalization()(latent)
    latent = tf.keras.layers.Dense(latent_dim)(latent)

    return tf.keras.models.Model(
        [node_input, edge_input], latent, name="transformer_encoder"
    )


def get_transformer_decoder(
    input_shape: tuple,
    latent_dim: int,
    num_layers: int = 2,
    num_heads: int = 8,
    dff: int = 128,
    dropout_rate: float = 0.1,
):
    """Build a Transformer decoder.

    Based on https://www.tensorflow.org/text/tutorials/transformer.
    Adapted according to https://academic.oup.com/gigascience/article/8/11/giz134/5626377?login=true
    and https://arxiv.org/abs/1711.03905.

    The model takes two inputs: the latent vector to decode, and the original encoder
    input, which is passed to the transformer decoder and also used to derive a mask
    of valid positions.

    Args:
        input_shape (tuple): shape of the sequences to reconstruct, without the batch axis
        latent_dim (int): dimensionality of the latent space
        num_layers (int): number of transformer layers to include
        num_heads (int): number of heads of the multi-head-attention layers used on the transformer decoder
        dff (int): dimensionality of the token embeddings
        dropout_rate (float): dropout rate

    Returns:
        keras.Model: model named "transformer_decoder", mapping [latent vector, encoder input] to two outputs:
        the output of deepof.model_utils.ProbabilisticDecoder, and the attention weights returned by the
        transformer decoder.

    """
    x = tf.keras.layers.Input(input_shape)  # Encoder input
    g = tf.keras.layers.Input([latent_dim])  # Decoder input, shaped as the latent space

    # A position is valid unless all of its entries along axis 2 are exactly zero
    validity_mask = tf.math.logical_not(tf.reduce_all(x == 0.0, axis=2))

    # Dense stack widening the latent vector before it is repeated over time
    generator = tf.keras.layers.Dense(latent_dim)(g)
    generator = tf.keras.layers.BatchNormalization()(generator)
    generator = tf.keras.layers.Dense(2 * latent_dim, activation="relu")(generator)
    generator = tf.keras.layers.BatchNormalization()(generator)
    generator = tf.keras.layers.Dense(4 * latent_dim, activation="relu")(generator)
    generator = tf.keras.layers.BatchNormalization()(generator)
    generator = tf.keras.layers.RepeatVector(input_shape[0])(generator)

    # Get masks for generated output (the first returned mask is not needed here)
    _, look_ahead_mask, padding_mask = deepof.model_utils.create_masks(generator)

    # NOTE(review): the decoder is always called with training=False — confirm that
    # running it in inference mode during fitting is intended.
    (transformer_embedding, attention_weights,) = deepof.model_utils.TransformerDecoder(
        num_layers=num_layers,
        seq_dim=input_shape[-1],
        key_dim=input_shape[-1],
        num_heads=num_heads,
        dff=dff,
        maximum_position_encoding=input_shape[0],
        rate=dropout_rate,
    )(
        x,
        generator,
        training=False,
        look_ahead_mask=look_ahead_mask,
        padding_mask=padding_mask,
    )

    x_decoded = deepof.model_utils.ProbabilisticDecoder(input_shape)(
        [transformer_embedding, validity_mask]
    )

    return tf.keras.models.Model(
        [g, x], [x_decoded, attention_weights], name="transformer_decoder"
    )


class VectorQuantizer(tf.keras.models.Model):
    """Vector quantizer layer.

    Quantizes the input vectors into a fixed number of clusters using L2 norm. Based on
    https://arxiv.org/pdf/1509.03700.pdf, and adapted for clustering using https://arxiv.org/abs/1806.02199.
    Implementation based on https://keras.io/examples/generative/vq_vae/.

    """

    def __init__(
        self, n_components, embedding_dim, beta, kmeans_loss: float = 0.0, **kwargs
    ):
        """Initialize the VQ layer.

        Args:
            n_components (int): number of embeddings to use
            embedding_dim (int): dimensionality of the embeddings
            beta (float): weight of the commitment term in the VQ loss
            kmeans_loss (float): weight passed to deepof.model_utils.compute_kmeans_loss. The penalty is skipped when this value is falsy (0.0 by default).
            **kwargs: additional arguments for the parent class

        """
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.n_components = n_components
        self.beta = beta
        self.kmeans = kmeans_loss

        # Initialize the VQ codebook, stored as (embedding_dim, n_components):
        # each column is one code
        w_init = tf.random_uniform_initializer()
        self.codebook = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.n_components), dtype="float32"
            ),
            trainable=True,
            name="vqvae_codebook",
        )

    def call(self, x):  # pragma: no cover
        """Quantize the input and register the associated losses.

        Args:
            x (tf.Tensor): input tensor, whose last axis has size embedding_dim

        Returns:
            tuple: (quantized, soft_counts), where quantized is a tensor with the same shape as x in which
            every vector has been replaced by its closest code, and soft_counts holds, for every
            flattened input vector, a normalized weight per code.
        """
        # Compute input shape and flatten, keeping the embedding dimension intact
        input_shape = tf.shape(x)

        # Add a disentangling penalty to the embeddings
        if self.kmeans:
            kmeans_loss = deepof.model_utils.compute_kmeans_loss(
                x, weight=self.kmeans, batch_size=input_shape[0]
            )
            self.add_loss(kmeans_loss)
            self.add_metric(kmeans_loss, name="kmeans_loss")

        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantize input using the codebook
        encoding_indices = tf.cast(
            self.get_code_indices(flattened, return_soft_counts=False), tf.int32
        )
        soft_counts = self.get_code_indices(flattened, return_soft_counts=True)

        encodings = tf.one_hot(encoding_indices, self.n_components)

        # One-hot rows select the matching codebook column for each input vector
        quantized = tf.matmul(encodings, self.codebook, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Compute vector quantization loss, and add it to the layer
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator (copy gradients through the non-differentiable layer).
        # This approach has been reported to have issues for clustering, so it is deliberately
        # left disabled; an extra reconstruction loss is used instead to ensure that the
        # gradients can flow through the encoder.
        # quantized = x + tf.stop_gradient(quantized - x)

        return quantized, soft_counts

    # noinspection PyTypeChecker
    def get_code_indices(
        self, flattened_inputs, return_soft_counts=False
    ):  # pragma: no cover
        """Getter for the code indices at any given time.

        Args:
            flattened_inputs (tf.Tensor): flattened input tensor (encoder output), with one vector of size embedding_dim per row
            return_soft_counts (bool): whether to return soft counts based on the distance to the codes, instead of the code indices

        Returns:
            tf.Tensor: index of the closest code for every row if return_soft_counts is False. Otherwise,
            one weight per row and code, proportional to the inverse square of the computed distance and
            normalized to sum to one over the codes.
        """
        # Compute the squared L2 distance between inputs and codes at a given time,
        # using the expansion |u - v|^2 = |u|^2 + |v|^2 - 2 u.v
        similarity = tf.matmul(flattened_inputs, self.codebook)
        distances = (
            tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
            + tf.reduce_sum(self.codebook**2, axis=0)
            - 2 * similarity
        )

        if return_soft_counts:
            # Compute soft counts based on the distance to the codes. Distances are
            # clipped away from zero first: an input that coincides with a code would
            # otherwise produce an infinite weight, and NaN after normalization.
            safe_distances = tf.maximum(distances, tf.keras.backend.epsilon())
            similarity = (1 / safe_distances) ** 2
            soft_counts = similarity / tf.expand_dims(
                tf.reduce_sum(similarity, axis=1), axis=1
            )
            return soft_counts

        # Return index of the closest code
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices


# noinspection PyCallingNonCallable
def get_vqvae(
    input_shape: tuple,
    edge_feature_shape: tuple,
    adjacency_matrix: np.ndarray,
    latent_dim: int,
    use_gnn: bool,
    n_components: int,
    beta: float = 1.0,
    kmeans_loss: float = 0.0,
    encoder_type: str = "recurrent",
    interaction_regularization: float = 0.0,
):
    """Build a Vector-Quantization variational autoencoder (VQ-VAE) model, adapted to the DeepOF setting.

    Args:
        input_shape (tuple): shape of the input to the encoder, including the leading (batch) axis, which is dropped when building the models.
        edge_feature_shape (tuple): shape of the edge feature matrix used for graph representations, including the leading (batch) axis.
        adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph to use.
        latent_dim (int): dimension of the latent space.
        use_gnn (bool): If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
        n_components (int): number of embeddings in the embedding layer.
        beta (float): beta parameter of the VQ loss.
        kmeans_loss (float): regularization parameter for the Gram matrix.
        encoder_type (str): type of encoder to use. Can be set to "recurrent" (default), "TCN", or "transformer".
        interaction_regularization (float): Regularization parameter for the interaction features.

    Returns:
        list: five models, in this order:
            - encoder (tf.keras.Model): connected encoder of the VQ-VAE model.
            - decoder (tf.keras.Model): decoder of the VQ-VAE model.
            - grouper (tf.keras.Model): encoder followed by the vector quantizer; outputs the quantized latents.
            - soft_grouper (tf.keras.Model): encoder followed by the vector quantizer; outputs the soft counts.
            - vqvae (tf.keras.Model): complete VQ-VAE model.

    Raises:
        ValueError: if encoder_type is not one of "recurrent", "TCN" or "transformer".

    """
    # Fail early and explicitly, instead of hitting an unbound local variable
    # further down when no branch matches
    valid_encoder_types = ("recurrent", "TCN", "transformer")
    if encoder_type not in valid_encoder_types:
        raise ValueError(
            f"Unknown encoder_type {encoder_type!r}: expected one of {valid_encoder_types}"
        )

    vq_layer = VectorQuantizer(
        n_components,
        latent_dim,
        beta=beta,
        kmeans_loss=kmeans_loss,
        name="vector_quantizer",
    )

    if encoder_type == "recurrent":
        encoder = get_recurrent_encoder(
            input_shape=input_shape[1:],
            edge_feature_shape=edge_feature_shape[1:],
            adjacency_matrix=adjacency_matrix,
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )
        decoder = get_recurrent_decoder(
            input_shape=input_shape[1:], latent_dim=latent_dim
        )

    elif encoder_type == "TCN":
        encoder = get_TCN_encoder(
            input_shape=input_shape[1:],
            edge_feature_shape=edge_feature_shape[1:],
            adjacency_matrix=adjacency_matrix,
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )
        decoder = get_TCN_decoder(input_shape=input_shape[1:], latent_dim=latent_dim)

    else:  # encoder_type == "transformer"
        encoder = get_transformer_encoder(
            input_shape[1:],
            edge_feature_shape=edge_feature_shape[1:],
            adjacency_matrix=adjacency_matrix,
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )
        decoder = get_transformer_decoder(input_shape[1:], latent_dim=latent_dim)

    # Connect encoder and quantizer
    inputs = tf.keras.layers.Input(input_shape[1:], name="encoder_input")
    a = tf.keras.layers.Input(edge_feature_shape[1:], name="encoder_edge_features")
    encoder_outputs = encoder([inputs, a])
    quantized_latents, soft_counts = vq_layer(encoder_outputs)

    # Connect full models
    encoder = tf.keras.Model([inputs, a], encoder_outputs, name="encoder")
    grouper = tf.keras.Model([inputs, a], quantized_latents, name="grouper")
    soft_grouper = tf.keras.Model([inputs, a], soft_counts, name="soft_grouper")

    # The decoder receives the quantized latents plus the raw input (used for masking)
    vqvae = tf.keras.Model(
        grouper.inputs, decoder([grouper.outputs, inputs]), name="VQ-VAE"
    )

    models = [encoder, decoder, grouper, soft_grouper, vqvae]

    return models


class VQVAE(tf.keras.models.Model):
    """VQ-VAE model adapted to the DeepOF setting.

    Bundles the sub-models returned by get_vqvae (encoder, decoder, grouper,
    soft grouper and the full autoencoder) and provides custom train and test
    steps that track both reconstruction losses, the VQ loss and the number of
    populated clusters.

    """

    def __init__(
        self,
        input_shape: tuple,
        edge_feature_shape: tuple,
        adjacency_matrix: np.ndarray = None,
        latent_dim: int = 8,
        n_components: int = 15,
        beta: float = 1.0,
        kmeans_loss: float = 0.0,
        use_gnn: bool = True,
        encoder_type: str = "recurrent",
        interaction_regularization: float = 0.0,
        **kwargs,
    ):
        """Build the VQ-VAE sub-models and the metrics tracked while fitting.

        Args:
            input_shape (tuple): Shape of the input to the full model.
            edge_feature_shape (tuple): shape of the edge feature matrix used for graph representations.
            adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph to use.
            latent_dim (int): Dimensionality of the latent space.
            n_components (int): Number of embeddings (clusters) in the embedding layer.
            beta (float): Beta parameter of the VQ loss, as described in the original VQVAE paper.
            kmeans_loss (float): Regularization parameter for the Gram matrix.
            use_gnn (bool): Whether the encoder uses a graph representation of the input. Forwarded to get_vqvae.
            encoder_type (str): Type of encoder to use. Can be set to "recurrent" (default), "TCN", or "transformer".
            interaction_regularization (float): Regularization parameter for the interaction features.
            **kwargs: Additional keyword arguments, passed to the Keras Model constructor.

        """
        super().__init__(**kwargs)

        # Keep every hyperparameter on the instance
        self.seq_shape = input_shape
        self.edge_feature_shape = edge_feature_shape
        self.adjacency_matrix = adjacency_matrix
        self.latent_dim = latent_dim
        self.use_gnn = use_gnn
        self.n_components = n_components
        self.beta = beta
        self.kmeans = kmeans_loss
        self.encoder_type = encoder_type
        self.interaction_regularization = interaction_regularization

        # Build the sub-models (arguments are passed positionally to get_vqvae)
        submodels = get_vqvae(
            self.seq_shape,
            self.edge_feature_shape,
            self.adjacency_matrix,
            self.latent_dim,
            self.use_gnn,
            self.n_components,
            self.beta,
            self.kmeans,
            self.encoder_type,
            self.interaction_regularization,
        )
        (
            self.encoder,
            self.decoder,
            self.grouper,
            self.soft_grouper,
            self.vqvae,
        ) = submodels

        # Running means updated by train_step
        mean_metric = tf.keras.metrics.Mean
        self.total_loss_tracker = mean_metric(name="total_loss")
        self.encoding_reconstruction_loss_tracker = mean_metric(
            name="encoding_reconstruction_loss"
        )
        self.reconstruction_loss_tracker = mean_metric(name="reconstruction_loss")
        self.vq_loss_tracker = mean_metric(name="vq_loss")
        self.cluster_population = mean_metric(name="number_of_populated_clusters")

        # Running means updated by test_step
        self.val_total_loss_tracker = mean_metric(name="val_total_loss")
        self.val_encoding_reconstruction_loss_tracker = mean_metric(
            name="val_encoding_reconstruction_loss"
        )
        self.val_reconstruction_loss_tracker = mean_metric(
            name="val_reconstruction_loss"
        )
        self.val_vq_loss_tracker = mean_metric(name="val_vq_loss")
        self.val_cluster_population = mean_metric(
            name="val_number_of_populated_clusters"
        )

    @tf.function
    def call(self, inputs, **kwargs):
        """Run the inputs through the full VQ-VAE sub-model."""
        outputs = self.vqvae(inputs, **kwargs)
        return outputs

    @property
    def metrics(self):  # pragma: no cover
        """Return the training trackers followed by the validation trackers."""
        return [
            self.total_loss_tracker,
            self.encoding_reconstruction_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
            self.cluster_population,
            self.val_total_loss_tracker,
            self.val_encoding_reconstruction_loss_tracker,
            self.val_reconstruction_loss_tracker,
            self.val_vq_loss_tracker,
            self.val_cluster_population,
        ]

    def _reconstruction_losses(self, x, a, y, training):  # pragma: no cover
        """Return the mean negative log-likelihood of both decoding paths.

        The first value comes from the full model (through the quantizer), the
        second from decoding the encoder output directly.
        """
        through_quantizer = self.vqvae([x, a], training=training)
        direct = self.decoder(
            [self.encoder([x, a], training=training), x], training=training
        )

        # The transformer decoder also returns attention scores; keep only
        # the reconstruction distribution
        if self.encoder_type == "transformer":
            through_quantizer, direct = through_quantizer[0], direct[0]

        # Only the first label set is used as the reconstruction target
        targets = next(iter(y if isinstance(y, tuple) else [y]))

        quantized_nll = -tf.reduce_mean(through_quantizer.log_prob(targets))
        direct_nll = -tf.reduce_mean(direct.log_prob(targets))
        return quantized_nll, direct_nll

    def _count_populated_clusters(self, x, a):  # pragma: no cover
        """Count the distinct cluster indices assigned by the soft grouper."""
        assignments = tf.argmax(self.soft_grouper([x, a]), axis=1)
        distinct = tf.unique(tf.reshape(assignments, [-1])).y
        return tf.shape(distinct)[0]

    @tf.function
    def train_step(self, data):  # pragma: no cover
        """Perform a training step."""
        x, a, y = data

        with tf.GradientTape() as tape:
            quantized_nll, direct_nll = self._reconstruction_losses(
                x, a, y, training=True
            )
            total_loss = quantized_nll + direct_nll + sum(self.vqvae.losses)

        # Backpropagation
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

        # The cluster count is computed before the VQ loss is read again below
        populated = self._count_populated_clusters(x, a)

        # Track losses
        self.total_loss_tracker.update_state(total_loss)
        self.encoding_reconstruction_loss_tracker.update_state(quantized_nll)
        self.reconstruction_loss_tracker.update_state(direct_nll)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))
        self.cluster_population.update_state(populated)

        # Log results (coupled with TensorBoard)
        logs = {
            "total_loss": self.total_loss_tracker.result(),
            "encoding_reconstruction_loss": self.encoding_reconstruction_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vq_loss": self.vq_loss_tracker.result(),
            "number_of_populated_clusters": self.cluster_population.result(),
        }
        logs.update({met.name: met.result() for met in self.vqvae.metrics})
        return logs

    @tf.function
    def test_step(self, data):  # pragma: no cover
        """Performs a test step."""
        x, a, y = data

        quantized_nll, direct_nll = self._reconstruction_losses(
            x, a, y, training=False
        )
        total_loss = quantized_nll + direct_nll + sum(self.vqvae.losses)

        # The cluster count is computed before the VQ loss is read again below
        populated = self._count_populated_clusters(x, a)

        # Track losses
        self.val_total_loss_tracker.update_state(total_loss)
        self.val_encoding_reconstruction_loss_tracker.update_state(quantized_nll)
        self.val_reconstruction_loss_tracker.update_state(direct_nll)
        self.val_vq_loss_tracker.update_state(sum(self.vqvae.losses))
        self.val_cluster_population.update_state(populated)

        # Log results (coupled with TensorBoard)
        logs = {
            "total_loss": self.val_total_loss_tracker.result(),
            "encoding_reconstruction_loss": self.val_encoding_reconstruction_loss_tracker.result(),
            "reconstruction_loss": self.val_reconstruction_loss_tracker.result(),
            "vq_loss": self.val_vq_loss_tracker.result(),
            "number_of_populated_clusters": self.val_cluster_population.result(),
        }
        logs.update({met.name: met.result() for met in self.vqvae.metrics})
        return logs


class GaussianMixtureLatent(tf.keras.models.Model):
    """Gaussian Mixture probabilistic latent space model.

    Used to represent the embedding of motion tracking data in a mixture of Gaussians
    with a provided number of components, with means, covariances and weights.
    Implementation based on VaDE (https://arxiv.org/abs/1611.05148)
    and VaDE-SC (https://openreview.net/forum?id=RQ428ZptQfU).

    Calling the model on an encoder output returns the embedding and the
    per-cluster probabilities, and registers the clustering, prior and
    (annealed) KL terms as model losses and metrics.

    """

    def __init__(
        self,
        input_shape: tuple,
        n_components: int,
        latent_dim: int,
        batch_size: int,
        kl_warmup: int = 5,
        kl_annealing_mode: str = "linear",
        mc_kl: int = 100,
        mmd_warmup: int = 15,
        mmd_annealing_mode: str = "linear",
        kmeans_loss: float = 0.0,
        reg_cluster_variance: bool = False,
        **kwargs,
    ):
        """Initialize the Gaussian Mixture Latent layer.

        Args:
            input_shape (tuple): despite the annotation, this value is used as an integer
                (`input_shape // batch_size`) when computing the number of KL warm-up
                iterations; get_vade passes the first entry of the full input shape
                (presumably the number of training instances).
            n_components (int): number of components in the Gaussian mixture.
            latent_dim (int): dimensionality of the latent space.
            batch_size (int): batch size for training.
            kl_warmup (int): number of epochs to warm up the KL divergence.
            kl_annealing_mode (str): mode to use for annealing the KL divergence. Must be one of "linear" and "sigmoid".
            mc_kl (int): number of Monte Carlo samples to use for computing the KL divergence. Stored, but not read anywhere else in this class.
            mmd_warmup (int): number of epochs to warm up the MMD. Stored, but not read anywhere else in this class.
            mmd_annealing_mode (str): mode to use for annealing the MMD. Must be one of "linear" and "sigmoid". Stored, but not read anywhere else in this class.
            kmeans_loss (float): weight of the Gram matrix regularization loss. The term is skipped when this is falsy.
            reg_cluster_variance (bool): whether to penalize uneven cluster variances in the latent space. Stored, but not read anywhere else in this class.
            **kwargs: keyword arguments passed to the parent class

        """
        super(GaussianMixtureLatent, self).__init__(**kwargs)
        self.seq_shape = input_shape
        self.n_components = n_components
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.kl_warmup = kl_warmup
        self.kl_annealing_mode = kl_annealing_mode
        self.mc_kl = mc_kl
        self.mmd_warmup = mmd_warmup
        self.mmd_annealing_mode = mmd_annealing_mode
        self.kmeans = kmeans_loss
        # Default optimizer; its iteration counter drives the KL annealing in call().
        # VaDE.__init__ replaces it with the optimizer of the full model.
        self.optimizer = Nadam(learning_rate=1e-3, clipvalue=0.75)
        self.reg_cluster_variance = reg_cluster_variance
        # Non-trainable switch: when set to 1.0 the clustering, prior and
        # categorical-entropy terms in call() are multiplied by zero.
        self.pretrain = tf.Variable(0.0, name="pretrain", trainable=False)

        # Initialize GM parameters: per-component means and log standard deviations,
        # both of shape (n_components, latent_dim)
        self.c_mu = tf.Variable(
            tf.initializers.GlorotNormal()(shape=[self.n_components, self.latent_dim]),
            name="mu_c",
        )
        self.log_c_sigma = tf.Variable(
            tf.initializers.GlorotNormal()([self.n_components, self.latent_dim]),
            name="log_sigma_c",
        )

        # Initialize the Gaussian Mixture prior with the specified number of components
        # (a constant uniform vector with entries 1 / n_components)
        self.prior = tf.constant(tf.ones([self.n_components]) * (1 / self.n_components))

        # Initialize layers: two dense heads producing the parameters of the
        # per-instance Gaussian. Half of IndependentNormal's parameter count is
        # presumably latent_dim units each -- TODO confirm.
        self.z_gauss_mean = Dense(
            tfpl.IndependentNormal.params_size(self.latent_dim) // 2,
            name="cluster_means",
            activation="linear",
            kernel_initializer="glorot_uniform",
            activity_regularizer=None,
        )
        self.z_gauss_var = Dense(
            tfpl.IndependentNormal.params_size(self.latent_dim) // 2,
            name="cluster_variances",
            activation="softplus",
            kernel_initializer="glorot_uniform",
            activity_regularizer=tf.keras.regularizers.l1(0.1),
        )

        self.cluster_control_layer = deepof.model_utils.ClusterControl(
            batch_size=self.batch_size,
            n_components=self.n_components,
            encoding_dim=self.latent_dim,
            k=self.n_components,
        )

        # Number of optimizer iterations over which the KL weight ramps up:
        # warm-up epochs times (input_shape // batch_size) iterations per epoch
        self.kl_warm_up_iters = tf.cast(
            self.kl_warmup * (self.seq_shape // self.batch_size), tf.int64
        )
        # NOTE: call() rebinds this attribute to a plain tensor on every pass
        self._kl_weight = tf.Variable(
            1.0, trainable=False, dtype=tf.float32, name="kl_weight"
        )

    def call(self, inputs, training=False):  # pragma: no cover
        """Compute the output of the layer.

        Args:
            inputs: encoder output, fed to the mean and variance dense heads.
            training (bool): if True, a sample from the per-instance Gaussian is
                used as embedding; otherwise its mean is used.

        Returns:
            tuple: (z, z_cat), the embedding and the cluster probabilities given
            the embedding. When n_components > 1, z is first passed through the
            cluster control layer.

        """
        z_gauss_mean = self.z_gauss_mean(inputs)
        z_gauss_var = self.z_gauss_var(inputs)

        # z_gauss_var is used as a log-variance here: the scale is sqrt(exp(.))
        z = tfd.MultivariateNormalDiag(
            loc=z_gauss_mean, scale_diag=tf.math.sqrt(tf.math.exp(z_gauss_var))
        )
        z_sample = tf.squeeze(z.sample())

        # Compute embedding probabilities given each cluster
        # (log-density of the embedding under every component, stacked on the last axis)
        p_z_c = tf.stack(
            [
                tfd.MultivariateNormalDiag(
                    loc=self.c_mu[i, :],
                    scale_diag=tf.math.exp(self.log_c_sigma)[i, :],
                ).log_prob((z_sample if training else z_gauss_mean))
                + 1e-6
                for i in range(self.n_components)
            ],
            axis=-1,
        )

        # The mixture prior is the fixed uniform vector set in __init__; it is only read here
        prior = self.prior

        # Compute cluster probabilities given embedding:
        # softmax over (log prior + log-density), evaluated in log space
        z_cat = tf.math.log(prior + 1e-6) + p_z_c
        z_cat = tf.nn.log_softmax(z_cat, axis=-1)
        z_cat = tf.math.exp(z_cat)

        # Add clustering loss (both terms vanish while pretrain is set to 1.0)
        loss_clustering = -tf.reduce_sum(
            tf.multiply(z_cat, tf.math.softmax(p_z_c, axis=-1)), axis=-1
        ) * (1.0 - tf.cast(self.pretrain, tf.float32))
        loss_prior = -tf.math.reduce_sum(
            tf.math.xlogy(z_cat, 1e-6 + prior), axis=-1
        ) * (1.0 - tf.cast(self.pretrain, tf.float32))

        self.add_metric(loss_clustering, name="clustering_loss", aggregation="mean")
        self.add_metric(loss_prior, name="prior_loss", aggregation="mean")

        # Update KL weight based on the current iteration:
        # a linear ramp min(iterations / warm-up iterations, 1), optionally
        # reshaped through sigmoid((2w - 1) / (w - w^2))
        if self.kl_warm_up_iters > 0:
            if self.kl_annealing_mode in ["linear", "sigmoid"]:
                self._kl_weight = tf.cast(
                    tf.keras.backend.min(
                        [self.optimizer.iterations / self.kl_warm_up_iters, 1.0]
                    ),
                    tf.float32,
                )
                if self.kl_annealing_mode == "sigmoid":
                    self._kl_weight = tf.math.sigmoid(
                        (2 * self._kl_weight - 1)
                        / (self._kl_weight - self._kl_weight**2)
                    )
            else:
                raise NotImplementedError(
                    "annealing_mode must be one of 'linear' and 'sigmoid'"
                )
        else:
            # No warm-up requested: use the full KL weight from the start
            self._kl_weight = tf.cast(1.0, tf.float32)

        # KL term: a per-instance variance term plus the sum of z_cat * log(z_cat);
        # the categorical part is switched off in pretrain mode
        loss_variational_1 = -1 / 2 * tf.reduce_sum(z_gauss_var + 1, axis=-1)
        loss_variational_2 = tf.math.reduce_sum(
            tf.math.xlogy(z_cat, 1e-6 + z_cat), axis=-1
        )
        kl = loss_variational_1 + loss_variational_2 * (
            1.0 - tf.cast(self.pretrain, tf.float32)
        )
        kl_batch = self._kl_weight * kl

        # The logged KL metric is unweighted; the loss added below is annealed
        self.add_metric(self._kl_weight, aggregation="mean", name="kl_weight")
        self.add_metric(kl, aggregation="mean", name="kl_divergence")

        self.add_loss(tf.math.reduce_mean(loss_clustering))
        self.add_loss(tf.math.reduce_mean(loss_prior))
        self.add_loss(tf.math.reduce_mean(kl_batch))

        if training:
            z = z_sample
        else:
            # Select corresponding mean
            z = z_gauss_mean

        # Tracks clustering metrics
        # NOTE(review): presumably ClusterControl returns the embedding unchanged
        # and only logs metrics -- confirm in deepof.model_utils
        if self.n_components > 1:
            z = self.cluster_control_layer([z, z_cat])

        # Optional Gram matrix regularization, weighted by kmeans_loss
        if self.kmeans:
            kmeans_loss = deepof.model_utils.compute_kmeans_loss(
                z, weight=self.kmeans, batch_size=self.batch_size
            )
            self.add_loss(kmeans_loss)
            self.add_metric(kmeans_loss, name="kmeans_loss")

        return z, z_cat


# noinspection PyCallingNonCallable
def get_vade(
    input_shape: tuple,
    edge_feature_shape: tuple,
    adjacency_matrix: np.ndarray,
    latent_dim: int,
    use_gnn: bool,
    n_components: int,
    batch_size: int = 64,
    kl_warmup: int = 15,
    kl_annealing_mode: str = "sigmoid",
    mc_kl: int = 100,
    kmeans_loss: float = 1.0,
    reg_cluster_variance: bool = False,
    encoder_type: str = "recurrent",
    interaction_regularization: float = 0.0,
):
    """Build a Gaussian mixture variational autoencoder (VaDE) model, adapted to the DeepOF setting.

    Args:
        input_shape (tuple): shape of the input data. The first entry is passed to the latent space on its own; the remaining entries define the model input.
        edge_feature_shape (tuple): shape of the edge feature matrix used for graph representations. The first entry is dropped when building the model input.
        adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph to use.
        latent_dim (int): dimensionality of the latent space.
        use_gnn (bool): If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
        n_components (int): number of components in the Gaussian mixture.
        batch_size (int): batch size for training.
        kl_warmup (int): number of epochs during which to warm up the KL divergence (the latent space converts it to iterations).
        kl_annealing_mode (str): mode to use for annealing the KL divergence. Must be one of "linear" and "sigmoid".
        mc_kl (int): number of Monte Carlo samples to use for computing the KL divergence.
        kmeans_loss (float): weight of the Gram matrix loss as described in deepof.model_utils.compute_kmeans_loss.
        reg_cluster_variance (bool): whether to penalize uneven cluster variances in the latent space.
        encoder_type (str): type of encoder to use. Can be set to "recurrent" (default), "TCN", or "transformer".
        interaction_regularization (float): weight of the interaction regularization term.

    Returns:
        encoder (tf.keras.Model): connected encoder of the VaDE model, mapping the inputs to the latent embedding.
        decoder (tf.keras.Model): decoder of the VaDE model.
        grouper (tf.keras.Model): deep clustering branch of the VaDE model. Outputs the cluster probabilities produced by the Gaussian mixture latent space for each instance.
        vade (tf.keras.Model): complete VaDE model

    Raises:
        ValueError: if encoder_type is not one of "recurrent", "TCN" or "transformer".

    """
    if encoder_type == "recurrent":
        encoder = get_recurrent_encoder(
            input_shape=input_shape[1:],
            adjacency_matrix=adjacency_matrix,
            edge_feature_shape=edge_feature_shape[1:],
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )
        decoder = get_recurrent_decoder(
            input_shape=input_shape[1:], latent_dim=latent_dim
        )

    elif encoder_type == "TCN":
        encoder = get_TCN_encoder(
            input_shape=input_shape[1:],
            adjacency_matrix=adjacency_matrix,
            edge_feature_shape=edge_feature_shape[1:],
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )
        decoder = get_TCN_decoder(input_shape=input_shape[1:], latent_dim=latent_dim)

    elif encoder_type == "transformer":
        encoder = get_transformer_encoder(
            input_shape[1:],
            edge_feature_shape=edge_feature_shape[1:],
            adjacency_matrix=adjacency_matrix,
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )
        decoder = get_transformer_decoder(input_shape[1:], latent_dim=latent_dim)

    else:
        # Fail early with a clear message instead of an UnboundLocalError below
        raise ValueError(
            "encoder_type must be one of 'recurrent', 'TCN' or 'transformer', "
            f"got {encoder_type!r}"
        )

    latent_space = GaussianMixtureLatent(
        input_shape=input_shape[0],
        n_components=n_components,
        latent_dim=latent_dim,
        batch_size=batch_size,
        kl_warmup=kl_warmup,
        kl_annealing_mode=kl_annealing_mode,
        mc_kl=mc_kl,
        kmeans_loss=kmeans_loss,
        reg_cluster_variance=reg_cluster_variance,
        name="gaussian_mixture_latent",
    )

    # Connect encoder and latent space
    inputs = Input(input_shape[1:])
    a = tf.keras.layers.Input(edge_feature_shape[1:], name="encoder_edge_features")
    encoder_outputs = encoder([inputs, a])
    latent, categorical = latent_space(encoder_outputs)
    embedding = tf.keras.Model([inputs, a], latent, name="encoder")
    grouper = tf.keras.Model([inputs, a], categorical, name="grouper")

    # Connect decoder (it receives the original inputs alongside the embedding)
    vade_outputs = decoder([embedding.outputs, inputs])

    # Instantiate fully connected model
    vade = tf.keras.Model(embedding.inputs, vade_outputs, name="VaDE")

    return embedding, decoder, grouper, vade


class Classifier(tf.keras.Model):
    """Classifier for supervised pose motif elucidation.

    Stacks one of the available encoders (built with a one-dimensional output)
    with a small dense head: a 16-unit ReLU layer, dropout with rate 0.5 and a
    single sigmoid output unit.

    """

    def __init__(
        self,
        input_shape: tuple,
        edge_feature_shape: tuple,
        adjacency_matrix: np.ndarray = None,
        use_gnn: bool = True,
        batch_size: int = 2048,
        bias_initializer: float = 0.0,
        encoder_type: str = "recurrent",
        **kwargs,
    ):
        """Initialize a classifier model.

        Args:
            input_shape (tuple): shape of the input data. The first entry is dropped when building the encoder.
            edge_feature_shape (tuple): shape of the edge feature matrix used for graph representations. The first entry is dropped when building the encoder.
            adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph to use.
            use_gnn (bool): If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
            batch_size (int): batch size for training. Accepted for interface compatibility; not used by this class.
            bias_initializer (float): value to initialize the bias of the last layer to (default: 0.0).
            encoder_type (str): type of encoder to use. Can be set to "recurrent" (default), "TCN", or "transformer".
            **kwargs: additional keyword arguments passed to the parent class.

        Raises:
            ValueError: if encoder_type is not one of "recurrent", "TCN" or "transformer".

        """
        super().__init__(**kwargs)

        if encoder_type == "recurrent":
            self.encoder = get_recurrent_encoder(
                input_shape=input_shape[1:],
                adjacency_matrix=adjacency_matrix,
                edge_feature_shape=edge_feature_shape[1:],
                latent_dim=1,
                use_gnn=use_gnn,
            )
        elif encoder_type == "TCN":
            self.encoder = get_TCN_encoder(
                input_shape=input_shape[1:],
                adjacency_matrix=adjacency_matrix,
                edge_feature_shape=edge_feature_shape[1:],
                latent_dim=1,
                use_gnn=use_gnn,
            )
        elif encoder_type == "transformer":
            self.encoder = get_transformer_encoder(
                input_shape[1:],
                edge_feature_shape=edge_feature_shape[1:],
                adjacency_matrix=adjacency_matrix,
                latent_dim=1,
                use_gnn=use_gnn,
            )
        else:
            # Without this branch the model would be built without an encoder
            # and only fail with an AttributeError on the first forward pass
            raise ValueError(
                "encoder_type must be one of 'recurrent', 'TCN' or 'transformer', "
                f"got {encoder_type!r}"
            )

        # NOTE(review): both dense layers below share the name "classifier".
        # Left untouched, since renaming would change the layer names stored
        # in existing checkpoints -- confirm whether this is intended.
        self.dense = tf.keras.layers.Dense(16, activation="relu", name="classifier")
        self.dropout = tf.keras.layers.Dropout(0.5)
        self.bias_initializer = tf.keras.initializers.Constant(bias_initializer)
        self.clf = tf.keras.layers.Dense(
            1,
            activation="sigmoid",
            name="classifier",
            bias_initializer=self.bias_initializer,
        )

    def call(self, inputs, training=None, mask=None):
        """Apply a forward pass of the classifier.

        Args:
            inputs (tf.Tensor): input data, passed as is to the encoder.
            training (bool): whether the model is in training mode. Passed explicitly to the dropout layer.
            mask (tf.Tensor): mask for the input data. Not used by this method.

        Returns:
            Output of the final sigmoid layer.

        """
        x = self.encoder(inputs)
        x = self.dense(x)
        x = self.dropout(x, training=training)
        x = self.clf(x)

        return x


# noinspection PyDefaultArgument,PyCallingNonCallable
class VaDE(tf.keras.models.Model):
    """Gaussian Mixture Variational Autoencoder for pose motif elucidation."""

    def __init__(
        self,
        input_shape: tuple,
        edge_feature_shape: tuple,
        adjacency_matrix: np.ndarray = None,
        latent_dim: int = 8,
        use_gnn: bool = True,
        n_components: int = 15,
        batch_size: int = 64,
        kl_annealing_mode: str = "linear",
        kl_warmup_epochs: int = 15,
        montecarlo_kl: int = 100,
        kmeans_loss: float = 1.0,
        reg_cat_clusters: float = 1.0,
        reg_cluster_variance: bool = False,
        encoder_type: str = "recurrent",
        interaction_regularization: float = 0.0,
        **kwargs,
    ):
        """Build the VaDE sub-models, share the optimizer with them and create the loss trackers.

        Args:
            input_shape (tuple): Shape of the input to the full model.
            edge_feature_shape (tuple): shape of the edge feature matrix used for graph representations.
            adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph to use.
            latent_dim (int): Dimensionality of the latent space.
            use_gnn (bool): If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
            n_components (int): Number of mixture components in the latent space.
            batch_size (int): Batch size for training.
            kl_annealing_mode (str): Annealing mode for KL annealing. Can be one of 'linear' and 'sigmoid'.
            kl_warmup_epochs (int): Number of epochs to warmup KL annealing.
            montecarlo_kl (int): Number of Monte Carlo samples for KL divergence.
            kmeans_loss (float): weight of the gram matrix regularization loss.
            reg_cat_clusters (float): setting for the penalty on uneven cluster membership in the latent space. When truthy, the categorical cluster loss trackers are created.
            reg_cluster_variance (bool): whether to penalize uneven cluster variances in the latent space.
            encoder_type (str): type of encoder to use. Can be set to "recurrent" (default), "TCN", or "transformer".
            interaction_regularization (float): Regularization parameter for the interaction features.
            **kwargs: Additional keyword arguments, passed to the Keras Model constructor.

        """
        super().__init__(**kwargs)

        # Store the hyperparameters and create the optimizer
        self.seq_shape = input_shape
        self.edge_feature_shape = edge_feature_shape
        self.adjacency_matrix = adjacency_matrix
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.use_gnn = use_gnn
        self.kl_annealing_mode = kl_annealing_mode
        self.kl_warmup = kl_warmup_epochs
        self.mc_kl = montecarlo_kl
        self.n_components = n_components
        self.optimizer = Nadam(learning_rate=1e-3, clipvalue=0.75)
        self.kmeans = kmeans_loss
        self.reg_cat_clusters = reg_cat_clusters
        self.reg_cluster_variance = reg_cluster_variance
        self.encoder_type = encoder_type
        self.interaction_regularization = interaction_regularization

        # Build encoder, decoder, clustering branch and the full autoencoder
        self.encoder, self.decoder, self.grouper, self.vade = get_vade(
            input_shape=self.seq_shape,
            edge_feature_shape=self.edge_feature_shape,
            adjacency_matrix=self.adjacency_matrix,
            latent_dim=self.latent_dim,
            use_gnn=self.use_gnn,
            n_components=self.n_components,
            batch_size=self.batch_size,
            kl_warmup=self.kl_warmup,
            kl_annealing_mode=self.kl_annealing_mode,
            mc_kl=self.mc_kl,
            kmeans_loss=self.kmeans,
            reg_cluster_variance=self.reg_cluster_variance,
            encoder_type=self.encoder_type,
            interaction_regularization=self.interaction_regularization,
        )

        # The latent space anneals its KL weight from the optimizer's iteration
        # counter, so the sub-models must share this model's optimizer
        self.vade.optimizer = self.optimizer
        latent_layer = self.vade.get_layer("gaussian_mixture_latent")
        latent_layer.optimizer = self.optimizer

        # Trackers for every component of the loss function
        mean_metric = tf.keras.metrics.Mean
        self.total_loss_tracker = mean_metric(name="total_loss")
        self.val_total_loss_tracker = mean_metric(name="val_total_loss")
        self.reconstruction_loss_tracker = mean_metric(name="reconstruction_loss")
        self.val_reconstruction_loss_tracker = mean_metric(
            name="val_reconstruction_loss"
        )

        # The categorical cluster trackers only exist when the penalty is enabled
        if self.reg_cat_clusters:
            self.cat_cluster_loss_tracker = mean_metric(name="cat_cluster_loss")
            self.val_cat_cluster_loss_tracker = mean_metric(
                name="val_cat_cluster_loss"
            )

    @property
    def metrics(self):  # pragma: no cover
        """Return the loss trackers of the VaDE model.

        The categorical cluster trackers are appended only when
        reg_cat_clusters is truthy.
        """
        tracked = [
            self.total_loss_tracker,
            self.val_total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.val_reconstruction_loss_tracker,
        ]

        if self.reg_cat_clusters:
            tracked.extend(
                [self.cat_cluster_loss_tracker, self.val_cat_cluster_loss_tracker]
            )

        return tracked

    @property
    def get_gmm_params(self):
        """Return the means, sigmas and weights of the Gaussian mixture latent space."""
        latent_layer = self.grouper.get_layer("gaussian_mixture_latent")

        # Sigmas are stored as logs in the layer; weights are the softmax of its prior
        means = latent_layer.c_mu
        sigmas = tf.math.exp(latent_layer.log_c_sigma)
        weights = tf.math.softmax(latent_layer.prior)

        return {"means": means, "sigmas": sigmas, "weights": weights}

    def set_pretrain_mode(self, switch):
        """Assign `switch` to the pretrain variable of the Gaussian mixture latent layer."""
        latent_layer = self.grouper.get_layer("gaussian_mixture_latent")
        latent_layer.pretrain.assign(switch)

    def pretrain(
        self,
        data,
        embed_x,
        embed_a,
        epochs=10,
        samples=10000,
        gmm_initialize=True,
        **kwargs,
    ):
        """Run a GMM directed pretraining of the encoder, to minimize the likelihood of getting stuck in a local minimum.

        The model is first fitted in pretrain mode. Optionally, a diagonal
        Gaussian mixture is then fitted on a random sample of embeddings, and
        its means and standard deviations are used to initialize the mixture
        components of the latent space.

        Args:
            data: training data, passed as is to self.fit.
            embed_x: node features to embed for the GMM fit. Indexed along the first axis.
            embed_a: edge features matching embed_x. Indexed with the same indices.
            epochs (int): number of pretraining epochs.
            samples (int): number of instances drawn (with replacement) from embed_x / embed_a for the GMM fit.
            gmm_initialize (bool): whether to initialize the mixture components from a fitted GMM.
            **kwargs: additional keyword arguments for self.fit. Those that are also
                constructor parameters of sklearn's GaussianMixture (such as verbose)
                are forwarded to the mixture model as well.

        """
        import inspect

        # Turn on pretrain mode, and make sure it is turned off again even if
        # fitting is interrupted; otherwise the clustering losses would stay
        # silently disabled
        self.set_pretrain_mode(1.0)
        try:
            self.fit(
                data,
                epochs=epochs,
                **kwargs,
            )
        finally:
            self.set_pretrain_mode(0.0)

        if not gmm_initialize:
            return

        # Only forward the keyword arguments GaussianMixture actually accepts:
        # fit-only options (e.g. callbacks) would raise a TypeError there
        gmm_params = inspect.signature(GaussianMixture).parameters
        gmm_kwargs = {key: val for key, val in kwargs.items() if key in gmm_params}

        with tf.device("CPU"):
            # Get embedding samples
            emb_idx = np.random.choice(range(embed_x.shape[0]), samples)

            # map to latent
            z = self.encoder([embed_x[emb_idx], embed_a[emb_idx]])
            # fit GMM
            gmm = GaussianMixture(
                n_components=self.n_components,
                covariance_type="diag",
                reg_covar=1e-04,
                **gmm_kwargs,
            ).fit(z)
            # get GMM parameters
            mu = gmm.means_
            sigma2 = gmm.covariances_

        # initialize mixture components; the layer stores log standard
        # deviations, hence log(sqrt(variance))
        latent_layer = self.grouper.get_layer("gaussian_mixture_latent")
        latent_layer.c_mu.assign(tf.convert_to_tensor(value=mu, dtype=tf.float32))
        latent_layer.log_c_sigma.assign(
            tf.math.log(
                tf.math.sqrt(tf.convert_to_tensor(value=sigma2, dtype=tf.float32))
            )
        )

    @tf.function
    def call(self, inputs, **kwargs):
        """Run the inputs through the full VaDE model.

        Args:
            inputs: inputs handed over unchanged to ``self.vade``.
            **kwargs: additional keyword arguments forwarded to ``self.vade``.

        Returns:
            Whatever ``self.vade`` returns for the given inputs.

        """
        vade_outputs = self.vade(inputs, **kwargs)
        return vade_outputs

    def train_step(self, data):  # pragma: no cover
        """Run a single optimization step on one batch.

        Args:
            data: a triplet with node features, edge features and labels. Labels can
              be a single element or a tuple; only the first entry is used, as the
              target of the reconstruction.

        Returns:
            dict: running results of the training trackers ("total_loss",
            "reconstruction_loss" and, if cluster regularization is on,
            "cat_cluster_loss"), merged with the metrics tracked by ``self.vade``.

        """
        node_feats, edge_feats, labels = data

        # Wrap a lone label so that the first entry can be pulled uniformly below
        label_iter = iter(labels if isinstance(labels, tuple) else [labels])

        with tf.GradientTape() as tape:

            # Forward pass through the full model
            model_out = self.vade([node_feats, edge_feats], training=True)

            # The transformer decoder also returns attention scores; drop them
            if self.encoder_type == "transformer":
                model_out = model_out[0]

            reconstructions = (
                model_out[0] if isinstance(model_out, list) else model_out
            )

            # Start from the losses registered by the model during the forward pass
            seq_inputs = next(label_iter)
            total_loss = sum(self.vade.losses)

            # Penalize uneven soft counts, to prevent the embedding layer from
            # collapsing into a few clusters.
            if self.reg_cat_clusters:

                soft_counts = self.grouper([node_feats, edge_feats], training=True)
                cluster_penalty = deepof.model_utils.cluster_frequencies_regularizer(
                    soft_counts=soft_counts, k=self.n_components
                )
                cluster_penalty = self.reg_cat_clusters * cluster_penalty
                total_loss += cluster_penalty

            # Negative mean log-likelihood of the targets under the reconstruction
            reconstruction_loss = -tf.reduce_mean(reconstructions.log_prob(seq_inputs))
            total_loss += reconstruction_loss

        # Backpropagation
        weights = self.vade.trainable_variables
        gradients = tape.gradient(total_loss, weights)
        self.optimizer.apply_gradients(zip(gradients, weights))

        # Track losses
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)

        # Collect results (coupled with TensorBoard)
        logs = dict(
            total_loss=self.total_loss_tracker.result(),
            reconstruction_loss=self.reconstruction_loss_tracker.result(),
        )

        if self.reg_cat_clusters:
            self.cat_cluster_loss_tracker.update_state(cluster_penalty)
            logs["cat_cluster_loss"] = self.cat_cluster_loss_tracker.result()

        # Add the metrics tracked implicitly within the model
        for met in self.vade.metrics:
            logs[met.name] = met.result()

        return logs

    # noinspection PyUnboundLocalVariable
    @tf.function
    def test_step(self, data):  # pragma: no cover
        """Evaluate the model on a single batch, without updating any weights.

        Args:
            data: a triplet with node features, edge features and labels. Labels can
              be a single element or a tuple; only the first entry is used, as the
              target of the reconstruction.

        Returns:
            dict: running results of the validation trackers, reported under the keys
            "total_loss", "reconstruction_loss" and, if cluster regularization is on,
            "cat_cluster_loss", merged with the metrics tracked by ``self.vade``.

        """
        node_feats, edge_feats, labels = data

        # Wrap a lone label so that the first entry can be pulled uniformly below
        label_iter = iter(labels if isinstance(labels, tuple) else [labels])

        # Forward pass through the full model
        model_out = self.vade([node_feats, edge_feats], training=False)

        # The transformer decoder also returns attention scores; drop them
        if self.encoder_type == "transformer":
            model_out = model_out[0]

        reconstructions = model_out[0] if isinstance(model_out, list) else model_out

        # Start from the losses registered by the model during the forward pass
        seq_inputs = next(label_iter)
        total_loss = sum(self.vade.losses)

        # Penalize uneven soft counts, to prevent the embedding layer from
        # collapsing into a few clusters.
        if self.reg_cat_clusters:
            soft_counts = self.grouper([node_feats, edge_feats], training=False)
            cluster_penalty = deepof.model_utils.cluster_frequencies_regularizer(
                soft_counts=soft_counts, k=self.n_components
            )
            cluster_penalty = self.reg_cat_clusters * cluster_penalty
            total_loss += cluster_penalty

        # Negative mean log-likelihood of the targets under the reconstruction
        reconstruction_loss = -tf.reduce_mean(reconstructions.log_prob(seq_inputs))
        total_loss += reconstruction_loss

        # Track losses
        self.val_total_loss_tracker.update_state(total_loss)
        self.val_reconstruction_loss_tracker.update_state(reconstruction_loss)

        # Collect results (coupled with TensorBoard)
        logs = dict(
            total_loss=self.val_total_loss_tracker.result(),
            reconstruction_loss=self.val_reconstruction_loss_tracker.result(),
        )

        if self.reg_cat_clusters:
            self.val_cat_cluster_loss_tracker.update_state(cluster_penalty)
            logs["cat_cluster_loss"] = self.val_cat_cluster_loss_tracker.result()

        # Add the metrics tracked implicitly within the model
        for met in self.vade.metrics:
            logs[met.name] = met.result()

        return logs


# noinspection PyDefaultArgument,PyCallingNonCallable
class Contrastive(tf.keras.models.Model):
    """Self-supervised contrastive embeddings.

    A single encoder embeds two windows sliced from each input sequence. The pair of
    (L2-normalized) embeddings is scored with the contrastive loss selected through
    ``deepof.model_utils.select_contrastive_loss``.
    """

    def __init__(
        self,
        input_shape: tuple,
        edge_feature_shape: tuple,
        adjacency_matrix: np.ndarray = None,
        encoder_type: str = "TCN",
        latent_dim: int = 8,
        use_gnn: bool = True,
        temperature: float = 0.1,
        similarity_function: str = "cosine",
        loss_function: str = "nce",
        beta: float = 0.1,
        tau: float = 0.1,
        interaction_regularization: float = 0.0,
        **kwargs,
    ):
        """Init a self-supervised Contrastive embedding model.

        Args:
            input_shape (tuple): Shape of the input to the full model.
            edge_feature_shape (tuple): shape of the edge feature matrix used for graph representations.
            adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph to use.
            encoder_type (str): type of encoder to use. Can be set to "recurrent", "TCN" (default), or "transformer".
            latent_dim (int): Dimensionality of the latent space.
            use_gnn (bool): If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
            temperature (float): forwarded as ``temperature`` to deepof.model_utils.select_contrastive_loss.
            similarity_function (str): forwarded as ``similarity`` to deepof.model_utils.select_contrastive_loss.
            loss_function (str): forwarded as ``loss_fn`` to deepof.model_utils.select_contrastive_loss.
            beta (float): forwarded as ``beta`` to deepof.model_utils.select_contrastive_loss.
            tau (float): forwarded as ``tau`` to deepof.model_utils.select_contrastive_loss.
            interaction_regularization (float): Regularization parameter for the interaction features.
            **kwargs: Additional keyword arguments, passed to the tf.keras.models.Model constructor.

        Raises:
            ValueError: if encoder_type is not one of "recurrent", "TCN" or "transformer".

        """
        super().__init__(**kwargs)
        self.seq_shape = input_shape
        self.edge_feature_shape = edge_feature_shape
        self.adjacency_matrix = adjacency_matrix
        self.latent_dim = latent_dim
        self.use_gnn = use_gnn
        # Each of the two contrasted windows spans half of the input sequence
        self.window_length = self.seq_shape[1] // 2
        self.temperature = temperature
        self.similarity_function = similarity_function
        self.loss_function = loss_function
        self.beta = beta
        self.tau = tau
        self.optimizer = Nadam(learning_rate=1e-3, clipvalue=0.75)
        self.encoder_type = encoder_type
        self.interaction_regularization = interaction_regularization

        # Define Contrastive model. The encoder sees one window at a time, so its
        # input shapes are expressed in terms of the window length.
        window_shape = (self.window_length, input_shape[-1])
        shared_encoder_args = dict(
            edge_feature_shape=(
                self.window_length,
                self.edge_feature_shape[2],
            ),
            adjacency_matrix=self.adjacency_matrix,
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )

        if encoder_type == "recurrent":
            self.encoder = get_recurrent_encoder(
                input_shape=window_shape, **shared_encoder_args
            )

        elif encoder_type == "TCN":
            self.encoder = get_TCN_encoder(
                input_shape=window_shape, **shared_encoder_args
            )

        elif encoder_type == "transformer":
            self.encoder = get_transformer_encoder(window_shape, **shared_encoder_args)

        else:
            # Fail at construction time, instead of leaving self.encoder undefined
            # and crashing with an AttributeError on first use.
            raise ValueError(
                "encoder_type must be one of 'recurrent', 'TCN' or 'transformer', "
                "got {!r}".format(encoder_type)
            )

        # Define metrics to track

        # Track all loss function components
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.val_total_loss_tracker = tf.keras.metrics.Mean(name="val_total_loss")
        self.mean_sim_tracker = tf.keras.metrics.Mean(name="pos_similarity")
        self.val_mean_sim_tracker = tf.keras.metrics.Mean(name="val_pos_similarity")
        self.neg_sim_tracker = tf.keras.metrics.Mean(name="neg_similarity")
        self.val_neg_sim_tracker = tf.keras.metrics.Mean(name="val_neg_similarity")

    @property
    def metrics(self):  # pragma: no cover
        """Return the training and validation metric trackers of the contrastive model."""
        metrics = [
            self.total_loss_tracker,
            self.val_total_loss_tracker,
            self.mean_sim_tracker,
            self.val_mean_sim_tracker,
            self.neg_sim_tracker,
            self.val_neg_sim_tracker,
        ]

        return metrics

    @tf.function
    def call(self, inputs, **kwargs):
        """Call the contrastive model (that is, its encoder) on the given inputs."""
        return self.encoder(inputs, **kwargs)

    @staticmethod
    def _ts_samples(mbatch, win):
        """Slice a batch into the two windows of length win that form a contrastive pair."""
        # NOTE(review): the first window starts at index 1, so it shares one step with
        # the second window when the sequence length is even - confirm this is intended.
        first_window = mbatch[:, 1 : win + 1]
        second_window = mbatch[:, -win:]

        return first_window, second_window

    def _contrastive_outputs(self, x, a, training):
        """Embed both windows of a batch and return (loss, positive similarity, negative similarity)."""
        # Get positive and negative pairs
        pos, neg = self._ts_samples(x, self.window_length)
        pos_a, neg_a = self._ts_samples(a, self.window_length)

        enc_pos = self.encoder([pos, pos_a], training=training)
        enc_neg = self.encoder([neg, neg_a], training=training)

        # normalize projection feature vectors
        enc_pos = tf.math.l2_normalize(enc_pos, axis=1)
        enc_neg = tf.math.l2_normalize(enc_neg, axis=1)

        # Compute contrastive loss
        (
            contrastive_loss,
            mean_sim,
            neg_sim,
        ) = deepof.model_utils.select_contrastive_loss(
            enc_pos,
            enc_neg,
            similarity=self.similarity_function,
            loss_fn=self.loss_function,
            temperature=self.temperature,
            tau=self.tau,
            beta=self.beta,
            elimination_topk=0.1,
            attraction=False,
        )

        return contrastive_loss, mean_sim, neg_sim

    def train_step(self, data):  # pragma: no cover
        """Perform a training step.

        Args:
            data: a triplet with node features, edge features and labels.

        Returns:
            dict: running results of the training trackers, keyed by metric name.

        """
        # Unpack data. Labels won't be used for now, but may come handy if exploring
        # regularizers in the future
        x, a, _ = data

        with tf.GradientTape() as tape:
            total_loss, mean_sim, neg_sim = self._contrastive_outputs(
                x, a, training=True
            )

        # Backpropagation
        grads = tape.gradient(total_loss, self.encoder.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.encoder.trainable_variables))

        # Track losses
        self.total_loss_tracker.update_state(total_loss)
        self.mean_sim_tracker.update_state(mean_sim)
        self.neg_sim_tracker.update_state(neg_sim)

        # Report training trackers only
        return {met.name: met.result() for met in self.metrics if "val" not in met.name}

    @tf.function
    def test_step(self, data):  # pragma: no cover
        """Perform a test step.

        Args:
            data: a triplet with node features, edge features and labels.

        Returns:
            dict: running results of the validation trackers, keyed by metric name
            with the "val_" prefix removed.

        """
        # Unpack data. Labels won't be used for now, but may come handy if exploring
        # regularizers in the future
        x, a, _ = data

        total_loss, mean_sim, neg_sim = self._contrastive_outputs(x, a, training=False)

        # Track losses
        self.val_total_loss_tracker.update_state(total_loss)
        self.val_mean_sim_tracker.update_state(mean_sim)
        self.val_neg_sim_tracker.update_state(neg_sim)

        # Report validation trackers only
        return {
            met.name.replace("val_", ""): met.result()
            for met in self.metrics
            if "val" in met.name
        }