209 lines
5.7 KiB
Cython

import cmath
import math
import numpy as np
from numpy cimport import_array
import_array()
from pandas._libs.util cimport (
is_array,
is_complex_object,
is_real_number_object,
)
from pandas.core.dtypes.common import is_dtype_equal
from pandas.core.dtypes.missing import (
array_equivalent,
isna,
)
cdef bint isiterable(obj):
    # Duck-typed iterability test: only checks that ``obj`` exposes an
    # ``__iter__`` attribute; iteration itself is never attempted.
    cdef bint defines_iter = hasattr(obj, '__iter__')
    return defines_iter
cdef bint has_length(obj):
    # True when ``obj`` exposes a ``__len__`` attribute. This is an
    # attribute-presence check only; ``len()`` is not called here.
    cdef bint defines_len = hasattr(obj, '__len__')
    return defines_len
cdef bint is_dictlike(obj):
    # "Dict-like" here means the object has a ``keys`` attribute AND
    # supports subscription via ``__getitem__``; no values are accessed.
    if not hasattr(obj, 'keys'):
        return False
    return hasattr(obj, '__getitem__')
cpdef assert_dict_equal(a, b, bint compare_keys=True,
                        rtol=1.e-5, atol=1.e-8, bint check_dtype=True):
    """
    Check that two dict-like objects hold almost-equal values.

    Parameters
    ----------
    a, b : dict-like
        Objects exposing ``keys`` and ``__getitem__``.
    compare_keys : bool, default True
        If True, assert that ``a`` and ``b`` have identical key sets.
        If False, the key-set check is skipped; values are still looked up
        in ``b`` for every key of ``a``, so a key missing from ``b`` raises
        ``KeyError``.
    rtol : float, default 1e-5
        Relative tolerance forwarded to ``assert_almost_equal`` for values.
    atol : float, default 1e-8
        Absolute tolerance forwarded to ``assert_almost_equal`` for values.
    check_dtype : bool, default True
        Forwarded to ``assert_almost_equal`` for values.

    Returns
    -------
    bool
        True when the comparison passes.

    Raises
    ------
    AssertionError
        If either object is not dict-like, the key sets differ (when
        ``compare_keys`` is True), or any pair of values differs.
    """
    assert is_dictlike(a) and is_dictlike(b), (
        "Cannot compare dict objects, one or both is not dict-like"
    )

    a_keys = frozenset(a.keys())
    b_keys = frozenset(b.keys())

    if compare_keys:
        # The message is only built on failure (assert evaluates it lazily).
        assert a_keys == b_keys, (
            f"dict keys are different; only in left: {set(a_keys - b_keys)}, "
            f"only in right: {set(b_keys - a_keys)}"
        )

    for k in a_keys:
        assert_almost_equal(a[k], b[k], rtol=rtol, atol=atol,
                            check_dtype=check_dtype)

    return True


