diff --git a/src/reader.py b/src/reader.py index 48b6155..31aaf6f 100644 --- a/src/reader.py +++ b/src/reader.py @@ -6,6 +6,7 @@ import sys import select import random import log +import bucket __interval__ = .1 @@ -68,36 +69,19 @@ class StdinReader: return None class RandomReader: - def __init__(self, keys={"a":{"weight":1}, "b":{"weight":1}}): + def __init__(self, keys): self.keys = keys - self.pool = RandomPool(keys) + self.pool = bucket.Bucket(1) + self.pool.chooser = bucket.BucketChooserProportionalRandom() + for k,v in keys.items(): + for i in range(v["weight"]): + self.pool.push(k) def read(self): - return self.pool.pop() - -class RandomPool: - def __init__(self, values): - self.values = { - k:int(1000*v["weight"]) for k,v in values.items() - } - self.total = sum(self.values.values()) - self.consumed = set() - log.info("RandomPool with", {k:int(100.0*v/self.total) for k,v in self.values.items()}) - - def reset(self): - self.consumed = set() - - def pop(self): - idx = random.randint(0, self.total-1) - idx_offset = 0 - for k in sorted(self.values.keys()): - if self.values[k] > idx+idx_offset: - return k - idx_offset -= self.values[k] - raise Exception(":(") - - def should_reset(self): - return len(self.consumed) >= len(self.values)/2 + result = self.pool.pick_n(1) + if not result: + return None + return result[0] class FileReader: def __init__(self, path):