pdpbox.info_plots.PredictPlot

class pdpbox.info_plots.PredictPlot(df, feature, feature_name, model, model_features, pred_func=None, n_classes=None, predict_kwds=None, chunk_size=-1, cust_grid_points=None, grid_type='percentile', num_grid_points=10, percentile_range=None, grid_range=None, show_outliers=False, endpoint=True)

Generates box plots depicting the statistical distribution of prediction values across distinct groups (or buckets) of a single feature.

For each feature bucket, a box plot summarizes the prediction values by their first quartile, median (second quartile), and third quartile. Comparing the boxes across buckets shows how the model's predictions shift from one group of the chosen feature to the next, which supports feature analysis and helps in interpreting model predictions.
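The per-bucket quartiles behind each box can be sketched with the standard library. This is a minimal illustration, not pdpbox's implementation. The helper names `box_stats` and `box_stats_by_bucket` are hypothetical, buckets are assumed to be already assigned, and `statistics.quantiles(..., method="inclusive")` is used because it interpolates linearly like NumPy's default percentile method. pdpbox's own quantile computation may differ.

```python
from statistics import quantiles


def box_stats(values):
    """Return (q1, median, q3) for one bucket's prediction values."""
    if len(values) < 2:
        # statistics.quantiles needs at least two points on older Pythons
        v = float(values[0])
        return v, v, v
    # "inclusive" interpolates linearly between order statistics (NumPy's default)
    q1, q2, q3 = quantiles(values, n=4, method="inclusive")
    return q1, q2, q3


def box_stats_by_bucket(buckets, preds):
    """Group predictions by bucket label and compute box statistics per bucket."""
    groups = {}
    for bucket, pred in zip(buckets, preds):
        groups.setdefault(bucket, []).append(pred)
    return {bucket: box_stats(vals) for bucket, vals in sorted(groups.items())}


if __name__ == "__main__":
    buckets = [0, 0, 0, 0, 0, 1, 1, 1]
    preds = [1, 2, 3, 4, 5, 6, 7, 8]
    for bucket, (q1, med, q3) in box_stats_by_bucket(buckets, preds).items():
        print(bucket, q1, med, q3)
```

With these inputs, bucket 0 yields (2.0, 3.0, 4.0) and bucket 1 yields (6.5, 7.0, 7.5).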

Attributes:
df : pd.DataFrame

A processed DataFrame that includes the feature columns, the target (for target plot) or predict (for predict plot) columns, the feature buckets, and the count of samples within each bucket.

feature_info : FeatureInfo

An instance of the FeatureInfo class.

feature_cols : list of str

List of feature columns.

target : list of int

List of target indices. For binary and regression problems, the list will be just [0]. For multi-class targets, the list is the class indices.

n_classes : int

The number of classes provided, or inferred from the model when it is not provided.

plot_type : str

The type of the plot to be generated.

plot_engines : dict

A dictionary that maps plot types to their plotting engines.

count_df : pd.DataFrame

A DataFrame that contains the count as well as the normalized count (percentage) of samples within each feature bucket.

summary_df : pd.DataFrame

A DataFrame that contains the summary statistics of target (for target plot) or predict (for predict plot) values for each feature bucket.

target_lines : list of pd.DataFrame

A list of DataFrames, each of which includes aggregate metrics across feature buckets for one target (for target plot) or predict (for predict plot) variable. For binary and regression problems, the list contains a single DataFrame; for multi-class targets, it contains one DataFrame per class.
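To make the `df` and `count_df` attributes more concrete, here is a standard-library sketch of percentile bucketing followed by per-bucket counts. It assumes a numeric feature and only loosely mirrors the `grid_type='percentile'` and `num_grid_points` constructor arguments. The helpers `percentile_grid`, `assign_buckets`, and `bucket_counts` are hypothetical, pdpbox's handling of duplicate grid points, `endpoint`, and outliers may differ, and expressing the normalized count on a 0 to 100 scale is an assumption.

```python
from bisect import bisect_right
from collections import Counter
from statistics import quantiles


def percentile_grid(values, num_grid_points=10):
    """Grid points at evenly spaced percentiles, from the minimum to the maximum."""
    # quantiles(n=k) returns the k - 1 interior cut points
    inner = quantiles(values, n=num_grid_points - 1, method="inclusive")
    # Repeated feature values can yield duplicate points; keep each once
    return sorted(set([min(values), *inner, max(values)]))


def assign_buckets(values, grid):
    """Map each value to a bucket index; bucket i covers [grid[i], grid[i + 1])."""
    if len(grid) < 2:
        # A constant feature collapses to one grid point, hence one bucket
        return [0] * len(values)
    last = len(grid) - 2
    # Clamp so the maximum value (equal to the last grid point) joins the last bucket
    return [min(max(bisect_right(grid, v) - 1, 0), last) for v in values]


def bucket_counts(buckets):
    """Return {bucket: (count, percentage of all samples)}."""
    counts = Counter(buckets)
    total = sum(counts.values())
    return {b: (c, 100.0 * c / total) for b, c in sorted(counts.items())}


if __name__ == "__main__":
    feature = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    grid = percentile_grid(feature, num_grid_points=3)
    buckets = assign_buckets(feature, grid)
    print(grid, bucket_counts(buckets))
```

Here three grid points produce two buckets: values 0 to 4 fall into bucket 0 and values 5 to 10 into bucket 1.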
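The relationship between `n_classes`, `target`, and `target_lines` can be sketched as follows. This is a hypothetical illustration under several assumptions: regression is encoded as `n_classes == 0` and binary classification as `n_classes == 2`, each prediction row holds one value per target index, and the per-bucket mean stands in for the aggregate metrics. Plain dicts replace the DataFrames, and the helper names `target_indices` and `target_lines` are not part of pdpbox.

```python
from collections import defaultdict


def target_indices(n_classes):
    """[0] for regression and binary problems, every class index for multi-class."""
    # Assumption: regression is n_classes == 0, binary is n_classes == 2
    if n_classes <= 2:
        return [0]
    return list(range(n_classes))


def target_lines(buckets, pred_rows, target):
    """One {bucket: mean prediction} mapping per target index."""
    lines = []
    for t in target:
        sums = defaultdict(float)
        counts = defaultdict(int)
        for bucket, row in zip(buckets, pred_rows):
            sums[bucket] += row[t]
            counts[bucket] += 1
        lines.append({b: sums[b] / counts[b] for b in sorted(sums)})
    return lines


if __name__ == "__main__":
    buckets = [0, 0, 1]
    # Three-class probabilities, one row per sample
    probs = [[0.5, 0.25, 0.25], [0.25, 0.5, 0.25], [0.0, 0.0, 1.0]]
    target = target_indices(3)
    for t, line in zip(target, target_lines(buckets, probs, target)):
        print(t, line)
```

For a binary or regression problem, each row would hold a single value and the function returns a one-element list, matching the single-DataFrame case described above.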

Methods

plot(**kwargs)

Generates the plot.

Both pdpbox.info_plots.TargetPlot and pdpbox.info_plots.PredictPlot inherit from the pdpbox.info_plots._InfoPlot class and share the same plot method.

class pdpbox.info_plots._InfoPlot(df, feature, feature_name, target=None, model=None, model_features=None, pred_func=None, n_classes=None, predict_kwds=None, chunk_size=-1, plot_type='target', **kwargs)

Methods

plot(which_classes=None, show_percentile=False, figsize=None, dpi=300, ncols=2, plot_params=None, engine='plotly', template='plotly_white')

The plot function for TargetPlot and PredictPlot.

Parameters:
which_classes : list of int, optional

List of class indices to plot. If None, all classes will be plotted. Default is None.

show_percentile : bool, optional

If True, percentiles are shown in the plot. Default is False.

figsize : tuple or None, optional

The figure size for the matplotlib or plotly figure. If None, the default figure size is used. Default is None.

dpi : int, optional

The resolution of the plot, measured in dots per inch. Only applicable when engine is ‘matplotlib’. Default is 300.

ncols : int, optional

The number of columns of subplots in the figure. Default is 2.

plot_params : dict or None, optional

Custom plot parameters that control the style and aesthetics of the plot. Default is None.

engine : {‘matplotlib’, ‘plotly’}, optional

The plotting engine to use. Default is ‘plotly’.

template : str, optional

The template to use for plotly plots. Only applicable when engine is ‘plotly’. See https://plotly.com/python/templates/ for details. Default is ‘plotly_white’.
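How `which_classes` and `ncols` might combine into a subplot layout can be sketched as below. The helpers `resolve_classes` and `subplot_grid` are hypothetical. The sketch assumes one subplot per selected class, that unknown class indices are rejected with a `ValueError`, and that the column count is capped at the number of panels; pdpbox's own layout logic may differ.

```python
import math


def resolve_classes(which_classes, target):
    """None selects every target index; otherwise keep the requested order."""
    if which_classes is None:
        return list(target)
    unknown = [c for c in which_classes if c not in target]
    if unknown:
        # Assumption: invalid indices are an error rather than silently dropped
        raise ValueError(f"unknown class indices: {unknown}")
    return list(which_classes)


def subplot_grid(n_plots, ncols=2):
    """(nrows, ncols) needed to place n_plots panels with at most ncols per row."""
    if n_plots < 1:
        raise ValueError("need at least one panel")
    ncols = min(ncols, n_plots)  # assumption: no empty columns
    return math.ceil(n_plots / ncols), ncols


if __name__ == "__main__":
    classes = resolve_classes(None, [0, 1, 2, 3, 4])
    print(classes, subplot_grid(len(classes), ncols=2))
```

With five classes and the default `ncols=2`, this layout needs three rows, the last of which holds a single panel.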
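Because `dpi` only applies to the matplotlib engine, it helps to recall how it interacts with `figsize`: matplotlib measures `figsize` in inches, so a rendered image measures `figsize` times `dpi` pixels on each side. The hypothetical helper below illustrates that arithmetic. It assumes the figure is rendered at the dpi it was created with, since `savefig` can override the dpi.

```python
def pixel_size(figsize, dpi=300):
    """Pixel dimensions of a matplotlib figure of figsize inches rendered at dpi."""
    width_in, height_in = figsize
    return round(width_in * dpi), round(height_in * dpi)


if __name__ == "__main__":
    # A 16 x 9 inch figure at the default dpi of 300
    print(pixel_size((16, 9)))  # (4800, 2700)
```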

Returns:
matplotlib.figure.Figure or plotly.graph_objects.Figure

A Matplotlib or Plotly figure object depending on the plot engine being used.

dict of matplotlib.axes.Axes or None

A dictionary of Matplotlib axes objects, keyed by axis name. If engine is ‘plotly’, it is None.

pd.DataFrame

A DataFrame that contains the summary statistics of target (for target plot) or predict (for predict plot) values for each feature bucket.