lucasmiranda42/deepof

View on GitHub
docs/build/html/_generated/deepof.model_utils.embedding_model_fitting.html

Summary

Maintainability
Test Coverage
<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />

  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>deepof.model_utils.embedding_model_fitting &mdash; deepof 0.6.2 documentation</title>
      <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
      <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
      <link rel="stylesheet" href="../_static/jupyter-sphinx.css" type="text/css" />
      <link rel="stylesheet" href="../_static/thebelab.css" type="text/css" />
      <link rel="stylesheet" href="../_static/custom.css" type="text/css" />
    <link rel="shortcut icon" href="../_static/deepof.ico"/>
  <!--[if lt IE 9]>
    <script src="../_static/js/html5shiv.min.js"></script>
  <![endif]-->
  
        <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
        <script src="../_static/jquery.js"></script>
        <script src="../_static/underscore.js"></script>
        <script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
        <script src="../_static/doctools.js"></script>
        <script src="../_static/sphinx_highlight.js"></script>
        <script src="../_static/thebelab-helper.js"></script>
        <script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
        <script src="https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@^1.0.1/dist/embed-amd.js"></script>
    <script src="../_static/js/theme.js"></script>
    <link rel="index" title="Index" href="../genindex.html" />
    <link rel="search" title="Search" href="../search.html" /> 
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >

          
          
          <a href="../index.html">
            
              <img src="../_static/deepof_sidebar.ico" class="logo" alt="Logo"/>
          </a>
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>
        </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
              <!-- Local TOC -->
              <div class="local-toc"><ul>
<li><a class="reference internal" href="#">deepof.model_utils.embedding_model_fitting</a><ul>
<li><a class="reference internal" href="#deepof.model_utils.embedding_model_fitting"><code class="docutils literal notranslate"><span class="pre">embedding_model_fitting()</span></code></a></li>
</ul>
</li>
</ul>
</div>
        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="../index.html">deepof</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="Page navigation">
  <ul class="wy-breadcrumbs">
      <li><a href="../index.html" class="icon icon-home" aria-label="Home"></a></li>
      <li class="breadcrumb-item active">deepof.model_utils.embedding_model_fitting</li>
      <li class="wy-breadcrumbs-aside">
            <a href="../_sources/_generated/deepof.model_utils.embedding_model_fitting.rst.txt" rel="nofollow"> View page source</a>
      </li>
  </ul>
  <hr/>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  
<style>
/* CSS overrides for sphinx_rtd_theme */

/* 24px margin */
.nbinput.nblast.container,
.nboutput.nblast.container {
    margin-bottom: 19px;  /* padding has already 5px */
}

/* ... except between code cells! */
.nblast.container + .nbinput.container {
    margin-top: -19px;
}

.admonition > p:before {
    margin-right: 4px;  /* make room for the exclamation icon */
}

