Source code for cyanide.suite

from __future__ import absolute_import, print_function, unicode_literals

import celery
import cyanide

import inspect
import platform
import socket
import sys
import traceback

from collections import OrderedDict, defaultdict, namedtuple
from functools import partial
from itertools import count, cycle

from celery.exceptions import TimeoutError
from celery.five import items, monotonic, range, values
from celery.utils.debug import blockdetection
from celery.utils.imports import qualname
from celery.utils.text import pluralize, truncate
from celery.utils.term import colored
from kombu.utils import retry_over_time

from .fbi import FBI
from .tasks import marker, _marker

try:
    from celery.platforms import isatty
except ImportError:
    from celery.utils import isatty

try:
    import celery.utils.time as timeutils
except ImportError:
    from celery.utils import timeutils

BANNER = """\
Cyanide v{version} [celery {celery_version}]

{platform}

[config]
.> app:    {app}
.> broker: {conninfo}
.> suite: {suite}

[toc: {total} {TESTS} total]
{toc}
"""

F_PROGRESS = """\
{0.index:2d}: {0.test.__name__:<36} {status} \
({0.iteration}/{0.total_iterations}) rep#{0.repeats} runtime: \
{runtime}/{elapsed}\
"""

E_STILL_WAITING = """\
Still waiting for {0}. Trying again {when}: {exc!r}\
"""

Progress = namedtuple('Progress', (
    'test', 'iteration', 'total_iterations',
    'index', 'repeats', 'runtime', 'elapsed', 'completed',
))


Inf = float('Inf')


