eight0153/CartPole-NEAT

View on GitHub
neat/graph.py

Summary

Maintainability
A
1 hr
Test Coverage
A
91%
"""Describes a computation graph."""
from collections import defaultdict
from enum import Enum

from neat.connection import Connection
from neat.node import Node, Sensor, Output, Activations


class Verbosity(Enum):
    """An enum capturing different verbosity levels for logging.

    The members carry increasing integer values. ``Graph.print`` compares
    the value of the graph's verbosity against the value of a message's
    verbosity to decide whether the message is printed.
    """
    # Lowest level (value 0); the default verbosity of a new Graph.
    SILENT = 0
    # Middle level (value 1); the default verbosity of a message passed to
    # Graph.print.
    MINIMAL = 1
    # Highest level (value 2).
    FULL = 2


class InvalidGraphError(Exception):
    """Raised when a graph is compiled but its structure is not valid.

    For example, the graph has no sensor nodes, no output nodes, or no
    output that is connected to a sensor.
    """


class GraphNotCompiledError(Exception):
    """Raised when a graph is used for computation before it has been
    compiled, or after its structure changed without recompiling.
    """


class InvalidGraphInputError(Exception):
    """Raised when the length of the input given to a graph differs from
    the number of input (sensor) nodes in that graph.
    """


