"""
Collection of helpers that make working with iterators easier and more transparent.
"""
__all__ = (
    "partition",
    "expandable_chain",
    "caching_iter",
    "iter_sort",
)
import itertools
from collections import deque
def partition(iterable, predicate=bool):
    """Split an iterable into two lazy iterators based on a predicate.

    The predicate is evaluated exactly once per element; its verdict is
    shared between both returned iterators through :func:`itertools.tee`.

    :param iterable: source iterable to split in two
    :param predicate: callable deciding which side each element lands on;
        defaults to :class:`bool` (plain truthiness)
    :return: a ``(rejected, accepted)`` tuple of iterators: the first yields
        the elements for which the predicate was falsy, the second those for
        which it was truthy.
    """
    tagged = ((predicate(item), item) for item in iterable)
    rejected_src, accepted_src = itertools.tee(tagged)
    rejected = (item for verdict, item in rejected_src if not verdict)
    accepted = (item for verdict, item in accepted_src if verdict)
    return rejected, accepted
class expandable_chain:
    """
    chained iterables, with the ability to add new iterables to the chain
    as long as the instance hasn't raised ``StopIteration`` already. This is
    fairly useful for implementing queues of things that must be processed.

    >>> from snakeoil.iterables import expandable_chain
    >>> l = range(5)
    >>> i = expandable_chain(l)
    >>> print(next(i))
    0
    >>> print(next(i))
    1
    >>> i.appendleft(range(5, 7))
    >>> print(next(i))
    5
    >>> print(next(i))
    6
    >>> print(next(i))
    2
    """

    # Formerly misspelled ``__slot__``; Python silently ignores that name, so
    # every instance still carried a per-instance ``__dict__``.
    __slots__ = ("iterables", "__weakref__")

    def __init__(self, *iterables):
        """
        accepts any number of iterables (zero included); more can be queued
        later through the append/extend methods
        """
        # deque of iterators still to be drained; set to None once the chain
        # has raised StopIteration, which closes it for further additions.
        self.iterables = deque()
        self.extend(iterables)

    def __iter__(self):
        return self

    def __next__(self):
        if self.iterables is not None:
            while self.iterables:
                try:
                    return next(self.iterables[0])
                except StopIteration:
                    # leading iterator is drained; move on to the next one
                    self.iterables.popleft()
            # nothing left: close the chain so later additions are rejected
            self.iterables = None
        raise StopIteration()

    def append(self, iterable):
        """append an iterable to the chain to be consumed

        :raises StopIteration: if the chain has already been exhausted
        """
        if self.iterables is None:
            raise StopIteration()
        self.iterables.append(iter(iterable))

    def appendleft(self, iterable):
        """prepend an iterable to the chain to be consumed

        :raises StopIteration: if the chain has already been exhausted
        """
        if self.iterables is None:
            raise StopIteration()
        self.iterables.appendleft(iter(iterable))

    def extend(self, iterables):
        """extend multiple iterables to the chain to be consumed

        :raises StopIteration: if the chain has already been exhausted
        """
        if self.iterables is None:
            raise StopIteration()
        self.iterables.extend(iter(x) for x in iterables)

    def extendleft(self, iterables):
        """prepend multiple iterables to the chain to be consumed

        As with :meth:`collections.deque.extendleft`, the given iterables end
        up in reverse order: the last one passed is consumed first.

        :raises StopIteration: if the chain has already been exhausted
        """
        if self.iterables is None:
            raise StopIteration()
        self.iterables.extendleft(iter(x) for x in iterables)
