sidekick-seq/sidekick/seq/iter.py
import itertools
import operator
from functools import wraps, cached_property
from .._utils import safe_repr
from ..functions import fn
from ..typing import Iterator, Tuple, T, TYPE_CHECKING
if TYPE_CHECKING:
from .. import api as sk # noqa: F401
# Unique sentinel object. It is not referenced in the visible part of this
# module -- presumably it marks "argument not supplied" where None is a
# legitimate value (confirm against callers).
NOT_GIVEN = object()
# Private alias for the builtin ``iter``, used by Iter and the helpers below.
# NOTE(review): presumably kept so the builtin stays reachable if the name
# ``iter`` is rebound elsewhere in the package -- confirm.
_iter = iter
class Iter(Iterator[T]):
    """
    Base sidekick iterator class.

    This class extends classical Python iterators with a few extra operators.
    Sidekick iterators accept slicing, indexing, concatenation (with the +
    sign), repetition (with the * sign) and pretty printing.

    Operations that return new iterators (e.g., slicing, concatenation, etc)
    consume the data stream. Operations that simply peek at data execute the
    generator (and thus may produce side-effects), but cache values and do not
    consume the data stream.
    """

    __slots__ = ("_iterator", "_size_hint")
    _iterator: Iterator[T]

    if TYPE_CHECKING:
        from .. import seq as _mod

        _mod = _mod
    else:

        @property
        def _mod(self):
            # Imported lazily, on first access, rather than at module import
            # time. A plain property is used instead of cached_property:
            # cached_property stores its value in the instance __dict__,
            # which classes declaring __slots__ normally do not have. The
            # import system already caches the module, so repeated access
            # stays cheap.
            from .. import seq

            return seq

    def __new__(cls, iterator: Iterator[T], size_hint: int = None):
        """
        Wrap ``iterator`` into an Iter.

        An argument that already is an Iter is returned unchanged (and
        ``size_hint`` is ignored). Otherwise ``iter(iterator)`` is stored
        together with the optional ``size_hint`` used by __length_hint__.
        """
        if isinstance(iterator, Iter):
            return iterator
        new = object.__new__(cls)
        new._iterator = _iter(iterator)
        new._size_hint = size_hint
        return new

    def __next__(self, _next=next):
        return _next(self._iterator)

    def __iter__(self):
        # Hands out the wrapped iterator itself, so a ``for`` loop pulls
        # directly from the underlying stream.
        return self._iterator

    def __repr__(self):
        """
        Show up to six leading elements without losing any data.

        Up to seven elements are read from the stream; whatever was read is
        chained back in front of the remaining data afterwards.
        """
        it = self._iterator
        head = []
        for _ in range(7):
            try:
                head.append(next(it))
            except StopIteration:
                # Fewer than seven elements: the stream is fully known.
                display = map(safe_repr, head)
                self._iterator = _iter(head)
                self._size_hint = len(head)
                break
            except Exception as ex:
                # The stream failed while peeking: show what was read and
                # re-arm the iterator so the same error is raised again
                # after the cached elements are consumed.
                ex_name = type(ex).__name__
                display = [*map(safe_repr, head), f"... ({ex_name})"]
                self._iterator = yield_and_raise(head, ex)
                self._size_hint = len(head)
                break
        else:
            # Seven elements exist: display six and an ellipsis.
            self._iterator = itertools.chain(_iter(head), it)
            display = [*map(safe_repr, head[:-1]), "..."]
        data = ", ".join(display)
        return f"sk.iter([{data}])"

    def __getitem__(self, item, _chain=itertools.chain):
        """
        Index, slice, filter or select elements of the stream.

        * non-negative int: return that element without consuming the
          stream (elements read on the way are cached and chained back).
        * slice: return a new Iter equivalent to
          ``islice(stream, start, stop, step)``.
        * callable: return a new Iter keeping elements accepted by the
          predicate.
        * list of bools / list of ints: return a list with the masked /
          indexed elements, read from a tee of the stream.
        * anything else: treated as an iterable of keys and handled lazily
          by compress_or_select.

        Raises IndexError for negative or out-of-range integer indexes and
        TypeError for a list whose first element is neither bool nor int.
        """
        if isinstance(item, int):
            if item >= 0:
                head = []
                for i, x in enumerate(self._iterator):
                    head.append(x)
                    if i == item:
                        self._iterator = _chain(head, self._iterator)
                        return x
                else:
                    # Stream ended before reaching the index: keep the data.
                    self._iterator = _iter(head)
                    self._size_hint = len(head)
                    raise IndexError(item)
            else:
                raise IndexError("negative indexes are not supported")

        elif isinstance(item, slice):
            # islice takes (start, stop, step), in that order.
            return Iter(
                itertools.islice(self._iterator, item.start, item.stop, item.step)
            )

        elif callable(item):
            return Iter(filter(item, self._iterator), self._size_hint)

        elif isinstance(item, list):
            if not item:
                return []
            if isinstance(item[0], bool):
                self._iterator, data = itertools.tee(self._iterator, 2)
                return [x for key, x in zip(item, data) if key]
            elif isinstance(item[0], int):
                self._iterator, data = itertools.tee(self._iterator, 2)
                data = list(itertools.islice(data, max(item) + 1))
                return [data[i] for i in item]
            else:
                raise TypeError("index must contain only integers or booleans")

        else:
            size = operator.length_hint(item, -1)
            size = None if size == -1 else size
            return Iter(compress_or_select(item, self._iterator), size)

    def __add__(self, other, _chain=itertools.chain):
        if hasattr(other, "__iter__"):
            return Iter(_chain(self._iterator, other))
        return NotImplemented

    def __radd__(self, other, _chain=itertools.chain):
        if hasattr(other, "__iter__"):
            return Iter(_chain(other, self._iterator))
        return NotImplemented

    def __iadd__(self, other, _chain=itertools.chain):
        """
        Append ``other`` to the stream in place and return self.
        """
        if not hasattr(other, "__iter__"):
            return NotImplemented
        self._iterator = _chain(self._iterator, other)
        # A previously stored hint no longer describes the longer stream.
        self._size_hint = None
        # In-place operators must return the result, otherwise ``it += x``
        # would rebind ``it`` to None.
        return self

    def __mul__(self, other):
        """
        Repeat the stream (int) or build a cartesian product (iterable).

        Raises ValueError for negative integers. Returns NotImplemented if
        ``other`` is neither an int nor iterable.
        """
        if isinstance(other, int):
            if other < 0:
                raise ValueError("cannot multiply by negative integers")
            return Iter(cycle_n(self._iterator, other))
        try:
            data = _iter(other)
        except TypeError:
            return NotImplemented
        # Each iterable is a separate positional argument of product.
        return Iter(itertools.product(self._iterator, data))

    def __rmul__(self, other):
        """
        Reflected version of __mul__; pairs are ordered (other, self).
        """
        if isinstance(other, int):
            return self.__mul__(other)
        try:
            data = _iter(other)
        except TypeError:
            return NotImplemented
        return Iter(itertools.product(data, self._iterator))

    def __rmatmul__(self, func):
        # ``func @ it`` maps func over the stream.
        if callable(func):
            return Iter(map(func, self._iterator), self._size_hint)
        return NotImplemented

    def __length_hint__(self):
        if self._size_hint is None:
            return operator.length_hint(self._iterator)
        return self._size_hint

    #
    # Conversion to collections
    #
    def list(self) -> list:
        """
        Convert iterator to list consuming iterator.

        Infinite operators do not terminate.
        """
        return list(self)

    def tuple(self) -> tuple:
        """
        Convert iterator to tuple consuming iterator.

        Infinite operators do not terminate.
        """
        return tuple(self)

    def set(self) -> set:
        """
        Convert iterator to set consuming iterator.

        Infinite operators do not terminate.
        """
        return set(self)

    def frozenset(self) -> frozenset:
        """
        Convert iterator to frozenset consuming iterator.

        Infinite operators do not terminate.
        """
        return frozenset(self)

    def str(self) -> str:
        """
        Convert iterator to string consuming iterator and concatenating
        elements.

        Infinite operators do not terminate.
        """
        return "".join(self)

    def bytes(self) -> bytes:
        """
        Convert iterator to bytes consuming iterator and concatenating
        elements.

        Infinite operators do not terminate.
        """
        return b"".join(self)

    #
    # API
    #
    def copy(self) -> "Iter":
        """
        Return a copy of iterator. Consuming the copy does not consume the
        original iterator.

        Internally, this method uses itertools.tee to perform the copy. If you
        know that the iterator will be consumed, it is faster and more memory
        efficient to convert it to a list and produce multiple iterators.
        """
        self._iterator, other = itertools.tee(self._iterator, 2)
        return Iter(other, self._size_hint)

    def tee(self, n=1) -> Tuple["Iter", ...]:
        """
        Split iterator into n additional copies.

        The copy method is simply an alias to iter.tee(1)[0]
        """
        self._iterator, *rest = itertools.tee(self._iterator, n + 1)
        hint = self._size_hint
        return tuple(Iter(it, hint) for it in rest)

    def peek(self, n: int) -> Tuple:
        """
        Peek the first n elements without consuming the iterator.
        """
        data = tuple(itertools.islice(self._iterator, n))
        # Put the peeked elements back in front of the remaining stream.
        self._iterator = itertools.chain(data, self._iterator)
        return data