class Graph:
    """A computation graph for arbitrary neural networks that allow recurrent
    connections.

    Nodes are stored by id, and every connection is filed under the id of the
    node that receives it. The graph has to be compiled (see `compile`)
    before it can be evaluated with `compute`; adding a node or a connection
    marks the graph as not compiled again.
    """

    def __init__(self, verbosity=Verbosity.SILENT):
        """Initialise the graph.

        Arguments:
            verbosity: how much should be printed to console.
        """
        # Maps node id -> node object.
        self.nodes = {}
        # Ids of the sensor (input) and output nodes, in insertion order.
        self.sensors = []
        self.outputs = []
        # Flat list of connections; (re)built by `compile`.
        self.connections = []
        # Maps the id of a receiving node -> list of its incoming
        # connections. `list` (rather than a lambda) is used as the factory
        # so that the dictionary can be pickled.
        self.connections_dict = defaultdict(list)

        self.verbosity = verbosity
        self.is_compiled = False

    def copy(self):
        """Make a copy of a graph.

        Every node and every connection that feeds a node of this graph is
        copied via its own `copy` method.

        Returns: the copy of the graph.
        """
        copy = Graph()

        for node_id, node in self.nodes.items():
            copy.add_node(node.copy())

            for connection in self.connections_dict[node_id]:
                copy.connections_dict[node_id].append(connection.copy())

        # If a graph is copied as-is, then it should still be compiled if the
        # original was compiled, and not compiled if the other was not
        # compiled. A compiled graph also needs its flat connection list,
        # which `add_node` alone does not populate.
        if self.is_compiled:
            copy.connections = copy._collect_connections()

        copy.is_compiled = self.is_compiled

        return copy

    def compile(self):
        """Make sure the graph is valid and prepare it for computation.

        Recurrent connections are marked starting from every output node and
        the flat connection list is rebuilt from scratch, so compiling the
        same graph repeatedly does not duplicate connections.

        Throws: InvalidGraphError
        """
        if not self.sensors:
            raise InvalidGraphError('Graph needs at least one sensor (input) '
                                    'node.')

        if not self.outputs:
            raise InvalidGraphError('Graph needs at least one output node.')

        has_path_to_input = False

        # Every output must be walked so that all recurrent connections get
        # marked, hence no early exit once a path to an input is found.
        for output in self.outputs:
            self._mark_recurrent_inputs(output)
            has_path_to_input |= self._has_path_to_input(output)

        if not has_path_to_input:
            raise InvalidGraphError('Graph needs at least one sensor (input) '
                                    'to be connected to an output.')

        self.connections = self._collect_connections()
        self.is_compiled = True

    def _collect_connections(self):
        """Flatten the per-node connection lists into a single list.

        Returns: a new list with the incoming connections of every node in
                 `self.nodes`, in the iteration order of `self.nodes`.
        """
        return [connection
                for node_id in self.nodes
                for connection in self.connections_dict[node_id]]

    def _mark_recurrent_inputs(self, node_id, visited=None):
        """Mark recurrent connections (i.e. cycles) in the graph.

        Arguments:
            node_id: the id (key in the nodes dictionary) of the current node
                     that is being evaluated. This should initially be set
                     to a terminal node (such as an output node).
            visited: the set of visited nodes in the search. Can also be
                     thought of as the current node's ancestor nodes.
                     Initially this should be left as None (an empty set).
        """
        if visited is None:
            visited = set()

        visited.add(node_id)

        for input_connection in self.connections_dict[node_id]:
            if input_connection.input_id in visited:
                # The input is an ancestor on the current path, so following
                # this connection would close a cycle.
                input_connection.is_recurrent = True
            else:
                # Each branch gets its own copy so that `visited` only ever
                # holds the ancestors of the node being evaluated.
                self._mark_recurrent_inputs(input_connection.input_id,
                                            visited.copy())

    def _has_path_to_input(self, node_id, visited=None):
        """Check if the given node has a path to the input.

        This is generally needed to check that the graph has at least one
        input connected to an output.

        Arguments:
            node_id: the id of the node that should be checked for a path to
                     an input node.
            visited: the set of nodes already visited in the search. It is
                     shared (and mutated) across the whole search so that
                     each node is explored at most once. Initially this
                     should be left as None (an empty set).

        Returns: True if a path exists to an input node, False otherwise.
        """
        if visited is None:
            visited = set()

        visited.add(node_id)

        for node_input in self.connections_dict[node_id]:
            # A node that was already visited either is still being explored
            # further up the call stack or has already been found to have no
            # path to an input, so it never needs to be explored again.
            if node_input.input_id not in visited \
                    and self._has_path_to_input(node_input.input_id,
                                                visited):
                return True

        return isinstance(self.nodes[node_id], Sensor)

    def add_node(self, node):
        """Add a node to the graph.

        Arguments:
            node: The node to be added.
        """
        self.nodes[node.id] = node

        if isinstance(node, Sensor):
            self.sensors.append(node.id)
        elif isinstance(node, Output):
            self.outputs.append(node.id)

        # Adding a node may break the graph so we force the graph to be
        # compiled again to enforce a re-run of sanity and validity checks.
        self.is_compiled = False

    def add_nodes(self, nodes):
        """Helper function to add a list of nodes to the graph.

        Arguments:
            nodes: a list of nodes that are to be added to the graph.
        """
        for node in nodes:
            self.add_node(node)

    def add_connection(self, connection):
        """Add a connection directly to the graph.

        The connection is filed under the id of its target node.

        Arguments:
            connection: The Connection object to be added to the graph.
        """
        self.connections_dict[connection.target_id].append(connection)

        # Adding a connection may break the graph so we force the graph to be
        # compiled again to enforce a re-run of sanity and validity checks.
        self.is_compiled = False

    def add_input(self, node_id, other_id):
        """Add an input (form a connection) to a node.

        Arguments:
            node_id: the id of the node that will receive the input.
            other_id: the id of the node that will provide the input.
        """
        self.connections_dict[node_id].append(Connection(node_id, other_id))

        # Adding a connection may break the graph so we force the graph to be
        # compiled again to enforce a re-run of sanity and validity checks.
        self.is_compiled = False

    @property
    def recurrent_connections(self):
        """The connections that are currently marked as recurrent.

        Only connections in the flat connection list, which is built by
        `compile`, are considered.

        Returns: a new list of the recurrent connections.
        """
        return [connection
                for connection in self.connections
                if connection.is_recurrent]

    def compute(self, x):
        """Compute the output of the neural network graph.

        Arguments:
            x: the input vector (one dimensional). It must have one element
               per sensor node.

        Returns: the output of the single output node if the graph has
                 exactly one output node, otherwise the softmax of the
                 outputs of all output nodes.

        Throws: GraphNotCompiledError, InvalidGraphInputError
        """
        if not self.is_compiled:
            raise GraphNotCompiledError('The graph must be compiled before '
                                        'being used, or after a change '
                                        'occurred to the graph structure.')

        if len(x) != len(self.sensors):
            raise InvalidGraphInputError('The input dimensions do not match '
                                         'the number of input nodes in the '
                                         'graph.')

        # Remember the outputs of the previous run; recurrent connections
        # read these instead of the outputs of the current run.
        for node in self.nodes.values():
            node.prev_output = node.output

        for value, sensor_id in zip(x, self.sensors):
            self.nodes[sensor_id].output = value

        network_output = [self._compute_output(output_id)
                          for output_id in self.outputs]

        if len(network_output) == 1:
            return network_output[0]

        return Activations.softmax(network_output)

    def _compute_output(self, node_id, level=0):
        """Compute the output of a node.

        The node's inputs are computed recursively, except for inputs that
        arrive over a recurrent connection, for which the input node's output
        from the previous run is used. The result is stored on the node.

        Arguments:
            node_id: the id of the node whose output should be computed.
            level: the level, or depth, the current node relative to the
                   starting node (typically an output node).

        Returns: the output of the node.
        """
        node = self.nodes[node_id]

        # Sensors start from the value fed into them, every other node
        # starts from its bias.
        node_output = node.output if node.id in self.sensors else node.bias

        for input_connection in self.connections_dict[node_id]:
            source = self.nodes[input_connection.input_id]

            if input_connection.is_recurrent:
                node_output += input_connection.weight * source.prev_output
            else:
                node_output += input_connection.weight * \
                               self._compute_output(input_connection.input_id,
                                                    level=level + 1)

        node.output = node.activation(node_output)

        return node.output

    def print_connections(self):
        """Print the connections (inputs) of every node in the graph."""
        for node in self.nodes:
            for input_connection in self.connections_dict[node]:
                print(input_connection)

    def print(self, msg, format_args=None, verbosity=Verbosity.MINIMAL):
        """Print a message whose visibility is controlled by the verbosity of
        the message and the graph's verbosity setting.

        The message is printed when the value of the graph's verbosity is at
        least the value of the message's verbosity.

        Arguments:
            msg: The string to print.
            format_args: any arguments needed for string formatting
                         (applied with the % operator).
            verbosity: The verbosity of the message to print.
        """
        if self.verbosity.value >= verbosity.value:
            if format_args:
                print(msg % format_args)
            else:
                print(msg)

    def __len__(self):
        """Return the number of nodes plus the number of connections in the
        flat connection list (which is built by `compile`).
        """
        return len(self.nodes) + len(self.connections)

    def to_json(self):
        """Encode a graph as JSON.

        The connections are taken from the flat connection list, so they
        reflect the state of the graph at the time it was last compiled.

        Returns: the graph encoded as a JSON-style dictionary with the keys
                 'nodes' and 'connections'.
        """
        return dict(
            nodes=[node.to_json() for node in self.nodes.values()],
            connections=[connection.to_json()
                         for connection in self.connections],
        )

    @staticmethod
    def from_json(config):
        """Load a graph object from JSON.

        Arguments:
            config: the JSON dictionary loaded from file. It must provide the
                    keys 'nodes' and 'connections'.

        Returns: a compiled graph object.

        Throws: InvalidGraphError
        """
        graph = Graph()

        graph.add_nodes([Node.from_json(node)
                         for node in config['nodes']])

        for connection in config['connections']:
            graph.add_connection(Connection.from_json(connection))

        graph.compile()

        return graph