diff --git a/shoebox/shoebox.py b/shoebox/shoebox.py index c3316c1..4a0c3f0 100644 --- a/shoebox/shoebox.py +++ b/shoebox/shoebox.py @@ -55,12 +55,19 @@ def now(): class RollChecker(object): def start(self, archive): + """Called when a new archive is selected.""" pass def check(self, archive): + """Should the current archive roll?""" pass +class NeverRollChecker(RollChecker): + def check(self, archive): + return False + + class TimeRollChecker(RollChecker): def __init__(self, timedelta): self.timedelta = timedelta @@ -102,7 +109,7 @@ class RollManager(object): return self.active_archive def _should_roll_archive(self): - return False + return self.roll_checker.check(self.active_archive) def _roll_archive(self): pass @@ -134,7 +141,7 @@ class WritingRollManager(RollManager): def write(self, payload): a = self.get_active_archive() a.write(payload) - if self._should_roll_archive(a): + if self._should_roll_archive(): self._roll_archive() @@ -158,7 +165,6 @@ class ArchiveWriter(Archive): pass - class ArchiveReader(Archive): """The active Archive for consuming. """ diff --git a/test/test_shoebox.py b/test/test_shoebox.py index a2ed1f7..0999ec9 100644 --- a/test/test_shoebox.py +++ b/test/test_shoebox.py @@ -68,4 +68,28 @@ class TestWritingRollManager(unittest.TestCase): x = shoebox.WritingRollManager(filename_template, roll_checker) archive = x.get_active_archive() self.assertTrue(isinstance(archive, shoebox.ArchiveWriter)) - self.assertTrue(roll_checker.start.called) \ No newline at end of file + self.assertTrue(roll_checker.start.called) + + def test_write_always_roll(self): + roll_checker = mock.Mock() + roll_checker.check.return_value = True + x = shoebox.WritingRollManager("template", roll_checker) + with mock.patch.object(x, "_roll_archive") as ra: + x.write("payload") + self.assertTrue(ra.called) + + def test_write_never_roll(self): + roll_checker = mock.Mock() + roll_checker.check.return_value = False + x = shoebox.WritingRollManager("template", roll_checker) + with mock.patch.object(x, "_roll_archive") as ra: + x.write("payload") + self.assertFalse(ra.called) + +class TestWriting(unittest.TestCase): + def test_write(self): + roll_checker = shoebox.NeverRollChecker() + x = shoebox.WritingRollManager("template_%s", roll_checker) + + for index in range(10): + x.write("payload_%d" % index)