from arsenal.iterextras import buf_iter, head_iter
from heapq import heapify, heappop, _siftup, heappush
def sorted_union(*iterators):
    """
    Merge multiple sorted inputs into a single sorted output.

    Produces the same sequence as ``sorted(itertools.chain(*iterators))``
    when every input is already sorted, but works lazily: elements are
    pulled from the inputs only as the merged output is consumed, so the
    inputs may be arbitrarily long (or infinite).  Empty inputs are allowed
    and contribute nothing.

    >>> list(sorted_union([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25]))
    [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25]
    """
    # The standard library already implements a lazy k-way merge; using it
    # avoids depending on the private `heapq._siftup` helper.  Imported here
    # so the function does not rely on the module-level import list.
    from heapq import merge

    # `yield from` keeps `sorted_union` itself a generator function.
    yield from merge(*iterators)


# Alias: the same function is also available under the name `merge_sorted`.
merge_sorted = sorted_union
class Item:
    """
    Heap entry that is ordered by its `cost` alone.

    `index` and `elems` are carried along as payload and take no part in
    comparisons.
    """

    def __init__(self, cost, index, elems):
        self.cost, self.index, self.elems = cost, index, elems

    def __lt__(self, other):
        # Only `<` is defined; that is the comparison `heapq` relies on.
        mine = self.cost
        theirs = other.cost
        return mine < theirs
# XXX: it would be better to just binarize the product (assuming it's associative)
def sorted_product(p, *iters):
    """
    Sorted product of `iters`, where the output is sorted by a monotonic product
    operator `p`. Examples: tuples, or multiplication/addition of positive numbers.

    Each yielded value is `p(t)`, where `t` is a tuple holding one element
    from every input, in argument order.  The search starts from the first
    element of every input and only ever advances one position at a time, so
    the outputs come out in sorted order only if each input is itself sorted
    and advancing any single position never decreases `p(t)`.

    Inputs are wrapped in `buf_iter` and read on demand, so they do not have
    to be finite.  If any input has no elements, the product is empty and
    nothing is yielded.

    At least two inputs are required.
    """
    n = len(iters)
    assert n > 1

    # use buffered iterator to ensure (random) access to the previously emitted
    # values of each iterator.
    iters = [buf_iter(it) for it in iters]

    def vals(z):
        "Translate a tuple of positions `z` into the tuple of elements at them."
        return tuple(it[j] for it, j in zip(iters, z))

    start = (0,)*n
    try:
        first = vals(start)
    except IndexError:
        # Some input is empty (it has no element 0), so there is no tuple to
        # emit at all.  Only the lookup is guarded; errors raised by `p`
        # itself still propagate.
        return

    # elements in the heap are wrapped `Item` to make it a min heap.
    q = []
    heappush(q, Item(p(first), 0, start))

    while q:
        item = heappop(q)
        x = item.elems
        j = item.index
        yield item.cost

        # next best item must differ by one, enqueue all such items
        # We reduce the number of pushes by the dotted rule trick: an item
        # reached by advancing position `j` only advances positions >= j, so
        # every tuple of positions is enqueued exactly once.
        for i in range(j, n):
            step = x[i] + 1
            try:
                iters[i][step]
            except IndexError:
                # `IndexError` is thrown when `iters[i]` is finite and we
                # requested more iterates than it has.
                continue
            y = x[:i] + (step,) + x[i+1:]
            # TODO: Efficiency improvement: memoize the prefix/suffix products
            # to save on the cost of p here.  We know that it differs from the
            # priority of the emitted `p(vals(x))` in only position `i`.
            # Alternatively, binarize the product.
            heappush(q, Item(p(vals(y)), i, y))
def main():
    """
    Self-test and demo for `sorted_product`.

    For several small finite inputs and several product operators, compares
    the output of `sorted_product` against a brute-force "enumerate everything
    and sort" baseline, printing both and asserting that they are equal.
    Finally prints the first 20 outputs for three infinite inputs.
    """
    import numpy as np
    import itertools
    from arsenal.iterextras import take
    from itertools import count

    # weighted tuples are the same idea as a path weight with backpointers; our
    # weighted tuple copies the tuple, so it is inefficient compared to the
    # lazier backpointer variant.
    class WeightedTuple:
        "A weight `w` paired with a `key` tuple; ordered by `(w, key)`."

        def __init__(self, w, *key):
            self.key = key
            self.w = w

        def __lt__(self, other):
            return (self.w, self.key) < (other.w, other.key)

        def __eq__(self, other):
            return (self.w, self.key) == (other.w, other.key)

        def __mul__(self, other):
            return LWeightedTuple(self.w*other.w, self, other)

        def __add__(self, other):
            return LWeightedTuple(self.w+other.w, self, other)

        def __iter__(self):
            # Allows unpacking as `w, key = weighted_tuple`.
            return iter((self.w, self.key))

        def __repr__(self):
            return repr((self.w, self.key))

    class LWeightedTuple(WeightedTuple):
        "WeightedTuple with lazy concatenation of keys."

        def __init__(self, w, a, b):
            self.w = w
            self.a = a
            self.b = b

        @property
        def key(self):
            # The key is rebuilt from the two operands on every access.
            return self.a.key + self.b.key

    def wprod(xs):
        # `np.prod` replaces `np.product`, a deprecated alias that NumPy 2.0
        # removed.
        return np.prod([WeightedTuple(x, x) for x in xs])

    def wsum(xs):
        return np.sum([WeightedTuple(x, x) for x in xs])

    def check(iters):
        "Compare `sorted_product` with brute force for each product operator."
        for p in [np.prod, np.sum, tuple, wprod]:

            # enumerate and sort; not lazy
            want = list(sorted(p(x) for x in itertools.product(*iters)))
            have = list(sorted_product(p, *iters))

            print()
            print('product operator:', p.__name__)
            print('HAVE:', have)
            print('WANT:', want)
            assert have == want
            print('pass.')

    print('===========')
    check([
        (.1, .4, 0.5),
        (0.09, 0.11, 0.8),
        (0.111, .3, 0.6),
    ])

    print('===========')
    check([
        (1, 2, 3),
        (4, 7, 11),
    ])

    print('===========')
    check([
        (0.01, .4, 0.5),
        (0.11, 0.8),
        (0.6,),
    ])

    print('===========')
    check([
        (1, 2, 3, 100),
        (4, 7, 9),
        (14, 17, 19),
        (24, 27, 29),
    ])

    print('===========')
    # Infinite inputs: only the first 20 outputs are requested.
    a = (3**i for i in count(1))
    b = (4**i for i in count(1))
    c = (5**i for i in count(1))

    for s,x in take(20, sorted_product(wsum, a, b, c)):
        print(s, x)
# Run the self-test/demo only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
    main()