monarch-initiative/N2V

View on GitHub
embiggen/embedders/pykeen_embedders/distmult.py

Summary

Maintainability
A
0 mins
Test Coverage
"""Submodule providing wrapper for PyKEEN's DistMult model."""
from typing import Union, Type, Dict, Any, Optional
from pykeen.training import TrainingLoop
from pykeen.models import DistMult
from embiggen.embedders.pykeen_embedders.entity_relation_embedding_model_pykeen import EntityRelationEmbeddingModelPyKEEN
from pykeen.triples import CoreTriplesFactory


class DistMultPyKEEN(EntityRelationEmbeddingModelPyKEEN):
    """Wrapper around PyKEEN's DistMult model."""

    @classmethod
    def model_name(cls) -> str:
        """Return the name identifying this model."""
        return "DistMult"

    def _build_model(self, triples_factory: CoreTriplesFactory) -> DistMult:
        """Create a new DistMult model from the provided triples factory.

        Parameters
        ------------------
        triples_factory: CoreTriplesFactory
            The triples factory forwarded to the DistMult constructor.

        Returns
        ------------------
        A DistMult instance whose embedding dimension is read from
        `self._embedding_size` and whose random seed is read from
        `self._random_state`.
        """
        # Collect the constructor arguments first so the mapping between
        # this wrapper's attributes and PyKEEN's keywords is easy to scan.
        model_kwargs = dict(
            triples_factory=triples_factory,
            embedding_dim=self._embedding_size,
            random_seed=self._random_state,
        )
        return DistMult(**model_kwargs)