[docs]def assert_equal(a, b): assert a == b, '{0!r} != {1!r}'.format(a, b)
[docs]class StopSuite(Exception): pass
[docs]class Sentinel(Exception): pass
[docs]def humanize_seconds(secs, prefix='', sep='', now='now', **kwargs): s = timeutils.humanize_seconds(secs, prefix, sep, now, **kwargs) if s == now and float(secs) > 0: return '{prefix}{sep}{0:.2f} seconds'.format( float(secs), prefix=prefix, sep=sep) return s
[docs]def pstatus(p, status=None): runtime = format(monotonic() - p.runtime, '.4f') elapsed = format(monotonic() - p.elapsed, '.4f') return F_PROGRESS.format( p, status=status or '', runtime=humanize_seconds(runtime, now=runtime), elapsed=humanize_seconds(elapsed, now=elapsed), )
[docs]class Speaker(object): def __init__(self, gap=5.0, file=None): self.gap = gap self.file = sys.stdout if file is None else file self.last_noise = monotonic() - self.gap * 2
[docs] def beep(self): now = monotonic() if now - self.last_noise >= self.gap: self.emit() self.last_noise = now
[docs] def emit(self): print('\a', file=self.file, end='')
[docs]class Meter(object): def __init__(self, s='.', end='', cursor='/-\\', file=None): self.s = s self.end = end self.file = sys.stdout if file is None else file self.counter = 0 self.cursor = cycle(cursor)
[docs] def emit(self, *args, **kwargs): self.counter += len(self.s) print(self.s * (self.counter - 1) + next(self.cursor), end='\r', file=self.file) self.file.flush()
[docs] def revert(self): self.counter = 0
[docs]class DummyMeter(object): def __init__(self, *args, **kwargs): pass
[docs] def emit(self, *args, **kwargs): pass
[docs] def revert(self): pass
[docs]def testgroup(*funs): return OrderedDict((fun.__name__, fun) for fun in funs)
[docs]class ManagerMixin(object): TaskPredicate = StopSuite Meter = Meter def _init_manager(self, app, block_timeout=30 * 60, stdout=None, stderr=None): self.stdout = sys.stdout if stdout is None else stdout self.stderr = sys.stderr if stderr is None else stderr self.connerrors = self.app.connection().recoverable_connection_errors self.block_timeout = block_timeout self.progress = None self.speaker = Speaker(file=self.stdout) self.fbi = FBI(app)
[docs] def new_meter(self): return self.Meter(file=self.stdout)
[docs] def missing_results(self, r): return [res.id for res in r if res.id not in res.backend._cache]
[docs] def wait_for(self, fun, catch, desc='thing', args=(), kwargs={}, errback=None, max_retries=10, interval_start=0.1, interval_step=0.5, interval_max=5.0, emit_warning=False, **options): meter = self.new_meter() def on_error(exc, intervals, retries): interval = next(intervals) if emit_warning: self.warn(E_STILL_WAITING.format( desc, when=humanize_seconds(interval, 'in', ' '), exc=exc, )) else: meter.emit() if errback: errback(exc, interval, retries) return interval return self.retry_over_time( fun, catch, args=args, kwargs=kwargs, errback=on_error, max_retries=max_retries, interval_start=interval_start, interval_step=interval_step, **options )
[docs] def ensure_not_for_a_while(self, fun, catch, desc='thing', max_retries=20, interval_start=0.1, interval_step=0.02, interval_max=1.0, emit_warning=False, **options): meter = self.new_meter() try: return self.wait_for( fun, catch, desc=desc, max_retries=max_retries, interval_start=interval_start, interval_step=interval_step, interval_max=interval_max, emit_warning=emit_warning, errback=meter.emit, ) except catch: pass else: raise AssertionError('Should not have happened: {0}'.format(desc)) finally: meter.revert()
[docs] def retry_over_time(self, *args, **kwargs): return retry_over_time(*args, **kwargs)
[docs] def join(self, r, propagate=False, max_retries=10, **kwargs): if self.no_join: return received = [] def on_result(task_id, value): received.append(task_id) for i in range(max_retries) if max_retries else count(0): received[:] = [] try: return r.get(callback=on_result, propagate=propagate, **kwargs) except (socket.timeout, TimeoutError) as exc: waiting_for = self.missing_results(r) self.speaker.beep() marker( 'Still waiting for {0}/{1}: [{2}]: {3!r}'.format( len(r) - len(received), len(r), truncate(', '.join(waiting_for)), exc), '!', ) self.fbi.diag(waiting_for) except self.connerrors as exc: self.speaker.beep() marker('join: connection lost: {0!r}'.format(exc), '!') raise self.TaskPredicate('Test failed: Missing task results')
[docs] def inspect(self, timeout=1): return self.app.control.inspect(timeout=timeout)
[docs] def query_tasks(self, ids, timeout=0.5): for reply in items(self.inspect(timeout).query_task(ids) or []): yield reply
[docs] def query_task_states(self, ids, timeout=0.5): states = defaultdict(set) for hostname, reply in self.query_tasks(ids, timeout=timeout): for task_id, (state, _) in items(reply): states[state].add(task_id) return states
[docs] def assert_accepted(self, ids, interval=0.5, desc='waiting for tasks to be accepted', **policy): return self.assert_task_worker_state( self.is_accepted, ids, interval=interval, desc=desc, **policy )
[docs] def assert_received(self, ids, interval=0.5, desc='waiting for tasks to be received', **policy): return self.assert_task_worker_state( self.is_accepted, ids, interval=interval, desc=desc, **policy )
[docs] def assert_task_worker_state(self, fun, ids, interval=0.5, **policy): return self.wait_for( partial(self.true_or_raise, fun, ids, timeout=interval), (Sentinel,), **policy )
[docs] def is_received(self, ids, **kwargs): return self._ids_matches_state( ['reserved', 'active', 'ready'], ids, **kwargs)
[docs] def is_accepted(self, ids, **kwargs): return self._ids_matches_state(['active', 'ready'], ids, **kwargs)
def _ids_matches_state(self, expected_states, ids, timeout=0.5): states = self.query_task_states(ids, timeout=timeout) return all( any(t in s for s in [states[k] for k in expected_states]) for t in ids )
[docs] def true_or_raise(self, fun, *args, **kwargs): res = fun(*args, **kwargs) if not res: raise Sentinel() return res
[docs]class Suite(ManagerMixin): TaskPredicate = StopSuite def __init__(self, app, no_color=False, **kwargs): self.app = app self._init_manager(app, **kwargs) if not isatty(self.stdout): no_color = True self.colored = colored(enabled=not no_color) self.init_groups()
[docs] def setup(self): pass
[docs] def teardown(self): pass
[docs] def print(self, message, file=None): print(message, file=self.stdout if file is None else file)
[docs] def error(self, message): print(self.colored.red(message), file=self.stderr)
[docs] def warn(self, message): print(self.colored.cyan(message), file=self.stdout)
[docs] def init_groups(self): acc = defaultdict(list) for attr in dir(self): if not _is_descriptor(self, attr): meth = getattr(self, attr) try: groups = meth.__func__.__testgroup__ except AttributeError: pass else: for g in groups: acc[g].append(meth) # sort the tests by the order in which they are defined in the class for g in values(acc): g[:] = sorted(g, key=lambda m: m.__func__.__testsort__) self.groups = dict( (name, testgroup(*tests)) for name, tests in items(acc) )
[docs] def run(self, names=None, iterations=50, offset=0, numtests=None, list_all=False, repeat=0, group='all', diag=False, no_join=False, **kw): self.no_join = no_join self.fbi.enable(diag) tests = self.filtertests(group, names)[offset:numtests or None] if list_all: return self.print(self.testlist(tests)) self.print(self.banner(tests)) self.print('+enable worker task events...') self.app.control.enable_events() it = count() if repeat == Inf else range(int(repeat) or 1) for i in it: marker( '{0} (repetition {1})'.format( self.colored.bold('suite start'), i + 1), '+', ) for j, test in enumerate(tests): self.runtest(test, iterations, j + 1, i + 1) marker( '{0} (repetition {1})'.format( self.colored.bold('suite end'), i + 1), '+', )
[docs] def assert_equal(self, a, b): return assert_equal(a, b)
[docs] def filtertests(self, group, names): tests = self.groups[group] try: return ([tests[n] for n in names] if names else list(values(tests))) except KeyError as exc: raise KeyError('Unknown test name: {0}'.format(exc))
[docs] def testlist(self, tests): return ',\n'.join( '.> {0}) {1}'.format(i + 1, t.__name__) for i, t in enumerate(tests) )
[docs] def banner(self, tests): app = self.app return BANNER.format( app='{0}:0x{1:x}'.format(app.main or '__main__', id(app)), version=cyanide.__version__, celery_version=celery.VERSION_BANNER, conninfo=app.connection().as_uri(), platform=platform.platform(), toc=self.testlist(tests), TESTS=pluralize(len(tests), 'test'), total=len(tests), suite=':'.join(qualname(self).rsplit('.', 1)), )
[docs] def runtest(self, fun, n=50, index=0, repeats=1): n = getattr(fun, '__iterations__', None) or n header = '[[[{0}({1})]]]'.format(fun.__name__, n) if repeats > 1: header = '{0} #{1}'.format(header, repeats) self.print(header) with blockdetection(self.block_timeout): with self.fbi.investigation(): runtime = elapsed = monotonic() i = 0 failed = False self.progress = Progress( fun, i, n, index, repeats, elapsed, runtime, 0, ) _marker.delay(pstatus(self.progress)) try: for i in range(n): runtime = monotonic() self.progress = Progress( fun, i + 1, n, index, repeats, runtime, elapsed, 0, ) self.execute_test(fun) except Exception: failed = True self.speaker.beep() raise finally: if n > 1 or failed: self.print('{0} {1} iterations in {2}'.format( 'failed after' if failed else 'completed', i + 1, humanize_seconds(monotonic() - elapsed), ), file=self.stderr if failed else self.stdout) if not failed: self.progress = Progress( fun, i + 1, n, index, repeats, runtime, elapsed, 1, )
[docs] def execute_test(self, fun): self.setup() try: try: fun() except StopSuite: raise except AssertionError as exc: self.on_test_error(exc, 'FAILED') except Exception as exc: self.on_test_error(exc, 'ERROR') else: self.print(pstatus(self.progress, self.colored.green('OK'))) finally: self.teardown()
[docs] def on_test_error(self, exc, status): self.error('-> {0!r}'.format(exc)) self.error(traceback.format_exc()) self.error(pstatus(self.progress, self.colored.red(status)))
[docs] def dump_progress(self): return pstatus(self.progress) if self.progress else 'No test running'
_creation_counter = count(0)
[docs]def testcase(*groups, **kwargs): if not groups: raise ValueError('@testcase requires at least one group name') def _mark_as_case(fun): fun.__testgroup__ = groups fun.__testsort__ = next(_creation_counter) fun.__iterations__ = kwargs.get('iterations') return fun return _mark_as_case
def _is_descriptor(obj, attr): try: cattr = getattr(obj.__class__, attr) except AttributeError: pass else: ismethod = inspect.ismethod(cattr) or inspect.isfunction(cattr) return not ismethod and hasattr(cattr, '__get__') return False