From 8ff274ccfcfb522b82916259e17cdbdf1bb7e7a6 Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Sun, 20 Oct 2019 20:53:00 -0600 Subject: [PATCH 01/11] fixes #95 - refactor huge assertion class --- assertpy/__init__.py | 3 +- assertpy/assertpy.py | 1294 +--------------------------------------- assertpy/base.py | 128 ++++ assertpy/collection.py | 92 +++ assertpy/contains.py | 170 ++++++ assertpy/date.py | 81 +++ assertpy/dict.py | 108 ++++ assertpy/dynamic.py | 82 +++ assertpy/exception.py | 71 +++ assertpy/extracting.py | 110 ++++ assertpy/file.py | 123 ++++ assertpy/helpers.py | 228 +++++++ assertpy/numeric.py | 208 +++++++ assertpy/snapshot.py | 127 ++++ assertpy/string.py | 207 +++++++ 15 files changed, 1759 insertions(+), 1273 deletions(-) create mode 100644 assertpy/base.py create mode 100644 assertpy/collection.py create mode 100644 assertpy/contains.py create mode 100644 assertpy/date.py create mode 100644 assertpy/dict.py create mode 100644 assertpy/dynamic.py create mode 100644 assertpy/exception.py create mode 100644 assertpy/extracting.py create mode 100644 assertpy/file.py create mode 100644 assertpy/helpers.py create mode 100644 assertpy/numeric.py create mode 100644 assertpy/snapshot.py create mode 100644 assertpy/string.py diff --git a/assertpy/__init__.py b/assertpy/__init__.py index f9cf769..4db1c0a 100644 --- a/assertpy/__init__.py +++ b/assertpy/__init__.py @@ -1,2 +1,3 @@ from __future__ import absolute_import -from .assertpy import assert_that, assert_warn, soft_assertions, contents_of, fail, soft_fail, __version__ +from .assertpy import assert_that, assert_warn, soft_assertions, fail, soft_fail, add_extension, __version__ +from .file import contents_of diff --git a/assertpy/assertpy.py b/assertpy/assertpy.py index a6c2b7b..f1e3d7e 100644 --- a/assertpy/assertpy.py +++ b/assertpy/assertpy.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2018, Activision Publishing, Inc. +# Copyright (c) 2015-2019, Activision Publishing, Inc. # All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, @@ -28,34 +28,27 @@ """Assertion library for python unit testing with a fluent API""" -from __future__ import division, print_function -import re import os -import sys -import datetime -import numbers -import collections -import inspect -import math import contextlib -import json +from .base import BaseMixin +from .contains import ContainsMixin +from .numeric import NumericMixin +from .string import StringMixin +from .collection import CollectionMixin +from .dict import DictMixin +from .date import DateMixin +from .file import FileMixin +from .extracting import ExtractingMixin +from .snapshot import SnapshotMixin +from .exception import ExceptionMixin +from .dynamic import DynamicMixin +from .helpers import HelpersMixin __version__ = '0.14' __tracebackhide__ = True # clean tracebacks via py.test integration contextlib.__tracebackhide__ = True # monkey patch contextlib with clean py.test tracebacks -if sys.version_info[0] == 3: - str_types = (str,) - xrange = range - unicode = str - Iterable = collections.abc.Iterable -else: - str_types = (basestring,) - xrange = xrange - unicode = unicode - Iterable = collections.Iterable - ### soft assertions ### _soft_ctx = 0 @@ -99,38 +92,6 @@ def assert_warn(val, description=''): just warn on assertion failures instead of raisings exceptions.""" return AssertionBuilder(val, description, 'warn') -def contents_of(f, encoding='utf-8'): - """Helper to read the contents of the given file or path into a string with the given encoding. - Encoding defaults to 'utf-8', other useful encodings are 'ascii' and 'latin-1'.""" - - try: - contents = f.read() - except AttributeError: - try: - with open(f, 'r') as fp: - contents = fp.read() - except TypeError: - raise ValueError('val must be file or path, but was type <%s>' % type(f).__name__) - except OSError: - if not isinstance(f, str_types): - raise ValueError('val must be file or path, but was type <%s>' % type(f).__name__) - raise - - if sys.version_info[0] == 3 and type(contents) is bytes: - # in PY3 force decoding of bytes to target encoding - return contents.decode(encoding, 'replace') - elif sys.version_info[0] == 2 and encoding == 'ascii': - # in PY2 force encoding back to ascii - return contents.encode('ascii', 'replace') - else: - # in all other cases, try to decode to target encoding - try: - return contents.decode(encoding, 'replace') - except AttributeError: - pass - # if all else fails, just return the contents "as is" - return contents - def fail(msg=''): """Force test failure with the given message.""" raise AssertionError('Fail: %s!' % msg if msg else 'Fail!') @@ -146,958 +107,24 @@ def soft_fail(msg=''): fail(msg) -class AssertionBuilder(object): +class AssertionBuilder(DynamicMixin, ExceptionMixin, SnapshotMixin, ExtractingMixin, + FileMixin, DateMixin, DictMixin, CollectionMixin, StringMixin, NumericMixin, + ContainsMixin, HelpersMixin, BaseMixin, object): """Assertion builder.""" def __init__(self, val, description='', kind=None, expected=None): """Construct the assertion builder.""" self.val = val self.description = description + if kind not in set([None, 'soft', 'warn']): + raise ValueError('bad assertion kind') self.kind = kind self.expected = expected - def described_as(self, description): - """Describes the assertion. On failure, the description is included in the error message.""" - self.description = str(description) - return self - - def is_equal_to(self, other, **kwargs): - """Asserts that val is equal to other.""" - if self._check_dict_like(self.val, check_values=False, return_as_bool=True) and \ - self._check_dict_like(other, check_values=False, return_as_bool=True): - if self._dict_not_equal(self.val, other, ignore=kwargs.get('ignore'), include=kwargs.get('include')): - self._dict_err(self.val, other, ignore=kwargs.get('ignore'), include=kwargs.get('include')) - else: - if self.val != other: - self._err('Expected <%s> to be equal to <%s>, but was not.' % (self.val, other)) - return self - - def is_not_equal_to(self, other): - """Asserts that val is not equal to other.""" - if self.val == other: - self._err('Expected <%s> to be not equal to <%s>, but was.' % (self.val, other)) - return self - - def is_same_as(self, other): - """Asserts that the val is identical to other, via 'is' compare.""" - if self.val is not other: - self._err('Expected <%s> to be identical to <%s>, but was not.' % (self.val, other)) - return self - - def is_not_same_as(self, other): - """Asserts that the val is not identical to other, via 'is' compare.""" - if self.val is other: - self._err('Expected <%s> to be not identical to <%s>, but was.' % (self.val, other)) - return self - - def is_true(self): - """Asserts that val is true.""" - if not self.val: - self._err('Expected , but was not.') - return self - - def is_false(self): - """Asserts that val is false.""" - if self.val: - self._err('Expected , but was not.') - return self - - def is_none(self): - """Asserts that val is none.""" - if self.val is not None: - self._err('Expected <%s> to be , but was not.' % self.val) - return self - - def is_not_none(self): - """Asserts that val is not none.""" - if self.val is None: - self._err('Expected not , but was.') - return self - - def is_type_of(self, some_type): - """Asserts that val is of the given type.""" - if type(some_type) is not type and\ - not issubclass(type(some_type), type): - raise TypeError('given arg must be a type') - if type(self.val) is not some_type: - if hasattr(self.val, '__name__'): - t = self.val.__name__ - elif hasattr(self.val, '__class__'): - t = self.val.__class__.__name__ - else: - t = 'unknown' - self._err('Expected <%s:%s> to be of type <%s>, but was not.' % (self.val, t, some_type.__name__)) - return self - - def is_instance_of(self, some_class): - """Asserts that val is an instance of the given class.""" - try: - if not isinstance(self.val, some_class): - if hasattr(self.val, '__name__'): - t = self.val.__name__ - elif hasattr(self.val, '__class__'): - t = self.val.__class__.__name__ - else: - t = 'unknown' - self._err('Expected <%s:%s> to be instance of class <%s>, but was not.' % (self.val, t, some_class.__name__)) - except TypeError: - raise TypeError('given arg must be a class') - return self - - def is_length(self, length): - """Asserts that val is the given length.""" - if type(length) is not int: - raise TypeError('given arg must be an int') - if length < 0: - raise ValueError('given arg must be a positive int') - if len(self.val) != length: - self._err('Expected <%s> to be of length <%d>, but was <%d>.' % (self.val, length, len(self.val))) - return self - - def contains(self, *items): - """Asserts that val contains the given item or items.""" - if len(items) == 0: - raise ValueError('one or more args must be given') - elif len(items) == 1: - if items[0] not in self.val: - if self._check_dict_like(self.val, return_as_bool=True): - self._err('Expected <%s> to contain key <%s>, but did not.' % (self.val, items[0])) - else: - self._err('Expected <%s> to contain item <%s>, but did not.' % (self.val, items[0])) - else: - missing = [] - for i in items: - if i not in self.val: - missing.append(i) - if missing: - if self._check_dict_like(self.val, return_as_bool=True): - self._err('Expected <%s> to contain keys %s, but did not contain key%s %s.' % (self.val, self._fmt_items(items), '' if len(missing) == 0 else 's', self._fmt_items(missing))) - else: - self._err('Expected <%s> to contain items %s, but did not contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(missing))) - return self - - def does_not_contain(self, *items): - """Asserts that val does not contain the given item or items.""" - if len(items) == 0: - raise ValueError('one or more args must be given') - elif len(items) == 1: - if items[0] in self.val: - self._err('Expected <%s> to not contain item <%s>, but did.' % (self.val, items[0])) - else: - found = [] - for i in items: - if i in self.val: - found.append(i) - if found: - self._err('Expected <%s> to not contain items %s, but did contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(found))) - return self - - def contains_only(self, *items): - """Asserts that val contains only the given item or items.""" - if len(items) == 0: - raise ValueError('one or more args must be given') - else: - extra = [] - for i in self.val: - if i not in items: - extra.append(i) - if extra: - self._err('Expected <%s> to contain only %s, but did contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(extra))) - - missing = [] - for i in items: - if i not in self.val: - missing.append(i) - if missing: - self._err('Expected <%s> to contain only %s, but did not contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(missing))) - return self - - def contains_sequence(self, *items): - """Asserts that val contains the given sequence of items in order.""" - if len(items) == 0: - raise ValueError('one or more args must be given') - else: - try: - for i in xrange(len(self.val) - len(items) + 1): - for j in xrange(len(items)): - if self.val[i+j] != items[j]: - break - else: - return self - except TypeError: - raise TypeError('val is not iterable') - self._err('Expected <%s> to contain sequence %s, but did not.' % (self.val, self._fmt_items(items))) - - def contains_duplicates(self): - """Asserts that val is iterable and contains duplicate items.""" - try: - if len(self.val) != len(set(self.val)): - return self - except TypeError: - raise TypeError('val is not iterable') - self._err('Expected <%s> to contain duplicates, but did not.' % self.val) - - def does_not_contain_duplicates(self): - """Asserts that val is iterable and does not contain any duplicate items.""" - try: - if len(self.val) == len(set(self.val)): - return self - except TypeError: - raise TypeError('val is not iterable') - self._err('Expected <%s> to not contain duplicates, but did.' % self.val) - - def is_empty(self): - """Asserts that val is empty.""" - if len(self.val) != 0: - if isinstance(self.val, str_types): - self._err('Expected <%s> to be empty string, but was not.' % self.val) - else: - self._err('Expected <%s> to be empty, but was not.' % self.val) - return self - - def is_not_empty(self): - """Asserts that val is not empty.""" - if len(self.val) == 0: - if isinstance(self.val, str_types): - self._err('Expected not empty string, but was empty.') - else: - self._err('Expected not empty, but was empty.') - return self - - def is_in(self, *items): - """Asserts that val is equal to one of the given items.""" - if len(items) == 0: - raise ValueError('one or more args must be given') - else: - for i in items: - if self.val == i: - return self - self._err('Expected <%s> to be in %s, but was not.' % (self.val, self._fmt_items(items))) - - def is_not_in(self, *items): - """Asserts that val is not equal to one of the given items.""" - if len(items) == 0: - raise ValueError('one or more args must be given') - else: - for i in items: - if self.val == i: - self._err('Expected <%s> to not be in %s, but was.' % (self.val, self._fmt_items(items))) - return self - -### numeric assertions ### - - COMPAREABLE_TYPES = set([datetime.datetime, datetime.timedelta, datetime.date, datetime.time]) - NON_COMPAREABLE_TYPES = set([complex]) - - def _validate_compareable(self, other): - self_type = type(self.val) - other_type = type(other) - - if self_type in self.NON_COMPAREABLE_TYPES: - raise TypeError('ordering is not defined for type <%s>' % self_type.__name__) - if self_type in self.COMPAREABLE_TYPES: - if other_type is not self_type: - raise TypeError('given arg must be <%s>, but was <%s>' % (self_type.__name__, other_type.__name__)) - return - if isinstance(self.val, numbers.Number): - if not isinstance(other, numbers.Number): - raise TypeError('given arg must be a number, but was <%s>' % other_type.__name__) - return - raise TypeError('ordering is not defined for type <%s>' % self_type.__name__) - - def _validate_number(self): - """Raise TypeError if val is not numeric.""" - if isinstance(self.val, numbers.Number) is False: - raise TypeError('val is not numeric') - - def _validate_real(self): - """Raise TypeError if val is not real number.""" - if isinstance(self.val, numbers.Real) is False: - raise TypeError('val is not real number') - - def is_zero(self): - """Asserts that val is numeric and equal to zero.""" - self._validate_number() - return self.is_equal_to(0) - - def is_not_zero(self): - """Asserts that val is numeric and not equal to zero.""" - self._validate_number() - return self.is_not_equal_to(0) - - def is_nan(self): - """Asserts that val is real number and NaN (not a number).""" - self._validate_number() - self._validate_real() - if not math.isnan(self.val): - self._err('Expected <%s> to be , but was not.' % self.val) - return self - - def is_not_nan(self): - """Asserts that val is real number and not NaN (not a number).""" - self._validate_number() - self._validate_real() - if math.isnan(self.val): - self._err('Expected not , but was.') - return self - - def is_inf(self): - """Asserts that val is real number and Inf (infinity).""" - self._validate_number() - self._validate_real() - if not math.isinf(self.val): - self._err('Expected <%s> to be , but was not.' % self.val) - return self - - def is_not_inf(self): - """Asserts that val is real number and not Inf (infinity).""" - self._validate_number() - self._validate_real() - if math.isinf(self.val): - self._err('Expected not , but was.') - return self - - def is_greater_than(self, other): - """Asserts that val is numeric and is greater than other.""" - self._validate_compareable(other) - if self.val <= other: - if type(self.val) is datetime.datetime: - self._err('Expected <%s> to be greater than <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) - else: - self._err('Expected <%s> to be greater than <%s>, but was not.' % (self.val, other)) - return self - - def is_greater_than_or_equal_to(self, other): - """Asserts that val is numeric and is greater than or equal to other.""" - self._validate_compareable(other) - if self.val < other: - if type(self.val) is datetime.datetime: - self._err('Expected <%s> to be greater than or equal to <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) - else: - self._err('Expected <%s> to be greater than or equal to <%s>, but was not.' % (self.val, other)) - return self - - def is_less_than(self, other): - """Asserts that val is numeric and is less than other.""" - self._validate_compareable(other) - if self.val >= other: - if type(self.val) is datetime.datetime: - self._err('Expected <%s> to be less than <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) - else: - self._err('Expected <%s> to be less than <%s>, but was not.' % (self.val, other)) - return self - - def is_less_than_or_equal_to(self, other): - """Asserts that val is numeric and is less than or equal to other.""" - self._validate_compareable(other) - if self.val > other: - if type(self.val) is datetime.datetime: - self._err('Expected <%s> to be less than or equal to <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) - else: - self._err('Expected <%s> to be less than or equal to <%s>, but was not.' % (self.val, other)) - return self - - def is_positive(self): - """Asserts that val is numeric and greater than zero.""" - return self.is_greater_than(0) - - def is_negative(self): - """Asserts that val is numeric and less than zero.""" - return self.is_less_than(0) - - def is_between(self, low, high): - """Asserts that val is numeric and is between low and high.""" - val_type = type(self.val) - self._validate_between_args(val_type, low, high) - - if self.val < low or self.val > high: - if val_type is datetime.datetime: - self._err('Expected <%s> to be between <%s> and <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), low.strftime('%Y-%m-%d %H:%M:%S'), high.strftime('%Y-%m-%d %H:%M:%S'))) - else: - self._err('Expected <%s> to be between <%s> and <%s>, but was not.' % (self.val, low, high)) - return self - - def is_not_between(self, low, high): - """Asserts that val is numeric and is between low and high.""" - val_type = type(self.val) - self._validate_between_args(val_type, low, high) - - if self.val >= low and self.val <= high: - if val_type is datetime.datetime: - self._err('Expected <%s> to not be between <%s> and <%s>, but was.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), low.strftime('%Y-%m-%d %H:%M:%S'), high.strftime('%Y-%m-%d %H:%M:%S'))) - else: - self._err('Expected <%s> to not be between <%s> and <%s>, but was.' % (self.val, low, high)) - return self - - def is_close_to(self, other, tolerance): - """Asserts that val is numeric and is close to other within tolerance.""" - self._validate_close_to_args(self.val, other, tolerance) - - if self.val < (other-tolerance) or self.val > (other+tolerance): - if type(self.val) is datetime.datetime: - tolerance_seconds = tolerance.days * 86400 + tolerance.seconds + tolerance.microseconds / 1000000 - h, rem = divmod(tolerance_seconds, 3600) - m, s = divmod(rem, 60) - self._err('Expected <%s> to be close to <%s> within tolerance <%d:%02d:%02d>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'), h, m, s)) - else: - self._err('Expected <%s> to be close to <%s> within tolerance <%s>, but was not.' % (self.val, other, tolerance)) - return self - - def is_not_close_to(self, other, tolerance): - """Asserts that val is numeric and is not close to other within tolerance.""" - self._validate_close_to_args(self.val, other, tolerance) - - if self.val >= (other-tolerance) and self.val <= (other+tolerance): - if type(self.val) is datetime.datetime: - tolerance_seconds = tolerance.days * 86400 + tolerance.seconds + tolerance.microseconds / 1000000 - h, rem = divmod(tolerance_seconds, 3600) - m, s = divmod(rem, 60) - self._err('Expected <%s> to not be close to <%s> within tolerance <%d:%02d:%02d>, but was.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'), h, m, s)) - else: - self._err('Expected <%s> to not be close to <%s> within tolerance <%s>, but was.' % (self.val, other, tolerance)) - return self - -### string assertions ### - def is_equal_to_ignoring_case(self, other): - """Asserts that val is case-insensitive equal to other.""" - if not isinstance(self.val, str_types): - raise TypeError('val is not a string') - if not isinstance(other, str_types): - raise TypeError('given arg must be a string') - if self.val.lower() != other.lower(): - self._err('Expected <%s> to be case-insensitive equal to <%s>, but was not.' % (self.val, other)) - return self - - def contains_ignoring_case(self, *items): - """Asserts that val is string and contains the given item or items.""" - if len(items) == 0: - raise ValueError('one or more args must be given') - if isinstance(self.val, str_types): - if len(items) == 1: - if not isinstance(items[0], str_types): - raise TypeError('given arg must be a string') - if items[0].lower() not in self.val.lower(): - self._err('Expected <%s> to case-insensitive contain item <%s>, but did not.' % (self.val, items[0])) - else: - missing = [] - for i in items: - if not isinstance(i, str_types): - raise TypeError('given args must all be strings') - if i.lower() not in self.val.lower(): - missing.append(i) - if missing: - self._err('Expected <%s> to case-insensitive contain items %s, but did not contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(missing))) - elif isinstance(self.val, Iterable): - missing = [] - for i in items: - if not isinstance(i, str_types): - raise TypeError('given args must all be strings') - found = False - for v in self.val: - if not isinstance(v, str_types): - raise TypeError('val items must all be strings') - if i.lower() == v.lower(): - found = True - break - if not found: - missing.append(i) - if missing: - self._err('Expected <%s> to case-insensitive contain items %s, but did not contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(missing))) - else: - raise TypeError('val is not a string or iterable') - return self - - def starts_with(self, prefix): - """Asserts that val is string or iterable and starts with prefix.""" - if prefix is None: - raise TypeError('given prefix arg must not be none') - if isinstance(self.val, str_types): - if not isinstance(prefix, str_types): - raise TypeError('given prefix arg must be a string') - if len(prefix) == 0: - raise ValueError('given prefix arg must not be empty') - if not self.val.startswith(prefix): - self._err('Expected <%s> to start with <%s>, but did not.' % (self.val, prefix)) - elif isinstance(self.val, Iterable): - if len(self.val) == 0: - raise ValueError('val must not be empty') - first = next(iter(self.val)) - if first != prefix: - self._err('Expected %s to start with <%s>, but did not.' % (self.val, prefix)) - else: - raise TypeError('val is not a string or iterable') - return self - - def ends_with(self, suffix): - """Asserts that val is string or iterable and ends with suffix.""" - if suffix is None: - raise TypeError('given suffix arg must not be none') - if isinstance(self.val, str_types): - if not isinstance(suffix, str_types): - raise TypeError('given suffix arg must be a string') - if len(suffix) == 0: - raise ValueError('given suffix arg must not be empty') - if not self.val.endswith(suffix): - self._err('Expected <%s> to end with <%s>, but did not.' % (self.val, suffix)) - elif isinstance(self.val, Iterable): - if len(self.val) == 0: - raise ValueError('val must not be empty') - last = None - for last in self.val: - pass - if last != suffix: - self._err('Expected %s to end with <%s>, but did not.' % (self.val, suffix)) - else: - raise TypeError('val is not a string or iterable') - return self - - def matches(self, pattern): - """Asserts that val is string and matches regex pattern.""" - if not isinstance(self.val, str_types): - raise TypeError('val is not a string') - if not isinstance(pattern, str_types): - raise TypeError('given pattern arg must be a string') - if len(pattern) == 0: - raise ValueError('given pattern arg must not be empty') - if re.search(pattern, self.val) is None: - self._err('Expected <%s> to match pattern <%s>, but did not.' % (self.val, pattern)) - return self - - def does_not_match(self, pattern): - """Asserts that val is string and does not match regex pattern.""" - if not isinstance(self.val, str_types): - raise TypeError('val is not a string') - if not isinstance(pattern, str_types): - raise TypeError('given pattern arg must be a string') - if len(pattern) == 0: - raise ValueError('given pattern arg must not be empty') - if re.search(pattern, self.val) is not None: - self._err('Expected <%s> to not match pattern <%s>, but did.' % (self.val, pattern)) - return self - - def is_alpha(self): - """Asserts that val is non-empty string and all characters are alphabetic.""" - if not isinstance(self.val, str_types): - raise TypeError('val is not a string') - if len(self.val) == 0: - raise ValueError('val is empty') - if not self.val.isalpha(): - self._err('Expected <%s> to contain only alphabetic chars, but did not.' % self.val) - return self - - def is_digit(self): - """Asserts that val is non-empty string and all characters are digits.""" - if not isinstance(self.val, str_types): - raise TypeError('val is not a string') - if len(self.val) == 0: - raise ValueError('val is empty') - if not self.val.isdigit(): - self._err('Expected <%s> to contain only digits, but did not.' % self.val) - return self - - def is_lower(self): - """Asserts that val is non-empty string and all characters are lowercase.""" - if not isinstance(self.val, str_types): - raise TypeError('val is not a string') - if len(self.val) == 0: - raise ValueError('val is empty') - if self.val != self.val.lower(): - self._err('Expected <%s> to contain only lowercase chars, but did not.' % self.val) - return self - - def is_upper(self): - """Asserts that val is non-empty string and all characters are uppercase.""" - if not isinstance(self.val, str_types): - raise TypeError('val is not a string') - if len(self.val) == 0: - raise ValueError('val is empty') - if self.val != self.val.upper(): - self._err('Expected <%s> to contain only uppercase chars, but did not.' % self.val) - return self - - def is_unicode(self): - """Asserts that val is a unicode string.""" - if type(self.val) is not unicode: - self._err('Expected <%s> to be unicode, but was <%s>.' % (self.val, type(self.val).__name__)) - return self - -### collection assertions ### - def is_iterable(self): - """Asserts that val is iterable collection.""" - if not isinstance(self.val, Iterable): - self._err('Expected iterable, but was not.') - return self - - def is_not_iterable(self): - """Asserts that val is not iterable collection.""" - if isinstance(self.val, Iterable): - self._err('Expected not iterable, but was.') - return self - - def is_subset_of(self, *supersets): - """Asserts that val is iterable and a subset of the given superset or flattened superset if multiple supersets are given.""" - if not isinstance(self.val, Iterable): - raise TypeError('val is not iterable') - if len(supersets) == 0: - raise ValueError('one or more superset args must be given') - - missing = [] - if hasattr(self.val, 'keys') and callable(getattr(self.val, 'keys')) and hasattr(self.val, '__getitem__'): - # flatten superset dicts - superdict = {} - for l,j in enumerate(supersets): - self._check_dict_like(j, check_values=False, name='arg #%d' % (l+1)) - for k in j.keys(): - superdict.update({k: j[k]}) - - for i in self.val.keys(): - if i not in superdict: - missing.append({i: self.val[i]}) # bad key - elif self.val[i] != superdict[i]: - missing.append({i: self.val[i]}) # bad val - if missing: - self._err('Expected <%s> to be subset of %s, but %s %s missing.' % (self.val, self._fmt_items(superdict), self._fmt_items(missing), 'was' if len(missing) == 1 else 'were')) - else: - # flatten supersets - superset = set() - for j in supersets: - try: - for k in j: - superset.add(k) - except Exception: - superset.add(j) - - for i in self.val: - if i not in superset: - missing.append(i) - if missing: - self._err('Expected <%s> to be subset of %s, but %s %s missing.' % (self.val, self._fmt_items(superset), self._fmt_items(missing), 'was' if len(missing) == 1 else 'were')) - - return self - -### dict assertions ### - def contains_key(self, *keys): - """Asserts the val is a dict and contains the given key or keys. Alias for contains().""" - self._check_dict_like(self.val, check_values=False, check_getitem=False) - return self.contains(*keys) - - def does_not_contain_key(self, *keys): - """Asserts the val is a dict and does not contain the given key or keys. Alias for does_not_contain().""" - self._check_dict_like(self.val, check_values=False, check_getitem=False) - return self.does_not_contain(*keys) - - def contains_value(self, *values): - """Asserts that val is a dict and contains the given value or values.""" - self._check_dict_like(self.val, check_getitem=False) - if len(values) == 0: - raise ValueError('one or more value args must be given') - missing = [] - for v in values: - if v not in self.val.values(): - missing.append(v) - if missing: - self._err('Expected <%s> to contain values %s, but did not contain %s.' % (self.val, self._fmt_items(values), self._fmt_items(missing))) - return self - - def does_not_contain_value(self, *values): - """Asserts that val is a dict and does not contain the given value or values.""" - self._check_dict_like(self.val, check_getitem=False) - if len(values) == 0: - raise ValueError('one or more value args must be given') - else: - found = [] - for v in values: - if v in self.val.values(): - found.append(v) - if found: - self._err('Expected <%s> to not contain values %s, but did contain %s.' % (self.val, self._fmt_items(values), self._fmt_items(found))) - return self - - def contains_entry(self, *args, **kwargs): - """Asserts that val is a dict and contains the given entry or entries.""" - self._check_dict_like(self.val, check_values=False) - entries = list(args) + [{k:v} for k,v in kwargs.items()] - if len(entries) == 0: - raise ValueError('one or more entry args must be given') - missing = [] - for e in entries: - if type(e) is not dict: - raise TypeError('given entry arg must be a dict') - if len(e) != 1: - raise ValueError('given entry args must contain exactly one key-value pair') - k = next(iter(e)) - if k not in self.val: - missing.append(e) # bad key - elif self.val[k] != e[k]: - missing.append(e) # bad val - if missing: - self._err('Expected <%s> to contain entries %s, but did not contain %s.' % (self.val, self._fmt_items(entries), self._fmt_items(missing))) - return self - - def does_not_contain_entry(self, *args, **kwargs): - """Asserts that val is a dict and does not contain the given entry or entries.""" - self._check_dict_like(self.val, check_values=False) - entries = list(args) + [{k:v} for k,v in kwargs.items()] - if len(entries) == 0: - raise ValueError('one or more entry args must be given') - found = [] - for e in entries: - if type(e) is not dict: - raise TypeError('given entry arg must be a dict') - if len(e) != 1: - raise ValueError('given entry args must contain exactly one key-value pair') - k = next(iter(e)) - if k in self.val and e[k] == self.val[k]: - found.append(e) - if found: - self._err('Expected <%s> to not contain entries %s, but did contain %s.' % (self.val, self._fmt_items(entries), self._fmt_items(found))) - return self - -### datetime assertions ### - def is_before(self, other): - """Asserts that val is a date and is before other date.""" - if type(self.val) is not datetime.datetime: - raise TypeError('val must be datetime, but was type <%s>' % type(self.val).__name__) - if type(other) is not datetime.datetime: - raise TypeError('given arg must be datetime, but was type <%s>' % type(other).__name__) - if self.val >= other: - self._err('Expected <%s> to be before <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) - return self - - def is_after(self, other): - """Asserts that val is a date and is after other date.""" - if type(self.val) is not datetime.datetime: - raise TypeError('val must be datetime, but was type <%s>' % type(self.val).__name__) - if type(other) is not datetime.datetime: - raise TypeError('given arg must be datetime, but was type <%s>' % type(other).__name__) - if self.val <= other: - self._err('Expected <%s> to be after <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) - return self - - def is_equal_to_ignoring_milliseconds(self, other): - if type(self.val) is not datetime.datetime: - raise TypeError('val must be datetime, but was type <%s>' % type(self.val).__name__) - if type(other) is not datetime.datetime: - raise TypeError('given arg must be datetime, but was type <%s>' % type(other).__name__) - if self.val.date() != other.date() or self.val.hour != other.hour or self.val.minute != other.minute or self.val.second != other.second: - self._err('Expected <%s> to be equal to <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) - return self - - def is_equal_to_ignoring_seconds(self, other): - if type(self.val) is not datetime.datetime: - raise TypeError('val must be datetime, but was type <%s>' % type(self.val).__name__) - if type(other) is not datetime.datetime: - raise TypeError('given arg must be datetime, but was type <%s>' % type(other).__name__) - if self.val.date() != other.date() or self.val.hour != other.hour or self.val.minute != other.minute: - self._err('Expected <%s> to be equal to <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M'), other.strftime('%Y-%m-%d %H:%M'))) - return self - - def is_equal_to_ignoring_time(self, other): - if type(self.val) is not datetime.datetime: - raise TypeError('val must be datetime, but was type <%s>' % type(self.val).__name__) - if type(other) is not datetime.datetime: - raise TypeError('given arg must be datetime, but was type <%s>' % type(other).__name__) - if self.val.date() != other.date(): - self._err('Expected <%s> to be equal to <%s>, but was not.' % (self.val.strftime('%Y-%m-%d'), other.strftime('%Y-%m-%d'))) - return self - -### file assertions ### - def exists(self): - """Asserts that val is a path and that it exists.""" - if not isinstance(self.val, str_types): - raise TypeError('val is not a path') - if not os.path.exists(self.val): - self._err('Expected <%s> to exist, but was not found.' % self.val) - return self - - def does_not_exist(self): - """Asserts that val is a path and that it does not exist.""" - if not isinstance(self.val, str_types): - raise TypeError('val is not a path') - if os.path.exists(self.val): - self._err('Expected <%s> to not exist, but was found.' % self.val) - return self - - def is_file(self): - """Asserts that val is an existing path to a file.""" - self.exists() - if not os.path.isfile(self.val): - self._err('Expected <%s> to be a file, but was not.' % self.val) - return self - - def is_directory(self): - """Asserts that val is an existing path to a directory.""" - self.exists() - if not os.path.isdir(self.val): - self._err('Expected <%s> to be a directory, but was not.' % self.val) - return self - - def is_named(self, filename): - """Asserts that val is an existing path to a file and that file is named filename.""" - self.is_file() - if not isinstance(filename, str_types): - raise TypeError('given filename arg must be a path') - val_filename = os.path.basename(os.path.abspath(self.val)) - if val_filename != filename: - self._err('Expected filename <%s> to be equal to <%s>, but was not.' % (val_filename, filename)) - return self + def _builder(self, val, description='', kind=None, expected=None): + """Helper to build a new Builder. Only used when we don't want to chain.""" + return AssertionBuilder(val, description, kind, expected) - def is_child_of(self, parent): - """Asserts that val is an existing path to a file and that file is a child of parent.""" - self.is_file() - if not isinstance(parent, str_types): - raise TypeError('given parent directory arg must be a path') - val_abspath = os.path.abspath(self.val) - parent_abspath = os.path.abspath(parent) - if not val_abspath.startswith(parent_abspath): - self._err('Expected file <%s> to be a child of <%s>, but was not.' % (val_abspath, parent_abspath)) - return self - -### collection of objects assertions ### - def extracting(self, *names, **kwargs): - """Asserts that val is collection, then extracts the named properties or named zero-arg methods into a list (or list of tuples if multiple names are given).""" - if not isinstance(self.val, Iterable): - raise TypeError('val is not iterable') - if isinstance(self.val, str_types): - raise TypeError('val must not be string') - if len(names) == 0: - raise ValueError('one or more name args must be given') - - def _extract(x, name): - if self._check_dict_like(x, check_values=False, return_as_bool=True): - if name in x: - return x[name] - else: - raise ValueError('item keys %s did not contain key <%s>' % (list(x.keys()), name)) - elif isinstance(x, Iterable): - self._check_iterable(x, name='item') - return x[name] - elif hasattr(x, name): - attr = getattr(x, name) - if callable(attr): - try: - return attr() - except TypeError: - raise ValueError('val method <%s()> exists, but is not zero-arg method' % name) - else: - return attr - else: - raise ValueError('val does not have property or zero-arg method <%s>' % name) - - def _filter(x): - if 'filter' in kwargs: - if isinstance(kwargs['filter'], str_types): - return bool(_extract(x, kwargs['filter'])) - elif self._check_dict_like(kwargs['filter'], check_values=False, return_as_bool=True): - for k in kwargs['filter']: - if isinstance(k, str_types): - if _extract(x, k) != kwargs['filter'][k]: - return False - return True - elif callable(kwargs['filter']): - return kwargs['filter'](x) - return False - return True - - def _sort(x): - if 'sort' in kwargs: - if isinstance(kwargs['sort'], str_types): - return _extract(x, kwargs['sort']) - elif isinstance(kwargs['sort'], Iterable): - items = [] - for k in kwargs['sort']: - if isinstance(k, str_types): - items.append(_extract(x, k)) - return tuple(items) - elif callable(kwargs['sort']): - return kwargs['sort'](x) - return 0 - - extracted = [] - for i in sorted(self.val, key=lambda x: _sort(x)): - if _filter(i): - items = [_extract(i, name) for name in names] - extracted.append(tuple(items) if len(items) > 1 else items[0]) - return AssertionBuilder(extracted, self.description, self.kind) - -### dynamic assertions ### - def __getattr__(self, attr): - """Asserts that val has attribute attr and that attribute's value is equal to other via a dynamic assertion of the form: has_().""" - if not attr.startswith('has_'): - raise AttributeError('assertpy has no assertion <%s()>' % attr) - - attr_name = attr[4:] - err_msg = False - is_dict = isinstance(self.val, Iterable) and hasattr(self.val, '__getitem__') - - if not hasattr(self.val, attr_name): - if is_dict: - if attr_name not in self.val: - err_msg = 'Expected key <%s>, but val has no key <%s>.' % (attr_name, attr_name) - else: - err_msg = 'Expected attribute <%s>, but val has no attribute <%s>.' % (attr_name, attr_name) - - def _wrapper(*args, **kwargs): - if err_msg: - self._err(err_msg) # ok to raise AssertionError now that we are inside wrapper - else: - if len(args) != 1: - raise TypeError('assertion <%s()> takes exactly 1 argument (%d given)' % (attr, len(args))) - - try: - val_attr = getattr(self.val, attr_name) - except AttributeError: - val_attr = self.val[attr_name] - - if callable(val_attr): - try: - actual = val_attr() - except TypeError: - raise TypeError('val does not have zero-arg method <%s()>' % attr_name) - else: - actual = val_attr - - expected = args[0] - if actual != expected: - self._err('Expected <%s> to be equal to <%s> on %s <%s>, but was not.' % (actual, expected, 'key' if is_dict else 'attribute', attr_name)) - return self - - return _wrapper - -### expected exceptions ### - def raises(self, ex): - """Asserts that val is callable and that when called raises the given error.""" - if not callable(self.val): - raise TypeError('val must be callable') - if not issubclass(ex, BaseException): - raise TypeError('given arg must be exception') - return AssertionBuilder(self.val, self.description, self.kind, ex) - - def when_called_with(self, *some_args, **some_kwargs): - """Asserts the val callable when invoked with the given args and kwargs raises the expected exception.""" - if not self.expected: - raise TypeError('expected exception not set, raises() must be called first') - try: - self.val(*some_args, **some_kwargs) - except BaseException as e: - if issubclass(type(e), self.expected): - # chain on with exception message as val - return AssertionBuilder(str(e), self.description, self.kind) - else: - # got exception, but wrong type, so raise - self._err('Expected <%s> to raise <%s> when called with (%s), but raised <%s>.' % ( - self.val.__name__, - self.expected.__name__, - self._fmt_args_kwargs(*some_args, **some_kwargs), - type(e).__name__)) - - # didn't fail as expected, so raise - self._err('Expected <%s> to raise <%s> when called with (%s).' % ( - self.val.__name__, - self.expected.__name__, - self._fmt_args_kwargs(*some_args, **some_kwargs))) - -### helpers ### def _err(self, msg): """Helper to raise an AssertionError, and optionally prepend custom description.""" out = '%s%s' % ('[%s] ' % self.description if len(self.description) > 0 else '', msg) @@ -1110,280 +137,3 @@ def _err(self, msg): return self else: raise AssertionError(out) - - def _fmt_items(self, i): - if len(i) == 0: - return '<>' - elif len(i) == 1 and hasattr(i, '__getitem__'): - return '<%s>' % i[0] - else: - return '<%s>' % str(i).lstrip('([').rstrip(',])') - - def _fmt_args_kwargs(self, *some_args, **some_kwargs): - """Helper to convert the given args and kwargs into a string.""" - if some_args: - out_args = str(some_args).lstrip('(').rstrip(',)') - if some_kwargs: - out_kwargs = ', '.join([str(i).lstrip('(').rstrip(')').replace(', ',': ') for i in [ - (k,some_kwargs[k]) for k in sorted(some_kwargs.keys())]]) - - if some_args and some_kwargs: - return out_args + ', ' + out_kwargs - elif some_args: - return out_args - elif some_kwargs: - return out_kwargs - else: - return '' - - def _validate_between_args(self, val_type, low, high): - low_type = type(low) - high_type = type(high) - - if val_type in self.NON_COMPAREABLE_TYPES: - raise TypeError('ordering is not defined for type <%s>' % val_type.__name__) - - if val_type in self.COMPAREABLE_TYPES: - if low_type is not val_type: - raise TypeError('given low arg must be <%s>, but was <%s>' % (val_type.__name__, low_type.__name__)) - if high_type is not val_type: - raise TypeError('given high arg must be <%s>, but was <%s>' % (val_type.__name__, low_type.__name__)) - elif isinstance(self.val, numbers.Number): - if isinstance(low, numbers.Number) is False: - raise TypeError('given low arg must be numeric, but was <%s>' % low_type.__name__) - if isinstance(high, numbers.Number) is False: - raise TypeError('given high arg must be numeric, but was <%s>' % high_type.__name__) - else: - raise TypeError('ordering is not defined for type <%s>' % val_type.__name__) - - if low > high: - raise ValueError('given low arg must be less than given high arg') - - def _validate_close_to_args(self, val, other, tolerance): - if type(val) is complex or type(other) is complex or type(tolerance) is complex: - raise TypeError('ordering is not defined for complex numbers') - - if isinstance(val, numbers.Number) is False and type(val) is not datetime.datetime: - raise TypeError('val is not numeric or datetime') - - if type(val) is datetime.datetime: - if type(other) is not datetime.datetime: - raise TypeError('given arg must be datetime, but was <%s>' % type(other).__name__) - if type(tolerance) is not datetime.timedelta: - raise TypeError('given tolerance arg must be timedelta, but was <%s>' % type(tolerance).__name__) - else: - if isinstance(other, numbers.Number) is False: - raise TypeError('given arg must be numeric') - if isinstance(tolerance, numbers.Number) is False: - raise TypeError('given tolerance arg must be numeric') - if tolerance < 0: - raise ValueError('given tolerance arg must be positive') - - def _check_dict_like(self, d, check_keys=True, check_values=True, check_getitem=True, name='val', return_as_bool=False): - if not isinstance(d, Iterable): - if return_as_bool: - return False - else: - raise TypeError('%s <%s> is not dict-like: not iterable' % (name, type(d).__name__)) - if check_keys: - if not hasattr(d, 'keys') or not callable(getattr(d, 'keys')): - if return_as_bool: - return False - else: - raise TypeError('%s <%s> is not dict-like: missing keys()' % (name, type(d).__name__)) - if check_values: - if not hasattr(d, 'values') or not callable(getattr(d, 'values')): - if return_as_bool: - return False - else: - raise TypeError('%s <%s> is not dict-like: missing values()' % (name, type(d).__name__)) - if check_getitem: - if not hasattr(d, '__getitem__'): - if return_as_bool: - return False - else: - raise TypeError('%s <%s> is not dict-like: missing [] accessor' % (name, type(d).__name__)) - if return_as_bool: - return True - - def _check_iterable(self, l, check_getitem=True, name='val'): - if not isinstance(l, Iterable): - raise TypeError('%s <%s> is not iterable' % (name, type(l).__name__)) - if check_getitem: - if not hasattr(l, '__getitem__'): - raise TypeError('%s <%s> does not have [] accessor' % (name, type(l).__name__)) - - def _dict_not_equal(self, val, other, ignore=None, include=None): - if ignore or include: - ignores = self._dict_ignore(ignore) - includes = self._dict_include(include) - - # guarantee include keys are in val - if include: - missing = [] - for i in includes: - if i not in val: - missing.append(i) - if missing: - self._err('Expected <%s> to include key%s %s, but did not include key%s %s.' % ( - val, - '' if len(includes) == 1 else 's', - self._fmt_items(['.'.join([str(s) for s in i]) if type(i) is tuple else i for i in includes]), - '' if len(missing) == 1 else 's', - self._fmt_items(missing))) - - if ignore and include: - k1 = set([k for k in val if k not in ignores and k in includes]) - elif ignore: - k1 = set([k for k in val if k not in ignores]) - else: # include - k1 = set([k for k in val if k in includes]) - - if ignore and include: - k2 = set([k for k in other if k not in ignores and k in includes]) - elif ignore: - k2 = set([k for k in other if k not in ignores]) - else: # include - k2 = set([k for k in other if k in includes]) - - if k1 != k2: - return True - else: - for k in k1: - if self._check_dict_like(val[k], check_values=False, return_as_bool=True) and self._check_dict_like(other[k], check_values=False, return_as_bool=True): - return self._dict_not_equal(val[k], other[k], - ignore=[i[1:] for i in ignores if type(i) is tuple and i[0] == k] if ignore else None, - include=[i[1:] for i in self._dict_ignore(include) if type(i) is tuple and i[0] == k] if include else None) - elif val[k] != other[k]: - return True - return False - else: - return val != other - - def _dict_ignore(self, ignore): - return [i[0] if type(i) is tuple and len(i) == 1 else i \ - for i in (ignore if type(ignore) is list else [ignore])] - - def _dict_include(self, include): - return [i[0] if type(i) is tuple else i \ - for i in (include if type(include) is list else [include])] - - def _dict_err(self, val, other, ignore=None, include=None): - def _dict_repr(d, other): - out = '' - ellip = False - for k,v in d.items(): - if k not in other: - out += '%s%s: %s' % (', ' if len(out) > 0 else '', repr(k), repr(v)) - elif v != other[k]: - out += '%s%s: %s' % (', ' if len(out) > 0 else '', repr(k), - _dict_repr(v, other[k]) if self._check_dict_like(v, check_values=False, return_as_bool=True) and self._check_dict_like(other[k], check_values=False, return_as_bool=True) else repr(v) - ) - else: - ellip = True - return '{%s%s}' % ('..' if ellip and len(out) == 0 else '.., ' if ellip else '', out) - - if ignore: - ignores = self._dict_ignore(ignore) - ignore_err = ' ignoring keys %s' % self._fmt_items(['.'.join([str(s) for s in i]) if type(i) is tuple else i for i in ignores]) - if include: - includes = self._dict_ignore(include) - include_err = ' including keys %s' % self._fmt_items(['.'.join([str(s) for s in i]) if type(i) is tuple else i for i in includes]) - - self._err('Expected <%s> to be equal to <%s>%s%s, but was not.' % ( - _dict_repr(val, other), - _dict_repr(other, val), - ignore_err if ignore else '', - include_err if include else '' - )) - -### snapshot testing ### - def snapshot(self, id=None, path='__snapshots'): - if sys.version_info[0] < 3: - raise NotImplementedError('snapshot testing requires Python 3') - - class _Encoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, set): - return {'__type__': 'set', '__data__': list(o)} - elif isinstance(o, complex): - return {'__type__': 'complex', '__data__': [o.real, o.imag]} - elif isinstance(o, datetime.datetime): - return {'__type__': 'datetime', '__data__': o.strftime('%Y-%m-%d %H:%M:%S')} - elif '__dict__' in dir(o) and type(o) is not type: - return { - '__type__': 'instance', - '__class__': o.__class__.__name__, - '__module__': o.__class__.__module__, - '__data__': o.__dict__ - } - return json.JSONEncoder.default(self, o) - - class _Decoder(json.JSONDecoder): - def __init__(self): - json.JSONDecoder.__init__(self, object_hook=self.object_hook) - - def object_hook(self, d): - if '__type__' in d and '__data__' in d: - if d['__type__'] == 'set': - return set(d['__data__']) - elif d['__type__'] == 'complex': - return complex(d['__data__'][0], d['__data__'][1]) - elif d['__type__'] == 'datetime': - return datetime.datetime.strptime(d['__data__'], '%Y-%m-%d %H:%M:%S') - elif d['__type__'] == 'instance': - mod = __import__(d['__module__'], fromlist=[d['__class__']]) - klass = getattr(mod, d['__class__']) - inst = klass.__new__(klass) - inst.__dict__ = d['__data__'] - return inst - return d - - def _save(name, val): - with open(name, 'w') as fp: - json.dump(val, fp, indent=2, separators=(',', ': '), sort_keys=True, cls=_Encoder) - - def _load(name): - with open(name, 'r') as fp: - return json.load(fp, cls=_Decoder) - - def _name(path, name): - try: - return os.path.join(path, 'snap-%s.json' % name.replace(' ','_').lower()) - except Exception: - raise ValueError('failed to create snapshot filename, either bad path or bad name') - - if id: - # custom id - snapname = _name(path, id) - else: - # make id from filename and line number - f = inspect.currentframe() - fpath = os.path.basename(f.f_back.f_code.co_filename) - fname = os.path.splitext(fpath)[0] - lineno = str(f.f_back.f_lineno) - snapname = _name(path, fname) - - if not os.path.exists(path): - os.makedirs(path) - - if os.path.isfile(snapname): - # snap exists, so load - snap = _load(snapname) - - if id: - # custom id, so test - return self.is_equal_to(snap) - else: - if lineno in snap: - # found sub-snap, so test - return self.is_equal_to(snap[lineno]) - else: - # lineno not in snap, so create sub-snap and pass - snap[lineno] = self.val - _save(snapname, snap) - else: - # no snap, so create and pass - _save(snapname, self.val if id else {lineno: self.val}) - - return self diff --git a/assertpy/base.py b/assertpy/base.py new file mode 100644 index 0000000..42b999c --- /dev/null +++ b/assertpy/base.py @@ -0,0 +1,128 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +class BaseMixin(object): + """Base mixin.""" + + def described_as(self, description): + """Describes the assertion. On failure, the description is included in the error message.""" + self.description = str(description) + return self + + def is_equal_to(self, other, **kwargs): + """Asserts that val is equal to other.""" + if self._check_dict_like(self.val, check_values=False, return_as_bool=True) and \ + self._check_dict_like(other, check_values=False, return_as_bool=True): + if self._dict_not_equal(self.val, other, ignore=kwargs.get('ignore'), include=kwargs.get('include')): + self._dict_err(self.val, other, ignore=kwargs.get('ignore'), include=kwargs.get('include')) + else: + if self.val != other: + self._err('Expected <%s> to be equal to <%s>, but was not.' % (self.val, other)) + return self + + def is_not_equal_to(self, other): + """Asserts that val is not equal to other.""" + if self.val == other: + self._err('Expected <%s> to be not equal to <%s>, but was.' % (self.val, other)) + return self + + def is_same_as(self, other): + """Asserts that the val is identical to other, via 'is' compare.""" + if self.val is not other: + self._err('Expected <%s> to be identical to <%s>, but was not.' % (self.val, other)) + return self + + def is_not_same_as(self, other): + """Asserts that the val is not identical to other, via 'is' compare.""" + if self.val is other: + self._err('Expected <%s> to be not identical to <%s>, but was.' % (self.val, other)) + return self + + def is_true(self): + """Asserts that val is true.""" + if not self.val: + self._err('Expected , but was not.') + return self + + def is_false(self): + """Asserts that val is false.""" + if self.val: + self._err('Expected , but was not.') + return self + + def is_none(self): + """Asserts that val is none.""" + if self.val is not None: + self._err('Expected <%s> to be , but was not.' % self.val) + return self + + def is_not_none(self): + """Asserts that val is not none.""" + if self.val is None: + self._err('Expected not , but was.') + return self + + def is_type_of(self, some_type): + """Asserts that val is of the given type.""" + if type(some_type) is not type and\ + not issubclass(type(some_type), type): + raise TypeError('given arg must be a type') + if type(self.val) is not some_type: + if hasattr(self.val, '__name__'): + t = self.val.__name__ + elif hasattr(self.val, '__class__'): + t = self.val.__class__.__name__ + else: + t = 'unknown' + self._err('Expected <%s:%s> to be of type <%s>, but was not.' % (self.val, t, some_type.__name__)) + return self + + def is_instance_of(self, some_class): + """Asserts that val is an instance of the given class.""" + try: + if not isinstance(self.val, some_class): + if hasattr(self.val, '__name__'): + t = self.val.__name__ + elif hasattr(self.val, '__class__'): + t = self.val.__class__.__name__ + else: + t = 'unknown' + self._err('Expected <%s:%s> to be instance of class <%s>, but was not.' % (self.val, t, some_class.__name__)) + except TypeError: + raise TypeError('given arg must be a class') + return self + + def is_length(self, length): + """Asserts that val is the given length.""" + if type(length) is not int: + raise TypeError('given arg must be an int') + if length < 0: + raise ValueError('given arg must be a positive int') + if len(self.val) != length: + self._err('Expected <%s> to be of length <%d>, but was <%d>.' % (self.val, length, len(self.val))) + return self \ No newline at end of file diff --git a/assertpy/collection.py b/assertpy/collection.py new file mode 100644 index 0000000..5f9d906 --- /dev/null +++ b/assertpy/collection.py @@ -0,0 +1,92 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys +import collections + +if sys.version_info[0] == 3: + Iterable = collections.abc.Iterable +else: + Iterable = collections.Iterable + + +class CollectionMixin(object): + """Collection assertions mixin.""" + + def is_iterable(self): + """Asserts that val is iterable collection.""" + if not isinstance(self.val, Iterable): + self._err('Expected iterable, but was not.') + return self + + def is_not_iterable(self): + """Asserts that val is not iterable collection.""" + if isinstance(self.val, Iterable): + self._err('Expected not iterable, but was.') + return self + + def is_subset_of(self, *supersets): + """Asserts that val is iterable and a subset of the given superset or flattened superset if multiple supersets are given.""" + if not isinstance(self.val, Iterable): + raise TypeError('val is not iterable') + if len(supersets) == 0: + raise ValueError('one or more superset args must be given') + + missing = [] + if hasattr(self.val, 'keys') and callable(getattr(self.val, 'keys')) and hasattr(self.val, '__getitem__'): + # flatten superset dicts + superdict = {} + for l,j in enumerate(supersets): + self._check_dict_like(j, check_values=False, name='arg #%d' % (l+1)) + for k in j.keys(): + superdict.update({k: j[k]}) + + for i in self.val.keys(): + if i not in superdict: + missing.append({i: self.val[i]}) # bad key + elif self.val[i] != superdict[i]: + missing.append({i: self.val[i]}) # bad val + if missing: + self._err('Expected <%s> to be subset of %s, but %s %s missing.' % (self.val, self._fmt_items(superdict), self._fmt_items(missing), 'was' if len(missing) == 1 else 'were')) + else: + # flatten supersets + superset = set() + for j in supersets: + try: + for k in j: + superset.add(k) + except Exception: + superset.add(j) + + for i in self.val: + if i not in superset: + missing.append(i) + if missing: + self._err('Expected <%s> to be subset of %s, but %s %s missing.' % (self.val, self._fmt_items(superset), self._fmt_items(missing), 'was' if len(missing) == 1 else 'were')) + + return self diff --git a/assertpy/contains.py b/assertpy/contains.py new file mode 100644 index 0000000..1fd87d5 --- /dev/null +++ b/assertpy/contains.py @@ -0,0 +1,170 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys + +if sys.version_info[0] == 3: + str_types = (str,) + xrange = range +else: + str_types = (basestring,) + xrange = xrange + + +class ContainsMixin(object): + """Containment assertions mixin.""" + + def contains(self, *items): + """Asserts that val contains the given item or items.""" + if len(items) == 0: + raise ValueError('one or more args must be given') + elif len(items) == 1: + if items[0] not in self.val: + if self._check_dict_like(self.val, return_as_bool=True): + self._err('Expected <%s> to contain key <%s>, but did not.' % (self.val, items[0])) + else: + self._err('Expected <%s> to contain item <%s>, but did not.' % (self.val, items[0])) + else: + missing = [] + for i in items: + if i not in self.val: + missing.append(i) + if missing: + if self._check_dict_like(self.val, return_as_bool=True): + self._err('Expected <%s> to contain keys %s, but did not contain key%s %s.' % (self.val, self._fmt_items(items), '' if len(missing) == 0 else 's', self._fmt_items(missing))) + else: + self._err('Expected <%s> to contain items %s, but did not contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(missing))) + return self + + def does_not_contain(self, *items): + """Asserts that val does not contain the given item or items.""" + if len(items) == 0: + raise ValueError('one or more args must be given') + elif len(items) == 1: + if items[0] in self.val: + self._err('Expected <%s> to not contain item <%s>, but did.' % (self.val, items[0])) + else: + found = [] + for i in items: + if i in self.val: + found.append(i) + if found: + self._err('Expected <%s> to not contain items %s, but did contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(found))) + return self + + def contains_only(self, *items): + """Asserts that val contains only the given item or items.""" + if len(items) == 0: + raise ValueError('one or more args must be given') + else: + extra = [] + for i in self.val: + if i not in items: + extra.append(i) + if extra: + self._err('Expected <%s> to contain only %s, but did contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(extra))) + + missing = [] + for i in items: + if i not in self.val: + missing.append(i) + if missing: + self._err('Expected <%s> to contain only %s, but did not contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(missing))) + return self + + def contains_sequence(self, *items): + """Asserts that val contains the given sequence of items in order.""" + if len(items) == 0: + raise ValueError('one or more args must be given') + else: + try: + for i in xrange(len(self.val) - len(items) + 1): + for j in xrange(len(items)): + if self.val[i+j] != items[j]: + break + else: + return self + except TypeError: + raise TypeError('val is not iterable') + self._err('Expected <%s> to contain sequence %s, but did not.' % (self.val, self._fmt_items(items))) + + def contains_duplicates(self): + """Asserts that val is iterable and contains duplicate items.""" + try: + if len(self.val) != len(set(self.val)): + return self + except TypeError: + raise TypeError('val is not iterable') + self._err('Expected <%s> to contain duplicates, but did not.' % self.val) + + def does_not_contain_duplicates(self): + """Asserts that val is iterable and does not contain any duplicate items.""" + try: + if len(self.val) == len(set(self.val)): + return self + except TypeError: + raise TypeError('val is not iterable') + self._err('Expected <%s> to not contain duplicates, but did.' % self.val) + + def is_empty(self): + """Asserts that val is empty.""" + if len(self.val) != 0: + if isinstance(self.val, str_types): + self._err('Expected <%s> to be empty string, but was not.' % self.val) + else: + self._err('Expected <%s> to be empty, but was not.' % self.val) + return self + + def is_not_empty(self): + """Asserts that val is not empty.""" + if len(self.val) == 0: + if isinstance(self.val, str_types): + self._err('Expected not empty string, but was empty.') + else: + self._err('Expected not empty, but was empty.') + return self + + def is_in(self, *items): + """Asserts that val is equal to one of the given items.""" + if len(items) == 0: + raise ValueError('one or more args must be given') + else: + for i in items: + if self.val == i: + return self + self._err('Expected <%s> to be in %s, but was not.' % (self.val, self._fmt_items(items))) + + def is_not_in(self, *items): + """Asserts that val is not equal to one of the given items.""" + if len(items) == 0: + raise ValueError('one or more args must be given') + else: + for i in items: + if self.val == i: + self._err('Expected <%s> to not be in %s, but was.' % (self.val, self._fmt_items(items))) + return self diff --git a/assertpy/date.py b/assertpy/date.py new file mode 100644 index 0000000..2dbafbb --- /dev/null +++ b/assertpy/date.py @@ -0,0 +1,81 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import datetime + + +class DateMixin(object): + """Date and time assertions mixin.""" + +### datetime assertions ### + def is_before(self, other): + """Asserts that val is a date and is before other date.""" + if type(self.val) is not datetime.datetime: + raise TypeError('val must be datetime, but was type <%s>' % type(self.val).__name__) + if type(other) is not datetime.datetime: + raise TypeError('given arg must be datetime, but was type <%s>' % type(other).__name__) + if self.val >= other: + self._err('Expected <%s> to be before <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) + return self + + def is_after(self, other): + """Asserts that val is a date and is after other date.""" + if type(self.val) is not datetime.datetime: + raise TypeError('val must be datetime, but was type <%s>' % type(self.val).__name__) + if type(other) is not datetime.datetime: + raise TypeError('given arg must be datetime, but was type <%s>' % type(other).__name__) + if self.val <= other: + self._err('Expected <%s> to be after <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) + return self + + def is_equal_to_ignoring_milliseconds(self, other): + if type(self.val) is not datetime.datetime: + raise TypeError('val must be datetime, but was type <%s>' % type(self.val).__name__) + if type(other) is not datetime.datetime: + raise TypeError('given arg must be datetime, but was type <%s>' % type(other).__name__) + if self.val.date() != other.date() or self.val.hour != other.hour or self.val.minute != other.minute or self.val.second != other.second: + self._err('Expected <%s> to be equal to <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) + return self + + def is_equal_to_ignoring_seconds(self, other): + if type(self.val) is not datetime.datetime: + raise TypeError('val must be datetime, but was type <%s>' % type(self.val).__name__) + if type(other) is not datetime.datetime: + raise TypeError('given arg must be datetime, but was type <%s>' % type(other).__name__) + if self.val.date() != other.date() or self.val.hour != other.hour or self.val.minute != other.minute: + self._err('Expected <%s> to be equal to <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M'), other.strftime('%Y-%m-%d %H:%M'))) + return self + + def is_equal_to_ignoring_time(self, other): + if type(self.val) is not datetime.datetime: + raise TypeError('val must be datetime, but was type <%s>' % type(self.val).__name__) + if type(other) is not datetime.datetime: + raise TypeError('given arg must be datetime, but was type <%s>' % type(other).__name__) + if self.val.date() != other.date(): + self._err('Expected <%s> to be equal to <%s>, but was not.' % (self.val.strftime('%Y-%m-%d'), other.strftime('%Y-%m-%d'))) + return self diff --git a/assertpy/dict.py b/assertpy/dict.py new file mode 100644 index 0000000..d1f9379 --- /dev/null +++ b/assertpy/dict.py @@ -0,0 +1,108 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +class DictMixin(object): + """Dict assertions mixin.""" + + def contains_key(self, *keys): + """Asserts the val is a dict and contains the given key or keys. Alias for contains().""" + self._check_dict_like(self.val, check_values=False, check_getitem=False) + return self.contains(*keys) + + def does_not_contain_key(self, *keys): + """Asserts the val is a dict and does not contain the given key or keys. Alias for does_not_contain().""" + self._check_dict_like(self.val, check_values=False, check_getitem=False) + return self.does_not_contain(*keys) + + def contains_value(self, *values): + """Asserts that val is a dict and contains the given value or values.""" + self._check_dict_like(self.val, check_getitem=False) + if len(values) == 0: + raise ValueError('one or more value args must be given') + missing = [] + for v in values: + if v not in self.val.values(): + missing.append(v) + if missing: + self._err('Expected <%s> to contain values %s, but did not contain %s.' % (self.val, self._fmt_items(values), self._fmt_items(missing))) + return self + + def does_not_contain_value(self, *values): + """Asserts that val is a dict and does not contain the given value or values.""" + self._check_dict_like(self.val, check_getitem=False) + if len(values) == 0: + raise ValueError('one or more value args must be given') + else: + found = [] + for v in values: + if v in self.val.values(): + found.append(v) + if found: + self._err('Expected <%s> to not contain values %s, but did contain %s.' % (self.val, self._fmt_items(values), self._fmt_items(found))) + return self + + def contains_entry(self, *args, **kwargs): + """Asserts that val is a dict and contains the given entry or entries.""" + self._check_dict_like(self.val, check_values=False) + entries = list(args) + [{k:v} for k,v in kwargs.items()] + if len(entries) == 0: + raise ValueError('one or more entry args must be given') + missing = [] + for e in entries: + if type(e) is not dict: + raise TypeError('given entry arg must be a dict') + if len(e) != 1: + raise ValueError('given entry args must contain exactly one key-value pair') + k = next(iter(e)) + if k not in self.val: + missing.append(e) # bad key + elif self.val[k] != e[k]: + missing.append(e) # bad val + if missing: + self._err('Expected <%s> to contain entries %s, but did not contain %s.' % (self.val, self._fmt_items(entries), self._fmt_items(missing))) + return self + + def does_not_contain_entry(self, *args, **kwargs): + """Asserts that val is a dict and does not contain the given entry or entries.""" + self._check_dict_like(self.val, check_values=False) + entries = list(args) + [{k:v} for k,v in kwargs.items()] + if len(entries) == 0: + raise ValueError('one or more entry args must be given') + found = [] + for e in entries: + if type(e) is not dict: + raise TypeError('given entry arg must be a dict') + if len(e) != 1: + raise ValueError('given entry args must contain exactly one key-value pair') + k = next(iter(e)) + if k in self.val and e[k] == self.val[k]: + found.append(e) + if found: + self._err('Expected <%s> to not contain entries %s, but did contain %s.' % (self.val, self._fmt_items(entries), self._fmt_items(found))) + return self \ No newline at end of file diff --git a/assertpy/dynamic.py b/assertpy/dynamic.py new file mode 100644 index 0000000..edaa963 --- /dev/null +++ b/assertpy/dynamic.py @@ -0,0 +1,82 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys +import collections + +if sys.version_info[0] == 3: + Iterable = collections.abc.Iterable +else: + Iterable = collections.Iterable + + +class DynamicMixin(object): + """Dynamic assertions mixin.""" + + def __getattr__(self, attr): + """Asserts that val has attribute attr and that attribute's value is equal to other via a dynamic assertion of the form: has_().""" + if not attr.startswith('has_'): + raise AttributeError('assertpy has no assertion <%s()>' % attr) + + attr_name = attr[4:] + err_msg = False + is_dict = isinstance(self.val, Iterable) and hasattr(self.val, '__getitem__') + + if not hasattr(self.val, attr_name): + if is_dict: + if attr_name not in self.val: + err_msg = 'Expected key <%s>, but val has no key <%s>.' % (attr_name, attr_name) + else: + err_msg = 'Expected attribute <%s>, but val has no attribute <%s>.' % (attr_name, attr_name) + + def _wrapper(*args, **kwargs): + if err_msg: + self._err(err_msg) # ok to raise AssertionError now that we are inside wrapper + else: + if len(args) != 1: + raise TypeError('assertion <%s()> takes exactly 1 argument (%d given)' % (attr, len(args))) + + try: + val_attr = getattr(self.val, attr_name) + except AttributeError: + val_attr = self.val[attr_name] + + if callable(val_attr): + try: + actual = val_attr() + except TypeError: + raise TypeError('val does not have zero-arg method <%s()>' % attr_name) + else: + actual = val_attr + + expected = args[0] + if actual != expected: + self._err('Expected <%s> to be equal to <%s> on %s <%s>, but was not.' % (actual, expected, 'key' if is_dict else 'attribute', attr_name)) + return self + + return _wrapper diff --git a/assertpy/exception.py b/assertpy/exception.py new file mode 100644 index 0000000..6093ecf --- /dev/null +++ b/assertpy/exception.py @@ -0,0 +1,71 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys +import collections + +if sys.version_info[0] == 3: + Iterable = collections.abc.Iterable +else: + Iterable = collections.Iterable + + +class ExceptionMixin(object): + """Expected exception mixin.""" + + def raises(self, ex): + """Asserts that val is callable and that when called raises the given error.""" + if not callable(self.val): + raise TypeError('val must be callable') + if not issubclass(ex, BaseException): + raise TypeError('given arg must be exception') + return self._builder(self.val, self.description, self.kind, ex) # don't chain! + + def when_called_with(self, *some_args, **some_kwargs): + """Asserts the val callable when invoked with the given args and kwargs raises the expected exception.""" + if not self.expected: + raise TypeError('expected exception not set, raises() must be called first') + try: + self.val(*some_args, **some_kwargs) + except BaseException as e: + if issubclass(type(e), self.expected): + # chain on with _message_ (don't chain to self!) + return self._builder(str(e), self.description, self.kind) + else: + # got exception, but wrong type, so raise + self._err('Expected <%s> to raise <%s> when called with (%s), but raised <%s>.' % ( + self.val.__name__, + self.expected.__name__, + self._fmt_args_kwargs(*some_args, **some_kwargs), + type(e).__name__)) + + # didn't fail as expected, so raise + self._err('Expected <%s> to raise <%s> when called with (%s).' % ( + self.val.__name__, + self.expected.__name__, + self._fmt_args_kwargs(*some_args, **some_kwargs))) diff --git a/assertpy/extracting.py b/assertpy/extracting.py new file mode 100644 index 0000000..f0a8b81 --- /dev/null +++ b/assertpy/extracting.py @@ -0,0 +1,110 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys +import collections + +if sys.version_info[0] == 3: + str_types = (str,) + Iterable = collections.abc.Iterable +else: + str_types = (basestring,) + Iterable = collections.Iterable + + +class ExtractingMixin(object): + """Collection flattening mixin.""" + +### collection of objects assertions ### + def extracting(self, *names, **kwargs): + """Asserts that val is collection, then extracts the named properties or named zero-arg methods into a list (or list of tuples if multiple names are given).""" + if not isinstance(self.val, Iterable): + raise TypeError('val is not iterable') + if isinstance(self.val, str_types): + raise TypeError('val must not be string') + if len(names) == 0: + raise ValueError('one or more name args must be given') + + def _extract(x, name): + if self._check_dict_like(x, check_values=False, return_as_bool=True): + if name in x: + return x[name] + else: + raise ValueError('item keys %s did not contain key <%s>' % (list(x.keys()), name)) + elif isinstance(x, Iterable): + self._check_iterable(x, name='item') + return x[name] + elif hasattr(x, name): + attr = getattr(x, name) + if callable(attr): + try: + return attr() + except TypeError: + raise ValueError('val method <%s()> exists, but is not zero-arg method' % name) + else: + return attr + else: + raise ValueError('val does not have property or zero-arg method <%s>' % name) + + def _filter(x): + if 'filter' in kwargs: + if isinstance(kwargs['filter'], str_types): + return bool(_extract(x, kwargs['filter'])) + elif self._check_dict_like(kwargs['filter'], check_values=False, return_as_bool=True): + for k in kwargs['filter']: + if isinstance(k, str_types): + if _extract(x, k) != kwargs['filter'][k]: + return False + return True + elif callable(kwargs['filter']): + return kwargs['filter'](x) + return False + return True + + def _sort(x): + if 'sort' in kwargs: + if isinstance(kwargs['sort'], str_types): + return _extract(x, kwargs['sort']) + elif isinstance(kwargs['sort'], Iterable): + items = [] + for k in kwargs['sort']: + if isinstance(k, str_types): + items.append(_extract(x, k)) + return tuple(items) + elif callable(kwargs['sort']): + return kwargs['sort'](x) + return 0 + + extracted = [] + for i in sorted(self.val, key=lambda x: _sort(x)): + if _filter(i): + items = [_extract(i, name) for name in names] + extracted.append(tuple(items) if len(items) > 1 else items[0]) + + # chain on with _extracted_ list (don't chain to self!) + return self._builder(extracted, self.description, self.kind) diff --git a/assertpy/file.py b/assertpy/file.py new file mode 100644 index 0000000..c881ee9 --- /dev/null +++ b/assertpy/file.py @@ -0,0 +1,123 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys + +if sys.version_info[0] == 3: + str_types = (str,) +else: + str_types = (basestring,) + + +def contents_of(f, encoding='utf-8'): + """Helper to read the contents of the given file or path into a string with the given encoding. + Encoding defaults to 'utf-8', other useful encodings are 'ascii' and 'latin-1'.""" + + try: + contents = f.read() + except AttributeError: + try: + with open(f, 'r') as fp: + contents = fp.read() + except TypeError: + raise ValueError('val must be file or path, but was type <%s>' % type(f).__name__) + except OSError: + if not isinstance(f, str_types): + raise ValueError('val must be file or path, but was type <%s>' % type(f).__name__) + raise + + if sys.version_info[0] == 3 and type(contents) is bytes: + # in PY3 force decoding of bytes to target encoding + return contents.decode(encoding, 'replace') + elif sys.version_info[0] == 2 and encoding == 'ascii': + # in PY2 force encoding back to ascii + return contents.encode('ascii', 'replace') + else: + # in all other cases, try to decode to target encoding + try: + return contents.decode(encoding, 'replace') + except AttributeError: + pass + # if all else fails, just return the contents "as is" + return contents + + +class FileMixin(object): + """File assertions mixin.""" + + def exists(self): + """Asserts that val is a path and that it exists.""" + if not isinstance(self.val, str_types): + raise TypeError('val is not a path') + if not os.path.exists(self.val): + self._err('Expected <%s> to exist, but was not found.' % self.val) + return self + + def does_not_exist(self): + """Asserts that val is a path and that it does not exist.""" + if not isinstance(self.val, str_types): + raise TypeError('val is not a path') + if os.path.exists(self.val): + self._err('Expected <%s> to not exist, but was found.' % self.val) + return self + + def is_file(self): + """Asserts that val is an existing path to a file.""" + self.exists() + if not os.path.isfile(self.val): + self._err('Expected <%s> to be a file, but was not.' % self.val) + return self + + def is_directory(self): + """Asserts that val is an existing path to a directory.""" + self.exists() + if not os.path.isdir(self.val): + self._err('Expected <%s> to be a directory, but was not.' % self.val) + return self + + def is_named(self, filename): + """Asserts that val is an existing path to a file and that file is named filename.""" + self.is_file() + if not isinstance(filename, str_types): + raise TypeError('given filename arg must be a path') + val_filename = os.path.basename(os.path.abspath(self.val)) + if val_filename != filename: + self._err('Expected filename <%s> to be equal to <%s>, but was not.' % (val_filename, filename)) + return self + + def is_child_of(self, parent): + """Asserts that val is an existing path to a file and that file is a child of parent.""" + self.is_file() + if not isinstance(parent, str_types): + raise TypeError('given parent directory arg must be a path') + val_abspath = os.path.abspath(self.val) + parent_abspath = os.path.abspath(parent) + if not val_abspath.startswith(parent_abspath): + self._err('Expected file <%s> to be a child of <%s>, but was not.' % (val_abspath, parent_abspath)) + return self diff --git a/assertpy/helpers.py b/assertpy/helpers.py new file mode 100644 index 0000000..00007d4 --- /dev/null +++ b/assertpy/helpers.py @@ -0,0 +1,228 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import print_function +import sys +import numbers +import datetime +import collections + +if sys.version_info[0] == 3: + Iterable = collections.abc.Iterable +else: + Iterable = collections.Iterable + + +class HelpersMixin(object): + """Helpers mixin.""" + + def _fmt_items(self, i): + if len(i) == 0: + return '<>' + elif len(i) == 1 and hasattr(i, '__getitem__'): + return '<%s>' % i[0] + else: + return '<%s>' % str(i).lstrip('([').rstrip(',])') + + def _fmt_args_kwargs(self, *some_args, **some_kwargs): + """Helper to convert the given args and kwargs into a string.""" + if some_args: + out_args = str(some_args).lstrip('(').rstrip(',)') + if some_kwargs: + out_kwargs = ', '.join([str(i).lstrip('(').rstrip(')').replace(', ',': ') for i in [ + (k,some_kwargs[k]) for k in sorted(some_kwargs.keys())]]) + + if some_args and some_kwargs: + return out_args + ', ' + out_kwargs + elif some_args: + return out_args + elif some_kwargs: + return out_kwargs + else: + return '' + + def _validate_between_args(self, val_type, low, high): + low_type = type(low) + high_type = type(high) + + if val_type in self.NUMERIC_NON_COMPAREABLE: + raise TypeError('ordering is not defined for type <%s>' % val_type.__name__) + + if val_type in self.NUMERIC_COMPAREABLE: + if low_type is not val_type: + raise TypeError('given low arg must be <%s>, but was <%s>' % (val_type.__name__, low_type.__name__)) + if high_type is not val_type: + raise TypeError('given high arg must be <%s>, but was <%s>' % (val_type.__name__, low_type.__name__)) + elif isinstance(self.val, numbers.Number): + if isinstance(low, numbers.Number) is False: + raise TypeError('given low arg must be numeric, but was <%s>' % low_type.__name__) + if isinstance(high, numbers.Number) is False: + raise TypeError('given high arg must be numeric, but was <%s>' % high_type.__name__) + else: + raise TypeError('ordering is not defined for type <%s>' % val_type.__name__) + + if low > high: + raise ValueError('given low arg must be less than given high arg') + + def _validate_close_to_args(self, val, other, tolerance): + if type(val) is complex or type(other) is complex or type(tolerance) is complex: + raise TypeError('ordering is not defined for complex numbers') + + if isinstance(val, numbers.Number) is False and type(val) is not datetime.datetime: + raise TypeError('val is not numeric or datetime') + + if type(val) is datetime.datetime: + if type(other) is not datetime.datetime: + raise TypeError('given arg must be datetime, but was <%s>' % type(other).__name__) + if type(tolerance) is not datetime.timedelta: + raise TypeError('given tolerance arg must be timedelta, but was <%s>' % type(tolerance).__name__) + else: + if isinstance(other, numbers.Number) is False: + raise TypeError('given arg must be numeric') + if isinstance(tolerance, numbers.Number) is False: + raise TypeError('given tolerance arg must be numeric') + if tolerance < 0: + raise ValueError('given tolerance arg must be positive') + + def _check_dict_like(self, d, check_keys=True, check_values=True, check_getitem=True, name='val', return_as_bool=False): + if not isinstance(d, Iterable): + if return_as_bool: + return False + else: + raise TypeError('%s <%s> is not dict-like: not iterable' % (name, type(d).__name__)) + if check_keys: + if not hasattr(d, 'keys') or not callable(getattr(d, 'keys')): + if return_as_bool: + return False + else: + raise TypeError('%s <%s> is not dict-like: missing keys()' % (name, type(d).__name__)) + if check_values: + if not hasattr(d, 'values') or not callable(getattr(d, 'values')): + if return_as_bool: + return False + else: + raise TypeError('%s <%s> is not dict-like: missing values()' % (name, type(d).__name__)) + if check_getitem: + if not hasattr(d, '__getitem__'): + if return_as_bool: + return False + else: + raise TypeError('%s <%s> is not dict-like: missing [] accessor' % (name, type(d).__name__)) + if return_as_bool: + return True + + def _check_iterable(self, l, check_getitem=True, name='val'): + if not isinstance(l, Iterable): + raise TypeError('%s <%s> is not iterable' % (name, type(l).__name__)) + if check_getitem: + if not hasattr(l, '__getitem__'): + raise TypeError('%s <%s> does not have [] accessor' % (name, type(l).__name__)) + + def _dict_not_equal(self, val, other, ignore=None, include=None): + if ignore or include: + ignores = self._dict_ignore(ignore) + includes = self._dict_include(include) + + # guarantee include keys are in val + if include: + missing = [] + for i in includes: + if i not in val: + missing.append(i) + if missing: + self._err('Expected <%s> to include key%s %s, but did not include key%s %s.' % ( + val, + '' if len(includes) == 1 else 's', + self._fmt_items(['.'.join([str(s) for s in i]) if type(i) is tuple else i for i in includes]), + '' if len(missing) == 1 else 's', + self._fmt_items(missing))) + + if ignore and include: + k1 = set([k for k in val if k not in ignores and k in includes]) + elif ignore: + k1 = set([k for k in val if k not in ignores]) + else: # include + k1 = set([k for k in val if k in includes]) + + if ignore and include: + k2 = set([k for k in other if k not in ignores and k in includes]) + elif ignore: + k2 = set([k for k in other if k not in ignores]) + else: # include + k2 = set([k for k in other if k in includes]) + + if k1 != k2: + return True + else: + for k in k1: + if self._check_dict_like(val[k], check_values=False, return_as_bool=True) and self._check_dict_like(other[k], check_values=False, return_as_bool=True): + return self._dict_not_equal(val[k], other[k], + ignore=[i[1:] for i in ignores if type(i) is tuple and i[0] == k] if ignore else None, + include=[i[1:] for i in self._dict_ignore(include) if type(i) is tuple and i[0] == k] if include else None) + elif val[k] != other[k]: + return True + return False + else: + return val != other + + def _dict_ignore(self, ignore): + return [i[0] if type(i) is tuple and len(i) == 1 else i \ + for i in (ignore if type(ignore) is list else [ignore])] + + def _dict_include(self, include): + return [i[0] if type(i) is tuple else i \ + for i in (include if type(include) is list else [include])] + + def _dict_err(self, val, other, ignore=None, include=None): + def _dict_repr(d, other): + out = '' + ellip = False + for k,v in d.items(): + if k not in other: + out += '%s%s: %s' % (', ' if len(out) > 0 else '', repr(k), repr(v)) + elif v != other[k]: + out += '%s%s: %s' % (', ' if len(out) > 0 else '', repr(k), + _dict_repr(v, other[k]) if self._check_dict_like(v, check_values=False, return_as_bool=True) and self._check_dict_like(other[k], check_values=False, return_as_bool=True) else repr(v) + ) + else: + ellip = True + return '{%s%s}' % ('..' if ellip and len(out) == 0 else '.., ' if ellip else '', out) + + if ignore: + ignores = self._dict_ignore(ignore) + ignore_err = ' ignoring keys %s' % self._fmt_items(['.'.join([str(s) for s in i]) if type(i) is tuple else i for i in ignores]) + if include: + includes = self._dict_ignore(include) + include_err = ' including keys %s' % self._fmt_items(['.'.join([str(s) for s in i]) if type(i) is tuple else i for i in includes]) + + self._err('Expected <%s> to be equal to <%s>%s%s, but was not.' % ( + _dict_repr(val, other), + _dict_repr(other, val), + ignore_err if ignore else '', + include_err if include else '' + )) diff --git a/assertpy/numeric.py b/assertpy/numeric.py new file mode 100644 index 0000000..e0effe8 --- /dev/null +++ b/assertpy/numeric.py @@ -0,0 +1,208 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import division, print_function +import sys +import math +import numbers +import datetime + + +class NumericMixin(object): + """Numeric assertions mixin.""" + + NUMERIC_COMPAREABLE = set([datetime.datetime, datetime.timedelta, datetime.date, datetime.time]) + NUMERIC_NON_COMPAREABLE = set([complex]) + + def _validate_compareable(self, other): + self_type = type(self.val) + other_type = type(other) + + if self_type in self.NUMERIC_NON_COMPAREABLE: + raise TypeError('ordering is not defined for type <%s>' % self_type.__name__) + if self_type in self.NUMERIC_COMPAREABLE: + if other_type is not self_type: + raise TypeError('given arg must be <%s>, but was <%s>' % (self_type.__name__, other_type.__name__)) + return + if isinstance(self.val, numbers.Number): + if not isinstance(other, numbers.Number): + raise TypeError('given arg must be a number, but was <%s>' % other_type.__name__) + return + raise TypeError('ordering is not defined for type <%s>' % self_type.__name__) + + def _validate_number(self): + """Raise TypeError if val is not numeric.""" + if isinstance(self.val, numbers.Number) is False: + raise TypeError('val is not numeric') + + def _validate_real(self): + """Raise TypeError if val is not real number.""" + if isinstance(self.val, numbers.Real) is False: + raise TypeError('val is not real number') + + def is_zero(self): + """Asserts that val is numeric and equal to zero.""" + self._validate_number() + return self.is_equal_to(0) + + def is_not_zero(self): + """Asserts that val is numeric and not equal to zero.""" + self._validate_number() + return self.is_not_equal_to(0) + + def is_nan(self): + """Asserts that val is real number and NaN (not a number).""" + self._validate_number() + self._validate_real() + if not math.isnan(self.val): + self._err('Expected <%s> to be , but was not.' % self.val) + return self + + def is_not_nan(self): + """Asserts that val is real number and not NaN (not a number).""" + self._validate_number() + self._validate_real() + if math.isnan(self.val): + self._err('Expected not , but was.') + return self + + def is_inf(self): + """Asserts that val is real number and Inf (infinity).""" + self._validate_number() + self._validate_real() + if not math.isinf(self.val): + self._err('Expected <%s> to be , but was not.' % self.val) + return self + + def is_not_inf(self): + """Asserts that val is real number and not Inf (infinity).""" + self._validate_number() + self._validate_real() + if math.isinf(self.val): + self._err('Expected not , but was.') + return self + + def is_greater_than(self, other): + """Asserts that val is numeric and is greater than other.""" + self._validate_compareable(other) + if self.val <= other: + if type(self.val) is datetime.datetime: + self._err('Expected <%s> to be greater than <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) + else: + self._err('Expected <%s> to be greater than <%s>, but was not.' % (self.val, other)) + return self + + def is_greater_than_or_equal_to(self, other): + """Asserts that val is numeric and is greater than or equal to other.""" + self._validate_compareable(other) + if self.val < other: + if type(self.val) is datetime.datetime: + self._err('Expected <%s> to be greater than or equal to <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) + else: + self._err('Expected <%s> to be greater than or equal to <%s>, but was not.' % (self.val, other)) + return self + + def is_less_than(self, other): + """Asserts that val is numeric and is less than other.""" + self._validate_compareable(other) + if self.val >= other: + if type(self.val) is datetime.datetime: + self._err('Expected <%s> to be less than <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) + else: + self._err('Expected <%s> to be less than <%s>, but was not.' % (self.val, other)) + return self + + def is_less_than_or_equal_to(self, other): + """Asserts that val is numeric and is less than or equal to other.""" + self._validate_compareable(other) + if self.val > other: + if type(self.val) is datetime.datetime: + self._err('Expected <%s> to be less than or equal to <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'))) + else: + self._err('Expected <%s> to be less than or equal to <%s>, but was not.' % (self.val, other)) + return self + + def is_positive(self): + """Asserts that val is numeric and greater than zero.""" + return self.is_greater_than(0) + + def is_negative(self): + """Asserts that val is numeric and less than zero.""" + return self.is_less_than(0) + + def is_between(self, low, high): + """Asserts that val is numeric and is between low and high.""" + val_type = type(self.val) + self._validate_between_args(val_type, low, high) + + if self.val < low or self.val > high: + if val_type is datetime.datetime: + self._err('Expected <%s> to be between <%s> and <%s>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), low.strftime('%Y-%m-%d %H:%M:%S'), high.strftime('%Y-%m-%d %H:%M:%S'))) + else: + self._err('Expected <%s> to be between <%s> and <%s>, but was not.' % (self.val, low, high)) + return self + + def is_not_between(self, low, high): + """Asserts that val is numeric and is between low and high.""" + val_type = type(self.val) + self._validate_between_args(val_type, low, high) + + if self.val >= low and self.val <= high: + if val_type is datetime.datetime: + self._err('Expected <%s> to not be between <%s> and <%s>, but was.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), low.strftime('%Y-%m-%d %H:%M:%S'), high.strftime('%Y-%m-%d %H:%M:%S'))) + else: + self._err('Expected <%s> to not be between <%s> and <%s>, but was.' % (self.val, low, high)) + return self + + def is_close_to(self, other, tolerance): + """Asserts that val is numeric and is close to other within tolerance.""" + self._validate_close_to_args(self.val, other, tolerance) + + if self.val < (other-tolerance) or self.val > (other+tolerance): + if type(self.val) is datetime.datetime: + tolerance_seconds = tolerance.days * 86400 + tolerance.seconds + tolerance.microseconds / 1000000 + h, rem = divmod(tolerance_seconds, 3600) + m, s = divmod(rem, 60) + self._err('Expected <%s> to be close to <%s> within tolerance <%d:%02d:%02d>, but was not.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'), h, m, s)) + else: + self._err('Expected <%s> to be close to <%s> within tolerance <%s>, but was not.' % (self.val, other, tolerance)) + return self + + def is_not_close_to(self, other, tolerance): + """Asserts that val is numeric and is not close to other within tolerance.""" + self._validate_close_to_args(self.val, other, tolerance) + + if self.val >= (other-tolerance) and self.val <= (other+tolerance): + if type(self.val) is datetime.datetime: + tolerance_seconds = tolerance.days * 86400 + tolerance.seconds + tolerance.microseconds / 1000000 + h, rem = divmod(tolerance_seconds, 3600) + m, s = divmod(rem, 60) + self._err('Expected <%s> to not be close to <%s> within tolerance <%d:%02d:%02d>, but was.' % (self.val.strftime('%Y-%m-%d %H:%M:%S'), other.strftime('%Y-%m-%d %H:%M:%S'), h, m, s)) + else: + self._err('Expected <%s> to not be close to <%s> within tolerance <%s>, but was.' % (self.val, other, tolerance)) + return self diff --git a/assertpy/snapshot.py b/assertpy/snapshot.py new file mode 100644 index 0000000..6c3ef9f --- /dev/null +++ b/assertpy/snapshot.py @@ -0,0 +1,127 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys +import datetime +import inspect +import json + + +class SnapshotMixin(object): + """Snapshot mixin.""" + + def snapshot(self, id=None, path='__snapshots'): + if sys.version_info[0] < 3: + raise NotImplementedError('snapshot testing requires Python 3') + + class _Encoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, set): + return {'__type__': 'set', '__data__': list(o)} + elif isinstance(o, complex): + return {'__type__': 'complex', '__data__': [o.real, o.imag]} + elif isinstance(o, datetime.datetime): + return {'__type__': 'datetime', '__data__': o.strftime('%Y-%m-%d %H:%M:%S')} + elif '__dict__' in dir(o) and type(o) is not type: + return { + '__type__': 'instance', + '__class__': o.__class__.__name__, + '__module__': o.__class__.__module__, + '__data__': o.__dict__ + } + return json.JSONEncoder.default(self, o) + + class _Decoder(json.JSONDecoder): + def __init__(self): + json.JSONDecoder.__init__(self, object_hook=self.object_hook) + + def object_hook(self, d): + if '__type__' in d and '__data__' in d: + if d['__type__'] == 'set': + return set(d['__data__']) + elif d['__type__'] == 'complex': + return complex(d['__data__'][0], d['__data__'][1]) + elif d['__type__'] == 'datetime': + return datetime.datetime.strptime(d['__data__'], '%Y-%m-%d %H:%M:%S') + elif d['__type__'] == 'instance': + mod = __import__(d['__module__'], fromlist=[d['__class__']]) + klass = getattr(mod, d['__class__']) + inst = klass.__new__(klass) + inst.__dict__ = d['__data__'] + return inst + return d + + def _save(name, val): + with open(name, 'w') as fp: + json.dump(val, fp, indent=2, separators=(',', ': '), sort_keys=True, cls=_Encoder) + + def _load(name): + with open(name, 'r') as fp: + return json.load(fp, cls=_Decoder) + + def _name(path, name): + try: + return os.path.join(path, 'snap-%s.json' % name.replace(' ','_').lower()) + except Exception: + raise ValueError('failed to create snapshot filename, either bad path or bad name') + + if id: + # custom id + snapname = _name(path, id) + else: + # make id from filename and line number + f = inspect.currentframe() + fpath = os.path.basename(f.f_back.f_code.co_filename) + fname = os.path.splitext(fpath)[0] + lineno = str(f.f_back.f_lineno) + snapname = _name(path, fname) + + if not os.path.exists(path): + os.makedirs(path) + + if os.path.isfile(snapname): + # snap exists, so load + snap = _load(snapname) + + if id: + # custom id, so test + return self.is_equal_to(snap) + else: + if lineno in snap: + # found sub-snap, so test + return self.is_equal_to(snap[lineno]) + else: + # lineno not in snap, so create sub-snap and pass + snap[lineno] = self.val + _save(snapname, snap) + else: + # no snap, so create and pass + _save(snapname, self.val if id else {lineno: self.val}) + + return self diff --git a/assertpy/string.py b/assertpy/string.py new file mode 100644 index 0000000..7a340e2 --- /dev/null +++ b/assertpy/string.py @@ -0,0 +1,207 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys +import re +import collections + +if sys.version_info[0] == 3: + str_types = (str,) + unicode = str + Iterable = collections.abc.Iterable +else: + str_types = (basestring,) + unicode = unicode + Iterable = collections.Iterable + + +class StringMixin(object): + """String assertions mixin.""" + + def is_equal_to_ignoring_case(self, other): + """Asserts that val is case-insensitive equal to other.""" + if not isinstance(self.val, str_types): + raise TypeError('val is not a string') + if not isinstance(other, str_types): + raise TypeError('given arg must be a string') + if self.val.lower() != other.lower(): + self._err('Expected <%s> to be case-insensitive equal to <%s>, but was not.' % (self.val, other)) + return self + + def contains_ignoring_case(self, *items): + """Asserts that val is string and contains the given item or items.""" + if len(items) == 0: + raise ValueError('one or more args must be given') + if isinstance(self.val, str_types): + if len(items) == 1: + if not isinstance(items[0], str_types): + raise TypeError('given arg must be a string') + if items[0].lower() not in self.val.lower(): + self._err('Expected <%s> to case-insensitive contain item <%s>, but did not.' % (self.val, items[0])) + else: + missing = [] + for i in items: + if not isinstance(i, str_types): + raise TypeError('given args must all be strings') + if i.lower() not in self.val.lower(): + missing.append(i) + if missing: + self._err('Expected <%s> to case-insensitive contain items %s, but did not contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(missing))) + elif isinstance(self.val, Iterable): + missing = [] + for i in items: + if not isinstance(i, str_types): + raise TypeError('given args must all be strings') + found = False + for v in self.val: + if not isinstance(v, str_types): + raise TypeError('val items must all be strings') + if i.lower() == v.lower(): + found = True + break + if not found: + missing.append(i) + if missing: + self._err('Expected <%s> to case-insensitive contain items %s, but did not contain %s.' % (self.val, self._fmt_items(items), self._fmt_items(missing))) + else: + raise TypeError('val is not a string or iterable') + return self + + def starts_with(self, prefix): + """Asserts that val is string or iterable and starts with prefix.""" + if prefix is None: + raise TypeError('given prefix arg must not be none') + if isinstance(self.val, str_types): + if not isinstance(prefix, str_types): + raise TypeError('given prefix arg must be a string') + if len(prefix) == 0: + raise ValueError('given prefix arg must not be empty') + if not self.val.startswith(prefix): + self._err('Expected <%s> to start with <%s>, but did not.' % (self.val, prefix)) + elif isinstance(self.val, Iterable): + if len(self.val) == 0: + raise ValueError('val must not be empty') + first = next(iter(self.val)) + if first != prefix: + self._err('Expected %s to start with <%s>, but did not.' % (self.val, prefix)) + else: + raise TypeError('val is not a string or iterable') + return self + + def ends_with(self, suffix): + """Asserts that val is string or iterable and ends with suffix.""" + if suffix is None: + raise TypeError('given suffix arg must not be none') + if isinstance(self.val, str_types): + if not isinstance(suffix, str_types): + raise TypeError('given suffix arg must be a string') + if len(suffix) == 0: + raise ValueError('given suffix arg must not be empty') + if not self.val.endswith(suffix): + self._err('Expected <%s> to end with <%s>, but did not.' % (self.val, suffix)) + elif isinstance(self.val, Iterable): + if len(self.val) == 0: + raise ValueError('val must not be empty') + last = None + for last in self.val: + pass + if last != suffix: + self._err('Expected %s to end with <%s>, but did not.' % (self.val, suffix)) + else: + raise TypeError('val is not a string or iterable') + return self + + def matches(self, pattern): + """Asserts that val is string and matches regex pattern.""" + if not isinstance(self.val, str_types): + raise TypeError('val is not a string') + if not isinstance(pattern, str_types): + raise TypeError('given pattern arg must be a string') + if len(pattern) == 0: + raise ValueError('given pattern arg must not be empty') + if re.search(pattern, self.val) is None: + self._err('Expected <%s> to match pattern <%s>, but did not.' % (self.val, pattern)) + return self + + def does_not_match(self, pattern): + """Asserts that val is string and does not match regex pattern.""" + if not isinstance(self.val, str_types): + raise TypeError('val is not a string') + if not isinstance(pattern, str_types): + raise TypeError('given pattern arg must be a string') + if len(pattern) == 0: + raise ValueError('given pattern arg must not be empty') + if re.search(pattern, self.val) is not None: + self._err('Expected <%s> to not match pattern <%s>, but did.' % (self.val, pattern)) + return self + + def is_alpha(self): + """Asserts that val is non-empty string and all characters are alphabetic.""" + if not isinstance(self.val, str_types): + raise TypeError('val is not a string') + if len(self.val) == 0: + raise ValueError('val is empty') + if not self.val.isalpha(): + self._err('Expected <%s> to contain only alphabetic chars, but did not.' % self.val) + return self + + def is_digit(self): + """Asserts that val is non-empty string and all characters are digits.""" + if not isinstance(self.val, str_types): + raise TypeError('val is not a string') + if len(self.val) == 0: + raise ValueError('val is empty') + if not self.val.isdigit(): + self._err('Expected <%s> to contain only digits, but did not.' % self.val) + return self + + def is_lower(self): + """Asserts that val is non-empty string and all characters are lowercase.""" + if not isinstance(self.val, str_types): + raise TypeError('val is not a string') + if len(self.val) == 0: + raise ValueError('val is empty') + if self.val != self.val.lower(): + self._err('Expected <%s> to contain only lowercase chars, but did not.' % self.val) + return self + + def is_upper(self): + """Asserts that val is non-empty string and all characters are uppercase.""" + if not isinstance(self.val, str_types): + raise TypeError('val is not a string') + if len(self.val) == 0: + raise ValueError('val is empty') + if self.val != self.val.upper(): + self._err('Expected <%s> to contain only uppercase chars, but did not.' % self.val) + return self + + def is_unicode(self): + """Asserts that val is a unicode string.""" + if type(self.val) is not unicode: + self._err('Expected <%s> to be unicode, but was <%s>.' % (self.val, type(self.val).__name__)) + return self From 4b424faab4721ad5ccb2cedac3fc3b86e0fd59de Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Sun, 20 Oct 2019 20:53:31 -0600 Subject: [PATCH 02/11] fixup traceback test --- tests/test_traceback.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_traceback.py b/tests/test_traceback.py index 8579745..d16fd0b 100644 --- a/tests/test_traceback.py +++ b/tests/test_traceback.py @@ -52,10 +52,10 @@ def test_traceback(): assert_that(frames[0][1]).is_equal_to('test_traceback') assert_that(frames[0][2]).is_equal_to(35) - assert_that(frames[1][0]).ends_with('assertpy.py') + assert_that(frames[1][0]).ends_with('base.py') assert_that(frames[1][1]).is_equal_to('is_equal_to') - assert_that(frames[1][2]).is_greater_than(160) + assert_that(frames[1][2]).is_greater_than(40) assert_that(frames[2][0]).ends_with('assertpy.py') assert_that(frames[2][1]).is_equal_to('_err') - assert_that(frames[2][2]).is_greater_than(1000) + assert_that(frames[2][2]).is_greater_than(100) From 83e8bb51ff817c46fad1332bc2db8f1fc20d8390 Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Sun, 20 Oct 2019 20:58:27 -0600 Subject: [PATCH 03/11] add extension system --- assertpy/assertpy.py | 32 ++++++++++++++----- tests/test_extensions.py | 66 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 8 deletions(-) create mode 100644 tests/test_extensions.py diff --git a/assertpy/assertpy.py b/assertpy/assertpy.py index f1e3d7e..83faf1e 100644 --- a/assertpy/assertpy.py +++ b/assertpy/assertpy.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2019, Activision Publishing, Inc. +# Copyright (c) 2015-2018, Activision Publishing, Inc. # All rights reserved. # # Redistribution and use in source and binary forms, with or without modification, @@ -29,6 +29,7 @@ """Assertion library for python unit testing with a fluent API""" import os +import types import contextlib from .base import BaseMixin from .contains import ContainsMixin @@ -50,7 +51,7 @@ contextlib.__tracebackhide__ = True # monkey patch contextlib with clean py.test tracebacks -### soft assertions ### +# soft assertions _soft_ctx = 0 _soft_err = [] @@ -78,19 +79,18 @@ def soft_assertions(): _soft_err = [] raise AssertionError(out) - -### factory methods ### +# factory methods def assert_that(val, description=''): """Factory method for the assertion builder with value to be tested and optional description.""" global _soft_ctx if _soft_ctx: - return AssertionBuilder(val, description, 'soft') - return AssertionBuilder(val, description) + return builder(val, description, 'soft') + return builder(val, description) def assert_warn(val, description=''): """Factory method for the assertion builder with value to be tested, optional description, and just warn on assertion failures instead of raisings exceptions.""" - return AssertionBuilder(val, description, 'warn') + return builder(val, description, 'warn') def fail(msg=''): """Force test failure with the given message.""" @@ -106,6 +106,22 @@ def soft_fail(msg=''): return fail(msg) +# assertion extensions +_extensions = [] +def add_extension(func): + if not callable(func): + raise TypeError('func must be callable') + _extensions.append(func) + +def builder(val, description='', kind=None, expected=None): + ab = AssertionBuilder(val, description, kind, expected) + if _extensions: + # glue extension method onto new builder instance + for func in _extensions: + meth = types.MethodType(func, ab) + setattr(ab, func.__name__, meth) + return ab + class AssertionBuilder(DynamicMixin, ExceptionMixin, SnapshotMixin, ExtractingMixin, FileMixin, DateMixin, DictMixin, CollectionMixin, StringMixin, NumericMixin, @@ -123,7 +139,7 @@ def __init__(self, val, description='', kind=None, expected=None): def _builder(self, val, description='', kind=None, expected=None): """Helper to build a new Builder. Only used when we don't want to chain.""" - return AssertionBuilder(val, description, kind, expected) + return builder(val, description, kind, expected) def _err(self, msg): """Helper to raise an AssertionError, and optionally prepend custom description.""" diff --git a/tests/test_extensions.py b/tests/test_extensions.py new file mode 100644 index 0000000..8bd6321 --- /dev/null +++ b/tests/test_extensions.py @@ -0,0 +1,66 @@ +# Copyright (c) 2015-2019, Activision Publishing, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from assertpy import assert_that, add_extension +import numbers + +def is_even(self): + if isinstance(self.val, numbers.Integral) is False: + raise TypeError('val is not integer number') + if self.val % 2 != 0: + self._err('Expected <%s> to be even, but was not.' % (self.val)) + return self + + +def test_extension(): + add_extension(is_even) + assert_that(124).is_even() + assert_that(12345678901234567890).is_even() + +def test_extension_failure(): + try: + add_extension(is_even) + assert_that(123).is_even() + fail('should have raised error') + except AssertionError as ex: + assert_that(str(ex)).is_equal_to('Expected <123> to be even, but was not.') + +def test_extension_failure_not_callable(): + try: + add_extension('foo') + fail('should have raised error') + except TypeError as ex: + assert_that(str(ex)).is_equal_to('func must be callable') + +def test_extension_failure_not_integer(): + try: + add_extension(is_even) + assert_that(124.0).is_even() + fail('should have raised error') + except TypeError as ex: + assert_that(str(ex)).is_equal_to('val is not integer number') From d18a31e39c708cd536208917cfd84844aedcaf34 Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Sun, 20 Oct 2019 21:20:06 -0600 Subject: [PATCH 04/11] fixup extension test --- tests/test_extensions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 8bd6321..360adc5 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -34,12 +34,13 @@ def is_even(self): raise TypeError('val is not integer number') if self.val % 2 != 0: self._err('Expected <%s> to be even, but was not.' % (self.val)) - return self + return self def test_extension(): add_extension(is_even) assert_that(124).is_even() + assert_that(124).is_type_of(int).is_even().is_greater_than(123).is_less_than(125).is_equal_to(124) assert_that(12345678901234567890).is_even() def test_extension_failure(): From d06a238a2509bbd4df12771d2a4f8f50b9ef7340 Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Mon, 21 Oct 2019 19:18:41 -0600 Subject: [PATCH 05/11] refactor type name getter into func --- assertpy/base.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/assertpy/base.py b/assertpy/base.py index 42b999c..847a6cc 100644 --- a/assertpy/base.py +++ b/assertpy/base.py @@ -87,18 +87,19 @@ def is_not_none(self): self._err('Expected not , but was.') return self + def _type(self, val): + if hasattr(val, '__name__'): + return val.__name__ + elif hasattr(val, '__class__'): + return val.__class__.__name__ + return 'unknown' + def is_type_of(self, some_type): """Asserts that val is of the given type.""" - if type(some_type) is not type and\ - not issubclass(type(some_type), type): + if type(some_type) is not type and not issubclass(type(some_type), type): raise TypeError('given arg must be a type') if type(self.val) is not some_type: - if hasattr(self.val, '__name__'): - t = self.val.__name__ - elif hasattr(self.val, '__class__'): - t = self.val.__class__.__name__ - else: - t = 'unknown' + t = self._type(self.val) self._err('Expected <%s:%s> to be of type <%s>, but was not.' % (self.val, t, some_type.__name__)) return self @@ -106,12 +107,7 @@ def is_instance_of(self, some_class): """Asserts that val is an instance of the given class.""" try: if not isinstance(self.val, some_class): - if hasattr(self.val, '__name__'): - t = self.val.__name__ - elif hasattr(self.val, '__class__'): - t = self.val.__class__.__name__ - else: - t = 'unknown' + t = self._type(self.val) self._err('Expected <%s:%s> to be instance of class <%s>, but was not.' % (self.val, t, some_class.__name__)) except TypeError: raise TypeError('given arg must be a class') From b53d26559c0de3a437046ac65e37266fff882536 Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Mon, 21 Oct 2019 19:19:00 -0600 Subject: [PATCH 06/11] remove unused check --- assertpy/assertpy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/assertpy/assertpy.py b/assertpy/assertpy.py index 83faf1e..0dc947d 100644 --- a/assertpy/assertpy.py +++ b/assertpy/assertpy.py @@ -132,8 +132,6 @@ def __init__(self, val, description='', kind=None, expected=None): """Construct the assertion builder.""" self.val = val self.description = description - if kind not in set([None, 'soft', 'warn']): - raise ValueError('bad assertion kind') self.kind = kind self.expected = expected From 97a01a884ca282fef3ff8db61d5139e307031ed7 Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Mon, 21 Oct 2019 19:19:51 -0600 Subject: [PATCH 07/11] add tests of is_multiple_of() extension --- tests/test_extensions.py | 61 +++++++++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 7 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 360adc5..546a5b5 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -29,21 +29,32 @@ from assertpy import assert_that, add_extension import numbers + def is_even(self): if isinstance(self.val, numbers.Integral) is False: - raise TypeError('val is not integer number') + raise TypeError('val must be an integer') if self.val % 2 != 0: self._err('Expected <%s> to be even, but was not.' % (self.val)) return self +def is_multiple_of(self, other): + if isinstance(self.val, numbers.Integral) is False: + raise TypeError('val must be an integer') + if isinstance(other, numbers.Integral) is False: + raise TypeError('given arg must be an integer') + _, rem = divmod(self.val, other) + if rem != 0: + self._err('Expected <%s> to be multiple of <%s>, but was not.' % (self.val, other)) + return self + + -def test_extension(): +def test_is_even_extension(): add_extension(is_even) assert_that(124).is_even() assert_that(124).is_type_of(int).is_even().is_greater_than(123).is_less_than(125).is_equal_to(124) - assert_that(12345678901234567890).is_even() -def test_extension_failure(): +def test_is_even_extension_failure(): try: add_extension(is_even) assert_that(123).is_even() @@ -51,17 +62,53 @@ def test_extension_failure(): except AssertionError as ex: assert_that(str(ex)).is_equal_to('Expected <123> to be even, but was not.') -def test_extension_failure_not_callable(): +def test_is_even_extension_failure_not_callable(): try: add_extension('foo') fail('should have raised error') except TypeError as ex: assert_that(str(ex)).is_equal_to('func must be callable') -def test_extension_failure_not_integer(): +def test_is_even_extension_failure_not_integer(): try: add_extension(is_even) assert_that(124.0).is_even() fail('should have raised error') except TypeError as ex: - assert_that(str(ex)).is_equal_to('val is not integer number') + assert_that(str(ex)).is_equal_to('val must be an integer') + +def test_is_multiple_of_extension(): + add_extension(is_multiple_of) + assert_that(24).is_multiple_of(1) + assert_that(24).is_multiple_of(2) + assert_that(24).is_multiple_of(3) + assert_that(24).is_multiple_of(4) + assert_that(24).is_multiple_of(6) + assert_that(24).is_multiple_of(8) + assert_that(24).is_multiple_of(12) + assert_that(24).is_multiple_of(24) + assert_that(124).is_type_of(int).is_even().is_multiple_of(31).is_equal_to(124) + +def test_is_multiple_of_extension_failure(): + try: + add_extension(is_multiple_of) + assert_that(24).is_multiple_of(5) + fail('should have raised error') + except AssertionError as ex: + assert_that(str(ex)).is_equal_to('Expected <24> to be multiple of <5>, but was not.') + +def test_is_multiple_of_extension_failure_not_integer(): + try: + add_extension(is_multiple_of) + assert_that(24.0).is_multiple_of(5) + fail('should have raised error') + except TypeError as ex: + assert_that(str(ex)).is_equal_to('val must be an integer') + +def test_is_multiple_of_extension_failure_arg_not_integer(): + try: + add_extension(is_multiple_of) + assert_that(24).is_multiple_of('foo') + fail('should have raised error') + except TypeError as ex: + assert_that(str(ex)).is_equal_to('given arg must be an integer') From 0c754547c6b75aa5d8519cfb932c91e82e53166d Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Mon, 21 Oct 2019 20:04:39 -0600 Subject: [PATCH 08/11] refactor to add extensions just once --- tests/test_extensions.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 546a5b5..ab518bc 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -47,16 +47,16 @@ def is_multiple_of(self, other): self._err('Expected <%s> to be multiple of <%s>, but was not.' % (self.val, other)) return self +add_extension(is_even) +add_extension(is_multiple_of) def test_is_even_extension(): - add_extension(is_even) assert_that(124).is_even() assert_that(124).is_type_of(int).is_even().is_greater_than(123).is_less_than(125).is_equal_to(124) def test_is_even_extension_failure(): try: - add_extension(is_even) assert_that(123).is_even() fail('should have raised error') except AssertionError as ex: @@ -71,14 +71,12 @@ def test_is_even_extension_failure_not_callable(): def test_is_even_extension_failure_not_integer(): try: - add_extension(is_even) assert_that(124.0).is_even() fail('should have raised error') except TypeError as ex: assert_that(str(ex)).is_equal_to('val must be an integer') def test_is_multiple_of_extension(): - add_extension(is_multiple_of) assert_that(24).is_multiple_of(1) assert_that(24).is_multiple_of(2) assert_that(24).is_multiple_of(3) @@ -91,7 +89,6 @@ def test_is_multiple_of_extension(): def test_is_multiple_of_extension_failure(): try: - add_extension(is_multiple_of) assert_that(24).is_multiple_of(5) fail('should have raised error') except AssertionError as ex: @@ -99,7 +96,6 @@ def test_is_multiple_of_extension_failure(): def test_is_multiple_of_extension_failure_not_integer(): try: - add_extension(is_multiple_of) assert_that(24.0).is_multiple_of(5) fail('should have raised error') except TypeError as ex: @@ -107,7 +103,6 @@ def test_is_multiple_of_extension_failure_not_integer(): def test_is_multiple_of_extension_failure_arg_not_integer(): try: - add_extension(is_multiple_of) assert_that(24).is_multiple_of('foo') fail('should have raised error') except TypeError as ex: From f6a60069cc73857418d36b2cda3bcefb85d37523 Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Mon, 21 Oct 2019 20:37:25 -0600 Subject: [PATCH 09/11] cleanups, extension system doc --- README.md | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 1ceeb3c..adfec74 100644 --- a/README.md +++ b/README.md @@ -568,7 +568,7 @@ assert_that(people).extracting('first_name').contains('Fred','Bob') assert_that(people).extracting('first_name').does_not_contain('Charlie') ``` -Of couse `extracting` works with subclasses too...suppose we create a simple class hierarchy by creating a `Developer` subclass of `Person`, like this: +Of course `extracting` works with subclasses too...suppose we create a simple class hierarchy by creating a `Developer` subclass of `Person`, like this: ```py class Developer(Person): @@ -852,7 +852,7 @@ forgiving behavior, you can use `soft_fail()` which is collected like any other ### Snapshot Testing -Take a snapshot of a python data structure, store it on disk in JSON format, and automatically compare the latest data to the stored data on every test run. The snapshot testing features of `assertpy` are borrowed from [Jest](https://facebook.github.io/jest/), a well-kwown and powerful Javascript testing framework. +Take a snapshot of a python data structure, store it on disk in JSON format, and automatically compare the latest data to the stored data on every test run. The snapshot testing features of `assertpy` are borrowed from [Jest](https://facebook.github.io/jest/), a well-kwown and powerful Javascript testing framework. Snapshots require Python 3. For example, snapshot the following dict: @@ -909,7 +909,33 @@ assert_that({'a':1,'b':2,'c':3}).snapshot(path='my-custom-folder') #### Snapshot Blackbox -Functional testing (which snapshot testing falls under) is very much blackbox testing. When something goes wrong, it's hard to pinpoint the issue, because they provide little *isolation*. On the plus side, snapshots can provide enormous *leverage* as a few well-place snapshots can strongly verify application state that would require dozens if not hundreds of unit tests. +Functional testing (which snapshot testing falls under) is very much blackbox testing. When something goes wrong, it's hard to pinpoint the issue, because functional tests provide little *isolation*. On the plus side, snapshots can provide enormous *leverage* as a few well-placed snapshot tests can strongly verify an application is working that would otherwise require dozens if not hundreds of unit tests. + +### Extending via Custom Assertions + +Sometimes you want to add your own custom assertions to `assertpy`. This can be done using the `add_extension()` helper. + +For example, we can write a custom `is_5()` assertion like this: + +```py +from assertpy import assert_that, add_extension + +def is_5(self): + if self.val != 5: + self._err(f'{self.val} is NOT 5!') + return self + +add_extension(is_5) +``` + +Once registered with `assertpy`, we can use our new assertion as expected (at the file level): + +```py +assert_that(5).is_5() +assert_that(6).is_5() # fails! +``` + +If you want better control of extension scope, such as writing extensions once and using them anywhere in your tests, you'll need to use the test setup functionality of your test runner. With [pytest](http://pytest.org/latest/contents.html), you can just use a `conftest.py` file and a _setup_ fixture. ### Chaining @@ -944,7 +970,7 @@ There are always a few new features in the works...if you'd like to help, check All files are licensed under the BSD 3-Clause License as follows: -> Copyright (c) 2015-2018, Activision Publishing, Inc. +> Copyright (c) 2015-2019, Activision Publishing, Inc. > All rights reserved. > > Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: From e1deba1d785fe8f7eacd90fb5b633102a313e823 Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Tue, 22 Oct 2019 16:03:09 -0600 Subject: [PATCH 10/11] import future print in the right spot --- assertpy/assertpy.py | 1 + assertpy/helpers.py | 1 - assertpy/numeric.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/assertpy/assertpy.py b/assertpy/assertpy.py index 0dc947d..c5ea208 100644 --- a/assertpy/assertpy.py +++ b/assertpy/assertpy.py @@ -28,6 +28,7 @@ """Assertion library for python unit testing with a fluent API""" +from __future__ import print_function import os import types import contextlib diff --git a/assertpy/helpers.py b/assertpy/helpers.py index 00007d4..e5099a1 100644 --- a/assertpy/helpers.py +++ b/assertpy/helpers.py @@ -26,7 +26,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from __future__ import print_function import sys import numbers import datetime diff --git a/assertpy/numeric.py b/assertpy/numeric.py index e0effe8..e8721bc 100644 --- a/assertpy/numeric.py +++ b/assertpy/numeric.py @@ -26,7 +26,7 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from __future__ import division, print_function +from __future__ import division import sys import math import numbers From 4afe08c1a4372b81d5750b995ff585de076f9dff Mon Sep 17 00:00:00 2001 From: Justin Shacklette Date: Sat, 26 Oct 2019 17:10:11 -0600 Subject: [PATCH 11/11] more extension doc --- README.md | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index adfec74..2c437cf 100644 --- a/README.md +++ b/README.md @@ -911,14 +911,14 @@ assert_that({'a':1,'b':2,'c':3}).snapshot(path='my-custom-folder') Functional testing (which snapshot testing falls under) is very much blackbox testing. When something goes wrong, it's hard to pinpoint the issue, because functional tests provide little *isolation*. On the plus side, snapshots can provide enormous *leverage* as a few well-placed snapshot tests can strongly verify an application is working that would otherwise require dozens if not hundreds of unit tests. -### Extending via Custom Assertions +### Extension System - adding custom assertions Sometimes you want to add your own custom assertions to `assertpy`. This can be done using the `add_extension()` helper. For example, we can write a custom `is_5()` assertion like this: ```py -from assertpy import assert_that, add_extension +from assertpy import add_extension def is_5(self): if self.val != 5: @@ -928,14 +928,75 @@ def is_5(self): add_extension(is_5) ``` -Once registered with `assertpy`, we can use our new assertion as expected (at the file level): +Once registered with `assertpy`, we can use our new assertion as expected: ```py assert_that(5).is_5() assert_that(6).is_5() # fails! ``` -If you want better control of extension scope, such as writing extensions once and using them anywhere in your tests, you'll need to use the test setup functionality of your test runner. With [pytest](http://pytest.org/latest/contents.html), you can just use a `conftest.py` file and a _setup_ fixture. +Of course, `is_5()` is only available in the test file where `add_extension()` is called. If you want better control of scope of your custom extensions, such as writing extensions once and using them in any test file, you'll need to use the test setup functionality of your test runner. With [pytest](http://pytest.org/latest/contents.html), you can just use a `conftest.py` file and a _fixture_. + +For example, if your `conftest.py` is: + +```py +import pytest +from assertpy import add_extension + +def is_5(self): + if self.val != 5: + self._err(f'{self.val} is NOT 5!') + return self + +@pytest.fixture(scope='module') +def my_extensions(): + add_extension(is_5) +``` + +Then in any test method in any test file (like `test_foo.py` for example), you just pass in the fixture and all of your extensions are available, like this: + +```py +from assertpy import assert_that + +def test_foo(my_extensions): + assert_that(5).is_5() + assert_that(6).is_5() # fails! +``` + +#### Writing custom assertions + +Here are some useful tips to help you write your own custom assertions: + +1. Use `self.val` to get the _actual_ value to be tested. +2. It's better to test the negative and fail if true. +3. Fail by raising an `AssertionError` +4. Always use the `self._err()` helper to fail (and print your failure message). +5. Always `return self` to allow for chaining. + +Putting it all together, here is another custom assertion example, but annotated with comments: + +```py +def is_multiple_of(self, other): + # validate actual value - must be "integer" (aka int or long) + if isinstance(self.val, numbers.Integral) is False and self.val > 0: + # bad input is error, not an assertion fail, so raise error + raise TypeError('val must be a positive integer') + + # validate expected value + if isinstance(other, numbers.Integral) is False and other > 0: + raise TypeError('given arg must be a positive integer') + + # divide and compute remainder using divmod() built-in + _, rem = divmod(self.val, other) + + # test the negative (is remainder non-zero?) + if rem > 0: + # non-zero remainder, so not multiple -> we fail! + self._err('Expected <%s> to be multiple of <%s>, but was not.' % (self.val, other)) + + # success, and return self to allow chaining + return self +``` ### Chaining