"""
Module consolidating common testing functions for checking plotting.

Currently all plotting tests are marked as slow via
``pytestmark = pytest.mark.slow`` at the module level.
"""
from __future__ import annotations
|
|
|
|
import os
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Sequence,
|
|
)
|
|
import warnings
|
|
|
|
import numpy as np
|
|
|
|
from pandas.util._decorators import cache_readonly
|
|
import pandas.util._test_decorators as td
|
|
|
|
from pandas.core.dtypes.api import is_list_like
|
|
|
|
import pandas as pd
|
|
from pandas import (
|
|
DataFrame,
|
|
Series,
|
|
to_datetime,
|
|
)
|
|
import pandas._testing as tm
|
|
|
|
if TYPE_CHECKING:
|
|
from matplotlib.axes import Axes
|
|
|
|
|
|
@td.skip_if_no_mpl
class TestPlotBase:
    """
    This is a common base class used for various plotting tests.

    ``setup_method`` builds fixture frames (``hist_df``, ``tdf``,
    ``hexbin_df``) and matplotlib-version flags; the ``_check_*`` helpers
    assert properties of the matplotlib artists produced by a plot call.
    """

    def setup_method(self, method):
        """
        Reset matplotlib rcParams and build the per-test fixtures.

        Parameters
        ----------
        method : callable
            The test method about to run (unused; part of pytest's
            xunit-style ``setup_method`` signature).
        """
        import matplotlib as mpl

        from pandas.plotting._matplotlib import compat

        self.compat = compat

        # Start from matplotlib's default rcParams so rc changes made
        # elsewhere (e.g. by _check_grid_settings) do not carry over.
        mpl.rcdefaults()

        # Integer bounds for the random "datetime" column below; to_datetime
        # interprets them with its default unit (nanoseconds since epoch).
        self.start_date_to_int64 = 812419200000000000
        self.end_date_to_int64 = 819331200000000000

        self.mpl_ge_2_2_3 = compat.mpl_ge_2_2_3()
        self.mpl_ge_3_0_0 = compat.mpl_ge_3_0_0()
        self.mpl_ge_3_1_0 = compat.mpl_ge_3_1_0()
        self.mpl_ge_3_2_0 = compat.mpl_ge_3_2_0()

        self.bp_n_objects = 7
        self.polycollection_factor = 2
        # Expected figure size used by _check_axes_shape when none is given.
        self.default_figsize = (6.4, 4.8)
        self.default_tick_position = "left"

        n = 100
        with tm.RNGContext(42):
            # Every random draw for hist_df happens inside the RNGContext and
            # in this exact order (dict literals evaluate in order); do not
            # reorder them or the fixture data changes.
            gender = np.random.choice(["Male", "Female"], size=n)
            classroom = np.random.choice(["A", "B", "C"], size=n)

            self.hist_df = DataFrame(
                {
                    "gender": gender,
                    "classroom": classroom,
                    "height": np.random.normal(66, 4, size=n),
                    "weight": np.random.normal(161, 32, size=n),
                    "category": np.random.randint(4, size=n),
                    "datetime": to_datetime(
                        np.random.randint(
                            self.start_date_to_int64,
                            self.end_date_to_int64,
                            size=n,
                            dtype=np.int64,
                        )
                    ),
                }
            )

        self.tdf = tm.makeTimeDataFrame()
        # NOTE(review): drawn outside the RNGContext above, so these values
        # are not seeded and differ between runs.
        self.hexbin_df = DataFrame(
            {
                "A": np.random.uniform(size=20),
                "B": np.random.uniform(size=20),
                "C": np.arange(20) + np.random.uniform(size=20),
            }
        )

    def teardown_method(self, method):
        """Close matplotlib figures left open by the test (``tm.close()``)."""
        tm.close()

    @cache_readonly
    def plt(self):
        """The ``matplotlib.pyplot`` module, imported lazily and cached."""
        import matplotlib.pyplot as plt

        return plt

    @cache_readonly
    def colorconverter(self):
        """``matplotlib.colors.colorConverter``, imported lazily and cached."""
        import matplotlib.colors as colors

        return colors.colorConverter

    def _check_legend_labels(self, axes, labels=None, visible=True):
        """
        Check each axes has expected legend labels.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        labels : list-like
            expected legend labels
        visible : bool
            expected legend visibility. labels are checked only when visible
            is True

        Raises
        ------
        ValueError
            If ``visible`` is True but ``labels`` is None.
        """
        if visible and (labels is None):
            raise ValueError("labels must be specified when visible is True")
        axes = self._flatten_visible(axes)
        for ax in axes:
            if visible:
                assert ax.get_legend() is not None
                self._check_text_labels(ax.get_legend().get_texts(), labels)
            else:
                assert ax.get_legend() is None

    def _check_legend_marker(self, ax, expected_markers=None, visible=True):
        """
        Check ax has expected legend markers.

        Parameters
        ----------
        ax : matplotlib Axes object
        expected_markers : list-like
            expected legend markers, in legend-handle order
        visible : bool
            expected legend visibility. markers are checked only when visible
            is True

        Raises
        ------
        ValueError
            If ``visible`` is True but ``expected_markers`` is None.
        """
        if visible and (expected_markers is None):
            raise ValueError("Markers must be specified when visible is True")
        if visible:
            handles, _ = ax.get_legend_handles_labels()
            markers = [handle.get_marker() for handle in handles]
            # Normalize to a list so any list-like (e.g. a tuple) compares
            # element-wise instead of failing on container type.
            assert markers == list(expected_markers)
        else:
            assert ax.get_legend() is None

    def _check_data(self, xp, rs):
        """
        Check each axes has identical lines.

        Parameters
        ----------
        xp : matplotlib Axes object
        rs : matplotlib Axes object

        Notes
        -----
        Line data is compared via ``get_xydata()`` with
        ``tm.assert_almost_equal``. All figures are closed with
        ``tm.close()`` once the comparison succeeds.
        """
        xp_lines = xp.get_lines()
        rs_lines = rs.get_lines()

        assert len(xp_lines) == len(rs_lines)
        for xpl, rsl in zip(xp_lines, rs_lines):
            tm.assert_almost_equal(xpl.get_xydata(), rsl.get_xydata())
        tm.close()

    def _check_visible(self, collections, visible=True):
        """
        Check each artist is visible or not.

        Parameters
        ----------
        collections : matplotlib Artist or its list-like
            target Artist or its list or collection
        visible : bool
            expected visibility
        """
        from matplotlib.collections import Collection

        # Anything that is neither a matplotlib Collection nor list-like is
        # treated as a single artist.
        if not isinstance(collections, Collection) and not is_list_like(collections):
            collections = [collections]

        for patch in collections:
            assert patch.get_visible() == visible

    def _check_patches_all_filled(
        self, axes: Axes | Sequence[Axes], filled: bool = True
    ) -> None:
        """
        Check for each artist whether it is filled or not.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        filled : bool
            expected filling
        """
        axes = self._flatten_visible(axes)
        for ax in axes:
            for patch in ax.patches:
                assert patch.fill == filled

    def _get_colors_mapped(self, series, colors):
        """
        Map every value of ``series`` to a color.

        Colors are assigned to ``series.unique()`` in order of first
        appearance. ``unique`` and ``colors`` may differ in length depending
        on the slice value; ``zip`` truncates to the shorter of the two, and
        any value left without a color raises ``KeyError``.

        Returns
        -------
        list
            One color per element of ``series.values``.
        """
        unique = series.unique()
        mapped = dict(zip(unique, colors))
        return [mapped[v] for v in series.values]

    def _check_colors(
        self, collections, linecolors=None, facecolors=None, mapping=None
    ):
        """
        Check each artist has expected line colors and face colors.

        Parameters
        ----------
        collections : list-like
            list or collection of target artist
        linecolors : list-like which has the same length as collections
            list of expected line colors
        facecolors : list-like which has the same length as collections
            list of expected face colors
        mapping : Series
            Series used for color grouping key
            used for andrew_curves, parallel_coordinates, radviz test
        """
        from matplotlib.collections import (
            Collection,
            LineCollection,
            PolyCollection,
        )
        from matplotlib.lines import Line2D

        conv = self.colorconverter
        if linecolors is not None:
            if mapping is not None:
                linecolors = self._get_colors_mapped(mapping, linecolors)
                linecolors = linecolors[: len(collections)]

            assert len(collections) == len(linecolors)
            for patch, color in zip(collections, linecolors):
                if isinstance(patch, Line2D):
                    result = patch.get_color()
                    # Line2D may contain a string color expression
                    result = conv.to_rgba(result)
                elif isinstance(patch, (PolyCollection, LineCollection)):
                    result = tuple(patch.get_edgecolor()[0])
                else:
                    result = patch.get_edgecolor()

                expected = conv.to_rgba(color)
                assert result == expected

        if facecolors is not None:
            if mapping is not None:
                facecolors = self._get_colors_mapped(mapping, facecolors)
                facecolors = facecolors[: len(collections)]

            assert len(collections) == len(facecolors)
            for patch, color in zip(collections, facecolors):
                if isinstance(patch, Collection):
                    # returned as list of np.array
                    result = patch.get_facecolor()[0]
                else:
                    result = patch.get_facecolor()

                if isinstance(result, np.ndarray):
                    result = tuple(result)

                expected = conv.to_rgba(color)
                assert result == expected

    def _check_text_labels(self, texts, expected):
        """
        Check each text has expected labels.

        Parameters
        ----------
        texts : matplotlib Text object, or its list-like
            target text, or its list
        expected : str or list-like which has the same length as texts
            expected text label, or its list
        """
        if not is_list_like(texts):
            assert texts.get_text() == expected
        else:
            labels = [t.get_text() for t in texts]
            assert len(labels) == len(expected)
            for label, e in zip(labels, expected):
                assert label == e

    def _check_ticks_props(
        self, axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None
    ):
        """
        Check each axes has expected tick properties.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xlabelsize : number
            expected xticks font size
        xrot : number
            expected xticks rotation
        ylabelsize : number
            expected yticks font size
        yrot : number
            expected yticks rotation
        """
        from matplotlib.ticker import NullFormatter

        def _check_axis_labels(axis, get_ticklabels, labelsize, rot):
            # Validate font size / rotation of one axis' tick labels; nothing
            # is inspected when neither property is requested.
            if labelsize is None and rot is None:
                return
            if isinstance(axis.get_minor_formatter(), NullFormatter):
                # If minor ticks have a NullFormatter, rot / fontsize are not
                # retained on them, so only the major labels are checked.
                labels = get_ticklabels()
            else:
                labels = get_ticklabels() + get_ticklabels(minor=True)

            for label in labels:
                if labelsize is not None:
                    tm.assert_almost_equal(label.get_fontsize(), labelsize)
                if rot is not None:
                    tm.assert_almost_equal(label.get_rotation(), rot)

        for ax in self._flatten_visible(axes):
            _check_axis_labels(ax.xaxis, ax.get_xticklabels, xlabelsize, xrot)
            _check_axis_labels(ax.yaxis, ax.get_yticklabels, ylabelsize, yrot)

    def _check_ax_scales(self, axes, xaxis="linear", yaxis="linear"):
        """
        Check each axes has expected scales.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xaxis : {'linear', 'log'}
            expected xaxis scale
        yaxis : {'linear', 'log'}
            expected yaxis scale
        """
        axes = self._flatten_visible(axes)
        for ax in axes:
            assert ax.xaxis.get_scale() == xaxis
            assert ax.yaxis.get_scale() == yaxis

    def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=None):
        """
        Check expected number of axes is drawn in expected layout.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        axes_num : number
            expected number of axes. Unnecessary axes should be set to
            invisible.
        layout : tuple
            expected layout, (expected number of rows , columns)
        figsize : tuple
            expected figsize. Defaults to ``self.default_figsize``.

        Notes
        -----
        The figure size is read from the figure of the first visible axes.
        The layout is estimated from all axes (visible or not).
        """
        from pandas.plotting._matplotlib.tools import flatten_axes

        if figsize is None:
            figsize = self.default_figsize
        visible_axes = self._flatten_visible(axes)

        if axes_num is not None:
            assert len(visible_axes) == axes_num
            for ax in visible_axes:
                # check something drawn on visible axes
                assert len(ax.get_children()) > 0

        if layout is not None:
            result = self._get_axes_layout(flatten_axes(axes))
            assert result == layout

        tm.assert_numpy_array_equal(
            visible_axes[0].figure.get_size_inches(),
            np.array(figsize, dtype=np.float64),
        )

    def _get_axes_layout(self, axes):
        """
        Estimate the ``(nrows, ncols)`` grid that ``axes`` are arranged in.

        Counts the distinct x and y coordinates of the first corner returned
        by each axes' ``get_position().get_points()``.
        """
        x_set = set()
        y_set = set()
        for ax in axes:
            # check axes coordinates to estimate layout
            points = ax.get_position().get_points()
            x_set.add(points[0][0])
            y_set.add(points[0][1])
        return (len(y_set), len(x_set))

    def _flatten_visible(self, axes):
        """
        Flatten axes, and filter only visible.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like

        Returns
        -------
        list
            The flattened axes whose ``get_visible()`` is truthy.
        """
        from pandas.plotting._matplotlib.tools import flatten_axes

        axes = flatten_axes(axes)
        axes = [ax for ax in axes if ax.get_visible()]
        return axes

    def _check_has_errorbars(self, axes, xerr=0, yerr=0):
        """
        Check axes has expected number of errorbars.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xerr : number
            expected number of x errorbar
        yerr : number
            expected number of y errorbar

        Notes
        -----
        A container counts when its ``has_xerr`` / ``has_yerr`` attribute is
        truthy; containers without the attribute are not counted.
        """
        for ax in self._flatten_visible(axes):
            containers = ax.containers
            xerr_count = sum(bool(getattr(c, "has_xerr", False)) for c in containers)
            yerr_count = sum(bool(getattr(c, "has_yerr", False)) for c in containers)
            assert xerr == xerr_count
            assert yerr == yerr_count

    def _check_box_return_type(
        self, returned, return_type, expected_keys=None, check_ax_title=True
    ):
        """
        Check box returned type is correct.

        Parameters
        ----------
        returned : object to be tested, returned from boxplot
        return_type : str
            return_type passed to boxplot
        expected_keys : list-like, optional
            group labels in subplot case. If not passed,
            the function checks assuming boxplot uses single ax
        check_ax_title : bool
            Whether to check the ax.title is the same as expected_key
            Intended to be checked by calling from ``boxplot``.
            Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
        """
        from matplotlib.axes import Axes

        types = {"dict": dict, "axes": Axes, "both": tuple}
        if expected_keys is None:
            # should be fixed when the returning default is changed
            if return_type is None:
                return_type = "dict"

            assert isinstance(returned, types[return_type])
            if return_type == "both":
                assert isinstance(returned.ax, Axes)
                assert isinstance(returned.lines, dict)
        else:
            # should be fixed when the returning default is changed
            if return_type is None:
                for r in self._flatten_visible(returned):
                    assert isinstance(r, Axes)
                return

            assert isinstance(returned, Series)

            assert sorted(returned.keys()) == sorted(expected_keys)
            for key, value in returned.items():
                assert isinstance(value, types[return_type])
                # check returned dict has correct mapping
                if return_type == "axes":
                    if check_ax_title:
                        assert value.get_title() == key
                elif return_type == "both":
                    if check_ax_title:
                        assert value.ax.get_title() == key
                    assert isinstance(value.ax, Axes)
                    assert isinstance(value.lines, dict)
                elif return_type == "dict":
                    line = value["medians"][0]
                    axes = line.axes
                    if check_ax_title:
                        assert axes.get_title() == key
                else:
                    raise AssertionError(
                        f"unexpected return_type: {return_type!r}"
                    )

    def _check_grid_settings(self, obj, kinds, kws=None):
        """
        Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792.

        For every kind, ``obj.plot`` is drawn on successive subplots of a
        1 x (4 * len(kinds)) grid with:

        1. rc grid off, no ``grid`` kwarg -> grid must be off
        2. rc grid on, ``grid=False``     -> grid must be off
        3. rc grid on, no ``grid`` kwarg  -> grid must be on
        4. rc grid off, ``grid=True``     -> grid must be on

        Cases 3 and 4 are skipped for "pie", "hexbin" and "scatter".

        Parameters
        ----------
        obj : object
            Object exposing a ``plot(kind=..., **kws)`` method.
        kinds : list of str
            Plot kinds to check.
        kws : dict, optional
            Extra keyword arguments forwarded to every ``obj.plot`` call.

        Notes
        -----
        Changes matplotlib's global rc ``axes.grid`` setting and does not
        restore it.
        """
        import matplotlib as mpl

        # Fresh dict per call instead of a shared mutable default.
        if kws is None:
            kws = {}

        def is_grid_on():
            xticks = self.plt.gca().xaxis.get_major_ticks()
            yticks = self.plt.gca().yaxis.get_major_ticks()
            # for mpl 2.2.2, gridOn and gridline.get_visible disagree.
            # for new MPL, they are the same.

            if self.mpl_ge_3_1_0:
                xoff = all(not g.gridline.get_visible() for g in xticks)
                yoff = all(not g.gridline.get_visible() for g in yticks)
            else:
                xoff = all(not g.gridOn for g in xticks)
                yoff = all(not g.gridOn for g in yticks)

            return not (xoff and yoff)

        spndx = 1
        for kind in kinds:

            self.plt.subplot(1, 4 * len(kinds), spndx)
            spndx += 1
            mpl.rc("axes", grid=False)
            obj.plot(kind=kind, **kws)
            assert not is_grid_on()

            self.plt.subplot(1, 4 * len(kinds), spndx)
            spndx += 1
            mpl.rc("axes", grid=True)
            obj.plot(kind=kind, grid=False, **kws)
            assert not is_grid_on()

            if kind not in ["pie", "hexbin", "scatter"]:
                self.plt.subplot(1, 4 * len(kinds), spndx)
                spndx += 1
                mpl.rc("axes", grid=True)
                obj.plot(kind=kind, **kws)
                assert is_grid_on()

                self.plt.subplot(1, 4 * len(kinds), spndx)
                spndx += 1
                mpl.rc("axes", grid=False)
                obj.plot(kind=kind, grid=True, **kws)
                assert is_grid_on()

    def _unpack_cycler(self, rcParams, field="color"):
        """
        Auxiliary function for correctly unpacking cycler after MPL >= 1.5.

        Returns the ``field`` entry of every item in
        ``rcParams["axes.prop_cycle"]``, in cycle order.
        """
        return [v[field] for v in rcParams["axes.prop_cycle"]]

    def get_x_axis(self, ax):
        """
        Return the shared-x-axes grouper of ``ax``.

        Uses ``ax._shared_axes["x"]`` when ``compat.mpl_ge_3_5_0()`` is true,
        else the older ``ax._shared_x_axes`` attribute.
        """
        return ax._shared_axes["x"] if self.compat.mpl_ge_3_5_0() else ax._shared_x_axes

    def get_y_axis(self, ax):
        """
        Return the shared-y-axes grouper of ``ax``.

        Uses ``ax._shared_axes["y"]`` when ``compat.mpl_ge_3_5_0()`` is true,
        else the older ``ax._shared_y_axes`` attribute.
        """
        return ax._shared_axes["y"] if self.compat.mpl_ge_3_5_0() else ax._shared_y_axes