#
# Wrapping the iterator API
#
def cycle_n(seq, n):
    """
    Yield the elements of seq repeated n times.

    The first pass reads seq lazily while caching every element; later
    passes replay the cache. Nothing is yielded when n is not positive or
    when seq is empty.
    """
    if not n > 0:
        return

    cache = []
    for item in seq:
        cache.append(item)
        yield item

    # An empty first pass means there is nothing to replay.
    if not cache:
        return

    remaining = n - 1
    while remaining > 0:
        yield from cache
        remaining -= 1
def compress(keys, seq):
    """
    Yield the elements of seq whose corresponding key is truthy.

    Iteration stops as soon as either seq or keys is exhausted.
    """
    yield from (item for item, keep in zip(seq, keys) if keep)
def select(keys, seq):
    """
    Yield seq[i] for every index i in keys, reading seq lazily.

    Elements read from seq are cached, so indexes may repeat or go
    backwards. When an index is not cached yet, just enough elements are
    pulled from seq to reach it; IndexError propagates if seq ends first.
    """
    cache = []
    fetch = itertools.islice
    for index in keys:
        try:
            yield cache[index]
        except IndexError:
            # Pull only the elements still missing to reach ``index``.
            shortfall = index + 1 - len(cache)
            cache += fetch(seq, shortfall)
            yield cache[index]