/* Fix math alignment, see https://github.com/rtfd/sphinx_rtd_theme/pull/686 */
.math {
    text-align: unset;
}
</style>
<section id="deepof-model-utils-embedding-model-fitting">
<h1>deepof.model_utils.embedding_model_fitting<a class="headerlink" href="#deepof-model-utils-embedding-model-fitting" title="Permalink to this heading"></a></h1>
<dl class="py function">
<dt class="sig sig-object py" id="deepof.model_utils.embedding_model_fitting">
<span class="sig-prename descclassname"><span class="pre">deepof.model_utils.</span></span><span class="sig-name descname"><span class="pre">embedding_model_fitting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preprocessed_object</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">ndarray</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">ndarray</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">ndarray</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">ndarray</span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">adjacency_matrix</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">ndarray</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">embedding_model</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">encoder_type</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">latent_dim</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span 
class="pre">epochs</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">log_history</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">log_hparams</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_components</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">output_path</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kmeans_loss</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pretrained</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">save_checkpoints</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">save_weights</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_type</span></span><span class="p"><span 
class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kl_annealing_mode</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kl_warmup</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reg_cat_clusters</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">recluster</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">temperature</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">contrastive_similarity_function</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">contrastive_loss_function</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">beta</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tau</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span 
class="n"><span class="pre">float</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">interaction_regularization</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">run</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#deepof.model_utils.embedding_model_fitting" title="Permalink to this definition"></a></dt>
<dd><p>Trains the specified embedding model on the preprocessed data.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>coordinates</strong> (<em>np.ndarray</em>) – Coordinates of the data.</p></li>
<li><p><strong>preprocessed_object</strong> (<em>tuple</em>) – Tuple containing the preprocessed data.</p></li>
<li><p><strong>adjacency_matrix</strong> (<em>np.ndarray</em>) – Adjacency matrix of the connectivity graph to use.</p></li>
<li><p><strong>embedding_model</strong> (<em>str</em>) – Model to use to embed and cluster the data. Must be one of VQVAE (default), VaDE, or contrastive.</p></li>
<li><p><strong>encoder_type</strong> (<em>str</em>) – Encoder architecture to use. Must be one of “recurrent”, “TCN”, or “transformer”.</p></li>
<li><p><strong>batch_size</strong> (<em>int</em>) – Batch size to use for training.</p></li>
<li><p><strong>latent_dim</strong> (<em>int</em>) – Encoding size to use for training.</p></li>
<li><p><strong>epochs</strong> (<em>int</em>) – Number of epochs to train the autoencoder for.</p></li>
<li><p><strong>log_history</strong> (<em>bool</em>) – Whether to log the history of the autoencoder.</p></li>
<li><p><strong>log_hparams</strong> (<em>bool</em>) – Whether to log the hyperparameters used for training.</p></li>
<li><p><strong>n_components</strong> (<em>int</em>) – Number of components to fit to the data.</p></li>
<li><p><strong>output_path</strong> (<em>str</em>) – Path to the output directory.</p></li>
<li><p><strong>kmeans_loss</strong> (<em>float</em>) – Weight of the gram loss, which adds a regularization term to VQVAE models which penalizes the correlation between the dimensions in the latent space.</p></li>
<li><p><strong>pretrained</strong> (<em>str</em>) – Path to the pretrained weights to use for the autoencoder.</p></li>
<li><p><strong>save_checkpoints</strong> (<em>bool</em>) – Whether to save checkpoints during training.</p></li>
<li><p><strong>save_weights</strong> (<em>bool</em>) – Whether to save the weights of the autoencoder after training.</p></li>
<li><p><strong>input_type</strong> (<em>str</em>) – Input type of the TableDict objects used for preprocessing. For logging purposes only.</p></li>
<li><p><strong>interaction_regularization</strong> (<em>float</em>) – Weight of the interaction regularization term (L1 penalization to all features not related to interactions).</p></li>
<li><p><strong>run</strong> (<em>int</em>) – Run number to use for logging.</p></li>
<li><p><em>VaDE model-specific parameters:</em></p></li>
<li><p><strong>kl_annealing_mode</strong> (<em>str</em>) – Mode to use for KL annealing. Must be one of “linear” (default) or “sigmoid”.</p></li>
<li><p><strong>kl_warmup</strong> (<em>int</em>) – Number of epochs during which KL is annealed.</p></li>
<li><p><strong>reg_cat_clusters</strong> (<em>float</em>) – Weight of the regularization term that penalizes uneven cluster membership in the latent space, by minimizing the KL divergence between cluster membership and a uniform categorical distribution.</p></li>
<li><p><strong>recluster</strong> (<em>bool</em>) – Whether to recluster the data after each training using a Gaussian Mixture Model.</p></li>
<li><p><em>Contrastive model-specific parameters:</em></p></li>
<li><p><strong>temperature</strong> (<em>float</em>) – Temperature parameter for the contrastive loss functions. Higher values put harsher penalties on negative pair similarity.</p></li>
<li><p><strong>contrastive_similarity_function</strong> (<em>str</em>) – Similarity function between positive and negative pairs. Must be one of ‘cosine’ (default), ‘euclidean’, ‘dot’, or ‘edit’.</p></li>
<li><p><strong>contrastive_loss_function</strong> (<em>str</em>) – Contrastive loss function. Must be one of ‘nce’ (default), ‘dcl’, ‘fc’, or ‘hard_dcl’. See specific documentation for details.</p></li>
<li><p><strong>beta</strong> (<em>float</em>) – Beta (concentration) parameter for the hard_dcl contrastive loss. Higher values lead to ‘harder’ negative samples.</p></li>
<li><p><strong>tau</strong> (<em>float</em>) – Tau parameter for the dcl and hard_dcl contrastive losses, indicating positive class probability.</p></li>
</ul>
</dd>
<dt class="field-even">Returns<span class="colon">:</span></dt>
<dd class="field-even"><p>List of trained models corresponding to the selected model class. The full trained model is last.</p>
</dd>
</dl>
</dd></dl>

</section>


           </div>
          </div>
          <footer>

  <hr/>

  <div role="contentinfo">
    <p>&#169; Copyright 2024, Lucas Miranda.</p>
  </div>

  Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    provided by <a href="https://readthedocs.org">Read the Docs</a>.
   

</footer>
        </div>
      </div>
    </section>
  </div>
  <script>
      jQuery(function () {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script> 

</body>
</html>