Explainer

The Explainer class is part of the jarvais.explainer module. It generates explainability reports for trained models.

jarvais.explainer.Explainer

A class to generate diagnostic plots and reports for models trained using TrainerSupervised.

Attributes:

    trainer (TrainerSupervised): The TrainerSupervised object containing the trained model.
    predictor (object): The AutoGluon predictor object used for inference.
    X_train (DataFrame): The training dataset used to train the model.
    X_test (DataFrame): The test dataset for evaluating the model.
    y_test (DataFrame): The true target values for the test dataset.
    output_dir (Path): The directory where plots, reports, and outputs are saved.
    sensitive_features (list): List of features considered sensitive for bias auditing.
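
A minimal usage sketch is shown below. It assumes an already-fit TrainerSupervised instance named trainer; from_trainer() reuses the trainer's data splits and output directory.

from jarvais.explainer import Explainer

# trainer is assumed to be a TrainerSupervised instance that has already been fit.
explainer = Explainer.from_trainer(trainer)
explainer.run()  # writes figures/, bias/ outputs, and a PDF report under output_dir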

Source code in src/jarvais/explainer/explainer.py
class Explainer():
    """
    A class to generate diagnostic plots and reports for models trained using TrainerSupervised.

    Attributes:
        trainer (TrainerSupervised): The TrainerSupervised object containing the trained model.
        predictor (object): The AutoGluon predictor object used for inference.
        X_train (pd.DataFrame): The training dataset used to train the model.
        X_test (pd.DataFrame): The test dataset for evaluating the model.
        y_test (pd.DataFrame): The true target values for the test dataset.
        output_dir (Path): The directory where plots, reports, and outputs are saved.
        sensitive_features (list, optional): List of features considered sensitive for bias auditing.
    """
    def __init__(
            self,
            trainer,
            X_train: pd.DataFrame,
            X_test: pd.DataFrame,
            y_test: pd.DataFrame,
            output_dir: str | Path | None = None,
            sensitive_features: list | None = None,
        ) -> None:

        self.trainer = trainer
        self.predictor = trainer.predictor
        self.X_train = X_train
        self.X_test = X_test
        self.y_test = y_test
        self.sensitive_features = sensitive_features

        self.output_dir = Path.cwd() if output_dir is None else Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        (self.output_dir / 'figures').mkdir(parents=True, exist_ok=True)

    def run(self) -> None:
        """Generate diagnostic plots and reports for the trained model."""

        self._run_bias_audit()

        plot_violin_of_bootstrapped_metrics(
            self.trainer,
            self.X_test,
            self.y_test,
            self.trainer.X_val,
            self.trainer.y_val,
            self.X_train,
            self.trainer.y_train,
            output_dir=self.output_dir / 'figures'
        )            

        if self.trainer.task in ['binary', 'multiclass']:
            plot_classification_diagnostics(
                self.y_test,
                self.predictor.predict_proba(self.X_test).iloc[:, 1],
                self.trainer.y_val,
                self.predictor.predict_proba(self.trainer.X_val).iloc[:, 1],
                self.trainer.y_train,
                self.predictor.predict_proba(self.X_train).iloc[:, 1],
                output_dir=self.output_dir / 'figures'
            )
            plot_shap_values(
                self.predictor,
                self.X_train,
                self.X_test,
                output_dir=self.output_dir / 'figures'
            )

        elif self.trainer.task == 'regression':
            plot_regression_diagnostics(
                self.y_test,
                self.predictor.predict(self.X_test, as_pandas=False),
                output_dir=self.output_dir / 'figures'
            )

        # Plot feature importance
        if self.trainer.task == 'survival': # NEEDS TO BE UPDATED
            model = self.trainer.predictors['CoxPH']
            result = permutation_importance(model, self.X_test,
                                            Surv.from_dataframe('event', 'time', self.y_test),
                                            n_repeats=15)

            importance_df = pd.DataFrame(
                {
                    "importance": result["importances_mean"],
                    "stddev": result["importances_std"],
                },
                index=self.X_test.columns,
            ).sort_values(by="importance", ascending=False)
            model_name = 'CoxPH'
        else:
            importance_df = self.predictor.feature_importance(
                pd.concat([self.X_test, self.y_test], axis=1))
            model_name = self.predictor.model_best

        plot_feature_importance(importance_df, self.output_dir / 'figures', model_name)
        generate_explainer_report_pdf(self.trainer.task, self.output_dir)

    def _run_bias_audit(self) -> List[pd.DataFrame]:

        bias_output_dir = self.output_dir / 'bias'
        bias_output_dir.mkdir(parents=True, exist_ok=True)

        if self.sensitive_features is None:
            if self.trainer.task == 'survival': # Data must not be one-hot encoded here
                self.sensitive_features = infer_sensitive_features(undummify(self.X_test, prefix_sep='|'))
            else:
                self.sensitive_features = infer_sensitive_features(self.X_test)

        y_pred = None if self.trainer.task == 'survival' else pd.Series(self.trainer.infer(self.X_test) )
        metrics = ['mean_prediction'] if self.trainer.task == 'regression' else ['mean_prediction', 'false_positive_rate'] 

        bias = BiasExplainer(
            self.y_test, 
            y_pred, 
            self.sensitive_features,
            self.trainer.task, 
            bias_output_dir,
            metrics
        )
        bias.run(relative=True)

    @classmethod
    def from_trainer(cls, trainer, **kwargs):
        """Create Explainer object from TrainerSupervised object."""
        return cls(trainer, trainer.X_train, trainer.X_test, trainer.y_test, trainer.output_dir, **kwargs)

run()

Generate diagnostic plots and reports for the trained model.
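
Based on the source below, a successful run populates the output directory roughly as follows; exact file names depend on the task and on which sensitive features are flagged, so the layout is indicative only.

output_dir/
├── figures/       # bootstrapped-metric violins, task-specific diagnostics, SHAP and feature-importance plots
├── bias/          # per-feature violin plots, model summaries, and *_fm_metrics.csv files from the bias audit
└── (PDF report)   # written by generate_explainer_report_pdf()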

Source code in src/jarvais/explainer/explainer.py
def run(self) -> None:
    """Generate diagnostic plots and reports for the trained model."""

    self._run_bias_audit()

    plot_violin_of_bootstrapped_metrics(
        self.trainer,
        self.X_test,
        self.y_test,
        self.trainer.X_val,
        self.trainer.y_val,
        self.X_train,
        self.trainer.y_train,
        output_dir=self.output_dir / 'figures'
    )            

    if self.trainer.task in ['binary', 'multiclass']:
        plot_classification_diagnostics(
            self.y_test,
            self.predictor.predict_proba(self.X_test).iloc[:, 1],
            self.trainer.y_val,
            self.predictor.predict_proba(self.trainer.X_val).iloc[:, 1],
            self.trainer.y_train,
            self.predictor.predict_proba(self.X_train).iloc[:, 1],
            output_dir=self.output_dir / 'figures'
        )
        plot_shap_values(
            self.predictor,
            self.X_train,
            self.X_test,
            output_dir=self.output_dir / 'figures'
        )

    elif self.trainer.task == 'regression':
        plot_regression_diagnostics(
            self.y_test,
            self.predictor.predict(self.X_test, as_pandas=False),
            output_dir=self.output_dir / 'figures'
        )

    # Plot feature importance
    if self.trainer.task == 'survival': # NEEDS TO BE UPDATED
        model = self.trainer.predictors['CoxPH']
        result = permutation_importance(model, self.X_test,
                                        Surv.from_dataframe('event', 'time', self.y_test),
                                        n_repeats=15)

        importance_df = pd.DataFrame(
            {
                "importance": result["importances_mean"],
                "stddev": result["importances_std"],
            },
            index=self.X_test.columns,
        ).sort_values(by="importance", ascending=False)
        model_name = 'CoxPH'
    else:
        importance_df = self.predictor.feature_importance(
            pd.concat([self.X_test, self.y_test], axis=1))
        model_name = self.predictor.model_best

    plot_feature_importance(importance_df, self.output_dir / 'figures', model_name)
    generate_explainer_report_pdf(self.trainer.task, self.output_dir)

from_trainer(trainer, **kwargs) classmethod

Create Explainer object from TrainerSupervised object.
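
Extra keyword arguments are forwarded to the Explainer constructor, so sensitive features can be supplied explicitly instead of being inferred from X_test. A sketch with purely illustrative column names:

# Hypothetical column names; pass the sensitive columns themselves
# (a DataFrame slice, Series, dict, or list, as accepted by BiasExplainer).
explainer = Explainer.from_trainer(
    trainer,
    sensitive_features=trainer.X_test[['sex', 'age_group']],
)
explainer.run()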

Source code in src/jarvais/explainer/explainer.py
@classmethod
def from_trainer(cls, trainer, **kwargs):
    """Create Explainer object from TrainerSupervised object."""
    return cls(trainer, trainer.X_train, trainer.X_test, trainer.y_test, trainer.output_dir, **kwargs)

The BiasExplainer class is used by the Explainer class to run a bias audit.

jarvais.explainer.BiasExplainer

A class for explaining and analyzing bias in a predictive model's outcomes based on sensitive features.

This class performs fairness audits by evaluating predictive outcomes with respect to sensitive features such as gender, age, and race. For each sensitive feature, it first runs a statistical screen, using the p-value of the OLS regression F-statistic, to check whether the model's predictions differ across subgroups. If the p-value is less than 0.05, indicating potential bias, the class generates visualizations (such as violin plots) and calculates fairness metrics (e.g., demographic parity, equalized odds). The results are reported for each sensitive feature, with optional relative fairness comparisons.

Attributes:

    y_true (DataFrame): The true target values for the model.
    y_pred (DataFrame): The predicted values from the model.
    sensitive_features (dict or DataFrame): A dictionary or DataFrame containing sensitive features used for fairness analysis.
    metrics (list): A list of metrics to calculate for fairness analysis. Defaults to ['mean_prediction', 'false_positive_rate', 'true_positive_rate'].
    mapper (dict): A dictionary mapping internal metric names to user-friendly descriptions.
    kwargs (dict): Additional parameters passed to various methods, such as metric calculation and plot generation.
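
BiasExplainer can also be used on its own, outside of Explainer. A minimal sketch for a binary task; y_test, y_prob, and sensitive_df are assumed to already exist (test labels, positive-class probabilities, and a DataFrame of sensitive columns, respectively).

from pathlib import Path

from jarvais.explainer import BiasExplainer

out = Path('bias_audit')
out.mkdir(parents=True, exist_ok=True)  # BiasExplainer writes here but does not create the directory

bias = BiasExplainer(
    y_true=y_test,                   # pd.Series of 0/1 labels
    y_pred=y_prob,                   # array of positive-class probabilities
    sensitive_features=sensitive_df,
    task='binary',
    output_dir=out,
)
bias.run(relative=True)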

Source code in src/jarvais/explainer/bias.py
class BiasExplainer():
    """
    A class for explaining and analyzing bias in a predictive model's outcomes based on sensitive features.

    This class performs various fairness audits by evaluating predictive outcomes with respect to sensitive features such as
    gender, age, race, and more. It first runs statistical analyses using the OLS regression F-statistic p-value to assess any possibility 
    of bias in the model's predictions based on sensitive features. If the p-value is less than 0.05, indicating potential bias, 
    the class generates visualizations (such as violin plots) and calculates fairness metrics (e.g., demographic parity, equalized odds). 
    The results are presented for each sensitive feature, with optional relative fairness comparisons.

    Attributes:
        y_true (pd.DataFrame):
            The true target values for the model.
        y_pred (pd.DataFrame):
            The predicted values from the model.
        sensitive_features (dict or pd.DataFrame):
            A dictionary or DataFrame containing sensitive features used for fairness analysis.
        metrics (list):
            A list of metrics to calculate for fairness analysis. Defaults to ['mean_prediction', 'false_positive_rate', 'true_positive_rate'].
        mapper (dict):
            A dictionary mapping internal metric names to user-friendly descriptions.
        kwargs (dict):
            Additional parameters passed to various methods, such as metric calculation and plot generation.
    """
    def __init__(
            self, 
            y_true: pd.Series, 
            y_pred: np.ndarray, 
            sensitive_features: dict, 
            task: str,
            output_dir: Path,
            metrics: list = ['mean_prediction', 'false_positive_rate', 'true_positive_rate'], 
            **kwargs: dict
        ) -> None:
        self.y_true = y_true
        self.y_pred = y_pred
        self.task = task
        self.output_dir = output_dir
        self.mapper = {"mean_prediction": "Demographic Parity",
                       "false_positive_rate": "(FPR) Equalized Odds",
                       "true_positive_rate": "(TPR) Equalized Odds or Equal Opportunity"}
        self.metrics = metrics
        self.kwargs = kwargs

        # Convert sensitive_features to DataFrame or leave as Series
        if isinstance(sensitive_features, pd.DataFrame) or isinstance(sensitive_features, pd.Series):
            self.sensitive_features = sensitive_features
        elif isinstance(sensitive_features, dict):
            self.sensitive_features = pd.DataFrame.from_dict(sensitive_features)
        elif isinstance(sensitive_features, list):
            if any(isinstance(item, list) for item in sensitive_features):
                self.sensitive_features = pd.DataFrame(sensitive_features, columns=[f'sensitive_feature_{i}' for i in range(len(sensitive_features))])
            else:
                self.sensitive_features = pd.DataFrame(sensitive_features, columns=['sensitive_feature'])
        else:
            raise ValueError("sensitive_features must be a pandas DataFrame, Series, dictionary or list")

    def _generate_violin(self, sensitive_feature: str, bias_metric:np.ndarray) -> None:
        """Generate a violin plot for the bias metric."""
        plt.figure(figsize=(8, 6)) 
        sns.set_theme(style="whitegrid")  

        sns.violinplot(
            x=self.sensitive_features[sensitive_feature], 
            y=bias_metric, 
            palette="muted",  
            inner="quart", 
            linewidth=1.25 
        )

        bias_metric_name = 'log_loss' if self.task == 'binary' else 'root_mean_squared_error'

        plt.title(f'{bias_metric_name.title()} Distribution by {sensitive_feature}', fontsize=16, weight='bold')  
        plt.xlabel(f'{sensitive_feature}', fontsize=14)  
        plt.ylabel(f'{bias_metric_name.title()} per Patient', fontsize=14) 
        plt.xticks(rotation=45, ha='right')

        plt.tight_layout()  
        plt.savefig(self.output_dir / f'{sensitive_feature}_{bias_metric_name}.png') 
        plt.show()

    def _subgroup_analysis_OLS(self, sensitive_feature: str, bias_metric:np.ndarray) -> float:
        """Fit a statsmodels OLS model to the bias metric data, using the sensitive feature and print summary based on p_val."""
        one_hot_encoded = pd.get_dummies(self.sensitive_features[sensitive_feature], prefix=sensitive_feature)
        X_columns = one_hot_encoded.columns

        X = one_hot_encoded.values  
        y = bias_metric  

        X = sm.add_constant(X.astype(float), has_constant='add')
        model = sm.OLS(y, X).fit()

        if model.f_pvalue < 0.05:
            output = []

            print(f"⚠️  **Possible Bias Detected in {sensitive_feature.title()}** ⚠️\n")
            output.append(f"=== Subgroup Analysis for '{sensitive_feature.title()}' Using OLS Regression ===\n")

            output.append("Model Statistics:")
            output.append(f"    R-squared:                  {model.rsquared:.3f}")
            output.append(f"    F-statistic:                {model.fvalue:.3f}")
            output.append(f"    F-statistic p-value:        {model.f_pvalue:.4f}")
            output.append(f"    AIC:                        {model.aic:.2f}")
            output.append(f"    Log-Likelihood:             {model.llf:.2f}")

            summary_df = pd.DataFrame({
                'Feature': ['const'] + X_columns.tolist(),     # Predictor names (includes 'const' if added)
                'Coefficient': model.params,    # Coefficients
                'Standard Error': model.bse     # Standard Errors
            })
            table_output = tabulate(summary_df, headers='keys', tablefmt='grid', showindex=False, floatfmt=".3f")
            output.append("Model Coefficients:")
            output.append('\n'.join(['    ' + line for line in table_output.split('\n')]))

            output_text = '\n'.join(output)
            print(output_text)

            with open(self.output_dir / f'{sensitive_feature}_OLS_model_summary.txt', 'w') as f:
                f.write(output_text)

        return model.f_pvalue

    def _subgroup_analysis_CoxPH(self, sensitive_feature: str) -> None:
        """Fit a CoxPH model using the sensitive feature and print summary based on p_val."""
        one_hot_encoded = pd.get_dummies(self.sensitive_features[sensitive_feature], prefix=sensitive_feature)
        df_encoded = self.y_true.join(one_hot_encoded)

        cph = CoxPHFitter(penalizer=0.0001)
        cph.fit(df_encoded, duration_col='time', event_col='event')            

        if cph.log_likelihood_ratio_test().p_value < 0.05:
            output = []

            print(f"⚠️  **Possible Bias Detected in {sensitive_feature.title()}** ⚠️")
            output.append(f"=== Subgroup Analysis for '{sensitive_feature.title()}' Using Cox Proportional Hazards Model ===\n")

            output.append("Model Statistics:")
            output.append(f"    AIC (Partial):               {cph.AIC_partial_:.2f}")
            output.append(f"    Log-Likelihood:              {cph.log_likelihood_:.2f}")
            output.append(f"    Log-Likelihood Ratio p-value: {cph.log_likelihood_ratio_test().p_value:.4f}")
            output.append(f"    Concordance Index (C-index):   {cph.concordance_index_:.2f}")

            summary_df = pd.DataFrame({
                'Feature': cph.summary.index.to_list(),
                'Coefficient': cph.summary['coef'].to_list(),
                'Standard Error': cph.summary['se(coef)'].to_list()
            })
            table_output = tabulate(summary_df, headers='keys', tablefmt='grid', showindex=False, floatfmt=".3f")
            output.append("Model Coefficients:")
            output.append('\n'.join(['    ' + line for line in table_output.split('\n')]))

            output_text = '\n'.join(output)
            print(output_text)

            with open(self.output_dir / f'{sensitive_feature}_Cox_model_summary.txt', 'w') as f:
                f.write(output_text)

    def _calculate_fair_metrics(
            self, 
            sensitive_feature: str, 
            fairness_threshold: float, 
            relative: bool
        ) -> pd.DataFrame:
        """Calculate the Fairlearn metrics and return the results in a DataFrame."""
        _metrics = {metric: get_metric(metric, sensitive_features=self.sensitive_features[sensitive_feature]) for metric in self.metrics}
        metric_frame = fm.MetricFrame(
            metrics=_metrics, 
            y_true=self.y_true, 
            y_pred=self.y_pred, 
            sensitive_features=self.sensitive_features[sensitive_feature], 
            **self.kwargs
        )
        result = pd.DataFrame(metric_frame.by_group.T, index=_metrics.keys())
        result = result.rename(columns=self.mapper)

        if relative:
            largest_feature = self.sensitive_features[sensitive_feature].mode().iloc[0]
            results_relative = result.T / result[largest_feature]
            results_relative = results_relative.applymap(
                lambda x: f"{x:.3f} ✅" if 1/fairness_threshold <= x <= fairness_threshold
                else f"{x:.3f} ❌")
            result = pd.concat([result, results_relative.T.rename(index=lambda x: f"Relative {x}")])

        return result

    def run(
            self, 
            relative: bool = False, 
            fairness_threshold: float = 1.2
        ) -> None:
        """
        Runs the bias explainer analysis on the provided data. It first evaluates the potential bias in the model's predictions
        using the OLS regression F-statistic p-value. If the p-value is below the threshold of 0.05, indicating 
        potential bias in the sensitive feature, the method proceeds to generate visualizations and calculate fairness metrics.

        Args:
            relative (bool): 
                If True, the metrics will be presented relative to the most frequent value of each sensitive feature.
            fairness_threshold (float): 
                A threshold for determining fairness based on relative metrics. If the relative metric exceeds this threshold, 
                a warning flag will be applied.
        """
        if self.task == 'binary':
            y_true_array = self.y_true.to_numpy()
            bias_metric = np.array([
                log_loss([y_true_array[idx]], [self.y_pred[idx]], labels=np.unique(y_true_array))
                for idx in range(len(y_true_array))
            ])
            self.y_pred = (self.y_pred >= .5).astype(int)
        elif self.task == 'regression':
            bias_metric = np.sqrt((self.y_true.to_numpy() - self.y_pred) ** 2)

        self.results = []
        for sensitive_feature in self.sensitive_features.columns:
            if self.task == 'survival':
                self._subgroup_analysis_CoxPH(sensitive_feature)
            else:
                f_pvalue = self._subgroup_analysis_OLS(sensitive_feature, bias_metric)
                if f_pvalue < 0.05:
                    self._generate_violin(sensitive_feature, bias_metric)
                    result = self._calculate_fair_metrics(sensitive_feature, fairness_threshold, relative)

                    print(f"\n=== Subgroup Analysis for '{sensitive_feature.title()}' using FairLearn ===\n")
                    table_output = tabulate(result.iloc[:, :4], headers='keys', tablefmt='grid')
                    print('\n'.join(['    ' + line for line in table_output.split('\n')]), '\n')

                    result.to_csv(self.output_dir / f'{sensitive_feature}_fm_metrics.csv')

run(relative=False, fairness_threshold=1.2)

Runs the bias explainer analysis on the provided data. It first evaluates the potential bias in the model's predictions using the OLS regression F-statistic p-value. If the p-value is below the threshold of 0.05, indicating potential bias in the sensitive feature, the method proceeds to generate visualizations and calculate fairness metrics.

Parameters:

    relative (bool, default False): If True, the metrics will be presented relative to the most frequent value of each sensitive feature.
    fairness_threshold (float, default 1.2): A threshold for determining fairness based on relative metrics. If the relative metric exceeds this threshold, a warning flag will be applied.
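
To make the relative comparison concrete, the sketch below mirrors the relative check used in _calculate_fair_metrics: each group's metric is divided by the value for the most frequent (reference) group, and the ratio is flagged when it falls outside the fairness threshold in either direction. All numbers are made up for illustration.

import pandas as pd

fairness_threshold = 1.2

# Hypothetical per-group values of one metric (e.g. mean prediction),
# with 'F' as the most frequent (reference) group.
by_group = pd.Series({'F': 0.40, 'M': 0.55, 'Other': 0.42})

relative = by_group / by_group['F']
flags = relative.apply(
    lambda x: f"{x:.3f} ✅" if 1/fairness_threshold <= x <= fairness_threshold else f"{x:.3f} ❌"
)
print(flags)
# F        1.000 ✅
# M        1.375 ❌   (0.55 / 0.40 exceeds the 1.2 threshold)
# Other    1.050 ✅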
Source code in src/jarvais/explainer/bias.py
def run(
        self, 
        relative: bool = False, 
        fairness_threshold: float = 1.2
    ) -> None:
    """
    Runs the bias explainer analysis on the provided data. It first evaluates the potential bias in the model's predictions
    using the OLS regression F-statistic p-value. If the p-value is below the threshold of 0.05, indicating 
    potential bias in the sensitive feature, the method proceeds to generate visualizations and calculate fairness metrics.

    Args:
        relative (bool): 
            If True, the metrics will be presented relative to the most frequent value of each sensitive feature.
        fairness_threshold (float): 
            A threshold for determining fairness based on relative metrics. If the relative metric exceeds this threshold, 
            a warning flag will be applied.
    """
    if self.task == 'binary':
        y_true_array = self.y_true.to_numpy()
        bias_metric = np.array([
            log_loss([y_true_array[idx]], [self.y_pred[idx]], labels=np.unique(y_true_array))
            for idx in range(len(y_true_array))
        ])
        self.y_pred = (self.y_pred >= .5).astype(int)
    elif self.task == 'regression':
        bias_metric = np.sqrt((self.y_true.to_numpy() - self.y_pred) ** 2)

    self.results = []
    for sensitive_feature in self.sensitive_features.columns:
        if self.task == 'survival':
            self._subgroup_analysis_CoxPH(sensitive_feature)
        else:
            f_pvalue = self._subgroup_analysis_OLS(sensitive_feature, bias_metric)
            if f_pvalue < 0.05:
                self._generate_violin(sensitive_feature, bias_metric)
                result = self._calculate_fair_metrics(sensitive_feature, fairness_threshold, relative)

                print(f"\n=== Subgroup Analysis for '{sensitive_feature.title()}' using FairLearn ===\n")
                table_output = tabulate(result.iloc[:, :4], headers='keys', tablefmt='grid')
                print('\n'.join(['    ' + line for line in table_output.split('\n')]), '\n')

                result.to_csv(self.output_dir / f'{sensitive_feature}_fm_metrics.csv')