lucasmiranda42/deepof

View on GitHub
docs/build/html/_generated/deepof.models.GaussianMixtureLatent.html

Summary

Maintainability
Test Coverage
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />

  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>deepof.models.GaussianMixtureLatent &mdash; deepof 0.6.2 documentation</title>
      <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
      <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
      <link rel="stylesheet" href="../_static/jupyter-sphinx.css" type="text/css" />
      <link rel="stylesheet" href="../_static/thebelab.css" type="text/css" />
      <link rel="stylesheet" href="../_static/custom.css" type="text/css" />
    <link rel="shortcut icon" href="../_static/deepof.ico"/>
  <!--[if lt IE 9]>
    <script src="../_static/js/html5shiv.min.js"></script>
  <![endif]-->

        <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
        <script src="../_static/jquery.js"></script>
        <script src="../_static/underscore.js"></script>
        <script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
        <script src="../_static/doctools.js"></script>
        <script src="../_static/sphinx_highlight.js"></script>
        <script src="../_static/thebelab-helper.js"></script>
        <!-- Load RequireJS exactly once, before embed-amd.js (an AMD bundle that needs it),
             and with its SRI hash. Previously require.js 2.3.4 was loaded twice: once
             unprotected before embed-amd.js and once with SRI after it.
             NOTE(review): this page is Sphinx build output, so the duplicate presumably comes
             from two doc extensions each injecting require.js. Fix it in docs conf.py as well,
             or the next build will reintroduce it. -->
        <script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
        <!-- NOTE(review): "@^1.0.1" is a floating semver range, so the fetched bytes can change
             and no integrity hash can be pinned. Use an exact version if SRI is wanted here. -->
        <script src="https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@^1.0.1/dist/embed-amd.js"></script>
    <script src="../_static/js/theme.js"></script>
    <link rel="index" title="Index" href="../genindex.html" />
    <link rel="search" title="Search" href="../search.html" />
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >

          
          
          <a href="../index.html">
            
              <img src="../_static/deepof_sidebar.ico" class="logo" alt="deepof documentation home"/>
          </a>
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>
        </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
              <!-- Local TOC -->
              <div class="local-toc"><ul>
<li><a class="reference internal" href="#">deepof.models.GaussianMixtureLatent</a><ul>
<li><a class="reference internal" href="#deepof.models.GaussianMixtureLatent"><code class="docutils literal notranslate"><span class="pre">GaussianMixtureLatent</span></code></a><ul>
<li><a class="reference internal" href="#deepof.models.GaussianMixtureLatent.__init__"><code class="docutils literal notranslate"><span class="pre">GaussianMixtureLatent.__init__()</span></code></a></li>
<li><a class="reference internal" href="#id0"><code class="docutils literal notranslate"><span class="pre">GaussianMixtureLatent.__init__()</span></code></a></li>
</ul>
</li>
</ul>
</li>
</ul>
</div>
        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="../index.html">deepof</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="Page navigation">
  <ul class="wy-breadcrumbs">
      <li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
      <li class="breadcrumb-item active">deepof.models.GaussianMixtureLatent</li>
      <li class="wy-breadcrumbs-aside">
            <a href="../_sources/_generated/deepof.models.GaussianMixtureLatent.rst.txt" rel="nofollow"> View page source</a>
      </li>
  </ul>
  <hr/>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  
<style>
/*
 * Local overrides layered on top of sphinx_rtd_theme.
 *
 * Notebook cell containers (.nbinput / .nboutput) that close a cell get a
 * 19px bottom margin; combined with the 5px padding the container already
 * has, that gives 24px of separation.
 */
.nbinput.nblast.container,
.nboutput.nblast.container { margin-bottom: 19px; }

/* Back-to-back code cells should not be spaced apart: cancel the gap above. */
.nblast.container + .nbinput.container { margin-top: -19px; }

/* Leave room between an admonition's leading (exclamation) icon and its text. */
.admonition > p:before { margin-right: 4px; }

