.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples_risk_control/2-advanced-analysis/plot_risk_control_multi-label_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_risk_control_2-advanced-analysis_plot_risk_control_multi-label_classification.py:


=========================================================
Control the risk of a multi-label classifier
=========================================================

In this example, we explain how to perform risk control for multi-label
classification using the Learn-Then-Test (LTT) procedure implemented in MAPIE.

.. GENERATED FROM PYTHON SOURCE LINES 10-24

.. code-block:: Python

    # sphinx_gallery_thumbnail_number = 4
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.metrics import precision_score
    from sklearn.model_selection import train_test_split
    from sklearn.multioutput import MultiOutputClassifier
    from sklearn.naive_bayes import GaussianNB

    from mapie.risk_control import MultiLabelClassificationController

    RANDOM_STATE = 42

.. GENERATED FROM PYTHON SOURCE LINES 25-29

First, we generate a two-dimensional toy dataset with three possible labels.
The points form a triangle: observations on the edges have a single label,
observations at the vertices have two labels (those of the two edges that meet
there), and observations at the center have all three labels.

.. GENERATED FROM PYTHON SOURCE LINES 29-115

.. code-block:: Python

    # Generate synthetic dataset
    np.random.seed(RANDOM_STATE)

    centers = [(0, 10), (-5, 0), (5, 0), (0, 5), (0, 0), (-4, 5), (5, 5)]
    covs = [
        np.eye(2),
        np.eye(2),
        np.eye(2),
        np.diag([5, 5]),
        np.diag([3, 1]),
        np.array([[4, 3], [3, 4]]),
        np.array([[3, -2], [-2, 3]]),
    ]

    x_min, x_max, y_min, y_max, step = -15, 15, -5, 15, 0.1
    n_samples = 800
    X = np.vstack(
        [
            np.random.multivariate_normal(center, cov, n_samples)
            for center, cov in zip(centers, covs)
        ]
    )
    classes = [[1, 0, 1], [1, 1, 0], [0, 1, 1], [1, 1, 1], [0, 1, 0], [1, 0, 0], [0, 0, 1]]
    y = np.vstack([np.full((n_samples, 3), row) for row in classes])

    # Split the dataset into training, calibration and test sets.
    X_train_cal, X_test, y_train_cal, y_test = train_test_split(X, y, test_size=0.2)
    X_train, X_calib, y_train, y_calib = train_test_split(
        X_train_cal, y_train_cal, test_size=0.25
    )

    # Plot the three datasets to visualize the distribution of the labels.
    colors = {
        (0, 0, 1): {"color": "#1f77b4", "lac": "0-0-1"},
        (0, 1, 1): {"color": "#ff7f0e", "lac": "0-1-1"},
        (1, 0, 1): {"color": "#2ca02c", "lac": "1-0-1"},
        (0, 1, 0): {"color": "#d62728", "lac": "0-1-0"},
        (1, 1, 0): {"color": "#ffd700", "lac": "1-1-0"},
        (1, 0, 0): {"color": "#c20078", "lac": "1-0-0"},
        (1, 1, 1): {"color": "#06C2AC", "lac": "1-1-1"},
    }

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    titles = ["Training Data", "Calibration Data", "Test Data"]
    datasets = [(X_train, y_train), (X_calib, y_calib), (X_test, y_test)]

    for i, (ax, (X_data, y_data), title) in enumerate(zip(axes, datasets, titles)):
        for label, props in colors.items():
            label = np.array(label)
            mask = np.all(y_data == label, axis=1)
            ax.scatter(
                X_data[mask, 0],
                X_data[mask, 1],
                color=props["color"],
                edgecolors="k",
                s=60,
                alpha=1,
                label=props["lac"] if i == 0 else None,
            )
        ax.set_title(title, fontsize=18)
        ax.set_xlabel("Feature 1", fontsize=16)
        ax.tick_params(labelsize=14)
        if i == 0:
            ax.set_ylabel("Feature 2", fontsize=16)
        else:
            ax.set_ylabel("")
            ax.set_yticks([])

    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="lower center",
        bbox_to_anchor=(0.5, 0),
        ncol=7,
        fontsize=18,
    )
    plt.suptitle("Visualization of Train, Calibration, and Test Sets", fontsize=22)
    plt.tight_layout(rect=[0, 0.08, 1, 0.95])
    plt.show()

.. image-sg:: /examples_risk_control/2-advanced-analysis/images/sphx_glr_plot_risk_control_multi-label_classification_001.png
   :alt: Visualization of Train, Calibration, and Test Sets, Training Data, Calibration Data, Test Data
   :srcset: /examples_risk_control/2-advanced-analysis/images/sphx_glr_plot_risk_control_multi-label_classification_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 116-120

Second, we fit a ``MultiOutputClassifier`` that trains one Gaussian Naive Bayes
classifier per label. Wrapping the estimator in ``MultiOutputClassifier`` lets
us use classifiers that do not natively support multi-label classification.

.. GENERATED FROM PYTHON SOURCE LINES 120-124

.. code-block:: Python

    clf = MultiOutputClassifier(GaussianNB())
    clf.fit(X_train, y_train)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    MultiOutputClassifier(estimator=GaussianNB())

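
Because ``MultiOutputClassifier`` wraps one binary classifier per label,
``clf.predict_proba`` returns a list with one ``(n_samples, 2)`` array per
label. The second column of each array is the probability that the label is
present. The snippet below is a simplified, self-contained sketch of what an
LTT-style calibration could do with such scores. It is not MAPIE's
implementation, and it rests on several assumptions. The precision risk of a
threshold is taken to be one minus the precision of each sample's predicted
label set, averaged over samples, with empty sets counted as precision 1. Each
threshold is tested with a plain Hoeffding p-value followed by a Bonferroni
correction, whereas MAPIE may use a tighter p-value. The synthetic data and the
helpers ``precision_risk`` and ``hoeffding_p_value`` are hypothetical, and the
variable names are chosen so they do not overwrite ``X``, ``y`` or ``clf`` from
this example.

.. code-block:: Python

    import numpy as np
    from sklearn.multioutput import MultiOutputClassifier
    from sklearn.naive_bayes import GaussianNB

    rng = np.random.default_rng(0)

    # Hypothetical synthetic data (not the triangle dataset above):
    # label k is present when feature k, plus some noise, is positive.
    X_demo = rng.normal(size=(600, 3))
    y_demo = (X_demo + 0.5 * rng.normal(size=X_demo.shape) > 0).astype(int)
    X_fit, y_fit = X_demo[:300], y_demo[:300]
    X_cal, y_cal = X_demo[300:], y_demo[300:]

    clf_demo = MultiOutputClassifier(GaussianNB()).fit(X_fit, y_fit)

    # predict_proba returns one (n_samples, 2) array per label;
    # column 1 is the probability that the label is present.
    proba_cal = np.column_stack([p[:, 1] for p in clf_demo.predict_proba(X_cal)])


    def precision_risk(proba, y_true, threshold):
        """Mean over samples of 1 - precision of the predicted label set.

        Assumed definition: an empty predicted set counts as precision 1.
        """
        y_pred = proba >= threshold
        n_pred = y_pred.sum(axis=1)
        n_correct = (y_pred & (y_true == 1)).sum(axis=1)
        precision = np.where(n_pred > 0, n_correct / np.maximum(n_pred, 1), 1.0)
        return 1.0 - precision.mean()


    def hoeffding_p_value(risk_hat, alpha, n):
        """Hoeffding p-value for the null hypothesis "true risk > alpha"."""
        return np.exp(-2 * n * np.maximum(alpha - risk_hat, 0.0) ** 2)


    alpha = 0.2  # maximum risk, i.e. a target precision of 0.8
    delta = 0.1  # 1 - confidence level
    candidate_thresholds = np.arange(0.01, 1, 0.01)

    risks = np.array([precision_risk(proba_cal, y_cal, t) for t in candidate_thresholds])
    p_values = hoeffding_p_value(risks, alpha, len(y_cal))

    # Bonferroni correction: a threshold is declared valid when its p-value
    # is at most delta divided by the number of tested thresholds.
    valid_thresholds = candidate_thresholds[p_values <= delta / len(candidate_thresholds)]
    print(f"{len(valid_thresholds)} valid thresholds out of {len(candidate_thresholds)}")

Under this sketch, the Bonferroni correction bounds by ``delta`` the probability
of declaring valid any threshold whose true risk exceeds ``alpha``. This is the
family-wise guarantee that LTT builds on. It also explains why a threshold with
an empirical precision only slightly above the target can still be rejected.
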
.. GENERATED FROM PYTHON SOURCE LINES 125-142

Next, we initialize a :class:`~mapie.risk_control.MultiLabelClassificationController`
with the following inputs:

- the probability estimation function of the fitted estimator, ``clf.predict_proba``,
- the risk to control (``"precision"`` in this example),
- the grid of candidate thresholds (``predict_params``),
- a target level (``target_level``, here the minimum precision to guarantee),
- a confidence level.

We then use the calibration data to compute statistically valid thresholds with
a risk control procedure.

When ``risk="precision"``, the controller relies on the LTT procedure, which is
designed to handle non-monotonic risks. Precision does not necessarily increase
as the threshold increases. ``risk="recall"`` can also be used. Because recall
can only decrease as the threshold increases, the controller can exploit this
monotonicity and uses either:

- RCPS (Risk-Controlling Prediction Sets), which provides a probabilistic guarantee,
- CRC (Conformal Risk Control), which provides a guarantee in expectation.

Please refer to the theoretical description of risk control in the MAPIE
documentation for more details.

.. GENERATED FROM PYTHON SOURCE LINES 142-161

.. code-block:: Python

    target_precision = 0.8
    confidence_level = 0.9

    mcc = MultiLabelClassificationController(
        predict_function=clf.predict_proba,
        risk="precision",
        method="ltt",
        predict_params=np.arange(0.01, 1, 0.01),
        target_level=target_precision,
        confidence_level=confidence_level,
    )
    mcc.calibrate(X_calib, y_calib)

    print(
        f"{len(mcc.valid_predict_params[0])} thresholds found that guarantee a precision of "
        f"at least {target_precision} with a confidence of {confidence_level}. "
        f"The best threshold is: {mcc.best_predict_param[0]:.3f}."
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    53 thresholds found that guarantee a precision of at least 0.8 with a confidence of 0.9. The best threshold is: 0.410.

.. GENERATED FROM PYTHON SOURCE LINES 162-164

In the plot below, we visualize how the threshold affects precision on the
calibration set, and which thresholds are statistically valid.

.. GENERATED FROM PYTHON SOURCE LINES 164-250

.. code-block:: Python

    tested_thresholds = mcc.predict_params
    precisions = 1 - mcc.r_hat  # risk is defined as 1 - precision
    # Naive choice: the threshold whose calibration precision is closest to
    # the target while not falling below it.
    naive_threshold_index = np.argmin(
        np.where(precisions >= target_precision, precisions - target_precision, np.inf)
    )
    valid_thresholds_indices = mcc.valid_index[0]  # valid_index is a list of lists
    mask_invalid_threshold = np.ones(len(tested_thresholds), dtype=bool)
    mask_invalid_threshold[valid_thresholds_indices] = False
    best_threshold_index = np.where(tested_thresholds == mcc.best_predict_param[0])[0][0]

    # Probability that each label is present on the test set, one column per label.
    probas_test = clf.predict_proba(X_test)
    proba_positive = np.column_stack([p[:, 1] for p in probas_test])
    y_pred_naive = (proba_positive >= tested_thresholds[naive_threshold_index]).astype(int)
    precision_naive_threshold = precision_score(
        y_test, y_pred_naive, average=None, zero_division=0
    ).mean()
    y_pred_ltt_best_threshold = (
        proba_positive >= tested_thresholds[best_threshold_index]
    ).astype(int)
    precision_best_ltt_threshold = precision_score(
        y_test, y_pred_ltt_best_threshold, average=None, zero_division=0
    ).mean()

    plt.figure()
    plt.scatter(
        tested_thresholds[valid_thresholds_indices],
        precisions[valid_thresholds_indices],
        c="tab:green",
        label="Valid thresholds",
    )
    plt.scatter(
        tested_thresholds[mask_invalid_threshold],
        precisions[mask_invalid_threshold],
        c="tab:red",
        label="Invalid thresholds",
    )
    plt.scatter(
        tested_thresholds[best_threshold_index],
        precisions[best_threshold_index],
        c="tab:green",
        label="Best threshold",
        marker="*",
        edgecolors="k",
        s=300,
    )
    plt.scatter(
        tested_thresholds[naive_threshold_index],
        precisions[naive_threshold_index],
        c="tab:red",
        label="Naive threshold",
        marker="*",
        edgecolors="k",
        s=300,
    )
    plt.axhline(target_precision, color="tab:gray", linestyle="--")
    plt.text(
        0.65,
        target_precision + 0.02,
        "Target precision",
        color="tab:gray",
        fontstyle="italic",
    )
    plt.xlabel("Threshold")
    plt.ylabel("Precision")
    plt.legend()
    plt.show()

    print(
        "With the naive threshold, the precision is:\n "
        f"- {precisions[naive_threshold_index]:.3f} on the calibration set\n "
        f"- {precision_naive_threshold:.3f} on the test set."
    )

    print(
        "\n\nWith risk control, the precision is:\n "
        f"- {precisions[best_threshold_index]:.3f} on the calibration set\n "
        f"- {precision_best_ltt_threshold:.3f} on the test set."
    )

.. image-sg:: /examples_risk_control/2-advanced-analysis/images/sphx_glr_plot_risk_control_multi-label_classification_002.png
   :alt: plot risk control multi label classification
   :srcset: /examples_risk_control/2-advanced-analysis/images/sphx_glr_plot_risk_control_multi-label_classification_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    With the naive threshold, the precision is:
     - 0.800 on the calibration set
     - 0.782 on the test set.


    With risk control, the precision is:
     - 0.842 on the calibration set
     - 0.833 on the test set.

.. GENERATED FROM PYTHON SOURCE LINES 251-262

The naive threshold is selected on the calibration set to match the target
precision, but it provides no statistical guarantee on unseen data. Here, its
test precision (0.782) falls below the 0.8 target. In contrast, the threshold
selected by risk control accounts for the uncertainty caused by the finite size
of the calibration set. With probability at least ``confidence_level`` (here
0.9), it guarantees that the target precision is met on unseen data. Its test
precision is 0.833.

As illustrated above, not every threshold whose calibration precision exceeds
the target is statistically valid. This highlights the importance of risk
control when deploying multi-label classifiers in practice.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.005 seconds)


.. _sphx_glr_download_examples_risk_control_2-advanced-analysis_plot_risk_control_multi-label_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_risk_control_multi-label_classification.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_risk_control_multi-label_classification.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_risk_control_multi-label_classification.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_