rusty1s/embedded_gcnn

View on GitHub
lib/layer/embedded_gcnn.py

Summary

Maintainability
A
1 hr
Test Coverage
from six.moves import xrange

import tensorflow as tf

from .var_layer import VarLayer
from ..tf import base, sparse_tensor_diag_matmul


def conv(features, adj_dist, adj_rad, weights, K=2):
    """Convolve node features of a single graph over its partitions.

    The result is the sum of one self term (features rescaled by the inverse
    degree, multiplied with the last weight slice) and one term per
    partition (features aggregated through a partition-weighted adjacency,
    multiplied with that partition's weight slice).

    Args:
        features: Dense 2-D tensor of node features; it is used as the dense
            operand of both `tf.matmul` and `tf.sparse_tensor_dense_matmul`.
        adj_dist: Sparse adjacency tensor; its row sums plus one are used as
            the node degrees.
        adj_rad: Sparse adjacency tensor handed to `base()`. It has to hold
            the same number of elements as `adj_dist`, in the same order.
        weights: Tensor whose first dimension has size P + 1. Slices 0..P-1
            are applied to the partitions, slice P to the self term.
        K: Forwarded unchanged to `base()`.

    Returns:
        The dense tensor obtained by summing the self term and all
        partition terms.
    """
    # The leading weight dimension holds one slice per partition plus one
    # extra slice that is reserved for the self term.
    num_partitions = weights.get_shape()[0].value - 1

    # Degree = row sum of adj_dist, incremented by one, as float32.
    row_sums = tf.sparse_reduce_sum(adj_dist, axis=1)
    deg = tf.cast(row_sums + tf.ones_like(row_sums), tf.float32)

    # Self term: scale every feature row by 1 / degree.
    inv_deg_column = tf.reshape(tf.pow(deg, -1), [-1, 1])
    result = tf.matmul(inv_deg_column * features, weights[num_partitions])

    # Scale adj_dist with degree^-0.5 from both sides.
    # NOTE(review): presumably this yields D^-1/2 * A * D^-1/2; confirm
    # against the implementation of sparse_tensor_diag_matmul.
    inv_sqrt_deg = tf.pow(deg, -0.5)
    scaled_adj = sparse_tensor_diag_matmul(
        adj_dist, inv_sqrt_deg, transpose=True)
    scaled_adj = sparse_tensor_diag_matmul(
        scaled_adj, inv_sqrt_deg, transpose=False)

    for p in xrange(num_partitions):
        mask = base(adj_rad, K, num_partitions, p)

        # The values of both sparse tensors can be multiplied element-wise
        # because `base()` keeps every entry of adj_rad and writes zeros
        # into the irrelevant ones instead of dropping them. This only works
        # as long as adj_dist and adj_rad store the same number of elements
        # in the same order.
        partition_adj = tf.SparseTensor(
            scaled_adj.indices,
            tf.multiply(scaled_adj.values, mask.values),
            scaled_adj.dense_shape)

        aggregated = tf.sparse_tensor_dense_matmul(partition_adj, features)
        result += tf.matmul(aggregated, weights[p])

    return result


class EmbeddedGCNN(VarLayer):
    """Layer that runs `conv` separately on every example of a batch.

    Args:
        in_channels: Size of the second dimension of the weight variable.
        out_channels: Size of the last dimension of the weight variable and
            of the bias variable.
        adjs_dist: Per-example sequence of adjacency tensors, indexed with
            the example's position in the batch and passed to `conv` as
            `adj_dist`.
        adjs_rad: Per-example sequence of adjacency tensors, indexed the
            same way and passed to `conv` as `adj_rad`.
        local_controllability: Stored as `K` and forwarded to `conv`.
        sampling_points: Stored as `P`; the weight variable gets P + 1
            slices along its first dimension.
        **kwargs: Passed through to the `VarLayer` constructor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 adjs_dist,
                 adjs_rad,
                 local_controllability=2,
                 sampling_points=8,
                 **kwargs):

        self.K = local_controllability
        self.P = sampling_points
        self.adjs_dist = adjs_dist
        self.adjs_rad = adjs_rad

        # One weight slice per sampling point plus one additional slice.
        weight_shape = [sampling_points + 1, in_channels, out_channels]
        bias_shape = [out_channels]

        super(EmbeddedGCNN, self).__init__(
            weight_shape=weight_shape, bias_shape=bias_shape, **kwargs)

    def _call(self, inputs):
        """Apply the convolution, optional bias and activation per example.

        Args:
            inputs: Sequence of per-example feature tensors; entry `i` is
                paired with `self.adjs_dist[i]` and `self.adjs_rad[i]`.

        Returns:
            A list with one output tensor per entry of `inputs`.
        """
        # NOTE(review): `self.vars`, `self.bias` and `self.act` are not set
        # in this class; presumably VarLayer provides them - confirm.
        outputs = []

        for i, example in enumerate(inputs):
            convolved = conv(example, self.adjs_dist[i], self.adjs_rad[i],
                             self.vars['weights'], self.K)

            if self.bias:
                convolved = tf.nn.bias_add(convolved, self.vars['bias'])

            outputs.append(self.act(convolved))

        return outputs