Source code for slickml.visualization._shap

from typing import List, Optional, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.figure import Figure

from slickml.utils import check_var

# TODO(amir): this options should be set globally too
sns.set_style("ticks")
mpl.rcParams["axes.linewidth"] = 2
mpl.rcParams["lines.linewidth"] = 2


# TODO(amir): double check this for `multi-outputs` problems
# TODO(amir): implement return_fig: Optional[bool] = False,
[docs]def plot_shap_summary( shap_values: np.ndarray, features: Union[pd.DataFrame, np.ndarray], *, plot_type: Optional[str] = "dot", figsize: Optional[Union[str, Tuple[float, float]]] = "auto", color: Optional[str] = None, cmap: Optional[LinearSegmentedColormap] = None, max_display: Optional[int] = 20, feature_names: Optional[List[str]] = None, layered_violin_max_num_bins: Optional[int] = 10, title: Optional[str] = None, sort: Optional[bool] = True, color_bar: Optional[bool] = True, class_names: Optional[List[str]] = None, class_inds: Optional[List[int]] = None, color_bar_label: Optional[str] = "Feature Value", save_path: Optional[str] = None, display_plot: Optional[bool] = True, ) -> None: """Visualizes shap beeswarm plot as summary of shapley values. Notes ----- This is a helper function to plot the ``shap`` summary plot based on all types of ``shap.Explainer`` including ``shap.LinearExplainer`` for linear models, ``shap.TreeExplainer`` for tree-based models, and ``shap.DeepExplainer`` deep neural network models. More on details are available at [shap-api]_. Parameters ---------- shap_values : np.ndarray Calculated SHAP values array. For single output explanations such as binary classification problems, this will be a matrix of SHAP values with a shape of ``(n_samples, n_features)``. Additionally, for multi-output explanations this would be a list of such matrices of SHAP values (``List[np.ndarray]``) features : Union[pd.DataFrame, np.ndarray] The feature matrix that was used to calculate the SHAP values. For the case of Numpy array it is recommened to pass the ``feature_names`` list as well for better visualization results plot_type : str, optional The type of summary plot where possible options are "bar", "dot", "violin", "layered_violin", and "compact_dot". Recommendations are "dot" for single-output such as binary classifications, "bar" for multi-output problems, "compact_dot" for Shap interactions, by default "dot" figsize : tuple, optional Figure size where "auto" is auto-scaled figure size based on the number of features that are being displayed. Passing a single float will cause each row to be that many inches high. Passing a pair of floats will scale the plot by that number of inches. If None is passed then the size of the current figure will be left unchanged, by default "auto" color : str, optional Color of plots when ``plot_type="violin"`` and ``plot_type=layered_violin"`` are "RdBl" color-map while color of the horizontal lines when ``plot_type="bar"`` is "#D0AAF3", by default None cmap : LinearSegmentedColormap, optional Color map when ``plot_type="violin"`` and ``plot_type=layered_violin"``, by default "RdBl" max_display : int, optional Limit to show the number of features in the plot, by default 20 feature_names : List[str], optional List of feature names to pass. It should follow the order of features, by default None layered_violin_max_num_bins : int, optional The number of bins for calculating the violin plots ranges and outliers, by default 10 title : str, optional Title of the plot, by default None sort : bool, optional Flag to plot sorted shap vlues in descending order, by default True color_bar : bool, optional Flag to show a color bar when ``plot_type="dot"`` or ``plot_type="violin"`` class_names : List[str], optional List of class names for multi-output problems, by default None class_inds : List[int], optional List of class indices for multi-output problems, by default None color_bar_label : str, optional Label for color bar, by default "Feature Value" save_path : str, optional The full or relative path to save the plot including the image format such as "myplot.png" or "../../myplot.pdf", by default None display_plot : bool, optional Whether to show the plot, by default True References ---------- .. [shap-api] https://shap-lrjball.readthedocs.io/en/latest/generated/shap.summary_plot.html Returns ------- None """ check_var( shap_values, var_name="shap_values", dtypes=np.ndarray, ) check_var( features, var_name="features", dtypes=(np.ndarray, pd.DataFrame), ) check_var( plot_type, var_name="plot_type", dtypes=str, values=( "bar", "dot", "violin", "layered_violin", "compact_dot", ), ) check_var( figsize, var_name="figsize", dtypes=(str, tuple), ) if color: check_var( color, var_name="color", dtypes=str, ) if not color and plot_type == "bar": color = "#D0AAF3" if cmap: check_var( cmap, var_name="cmap", dtypes=LinearSegmentedColormap, ) check_var( max_display, var_name="max_display", dtypes=int, ) if feature_names: check_var( feature_names, var_name="feature_names", dtypes=list, ) check_var( layered_violin_max_num_bins, var_name="layered_violin_max_num_bins", dtypes=int, ) if title: check_var( title, var_name="title", dtypes=str, ) check_var( sort, var_name="sort", dtypes=bool, ) check_var( color_bar, var_name="color_bar", dtypes=bool, ) if class_names: check_var( class_names, var_name="class_names", dtypes=list, ) if class_inds: check_var( class_inds, var_name="class_inds", dtypes=list, ) check_var( color_bar_label, var_name="color_bar_label", dtypes=str, ) if save_path: check_var( save_path, var_name="save_path", dtypes=str, ) check_var( display_plot, var_name="display_plot", dtypes=bool, ) shap.summary_plot( shap_values, features, plot_type=plot_type, plot_size=figsize, color=color, cmap=cmap, max_display=max_display, feature_names=feature_names, title=title, show=display_plot, sort=sort, color_bar=color_bar, layered_violin_max_num_bins=layered_violin_max_num_bins, class_names=class_names, class_inds=class_inds, color_bar_label=color_bar_label, ) if save_path: plt.savefig( save_path, bbox_inches="tight", dpi=200, ) if display_plot: plt.show() return None
[docs]def plot_shap_waterfall( shap_values: np.ndarray, features: Union[pd.DataFrame, np.ndarray], *, figsize: Optional[Tuple[float, float]] = (8, 5), bar_color: Optional[str] = "#B3C3F3", bar_thickness: Optional[Union[float, int]] = 0.5, line_color: Optional[str] = "purple", marker: Optional[str] = "o", markersize: Optional[Union[int, float]] = 7, markeredgecolor: Optional[str] = "purple", markerfacecolor: Optional[str] = "purple", markeredgewidth: Optional[Union[int, float]] = 1, max_display: Optional[int] = 20, title: Optional[str] = None, fontsize: Optional[Union[int, float]] = 12, save_path: Optional[str] = None, display_plot: Optional[bool] = True, return_fig: Optional[bool] = False, ) -> Optional[Figure]: """Visualizes the Shapley values as a waterfall plot. pl This function is a helper function to plot the shap summary plot based on all types of shap explainers including tree, linear, and dnn. Parameters ---------- shap_values : np.ndarray Calculated SHAP values array. For single output explanations such as binary classification problems, this will be a matrix of SHAP values with a shape of ``(n_samples, n_features)``. Additionally, for multi-output explanations this would be a list of such matrices of SHAP values (``List[np.ndarray]``) features : Union[pd.DataFrame, np.ndarray] The feature matrix that was used to calculate the SHAP values. For the case of Numpy array it is recommened to pass the ``feature_names`` list as well for better visualization results figsize : Tuple[float, float], optional Figure size, by default (8, 5) bar_color : str, optional Color of the horizontal bar lines, "#B3C3F3" bar_thickness : Union[float, int], optional Thickness (hight) of the horizontal bar lines, by default 0.5 line_color : str, optional Color of the line plot, by default "purple" marker : str, optional Marker style of the lollipops. More valid marker styles can be found at [markers-api]_, by default "o" markersize : Union[int, float], optional Markersize, by default 7 markeredgecolor : str, optional Marker edge color, by default "purple" markerfacecolor: str, optional Marker face color, by default "purple" markeredgewidth : Union[int, float], optional Marker edge width, by default 1 max_display : int, optional Limit to show the number of features in the plot, by default 20 title : str, optional Title of the plot, by default None fontsize : Union[int, float], optional Fontsize for xlabel and ylabel, and ticks parameters, by default 12 save_path : str, optional The full or relative path to save the plot including the image format such as "myplot.png" or "../../myplot.pdf", by default None display_plot : bool, optional Whether to show the plot, by default True return_fig : bool, optional Whether to return figure object, by default False References ---------- .. [markers-api] https://matplotlib.org/stable/api/markers_api.html Returns ------- Figure, optional """ check_var( shap_values, var_name="shap_values", dtypes=np.ndarray, ) check_var( features, var_name="features", dtypes=(np.ndarray, pd.DataFrame), ) check_var( figsize, var_name="figsize", dtypes=(str, tuple), ) check_var( bar_color, var_name="bar_color", dtypes=str, ) check_var( bar_thickness, var_name="bar_thickness", dtypes=(float, int), ) check_var( line_color, var_name="line_color", dtypes=str, ) check_var( marker, var_name="marker", dtypes=str, ) check_var( markersize, var_name="markersize", dtypes=(int, float), ) check_var( markeredgecolor, var_name="markeredgecolor", dtypes=str, ) check_var( markerfacecolor, var_name="markerfacecolor", dtypes=str, ) check_var( markeredgewidth, var_name="markeredgewidth", dtypes=(int, float), ) check_var( max_display, var_name="max_display", dtypes=int, ) if title: check_var( title, var_name="title", dtypes=str, ) check_var( fontsize, var_name="font_size", dtypes=(int, float), ) if save_path: check_var( save_path, var_name="save_path", dtypes=str, ) check_var( display_plot, var_name="display_plot", dtypes=bool, ) check_var( return_fig, var_name="return_fig", dtypes=bool, ) # main calculation of cum/comp ratios feature_names = features.columns shap_ratio = (np.abs(shap_values).sum(0) / np.abs(shap_values).sum()) * 100 feature_names = feature_names[np.argsort(shap_ratio)[::-1]] shap_ratio_order = np.sort(shap_ratio)[::-1] cum_sum = np.cumsum(shap_ratio_order) feature_names = feature_names[:max_display] shap_ratio_order = shap_ratio_order[:max_display] cum_sum = cum_sum[:max_display] fig, ax1 = plt.subplots(figsize=figsize) # subplot 1: cumsum shap line-marker plot ax1.plot( cum_sum[::-1], feature_names[::-1], color=line_color, marker=marker, markeredgecolor=markeredgecolor, markerfacecolor=markerfacecolor, markeredgewidth=markeredgewidth, markersize=markersize, ) # subplot2: barplot ax2 = ax1.twiny() ax2.barh( feature_names[::-1], shap_ratio_order[::-1], height=bar_thickness, alpha=0.6, color=bar_color, ) ax1.grid(True) ax2.grid(False) ax1.set_xticks( np.arange( 0, round(cum_sum.max(), -1) + 1, 10, ), ) ax2.set_xticks( np.arange( 0, round(shap_ratio_order.max(), -1) + 1, 10, ), ) ax1.tick_params( axis="both", which="major", labelsize=fontsize, ) ax1.set( ylim=[ -1, len(feature_names), ], xlabel="Cumulative Ratio (%)", ylabel="Feature", title=title, ) ax2.set( xlabel="Composition Ratio (%)", ) if save_path: plt.savefig( save_path, bbox_inches="tight", dpi=200, ) if display_plot: plt.show() if return_fig: return fig return None