stellargraph/stellargraph

View on GitHub
stellargraph/__init__.py

Summary

Maintainability
A
0 mins
Test Coverage
# -*- coding: utf-8 -*-
#
# Copyright 2018-2020 Data61, CSIRO
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# The public API of the top-level ``stellargraph`` package: these are the names exported by
# ``from stellargraph import *`` (and picked up by documentation tools that honour ``__all__``).
# Every public (non-underscored) name imported at the top level below should be listed here.
__all__ = [
    "data",
    "datasets",
    "calibration",
    "ensemble",
    "interpretability",
    "losses",
    "layer",
    "mapper",
    "utils",
    "custom_keras_layers",
    "StellarDiGraph",
    "StellarGraph",
    "GraphSchema",
    # imported publicly from stellargraph.core.indexed_array below, so it must be exported too
    "IndexedArray",
    "__version__",
]

# Version
from .version import __version__

# Import modules
from stellargraph import (
    data,
    calibration,
    datasets,
    ensemble,
    interpretability,
    losses,
    layer,
    mapper,
    utils,
)

# Top-level imports
from stellargraph.core.graph import StellarGraph, StellarDiGraph
from stellargraph.core.indexed_array import IndexedArray
from stellargraph.core.schema import GraphSchema
import warnings

# Custom layers for keras deserialization (this is computed from a manual list to make it clear
# what's included)

# the `link_inference` module is shadowed in `sg.layer` by the `link_inference` function, so these
# layers need to be manually imported
from .layer.link_inference import (
    LinkEmbedding as _LinkEmbedding,
    LeakyClippedLinear as _LeakyClippedLinear,
)

# Each entry is keyed by the layer's class ``__name__``, so the whole dictionary can be handed
# directly to Keras as ``custom_objects``. The order of the tuple below is preserved in the dict.
custom_keras_layers = {
    layer_class.__name__: layer_class
    for layer_class in (
        layer.GraphConvolution,
        layer.ClusterGraphConvolution,
        layer.GraphAttention,
        layer.GraphAttentionSparse,
        layer.SqueezedSparseConversion,
        layer.graphsage.MeanAggregator,
        layer.graphsage.MaxPoolingAggregator,
        layer.graphsage.MeanPoolingAggregator,
        layer.graphsage.AttentionalAggregator,
        layer.hinsage.MeanHinAggregator,
        layer.rgcn.RelationalGraphConvolution,
        layer.ppnp.PPNPPropagationLayer,
        layer.appnp.APPNPPropagationLayer,
        layer.misc.GatherIndices,
        layer.deep_graph_infomax.DGIDiscriminator,
        layer.deep_graph_infomax.DGIReadout,
        layer.graphsage.GraphSAGEAggregator,
        layer.knowledge_graph.ComplExScore,
        layer.knowledge_graph.DistMultScore,
        layer.knowledge_graph.RotatEScore,
        layer.knowledge_graph.RotHEScore,
        layer.preprocessing_layer.GraphPreProcessingLayer,
        layer.preprocessing_layer.SymmetricGraphPreProcessingLayer,
        layer.watch_your_step.AttentiveWalk,
        layer.sort_pooling.SortPooling,
        layer.gcn_lstm.FixedAdjacencyGraphConvolution,
        # imported privately above because ``sg.layer.link_inference`` is shadowed by a function
        _LinkEmbedding,
        _LeakyClippedLinear,
    )
}
"""
A dictionary mapping the class name of every ``tensorflow.keras`` layer defined by StellarGraph
to the layer class itself.

A saved Keras model that uses StellarGraph layers can be loaded again by passing this dictionary
as the ``custom_objects`` argument of a model loading function such as
``tensorflow.keras.models.load_model``.

Example::

    import stellargraph as sg
    from tensorflow import keras
    keras.models.load_model("/path/to/model", custom_objects=sg.custom_keras_layers)
"""