def _check_plot_works(f, filterwarnings="always", default_axes=False, **kwargs):
    """
    Create plot and ensure that plot return object is valid.

    Parameters
    ----------
    f : func
        Plotting function.
    filterwarnings : str
        Warnings filter.
        See https://docs.python.org/3/library/warnings.html#warning-filter
    default_axes : bool, optional
        If False (default):
            - If `ax` not in `kwargs`, then create subplot(211) and plot there
            - Create new subplot(212) and plot there as well
            - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`)
        If True:
            - Simply run plotting function with kwargs provided
            - All required axes instances will be created automatically
            - It is recommended to use it when the plotting function
            creates multiple axes itself. It helps avoid warnings like
            'UserWarning: To output multiple subplots,
            the figure containing the passed axes is being cleared'
    **kwargs
        Keyword arguments passed to the plotting function.

    Returns
    -------
    Plot object returned by the last plotting.

    Notes
    -----
    The figure (``kwargs["figure"]`` if given, else the current pyplot
    figure) is closed via ``tm.close`` whether or not plotting succeeds;
    any exception propagates unchanged.
    """
    import matplotlib.pyplot as plt

    gen_plots = _gen_default_plot if default_axes else _gen_two_subplots

    ret = None
    with warnings.catch_warnings():
        warnings.simplefilter(filterwarnings)
        # Bind ``fig`` before ``try`` so the ``finally`` clause always has a
        # figure to close.
        fig = kwargs.get("figure", plt.gcf())
        try:
            plt.clf()

            for ret in gen_plots(f, fig, **kwargs):
                tm.assert_is_valid_plot_return_object(ret)

            # Saving forces the current figure to actually be rendered.
            with tm.ensure_clean(return_filelike=True) as path:
                plt.savefig(path)
        finally:
            tm.close(fig)

    return ret
def _gen_default_plot(f, fig, **kwargs):
|
|
"""
|
|
Create plot in a default way.
|
|
"""
|
|
yield f(**kwargs)
|
|
|
|
|
|
def _gen_two_subplots(f, fig, **kwargs):
|
|
"""
|
|
Create plot on two subplots forcefully created.
|
|
"""
|
|
if "ax" not in kwargs:
|
|
fig.add_subplot(211)
|
|
yield f(**kwargs)
|
|
|
|
if f is pd.plotting.bootstrap_plot:
|
|
assert "ax" not in kwargs
|
|
else:
|
|
kwargs["ax"] = fig.add_subplot(212)
|
|
yield f(**kwargs)
|
|
|
|
|
|
def curpath():
    """Return the absolute path of the directory containing this module."""
    module_file = os.path.abspath(__file__)
    return os.path.dirname(module_file)