cpdef assert_almost_equal(a, b,
                          rtol=1.e-5, atol=1.e-8,
                          bint check_dtype=True,
                          obj=None, lobj=None, robj=None, index_values=None):
    """
    Check that left and right objects are almost equal.

    Parameters
    ----------
    a : object
    b : object
    rtol : float, default 1e-5
        Relative tolerance.

        .. versionadded:: 1.1.0
    atol : float, default 1e-8
        Absolute tolerance.

        .. versionadded:: 1.1.0
    check_dtype : bool, default True
        Check dtype if both a and b are np.ndarray. Also applied to nested
        elements and dict values.
    obj : str, default None
        Specify object name being compared, internally used to show
        appropriate assertion message.
    lobj : str, default None
        Specify left object name being compared, internally used to show
        appropriate assertion message.
    robj : str, default None
        Specify right object name being compared, internally used to show
        appropriate assertion message.
    index_values : ndarray, default None
        Specify shared index values of objects being compared, internally used
        to show appropriate assertion message.

        .. versionadded:: 1.1.0

    Returns
    -------
    bool
        True when the objects are considered almost equal.

    Raises
    ------
    AssertionError
        If the objects differ (type, shape, dtype, length, or values).
    """
    cdef:
        double diff = 0.0
        Py_ssize_t i, na, nb, n_items
        double fa, fb
        bint is_unequal = False, a_is_ndarray, b_is_ndarray

    if lobj is None:
        lobj = a
    if robj is None:
        robj = b

    if isinstance(a, dict) or isinstance(b, dict):
        # Forward the caller's tolerances/dtype check so dict values are
        # compared with the same settings as everything else.
        return assert_dict_equal(a, b, compare_keys=True, rtol=rtol, atol=atol,
                                 check_dtype=check_dtype)

    if isinstance(a, str) or isinstance(b, str):
        assert a == b, f"{a} != {b}"
        return True

    a_is_ndarray = is_array(a)
    b_is_ndarray = is_array(b)

    if obj is None:
        if a_is_ndarray or b_is_ndarray:
            obj = 'numpy array'
        else:
            obj = 'Iterable'

    if isiterable(a):

        if not isiterable(b):
            from pandas._testing import assert_class_equal
            # classes can't be the same, to raise error
            assert_class_equal(a, b, obj=obj)

        assert has_length(a) and has_length(b), (
            f"Can't compare objects without length, one or both is invalid: ({a}, {b})"
        )

        if a_is_ndarray and b_is_ndarray:
            na, nb = a.size, b.size
            if a.shape != b.shape:
                from pandas._testing import raise_assert_detail
                raise_assert_detail(
                    obj, f'{obj} shapes are different', a.shape, b.shape)

            if check_dtype and not is_dtype_equal(a.dtype, b.dtype):
                from pandas._testing import assert_attr_equal
                assert_attr_equal('dtype', a, b, obj=obj)

            if array_equivalent(a, b, strict_nan=True):
                return True

            if a.ndim == 0:
                # 0-d arrays define __len__ but len() fails on them; compare
                # the wrapped scalars instead of iterating.
                return assert_almost_equal(a[()], b[()], rtol=rtol, atol=atol,
                                           check_dtype=check_dtype)

        else:
            na, nb = len(a), len(b)

        if na != nb:
            from pandas._testing import raise_assert_detail

            # if we have a small diff set, print it
            r = None
            if abs(na - nb) < 10:
                try:
                    r = list(set(a) ^ set(b))
                except TypeError:
                    # Unhashable elements: fall back to reporting lengths only
                    # rather than masking the assertion with a TypeError.
                    r = None

            raise_assert_detail(obj, f"{obj} length are different", na, nb, r)

        # Items are compared along the first axis, so percentages must be
        # relative to that count (not the total ndarray size).
        n_items = len(a)
        for i in range(n_items):
            try:
                assert_almost_equal(a[i], b[i], rtol=rtol, atol=atol,
                                    check_dtype=check_dtype)
            except AssertionError:
                is_unequal = True
                diff += 1

        if is_unequal:
            from pandas._testing import raise_assert_detail
            msg = (f"{obj} values are different "
                   f"({np.round(diff * 100.0 / n_items, 5)} %)")
            raise_assert_detail(obj, msg, lobj, robj, index_values=index_values)

        return True

    elif isiterable(b):
        from pandas._testing import assert_class_equal
        # classes can't be the same, to raise error
        assert_class_equal(a, b, obj=obj)

    if isna(a) and isna(b):
        # TODO: Should require same-dtype NA?
        # nan / None comparison
        return True

    if a == b:
        # object comparison
        return True

    if is_real_number_object(a) and is_real_number_object(b):
        if array_equivalent(a, b, strict_nan=True):
            # inf comparison
            return True

        fa, fb = a, b

        if not math.isclose(fa, fb, rel_tol=rtol, abs_tol=atol):
            assert False, (f"expected {fb:.5f} but got {fa:.5f}, "
                           f"with rtol={rtol}, atol={atol}")
        return True

    if is_complex_object(a) and is_complex_object(b):
        if array_equivalent(a, b, strict_nan=True):
            # inf comparison
            return True

        if not cmath.isclose(a, b, rel_tol=rtol, abs_tol=atol):
            assert False, (f"expected {b:.5f} but got {a:.5f}, "
                           f"with rtol={rtol}, atol={atol}")
        return True

    raise AssertionError(f"{a} != {b}")