eight0153/CartPole-NEAT

View on GitHub
neat/connection.py

Summary

Maintainability
A
0 mins
Test Coverage
A
97%
import random


class Connection:
    """A connection between two nodes in a neural network computation graph.

    Two connections compare equal (and hash equally) when they link the same
    pair of nodes; the weight, id and recurrent flag are not considered.
    """
    count = 0  # a count of unique connections created so far

    def __init__(self, target_id, input_id):
        """Create a connection between nodes.

        The weight is drawn from a standard normal distribution and the
        class-wide counter is incremented to produce the connection's id.

        Arguments:
            target_id: The id of the node that receives the input.
            input_id: The id of the node that provides the input.
        """
        self.target_id = target_id
        self.input_id = input_id
        self.weight = random.gauss(0, 1)
        self.is_recurrent = False

        Connection.count += 1
        self.id = Connection.count
        # Identity of this particular Python object; copies get their own.
        self.object_id = id(self)

    def copy(self):
        """Make a copy of a connection.

        The copy shares the original's id, weight and recurrent flag but has
        its own object_id. The class-wide counter is left unchanged.

        Returns: the copy of the connection.
        """
        duplicate = Connection(self.target_id, self.input_id)
        # copies of connections are not unique and therefore not counted,
        # so undo the increment performed by the constructor.
        Connection.count -= 1
        duplicate.id = self.id
        duplicate.weight = self.weight
        duplicate.is_recurrent = self.is_recurrent

        return duplicate

    def to_json(self):
        """Encode the connection as JSON.

        Returns: the connection encoded as a JSON-compatible dictionary.
        """
        return dict(
            target_id=self.target_id,
            input_id=self.input_id,
            id=self.id,
            object_id=self.object_id,
            weight=self.weight,
            # Without this the recurrent flag is lost on a save/load cycle.
            is_recurrent=self.is_recurrent
        )

    @staticmethod
    def from_json(config):
        """Load a connection object from JSON.

        The connection is built through the regular constructor, so the
        class-wide counter is incremented by one for each loaded connection.

        Arguments:
            config: the JSON dictionary loaded from file.

        Returns: a connection object.
        """
        connection = Connection(config['target_id'], config['input_id'])
        connection.id = config['id']
        connection.object_id = config['object_id']
        connection.weight = config['weight']
        # Files written before the flag was serialised lack this key;
        # fall back to the constructor's default of non-recurrent.
        connection.is_recurrent = config.get('is_recurrent', False)

        return connection

    def __str__(self):
        label = 'Connection_{}->{}'.format(self.target_id, self.input_id)

        return label + (' (recurrent)' if self.is_recurrent else '')

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of raising AttributeError on a missing attribute.
        if not isinstance(other, Connection):
            return NotImplemented

        return self.target_id == other.target_id and \
               self.input_id == other.input_id

    def __hash__(self):
        # Hash exactly the fields compared in __eq__. The tuple hash spreads
        # values far better than a small modular sum, which matters when
        # many connections are stored in sets or used as dictionary keys.
        return hash((self.target_id, self.input_id))