def compress_or_select(keys, seq):
    """
    Dispatch to compress or select according to the first key.

    A first key that is exactly True or False selects boolean masking
    (compress); any other int selects positional indexing (select). Any
    other first key raises TypeError. Nothing is yielded when keys is
    empty, or when masking is chosen and seq is empty.
    """
    key_stream = _iter(keys)
    items = _iter(seq)
    try:
        first = next(key_stream)
        if first is True or first is False:
            handler = compress
            # The first key is already consumed, so it is applied by hand
            # to the first element of seq.
            head = next(items)
            if first:
                yield head
        elif isinstance(first, int):
            handler = select
            # Push the index back so select sees the complete key stream.
            key_stream = itertools.chain([first], key_stream)
        else:
            raise TypeError(f"invalid key: {first!r}")
    except StopIteration:
        return
    yield from handler(key_stream, items)
@fn
def generator(func):
    """
    Decorate a generator function so that calling it returns a sidekick
    iterator (Iter) wrapping the generator it produces.

    Examples:
        >>> @sk.generator
        ... def fibonacci():
        ...     x = y = 1
        ...     while True:
        ...         yield x
        ...         x, y = y, x + y
        >>> fibonacci()
        sk.iter([1, 1, 2, 3, 5, 8, ...])
    """

    def gen(*args, **kwargs):
        return Iter(func(*args, **kwargs))

    # Copy func's metadata onto the wrapper, then lift it with fn.
    return fn(wraps(func)(gen))
def stop(x=None):
    """
    Raise StopIteration carrying x as its value.
    """
    signal = StopIteration(x)
    raise signal
def yield_and_raise(data, exc):
    """
    Yield every element of data, then raise exc.

    Used by Iter.__repr__ to rebuild a stream that failed while being
    previewed: the elements already read come first, followed by the
    original error.
    """
    pending_error = exc
    yield from data
    raise pending_error
# Expose the decorator as ``fn.generator``. Wrapping it in staticmethod keeps
# attribute access from binding it as a method (assuming ``fn`` is a class --
# confirm against ..functions).
fn.generator = staticmethod(generator)