def _top_level_deprecation_warning(name, path):
    """
    Emit a ``DeprecationWarning`` saying that ``name`` has moved out of the top-level
    ``stellargraph`` namespace into the ``stellargraph.<path>`` submodule.

    Args:
        name (str): the name that used to be importable from ``stellargraph`` directly.
        path (str): the submodule, relative to ``stellargraph``, that now provides ``name``.
    """
    replacement = f"stellargraph.{path}.{name}"
    message = (
        f"'{name}' is no longer available at the top-level. "
        f"Please use '{replacement}' instead."
    )
    # stacklevel=3 skips this helper (1) and the deprecated shim that called it (2), so the
    # warning is attributed to the user code that called the shim
    warnings.warn(message, DeprecationWarning, stacklevel=3)


def expected_calibration_error(*args, **kwargs):
    """
    Deprecated top-level alias; use ``stellargraph.calibration.expected_calibration_error``.

    Emits a ``DeprecationWarning``, then forwards all positional and keyword arguments unchanged
    to ``stellargraph.calibration.expected_calibration_error`` and returns its result.
    """
    _top_level_deprecation_warning("expected_calibration_error", "calibration")
    relocated = calibration.expected_calibration_error
    return relocated(*args, **kwargs)


def plot_reliability_diagram(*args, **kwargs):
    """
    Deprecated top-level alias; use ``stellargraph.calibration.plot_reliability_diagram``.

    Emits a ``DeprecationWarning``, then forwards all positional and keyword arguments unchanged
    to ``stellargraph.calibration.plot_reliability_diagram`` and returns its result.
    """
    _top_level_deprecation_warning("plot_reliability_diagram", "calibration")
    relocated = calibration.plot_reliability_diagram
    return relocated(*args, **kwargs)


def Ensemble(*args, **kwargs):
    """
    Deprecated top-level alias; use ``stellargraph.ensemble.Ensemble``.

    Emits a ``DeprecationWarning``, then forwards all positional and keyword arguments unchanged
    to ``stellargraph.ensemble.Ensemble`` and returns the result.
    """
    # NOTE: this shim is a plain function, so it cannot be used with ``isinstance`` or as a base
    # class; only calling it is supported.
    _top_level_deprecation_warning("Ensemble", "ensemble")
    relocated = ensemble.Ensemble
    return relocated(*args, **kwargs)


def BaggingEnsemble(*args, **kwargs):
    """
    Deprecated top-level alias; use ``stellargraph.ensemble.BaggingEnsemble``.

    Emits a ``DeprecationWarning``, then forwards all positional and keyword arguments unchanged
    to ``stellargraph.ensemble.BaggingEnsemble`` and returns the result.
    """
    # NOTE: this shim is a plain function, so it cannot be used with ``isinstance`` or as a base
    # class; only calling it is supported.
    _top_level_deprecation_warning("BaggingEnsemble", "ensemble")
    relocated = ensemble.BaggingEnsemble
    return relocated(*args, **kwargs)


def TemperatureCalibration(*args, **kwargs):
    """
    Deprecated top-level alias; use ``stellargraph.calibration.TemperatureCalibration``.

    Emits a ``DeprecationWarning``, then forwards all positional and keyword arguments unchanged
    to ``stellargraph.calibration.TemperatureCalibration`` and returns the result.
    """
    # NOTE: this shim is a plain function, so it cannot be used with ``isinstance`` or as a base
    # class; only calling it is supported.
    _top_level_deprecation_warning("TemperatureCalibration", "calibration")
    relocated = calibration.TemperatureCalibration
    return relocated(*args, **kwargs)


def IsotonicCalibration(*args, **kwargs):
    """
    Deprecated top-level alias; use ``stellargraph.calibration.IsotonicCalibration``.

    Emits a ``DeprecationWarning``, then forwards all positional and keyword arguments unchanged
    to ``stellargraph.calibration.IsotonicCalibration`` and returns the result.
    """
    # NOTE: this shim is a plain function, so it cannot be used with ``isinstance`` or as a base
    # class; only calling it is supported.
    _top_level_deprecation_warning("IsotonicCalibration", "calibration")
    relocated = calibration.IsotonicCalibration
    return relocated(*args, **kwargs)