monarch-initiative/N2V
embiggen/embedders/graph_embedding_pipeline.py
"""Module offering pipeline to embed graph nodes, node types and edge types."""
from typing import Union, Dict, Optional, Type
from ensmallen import Graph

from embiggen.utils.pipeline import iterate_graphs
from embiggen.utils.abstract_models import AbstractEmbeddingModel, EmbeddingResult



def embed_graph(
    graph: Union[Graph, str],
    embedding_model: Union[str, Type[AbstractEmbeddingModel]],
    repository: Optional[str] = None,
    version: Optional[str] = None,
    library_name: Optional[str] = None,
    smoke_test: bool = False,
    return_dataframe: bool = True,
    **kwargs: Dict
) -> EmbeddingResult:
    """Return embedding of the provided graph.

    Parameters
    ---------------------
    graph: Union[Graph, str]
        The graph or graph name to embed.
        If a graph name is provided, it will be retrieved from Ensmallen's automatic retrieval.
    embedding_model: Union[str, Type[AbstractEmbeddingModel]]
        Model or model name to use.
        If a model name is provided, it will be retrieved from the models' library.
    repository: Optional[str] = None
        Repository from where to retrieve the provided graph names
        from the Ensmallen automatic retrieval.
    version: Optional[str] = None
        Graph version to retrieve.
    library_name: Optional[str] = None
        The library from where to retrieve the embedding model.
    smoke_test: bool = False
        Whether this run should be considered a smoke test
        and therefore use the smoke test configurations for
        the provided model names and feature names.
    return_dataframe: bool = True
        Whether to return a pandas DataFrame with the embedding.
    **kwargs: Dict
        Kwargs to forward to the embedding model creation.
        If a model instance rather than a model name was provided
        together with kwargs, an exception will be raised, as it is
        unclear how these arguments should be used.
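
    Returns
    ---------------------
    An EmbeddingResult wrapping the computed embedding.

    Examples
    ---------------------
    A minimal usage sketch. The graph name, repository and model name
    below are illustrative placeholders and may not match the graphs
    and models available in your installation.

    >>> from embiggen.embedders.graph_embedding_pipeline import embed_graph
    >>> embedding = embed_graph(
    ...     graph="Cora",
    ...     repository="linqs",
    ...     embedding_model="Node2Vec SkipGram",
    ...     return_dataframe=True,
    ... )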
    """
    graph = next(iterate_graphs(
        graphs=graph,
        repositories=repository,
        versions=version
    ))

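    # When a model name was provided, resolve it from the requested library
    # and instantiate it with the forwarded kwargs; an already-built model
    # cannot be combined with kwargs, as it is unclear how to use them.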
    if isinstance(embedding_model, str):
        embedding_model: AbstractEmbeddingModel = AbstractEmbeddingModel.get_model_from_library(
            model_name=embedding_model,
            library_name=library_name,
        )(**kwargs)
    elif kwargs:
        raise ValueError(
            "Please be advised that even though you have provided yourself "
            "the embedding model, you have also provided the kwargs which "
            "would normally be forwarded to the creation of the embedding "
            "model. It is unclear what to do with these arguments."
        )

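    # Whether provided directly or just instantiated, the object must be
    # an embedding model extending `AbstractEmbeddingModel`.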
    if not isinstance(embedding_model, AbstractEmbeddingModel):
        raise ValueError(
            "The provided object is not an embedding model, that is, "
            "it does not extend the class `AbstractEmbeddingModel`."
        )

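    # For smoke tests, swap in the model's smoke-test configuration so the
    # pipeline can be exercised with a lightweight setup.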
    if smoke_test:
        try:
            embedding_model = embedding_model.into_smoke_test()
        except Exception as e:
            raise ValueError(
                "An exception was raised while trying to create "
                f"a smoke test version of the model called {embedding_model.model_name()} "
                f"from the library {library_name}, specifically "
                f"implemented in the class {embedding_model.__class__.__name__}. "
                "Most likely there is an error in this class smoke test parameters. "
                f"The body of the exception was: {str(e)}. "
            ) from e

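    # Some models require the graph nodes to be sorted by decreasing
    # outbound node degree; reorder the graph before fitting if needed.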
    if embedding_model.requires_nodes_sorted_by_decreasing_node_degree():
        graph = graph.sort_by_decreasing_outbound_node_degree()

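    # Fit the model and compute the embedding, wrapping any failure with
    # context about the graph, model and library involved.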
    try:
        return embedding_model.fit_transform(
            graph,
            return_dataframe=return_dataframe,
        )
    except Exception as e:
        raise ValueError(
            "An exception was raised while trying to compute "
            f"a node embedding on the graph {graph.get_name()} "
            f"using the model called {embedding_model.model_name()} "
            f"from the library {library_name}, specifically "
            f"implemented in the class {embedding_model.__class__.__name__}. "
            f"The body of the exception was: {str(e)}"
        ) from e