piotrmaslanka/satella

View on GitHub
satella/coding/structures/sparse_matrix.py

Summary

Maintainability
A
0 mins
Test Coverage
A
95%
import collections
import typing as tp

from satella.coding.recast_exceptions import silence_excs
from satella.coding.typing import T

# Key accepted by SparseMatrix item access: a (column, row) pair in which each
# component is either an integer index or a slice.
KeyArg = tp.Tuple[tp.Union[int, slice], tp.Union[int, slice]]


def _cleanup_key(inst: tp.Union[int, slice], max_count: int):
    """
    Normalize a single component (column or row) of a matrix key.

    :param inst: an integer index, the full slice ``:`` or ``Ellipsis``
    :param max_count: current size of the matrix along this axis, used to
        resolve negative indices
    :return: ``Ellipsis`` if the whole axis was requested, otherwise
        a non-negative integer index
    :raises IndexError: if a slice other than ``:`` was given, or if a negative
        index reaches before the start of the axis
    :raises TypeError: if the key is neither an integer, a slice nor ``Ellipsis``
    """
    if inst is Ellipsis:
        # "..." is an accepted synonym of ":" - pass it through unchanged
        return Ellipsis

    if isinstance(inst, slice):
        if not (inst.start is None and inst.stop is None and inst.step is None):
            raise IndexError('Custom slicing is not supported!')
        return Ellipsis

    if not isinstance(inst, int):
        raise TypeError('Matrix indices must be integers, slices or Ellipsis, '
                        'not %s' % (type(inst).__name__,))

    if inst < 0:
        inst += max_count
        if inst < 0:
            # Still negative after wrapping around - there is no such element,
            # and returning it would make callers operate on a negative index.
            raise IndexError('Index out of range')
    return inst


