
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples_calibration/1-quickstart/plot_calibration_venn_abers_multiclass.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_calibration_1-quickstart_plot_calibration_venn_abers_multiclass.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_calibration_1-quickstart_plot_calibration_venn_abers_multiclass.py:


Calibrating a multi-class classifier with Venn-ABERS
====================================================

This example shows how to calibrate a multi-class classifier with
`VennAbersCalibrator` and visualize the impact on predicted probabilities.
We compare an uncalibrated model against its Venn-ABERS calibrated version
using reliability diagrams and multi-class Brier scores.

.. GENERATED FROM PYTHON SOURCE LINES 11-24

.. code-block:: Python


    from __future__ import annotations

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.calibration import calibration_curve
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.preprocessing import label_binarize

    from mapie.calibration import VennAbersCalibrator
    from mapie.utils import train_conformalize_test_split


.. GENERATED FROM PYTHON SOURCE LINES 25-29

1. Build a miscalibrated multi-class classifier
-----------------------------------------------

We generate a 3-class dataset and fit a random forest, a model family whose
predicted probabilities are often poorly calibrated out of the box.

.. GENERATED FROM PYTHON SOURCE LINES 29-52

.. code-block:: Python


    X, y = make_classification(
        n_samples=5000,
        n_features=20,
        n_informative=12,
        n_redundant=2,
        n_classes=3,
        n_clusters_per_class=1,
        class_sep=0.8,
        random_state=7,
    )
    classes = np.unique(y)

    # Split into train, calibration, and test sets
    (X_train, X_calib, X_test, y_train, y_calib, y_test) = train_conformalize_test_split(
        X, y, train_size=0.5, conformalize_size=0.2, test_size=0.3, random_state=42
    )

    base_model = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=7)
    base_model.fit(X_train, y_train)
    probs_raw = base_model.predict_proba(X_test)


.. GENERATED FROM PYTHON SOURCE LINES 53-58

2. Calibrate with Venn-ABERS
----------------------------

The calibrator fits its own instance of the base model internally and learns
a calibration mapping from the held-out calibration set. Venn-ABERS natively
supports multi-class problems.

.. GENERATED FROM PYTHON SOURCE LINES 58-67

.. code-block:: Python


    va_calibrator = VennAbersCalibrator(
        estimator=RandomForestClassifier(n_estimators=100, max_depth=10, random_state=7),
        inductive=True,
        random_state=7,
    )
    va_calibrator.fit(X_train, y_train, X_calib=X_calib, y_calib=y_calib)
    probs_va = va_calibrator.predict_proba(X_test)


.. GENERATED FROM PYTHON SOURCE LINES 68-72

3. Multi-class Brier score helper
---------------------------------

We compute the squared error between predicted probabilities and one-hot
encoded labels, summed over classes and averaged over samples.

.. GENERATED FROM PYTHON SOURCE LINES 72-82

.. code-block:: Python


    def multiclass_brier(y_true: np.ndarray, proba: np.ndarray) -> float:
        y_onehot = label_binarize(y_true, classes=classes)
        return float(np.mean(np.sum((y_onehot - proba) ** 2, axis=1)))


    brier_raw = multiclass_brier(y_test, probs_raw)
    brier_va = multiclass_brier(y_test, probs_va)


.. GENERATED FROM PYTHON SOURCE LINES 83-87

4. Reliability diagrams and Brier scores
----------------------------------------

We plot one-vs-rest reliability curves for each class before and after
calibration. The closer a curve lies to the diagonal, the better calibrated
that class is, and a lower Brier score indicates more accurate probability
estimates overall.

.. GENERATED FROM PYTHON SOURCE LINES 87-122

.. code-block:: Python


    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    for cls in classes:
        y_true_cls = (y_test == cls).astype(int)
        prob_raw_cls = probs_raw[:, cls]
        prob_va_cls = probs_va[:, cls]

        frac_pos_raw, mean_pred_raw = calibration_curve(
            y_true_cls, prob_raw_cls, n_bins=10, strategy="uniform"
        )
        frac_pos_va, mean_pred_va = calibration_curve(
            y_true_cls, prob_va_cls, n_bins=10, strategy="uniform"
        )

        axes[0].plot(mean_pred_raw, frac_pos_raw, marker="o", label=f"class {cls}")
        axes[1].plot(mean_pred_va, frac_pos_va, marker="o", label=f"class {cls}")

    for ax, title in zip(
        axes,
        [
            f"Before calibration (Brier={brier_raw:.3f})",
            f"After Venn-ABERS (Brier={brier_va:.3f})",
        ],
    ):
        ax.plot([0, 1], [0, 1], "k--", linewidth=1)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_xlabel("Mean predicted probability")
        ax.set_ylabel("Fraction of positives")
        ax.set_title(title)
        ax.grid(True)
        ax.legend()

    plt.tight_layout()
    plt.show()



.. image-sg:: /examples_calibration/1-quickstart/images/sphx_glr_plot_calibration_venn_abers_multiclass_001.png
   :alt: Before calibration (Brier=0.204), After Venn-ABERS (Brier=0.156)
   :srcset: /examples_calibration/1-quickstart/images/sphx_glr_plot_calibration_venn_abers_multiclass_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 3.764 seconds)


.. _sphx_glr_download_examples_calibration_1-quickstart_plot_calibration_venn_abers_multiclass.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_calibration_venn_abers_multiclass.ipynb <plot_calibration_venn_abers_multiclass.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_calibration_venn_abers_multiclass.py <plot_calibration_venn_abers_multiclass.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_calibration_venn_abers_multiclass.zip <plot_calibration_venn_abers_multiclass.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_