class caching_iter:
    """
    On demand consumes from an iterable so as to appear like a tuple

    Items are pulled from the wrapped iterable only as far as an access
    requires, and are cached so the source is never consumed twice.  When a
    ``sorter`` is supplied, the cache is built from
    ``tuple(sorter(all_items))``. Indexing, iteration, ``len()``, hashing and
    comparisons therefore consume the whole source first, while truth
    testing only pulls a single item.

    >>> from snakeoil.iterables import caching_iter
    >>> i = iter(range(5))
    >>> ci = caching_iter(i)
    >>> print(ci[0])
    0
    >>> print(ci[2])
    2
    >>> print(next(i))
    3
    """

    # State:
    #   iterable    - remaining source iterator, or None once fully consumed
    #   cached_list - list of items pulled so far while ``iterable`` is live;
    #                 a tuple once ``iterable`` is None
    #   sorter      - optional callable applied once to the full item list;
    #                 reset to None after it has been applied
    __slots__ = ("iterable", "__weakref__", "cached_list", "sorter")

    def __init__(self, iterable, sorter=None):
        self.sorter = sorter
        self.iterable = iter(iterable)
        self.cached_list = []

    def __setitem__(self, key, val):
        raise TypeError("unmodifiable")

    def __getitem__(self, index):
        """Return the item at ``index``, consuming the source only as needed.

        Slices and negative indexes need the full length, so they exhaust the
        source first.

        :raises IndexError: if ``index`` is out of range
        """
        if isinstance(index, slice):
            self._flatten()
            return self.cached_list[index]
        if self.iterable is not None and self.sorter:
            # sorted order is only known once every item has been seen
            self._flatten()
        existing_len = len(self.cached_list)
        if index < 0:
            if self.iterable is not None:
                self._flatten()
                existing_len = len(self.cached_list)
            index += existing_len
            if index < 0:
                raise IndexError("list index out of range")
        elif index >= existing_len and self.iterable is not None:
            # pull just enough items for position ``index`` to exist
            self.cached_list.extend(
                itertools.islice(self.iterable, index + 1 - existing_len)
            )
            if len(self.cached_list) <= index:
                # source ran dry before reaching ``index``
                self.iterable = None
                self.cached_list = tuple(self.cached_list)
                raise IndexError("list index out of range")
        return self.cached_list[index]

    def _flatten(self):
        """Consume the rest of the source, apply ``sorter`` and freeze the cache."""
        if self.iterable is not None:
            self.cached_list.extend(self.iterable)
            if self.sorter:
                self.cached_list = tuple(self.sorter(self.cached_list))
                self.sorter = None
            else:
                self.cached_list = tuple(self.cached_list)
            self.iterable = None

    # Ordering is lexicographic, like tuples: a strict prefix sorts first.
    # The previous zip_longest loop padded the shorter side with None, which
    # raises TypeError on Python 3 whenever the lengths differ.
    def __lt__(self, other):
        self._flatten()
        return self.cached_list < tuple(other)

    def __gt__(self, other):
        self._flatten()
        return self.cached_list > tuple(other)

    def __le__(self, other):
        return self.__lt__(other) or self.__eq__(other)

    def __ge__(self, other):
        return not self.__lt__(other)

    def __eq__(self, other):
        self._flatten()
        return self.cached_list == other

    def __ne__(self, other):
        return not self.__eq__(other)

    def __bool__(self):
        if self.cached_list:
            return True
        # an iterator object is always truthy, so test for None explicitly
        if self.iterable is not None:
            for x in self.iterable:
                self.cached_list.append(x)
                return True
        # if we've made it here... then nothing more in the iterable.
        self.iterable = self.sorter = None
        self.cached_list = ()
        return False

    def __len__(self):
        self._flatten()
        return len(self.cached_list)

    def __iter__(self):
        if self.sorter is not None and self.iterable is not None:
            self._flatten()
        # Walk by position and re-read ``self.cached_list`` on every step.
        # Several live iterators (or interleaved indexing) then share one
        # cache without skipping items that another consumer pulled from
        # the source.
        position = 0
        while True:
            cached = self.cached_list
            if position < len(cached):
                yield cached[position]
                position += 1
            elif self.iterable is None:
                return
            else:
                try:
                    item = next(self.iterable)
                except StopIteration:
                    self.iterable = None
                    self.cached_list = tuple(self.cached_list)
                    return
                self.cached_list.append(item)

    def __hash__(self):
        # must honour ``sorter``; hashing the unsorted items would also leave
        # the cache frozen in unsorted order for every later access.
        self._flatten()
        return hash(self.cached_list)

    def __str__(self):
        return "iterable(%s), cached: %s" % (self.iterable, str(self.cached_list))
def iter_sort(sorter, *iterables):
    """Merge a number of sorted iterables into a single sorted iterable.

    :type sorter: callable
    :param sorter: receives a list of ``[element, iterator]`` pairs and must
        return them as a list ordered by element. The returned list is
        mutated in place afterwards. It is invoked once up front (unless
        exactly one iterable is non-empty) and again every time the leading
        iterator advances.  Beware that plain :func:`sorted` compares the
        iterators themselves when two elements are equal, which raises
        ``TypeError``. When duplicates across iterables are possible, use a
        key-based sorter such as
        ``lambda pairs: sorted(pairs, key=lambda pair: pair[0])``.
    :param iterables: iterables to consume from. It's **required**
        that each iterable to consume from is presorted already within
        that specific iterable.
    :return: yields items one by one in combined sorted order

    For example:

    >>> from snakeoil.iterables import iter_sort
    >>> iter1 = range(0, 5, 2)
    >>> iter2 = range(1, 6, 2)
    >>> # note that these lists will be consumed as they go,
    >>> # sorted is just being used to compare the individual items
    >>> sorted_iter = iter_sort(sorted, iter1, iter2)
    >>> print(list(sorted_iter))
    [0, 1, 2, 3, 4, 5]
    """
    exhausted = object()
    # one mutable [current head, iterator] pair per non-empty source
    heads = []
    for source in iterables:
        try:
            source = iter(source)
            heads.append([next(source), source])
        except StopIteration:
            # empty sources contribute nothing
            continue

    if len(heads) == 1:
        # a lone source needs no merging; stream it straight through
        first, remainder = heads[0]
        yield first
        for item in remainder:
            yield item
        return

    heads = sorter(heads)
    while heads:
        leader = heads[0]
        yield leader[0]
        following = next(leader[1], exhausted)
        if following is not exhausted:
            leader[0] = following
            heads = sorter(heads)
            continue
        # leader ran dry; the remaining pairs are still ordered, no re-sort
        del heads[0]
        if len(heads) == 1:
            first, remainder = heads[0]
            yield first
            for item in remainder:
                yield item
            return