/* Math alignment fix; background: https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
.math { text-align: unset; }
</style>
<section id="deepof-models-gaussianmixturelatent">
<h1>deepof.models.GaussianMixtureLatent<a class="headerlink" href="#deepof-models-gaussianmixturelatent" title="Permalink to this heading"></a></h1>
<dl class="py class">
<dt class="sig sig-object py" id="deepof.models.GaussianMixtureLatent">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">deepof.models.</span></span><span class="sig-name descname"><span class="pre">GaussianMixtureLatent</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#deepof.models.GaussianMixtureLatent" title="Permalink to this definition"></a></dt>
<dd><p>Gaussian Mixture probabilistic latent space model.</p>
<p>Used to represent the embedding of motion tracking data in a mixture of Gaussians
with a user-specified number of components, each with its own mean, covariance and weight.
Implementation based on VaDE (<a class="reference external" href="https://arxiv.org/abs/1611.05148">https://arxiv.org/abs/1611.05148</a>)
and VaDE-SC (<a class="reference external" href="https://openreview.net/forum?id=RQ428ZptQfU">https://openreview.net/forum?id=RQ428ZptQfU</a>).</p>
<dl class="py method">
<dt class="sig sig-object py" id="deepof.models.GaussianMixtureLatent.__init__">
<span class="sig-name descname"><span class="pre">__init__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input_shape</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">tuple</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_components</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">latent_dim</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kl_warmup</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kl_annealing_mode</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'linear'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mc_kl</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> 
</span><span class="default_value"><span class="pre">100</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mmd_warmup</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">15</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mmd_annealing_mode</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'linear'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kmeans_loss</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reg_cluster_variance</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#deepof.models.GaussianMixtureLatent.__init__" title="Permalink to this definition"></a></dt>
<dd><p>Initialize the Gaussian Mixture Latent layer.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>input_shape</strong> (<em>tuple</em>) – shape of the input data</p></li>
<li><p><strong>n_components</strong> (<em>int</em>) – number of components in the Gaussian mixture.</p></li>
<li><p><strong>latent_dim</strong> (<em>int</em>) – dimensionality of the latent space.</p></li>
<li><p><strong>batch_size</strong> (<em>int</em>) – batch size for training.</p></li>
<li><p><strong>kl_warmup</strong> (<em>int</em>) – number of epochs to warm up the KL divergence.</p></li>
<li><p><strong>kl_annealing_mode</strong> (<em>str</em>) – mode to use for annealing the KL divergence. Must be either “linear” or “sigmoid”.</p></li>
<li><p><strong>mc_kl</strong> (<em>int</em>) – number of Monte Carlo samples to use for computing the KL divergence.</p></li>
<li><p><strong>mmd_warmup</strong> (<em>int</em>) – number of epochs to warm up the MMD.</p></li>
<li><p><strong>mmd_annealing_mode</strong> (<em>str</em>) – mode to use for annealing the MMD. Must be either “linear” or “sigmoid”.</p></li>
<li><p><strong>kmeans_loss</strong> (<em>float</em>) – weight of the Gram matrix regularization loss.</p></li>
<li><p><strong>reg_cluster_variance</strong> (<em>bool</em>) – whether to penalize uneven cluster variances in the latent space.</p></li>
<li><p><strong>**kwargs</strong> – keyword arguments passed to the parent class</p></li>
</ul>
</dd>
</dl>
</dd></dl>

<p class="rubric">Methods</p>
<table class="autosummary longtable docutils align-default">
<tbody>
<tr class="row-odd"><td><p><a class="reference internal" href="#id0" title="deepof.models.GaussianMixtureLatent.__init__"><code class="xref py py-obj docutils literal notranslate"><span class="pre">__init__</span></code></a>(input_shape, n_components, ...[, ...])</p></td>
<td><p>Initialize the Gaussian Mixture Latent layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">add_loss</span></code>(losses, **kwargs)</p></td>
<td><p>Add loss tensor(s), potentially dependent on layer inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">add_metric</span></code>(value[, name])</p></td>
<td><p>Adds a metric tensor to the layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">add_update</span></code>(updates)</p></td>
<td><p>Add update op(s), potentially dependent on layer inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">add_variable</span></code>(*args, **kwargs)</p></td>
<td><p>Deprecated, do NOT use! Alias for <cite>add_weight</cite>.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">add_weight</span></code>([name, shape, dtype, ...])</p></td>
<td><p>Adds a new variable to the layer.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">build</span></code>(input_shape)</p></td>
<td><p>Builds the model based on input shapes received.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">build_from_config</span></code>(config)</p></td>
<td><p></p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">call</span></code>(inputs[, training])</p></td>
<td><p>Compute the output of the layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">compile</span></code>([optimizer, loss, metrics, ...])</p></td>
<td><p>Configures the model for training.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">compile_from_config</span></code>(config)</p></td>
<td><p></p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">compute_loss</span></code>([x, y, y_pred, sample_weight])</p></td>
<td><p>Compute the total loss, validate it, and return it.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">compute_mask</span></code>(inputs[, mask])</p></td>
<td><p>Computes an output mask tensor.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">compute_metrics</span></code>(x, y, y_pred, sample_weight)</p></td>
<td><p>Update metric states and collect all metrics to be returned.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">compute_output_shape</span></code>(input_shape)</p></td>
<td><p>Computes the output shape of the layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">compute_output_signature</span></code>(input_signature)</p></td>
<td><p>Compute the output tensor signature of the layer based on the inputs.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">count_params</span></code>()</p></td>
<td><p>Count the total number of scalars composing the weights.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">evaluate</span></code>([x, y, batch_size, verbose, ...])</p></td>
<td><p>Returns the loss value &amp; metrics values for the model in test mode.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">evaluate_generator</span></code>(generator[, steps, ...])</p></td>
<td><p>Evaluates the model on a data generator.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">export</span></code>(filepath)</p></td>
<td><p>Create a SavedModel artifact for inference (e.g. via TF-Serving).</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">finalize_state</span></code>()</p></td>
<td><p>Finalizes the layers state after updating layer weights.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">fit</span></code>([x, y, batch_size, epochs, verbose, ...])</p></td>
<td><p>Trains the model for a fixed number of epochs (dataset iterations).</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">fit_generator</span></code>(generator[, steps_per_epoch, ...])</p></td>
<td><p>Fits the model on data yielded batch-by-batch by a Python generator.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">from_config</span></code>(config[, custom_objects])</p></td>
<td><p>Creates a layer from its config.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_build_config</span></code>()</p></td>
<td><p></p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_compile_config</span></code>()</p></td>
<td><p></p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_config</span></code>()</p></td>
<td><p>Returns the config of the <cite>Model</cite>.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_input_at</span></code>(node_index)</p></td>
<td><p>Retrieves the input tensor(s) of a layer at a given node.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_input_mask_at</span></code>(node_index)</p></td>
<td><p>Retrieves the input mask tensor(s) of a layer at a given node.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_input_shape_at</span></code>(node_index)</p></td>
<td><p>Retrieves the input shape(s) of a layer at a given node.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_layer</span></code>([name, index])</p></td>
<td><p>Retrieves a layer based on either its name (unique) or index.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_metrics_result</span></code>()</p></td>
<td><p>Returns the model's metrics values as a dict.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_output_at</span></code>(node_index)</p></td>
<td><p>Retrieves the output tensor(s) of a layer at a given node.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_output_mask_at</span></code>(node_index)</p></td>
<td><p>Retrieves the output mask tensor(s) of a layer at a given node.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_output_shape_at</span></code>(node_index)</p></td>
<td><p>Retrieves the output shape(s) of a layer at a given node.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_weight_paths</span></code>()</p></td>
<td><p>Retrieve all the variables and their paths for the model.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">get_weights</span></code>()</p></td>
<td><p>Retrieves the weights of the model.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">load_weights</span></code>(filepath[, skip_mismatch, ...])</p></td>
<td><p>Loads all layer weights from a saved file.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">make_predict_function</span></code>([force])</p></td>
<td><p>Creates a function that executes one step of inference.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">make_test_function</span></code>([force])</p></td>
<td><p>Creates a function that executes one step of evaluation.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">make_train_function</span></code>([force])</p></td>
<td><p>Creates a function that executes one step of training.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">predict</span></code>(x[, batch_size, verbose, steps, ...])</p></td>
<td><p>Generates output predictions for the input samples.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">predict_generator</span></code>(generator[, steps, ...])</p></td>
<td><p>Generates predictions for the input samples from a data generator.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">predict_on_batch</span></code>(x)</p></td>
<td><p>Returns predictions for a single batch of samples.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">predict_step</span></code>(data)</p></td>
<td><p>The logic for one inference step.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">reset_metrics</span></code>()</p></td>
<td><p>Resets the state of all the metrics in the model.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">reset_states</span></code>()</p></td>
<td><p></p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">save</span></code>(filepath[, overwrite, save_format])</p></td>
<td><p>Saves a model as a TensorFlow SavedModel or HDF5 file.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">save_spec</span></code>([dynamic_batch])</p></td>
<td><p>Returns the <cite>tf.TensorSpec</cite> of call args as a tuple <cite>(args, kwargs)</cite>.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">save_weights</span></code>(filepath[, overwrite, ...])</p></td>
<td><p>Saves all layer weights.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">set_weights</span></code>(weights)</p></td>
<td><p>Sets the weights of the layer, from NumPy arrays.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">summary</span></code>([line_length, positions, print_fn, ...])</p></td>
<td><p>Prints a string summary of the network.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">test_on_batch</span></code>(x[, y, sample_weight, ...])</p></td>
<td><p>Test the model on a single batch of samples.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">test_step</span></code>(data)</p></td>
<td><p>The logic for one evaluation step.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">to_json</span></code>(**kwargs)</p></td>
<td><p>Returns a JSON string containing the network configuration.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">to_yaml</span></code>(**kwargs)</p></td>
<td><p>Returns a yaml string containing the network configuration.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">train_on_batch</span></code>(x[, y, sample_weight, ...])</p></td>
<td><p>Runs a single gradient update on a single batch of data.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">train_step</span></code>(data)</p></td>
<td><p>The logic for one training step.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">with_name_scope</span></code>(method)</p></td>
<td><p>Decorator to automatically enter the module name scope.</p></td>
</tr>
</tbody>
</table>
<p class="rubric">Attributes</p>
<table class="autosummary longtable docutils align-default">
<tbody>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">activity_regularizer</span></code></p></td>
<td><p>Optional regularizer function for the output of this layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">compute_dtype</span></code></p></td>
<td><p>The dtype of the layer's computations.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">distribute_reduction_method</span></code></p></td>
<td><p>The method employed to reduce per-replica values during training.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">distribute_strategy</span></code></p></td>
<td><p>The <cite>tf.distribute.Strategy</cite> this model was created under.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">dtype</span></code></p></td>
<td><p>The dtype of the layer weights.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">dtype_policy</span></code></p></td>
<td><p>The dtype policy associated with this layer.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">dynamic</span></code></p></td>
<td><p>Whether the layer is dynamic (eager-only); set in the constructor.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">inbound_nodes</span></code></p></td>
<td><p>Return Functional API nodes upstream of this layer.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">input</span></code></p></td>
<td><p>Retrieves the input tensor(s) of a layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">input_mask</span></code></p></td>
<td><p>Retrieves the input mask tensor(s) of a layer.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">input_shape</span></code></p></td>
<td><p>Retrieves the input shape(s) of a layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">input_spec</span></code></p></td>
<td><p><cite>InputSpec</cite> instance(s) describing the input format for this layer.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">jit_compile</span></code></p></td>
<td><p>Specify whether to compile the model with XLA.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">layers</span></code></p></td>
<td><p></p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">losses</span></code></p></td>
<td><p>List of losses added using the <cite>add_loss()</cite> API.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">metrics</span></code></p></td>
<td><p>Return metrics added using <cite>compile()</cite> or <cite>add_metric()</cite>.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">metrics_names</span></code></p></td>
<td><p>Returns the model's display labels for all outputs.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">name</span></code></p></td>
<td><p>Name of the layer (string), set in the constructor.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">name_scope</span></code></p></td>
<td><p>Returns a <cite>tf.name_scope</cite> instance for this class.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">non_trainable_variables</span></code></p></td>
<td><p>Sequence of non-trainable variables owned by this module and its submodules.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">non_trainable_weights</span></code></p></td>
<td><p>List of all non-trainable weights tracked by this layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">outbound_nodes</span></code></p></td>
<td><p>Return Functional API nodes downstream of this layer.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">output</span></code></p></td>
<td><p>Retrieves the output tensor(s) of a layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">output_mask</span></code></p></td>
<td><p>Retrieves the output mask tensor(s) of a layer.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">output_shape</span></code></p></td>
<td><p>Retrieves the output shape(s) of a layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">run_eagerly</span></code></p></td>
<td><p>Settable attribute indicating whether the model should run eagerly.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">state_updates</span></code></p></td>
<td><p>Deprecated, do NOT use!</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">stateful</span></code></p></td>
<td><p></p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">submodules</span></code></p></td>
<td><p>Sequence of all sub-modules.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">supports_masking</span></code></p></td>
<td><p>Whether this layer supports computing a mask using <cite>compute_mask</cite>.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">trainable</span></code></p></td>
<td><p></p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">trainable_variables</span></code></p></td>
<td><p>Sequence of trainable variables owned by this module and its submodules.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">trainable_weights</span></code></p></td>
<td><p>List of all trainable weights tracked by this layer.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">updates</span></code></p></td>
<td><p></p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">variable_dtype</span></code></p></td>
<td><p>Alias of <cite>Layer.dtype</cite>, the dtype of the weights.</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">variables</span></code></p></td>
<td><p>Returns the list of all layer variables/weights.</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">weights</span></code></p></td>
<td><p>Returns the list of all layer variables/weights.</p></td>
</tr>
</tbody>
</table>
<dl class="py method">
<dt class="sig sig-object py" id="id0">
<span class="sig-name descname"><span class="pre">__init__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input_shape</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">tuple</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_components</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">latent_dim</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kl_warmup</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kl_annealing_mode</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'linear'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mc_kl</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> 
</span><span class="default_value"><span class="pre">100</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mmd_warmup</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">15</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mmd_annealing_mode</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'linear'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kmeans_loss</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reg_cluster_variance</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#id0" title="Permalink to this definition"></a></dt>
<dd><p>Initialize the Gaussian Mixture Latent layer.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>input_shape</strong> (<em>tuple</em>) – shape of the input data.</p></li>
<li><p><strong>n_components</strong> (<em>int</em>) – number of components in the Gaussian mixture.</p></li>
<li><p><strong>latent_dim</strong> (<em>int</em>) – dimensionality of the latent space.</p></li>
<li><p><strong>batch_size</strong> (<em>int</em>) – batch size for training.</p></li>
<li><p><strong>kl_warmup</strong> (<em>int</em>) – number of epochs to warm up the KL divergence.</p></li>
<li><p><strong>kl_annealing_mode</strong> (<em>str</em>) – mode to use for annealing the KL divergence. Must be one of “linear” or “sigmoid”.</p></li>
<li><p><strong>mc_kl</strong> (<em>int</em>) – number of Monte Carlo samples to use for computing the KL divergence.</p></li>
<li><p><strong>mmd_warmup</strong> (<em>int</em>) – number of epochs to warm up the MMD.</p></li>
<li><p><strong>mmd_annealing_mode</strong> (<em>str</em>) – mode to use for annealing the MMD. Must be one of “linear” or “sigmoid”.</p></li>
<li><p><strong>kmeans_loss</strong> (<em>float</em>) – weight of the Gram matrix regularization loss.</p></li>
<li><p><strong>reg_cluster_variance</strong> (<em>bool</em>) – whether to penalize uneven cluster variances in the latent space.</p></li>
<li><p><strong>**kwargs</strong> – keyword arguments passed to the parent class.</p></li>
</ul>
</dd>
</dl>
</dd></dl>

</dd></dl>

</section>


           </div>
          </div>
          <footer>

  <hr/>

  <div role="contentinfo">
    <p>&#169; Copyright 2024, Lucas Miranda.</p>
  </div>

  Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    provided by <a href="https://readthedocs.org">Read the Docs</a>.
   

</footer>
        </div>
      </div>
    </section>
  </div>
  <script>
    // Once the DOM is ready, switch on the Read the Docs theme's sidebar
    // navigation. NOTE(review): the `true` flag presumably selects sticky
    // navigation — confirm against sphinx_rtd_theme's theme.js.
    jQuery(document).ready(function () {
      SphinxRtdTheme.Navigation.enable(true);
    });
  </script>

</body>
</html>