class SparseMatrix(tp.Generic[T]):
    """
    A matrix of infinite size, that supports assignments.

    Set elements like this:

    >>> sm[1, 2] = 5
    >>> sm[:,1] = [5]
    >>> sm[1,:] = [5]
    >>> sm[:,:] = [[5]]

    where the first argument is the number of the column, counted from 0,
    and the second one is the number of the row, also counted from 0

    Note that custom slicing (ie. slices which are not :) will not be supported.
    Negative indices are supported.

    Undefined elements are considered to be of value None.

    Iterating over this matrix will yield its consecutive rows.

    You can use the constructor in following way:

    >>> sm = SparseMatrix([[1, 2], [3, 4]])

    to construct a matrix that looks like

    ::

        |1 2|
        |3 4|
    """
    # rows_dict: row_no => (col_no => value), only defined elements are stored
    # known_column_count: col_no => amount of defined elements in that column
    # no_cols, no_rows: current dimensions of the matrix
    __slots__ = ('rows_dict', 'known_column_count', 'no_cols', 'no_rows')

    def __init__(self, matrix_data: tp.Optional[tp.List[tp.List[T]]] = None):
        """
        :param matrix_data: optional list of rows to fill the matrix with
        """
        self.clear()  # sets every slot to its "empty matrix" value
        if matrix_data:
            self[:, :] = matrix_data

    def _defined_values(self) -> tp.Iterator[T]:
        """Return an iterator over all elements that are not None, row by row."""
        return (value for row in self for value in row if value is not None)

    def max(self) -> T:
        """
        Return maximum element.

        None elements will be ignored.

        :raises IndexError: if the matrix is empty
        """
        if not self.no_rows or not self.no_cols:
            raise IndexError('Cannot take the maximum of an empty matrix')
        # Inside a method body the name "max" refers to the builtin,
        # not to this method.
        return max(self._defined_values(), default=None)

    def min(self) -> T:
        """
        Return minimum element.

        None elements will be ignored.

        :raises IndexError: if the matrix is empty
        """
        if not self.no_rows or not self.no_cols:
            raise IndexError('Cannot take the minimum of an empty matrix')
        # Inside a method body the name "min" refers to the builtin,
        # not to this method.
        return min(self._defined_values(), default=None)

    def get_neighbour_coordinates(self, col: int, row: int,
                                  include_diagonals: bool = True
                                  ) -> tp.Iterator[tp.Tuple[int, int]]:
        """
        Return an iterator of coordinates to points neighbouring given point.

        Only coordinates lying within current bounds of the matrix are returned.

        :param col: column
        :param row: row
        :param include_diagonals: whether to include points having only a single point in common
        :return: an iterable of coordinates of neighbouring points. An iterator of tuple
            (col, row)
        """
        for delta_row in (-1, 0, 1):
            for delta_col in (-1, 0, 1):
                if delta_row == 0 and delta_col == 0:
                    continue  # that is the point itself
                if not include_diagonals and delta_row and delta_col:
                    continue  # both deltas non-zero means a diagonal neighbour
                cand_col = col + delta_col
                cand_row = row + delta_row
                if 0 <= cand_col < self.no_cols and 0 <= cand_row < self.no_rows:
                    yield cand_col, cand_row

    def append_row(self, y: tp.Iterable[T]) -> None:
        """
        Append a row to the bottom of the matrix

        Note that the matrix grows only if at least one of the values is not None.

        :param y: iterable with consequent columns
        """
        next_row = self.no_rows  # read once, as assignments below change it
        for col_no, z in enumerate(y):
            self[col_no, next_row] = z

    def clear(self) -> None:
        """
        Clear the contents of the sparse matrix
        """
        self.rows_dict = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
        self.known_column_count = {}  # tp.Dict[int, int] column_no => amount
        self.no_cols = 0
        self.no_rows = 0

    def __eq__(self, other: object) -> bool:
        """Two sparse matrices are equal if they store the same elements."""
        if not isinstance(other, SparseMatrix):
            return NotImplemented
        return self.rows_dict == other.rows_dict

    def __bool__(self) -> bool:
        """Return whether the matrix has any rows, in line with __len__."""
        return self.no_rows > 0

    def _sanitize_key(self, key: KeyArg) -> KeyArg:
        """Split the key into (column, row) and normalize each part."""
        col, row = key
        return _cleanup_key(col, self.no_cols), _cleanup_key(row, self.no_rows)

    @classmethod
    def from_iterable(cls, y: tp.Iterable[tp.Iterable[T]]) -> 'SparseMatrix':
        """
        Construct a sparse matrix given a row-first iterable. That iterable must
        return another iterable, that will yield values for given column.

        :param y: an iterable describing the sparse matrix
        :return: a sparse matrix object
        """
        sm = SparseMatrix()
        for row_no, cols in enumerate(y):
            for col_no, value in enumerate(cols):
                sm[col_no, row_no] = value
        return sm

    @property
    def columns(self) -> int:
        """Return the amount of columns"""
        return self.no_cols

    @property
    def rows(self) -> int:
        """Return the amount of rows"""
        return self.no_rows

    def __iter__(self) -> tp.Iterator[tp.List]:
        """Iterate over consecutive rows, each of them being a list."""
        return (self.get_row(i) for i in range(self.no_rows))

    def __len__(self) -> int:
        """Return the amount of rows."""
        return self.no_rows

    def get_row(self, row_no: int) -> tp.List[T]:
        """
        Return a single row of provided number.

        The returned array has the same length as .columns

        :param row_no: row number, numbered from 0
        """
        # .get() is used throughout so as to avoid adding new entries
        # to the defaultdicts
        cols = self.rows_dict.get(row_no)
        if cols is None:
            return [None] * self.no_cols
        return [cols.get(col_no) for col_no in range(self.no_cols)]

    def shoot(self) -> 'SparseMatrix':
        """
        Insert an empty cell between current cells. So the matrix which looked like
        [[1, 2], [3, 4]] will now look like [[1, None, 2], [None, None, None], [3, None, 4]]

        :return: a new matrix, this one is left unchanged
        """
        new_sparse = SparseMatrix()
        for row_no, row in enumerate(self):
            for col_no, value in enumerate(row):
                if value is not None:  # assigning None would be a no-op anyway
                    new_sparse[col_no * 2, row_no * 2] = value
        return new_sparse

    def _increment_column_count(self, col_no: int) -> None:
        """Note down that a new element has been defined in given column."""
        self.known_column_count[col_no] = self.known_column_count.get(col_no, 0) + 1

    def _decrement_column_count(self, col_no: int) -> None:
        """Note down that an element has been removed from given column."""
        if self.known_column_count[col_no] == 1:
            # Last element of that column - forget the column altogether,
            # so that it does not count towards the amount of columns
            del self.known_column_count[col_no]
        else:
            self.known_column_count[col_no] -= 1

    def delete_row(self, row_no: int) -> None:
        """
        Delete a row with specified number

        This removes every element of that row. Rows below it are not moved.

        :param row_no: number of the row to delete
        """
        # .get() instead of indexing, because indexing a defaultdict would
        # create an empty row if it did not exist
        cols = self.rows_dict.get(row_no)
        if not cols:
            return
        for col_no in list(cols):  # copy, as deleting modifies the dictionary
            del self[col_no, row_no]

    def __setitem__(self, key: KeyArg,
                    value: tp.Union[tp.Iterable[T], T]) -> None:
        """
        Use this to update either a single piece of the matrix, or entire row.

        Use like:

        >>> sm[1, 2] = 5
        >>> sm[...,1] = [5]
        >>> sm[1,...] = [5]
        >>> sm[...,...] = [[5]]
        >>> sm[:,1] = [5]
        >>> sm[1,:] = [5]
        >>> sm[:,:] = [[5]]

        Note that setting an element to None is the same as deleting it.
        """
        if value is None:
            del self[key]
            return

        col, row = self._sanitize_key(key)

        if col is Ellipsis and row is Ellipsis:
            # Replace the entire contents with those of a freshly built matrix
            sm = SparseMatrix.from_iterable(value)
            self.rows_dict = sm.rows_dict
            self.no_rows = sm.no_rows
            self.no_cols = sm.no_cols
            self.known_column_count = sm.known_column_count
        elif col is Ellipsis:
            for col_no, val in enumerate(value):
                self[col_no, row] = val
        elif row is Ellipsis:
            for row_no, val in enumerate(value):
                self[col, row_no] = val
        else:
            # Grow the matrix if the element lies outside of current bounds
            if col >= self.no_cols:
                self.no_cols = col + 1
            if row >= self.no_rows:
                self.no_rows = row + 1

            # Count the element only if it was not defined before
            if row not in self.rows_dict or col not in self.rows_dict[row]:
                self._increment_column_count(col)
            self.rows_dict[row][col] = value

    def __getitem__(self,
                    item: KeyArg) -> tp.Union[tp.List[T], tp.List[tp.List[T]], T]:
        """
        Return a single element, a single row, a single column or the entire matrix
        as a list of rows, depending on which parts of the key are : (or ...).

        :raises IndexError: if the requested element lies outside of the matrix
        """
        col, row = self._sanitize_key(item)

        if col is Ellipsis and row is Ellipsis:
            return list(self)
        if col is Ellipsis:
            return [self[col_no, row] for col_no in range(self.no_cols)]
        if row is Ellipsis:
            return [self[col, row_no] for row_no in range(self.no_rows)]

        if row >= self.no_rows or col >= self.no_cols:
            raise IndexError()

        # check so as to avoid adding new entries
        if row not in self.rows_dict:
            return None
        return self.rows_dict[row].get(col)

    def _calculate_column_count(self) -> int:
        """Return the amount of columns, basing on the elements currently defined."""
        # max() of an empty collection raises ValueError, hence the default,
        # which yields 0 for a matrix without any elements
        return max(self.known_column_count, default=-1) + 1

    def _calculate_row_count(self) -> int:
        """Return the amount of rows, basing on the elements currently defined."""
        return max(self.rows_dict, default=-1) + 1

    @silence_excs(KeyError)
    def __delitem__(self, key: KeyArg) -> None:
        """
        Delete a single element, a row, a column or the entire contents.

        Deleting an element that is not defined is a no-op. Dimensions of the
        matrix are recalculated after an element has been removed.
        """
        col, row = self._sanitize_key(key)

        if row is Ellipsis and col is Ellipsis:
            self.clear()
        elif col is Ellipsis:
            for col_no in range(self.no_cols):
                del self[col_no, row]
        elif row is Ellipsis:
            for row_no in range(self.no_rows):
                del self[col, row_no]
        else:
            # Check if the element is there
            if (row not in self.rows_dict) or (col not in self.rows_dict[row]):
                return

            del self.rows_dict[row][col]

            if not self.rows_dict[row]:  # Have we got an empty row now?
                del self.rows_dict[row]

            self._decrement_column_count(col)

            self.no_cols = self._calculate_column_count()
            self.no_rows = self._calculate_row_count()