.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples_multilabel_classification/1-quickstart/plot_tutorial_risk_control.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_multilabel_classification_1-quickstart_plot_tutorial_risk_control.py:


=========================================================================
Tutorial for recall and precision control for multi-label classification
=========================================================================

In this tutorial, we compare the prediction sets estimated by the
RCPS and CRC methods implemented in MAPIE, for recall control purposes,
on a two-dimensional toy dataset.
We will also look at the Learn Then Test (LTT) procedure, which creates
prediction sets for precision control.

Throughout this tutorial, we will answer the following questions:

- How does the threshold vary according to the desired risk?

- Is the chosen conformal method well calibrated (i.e. is the actual risk
  equal to the desired one)?

.. GENERATED FROM PYTHON SOURCE LINES 19-28

.. code-block:: default


    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.model_selection import train_test_split
    from sklearn.multioutput import MultiOutputClassifier
    from sklearn.naive_bayes import GaussianNB

    from mapie.risk_control import PrecisionRecallController

.. GENERATED FROM PYTHON SOURCE LINES 29-35

1. Construction of the dataset
----------------------------------------------------------------------------
We use a two-dimensional toy dataset with three possible labels. The idea
is to create a triangle where the observations on the edges have only one
label, those on the vertices have two labels (those of the two adjacent
edges) and those at the center have all three labels.

.. GENERATED FROM PYTHON SOURCE LINES 35-70

.. code-block:: default


    # One Gaussian cluster per label combination (vertices, center and
    # edges of the triangle)
    centers = [(0, 10), (-5, 0), (5, 0), (0, 5), (0, 0), (-4, 5), (5, 5)]
    covs = [
        np.eye(2), np.eye(2), np.eye(2), np.diag([5, 5]), np.diag([3, 1]),
        np.array([
            [4, 3],
            [3, 4]
        ]),
        np.array([
            [3, -2],
            [-2, 3]
        ]),
    ]

    x_min, x_max, y_min, y_max, step = -15, 15, -5, 15, 0.1
    n_samples = 800
    X = np.vstack([
        np.random.multivariate_normal(center, cov, n_samples)
        for center, cov in zip(centers, covs)
    ])
    # Row k gives the three binary labels of the k-th cluster
    classes = [
        [1, 0, 1], [1, 1, 0], [0, 1, 1], [1, 1, 1],
        [0, 1, 0], [1, 0, 0], [0, 0, 1]
    ]
    y = np.vstack([np.full((n_samples, 3), row) for row in classes])

    # 20% test set, then the remainder is split 75% train / 25% calibration
    X_train_cal, X_test, y_train_cal, y_test = train_test_split(
        X, y, test_size=0.2
    )
    X_train, X_cal, y_train, y_cal = train_test_split(
        X_train_cal, y_train_cal, test_size=0.25
    )

.. GENERATED FROM PYTHON SOURCE LINES 71-72

Let’s see our data.

.. GENERATED FROM PYTHON SOURCE LINES 72-95

.. code-block:: default


    colors = {
        (0, 0, 1): {"color": "#1f77b4", "lac": "0-0-1"},
        (0, 1, 1): {"color": "#ff7f0e", "lac": "0-1-1"},
        (1, 0, 1): {"color": "#2ca02c", "lac": "1-0-1"},
        (0, 1, 0): {"color": "#d62728", "lac": "0-1-0"},
        (1, 1, 0): {"color": "#ffd700", "lac": "1-1-0"},
        (1, 0, 0): {"color": "#c20078", "lac": "1-0-0"},
        (1, 1, 1): {"color": "#06C2AC", "lac": "1-1-1"}
    }

    for i in range(7):
        # label combination shared by the i-th block of n_samples points
        labels = tuple(y[n_samples * i])
        plt.scatter(
            X[n_samples * i:n_samples * (i + 1), 0],
            X[n_samples * i:n_samples * (i + 1), 1],
            color=colors[labels]["color"],
            label=colors[labels]["lac"],  # keeps legend entries in plot order
            marker='o',
            s=10,
            edgecolor='k'
        )
    plt.legend()
    plt.show()

.. image-sg:: /examples_multilabel_classification/1-quickstart/images/sphx_glr_plot_tutorial_risk_control_001.png
   :alt: plot tutorial risk control
   :srcset: /examples_multilabel_classification/1-quickstart/images/sphx_glr_plot_tutorial_risk_control_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 96-117

2. Recall risk control with CRC and RCPS
----------------------------------------------------------------------------
2.1 Fitting PrecisionRecallController
----------------------------------------------------------------------------
``PrecisionRecallController`` will be fitted with the RCPS and CRC methods.
For the RCPS method, we will test all three Upper Confidence Bounds (UCB):
Hoeffding, Bernstein and Waudby-Smith–Ramdas.
The two methods give two different guarantees on the risk:

* RCPS: :math:`\mathbb{P}(R(\hat{\mathcal{T}}_{\hat{\lambda}}) \leq \alpha) \geq 1 - \delta`
  where :math:`R(\hat{\mathcal{T}}_{\hat{\lambda}})` is the risk we want to
  control, :math:`\alpha` is the desired risk and :math:`1 - \delta` is the
  confidence level

* CRC: :math:`\mathbb{E}\left[L_{n+1}(\hat{\lambda})\right] \leq \alpha`
  where :math:`L_{n+1}(\hat{\lambda})` is the loss of a new observation and
  :math:`\alpha` is the desired risk

In both cases, the objective of the method is to find the optimal value of
:math:`\lambda` (the threshold above which we consider a label as being
present) such that the recall on the test points is at least equal to the
required recall.

.. GENERATED FROM PYTHON SOURCE LINES 117-149

.. code-block:: default


    method_params = {
        "RCPS - Hoeffding": ("rcps", "hoeffding"),
        "RCPS - Bernstein": ("rcps", "bernstein"),
        "RCPS - WSR": ("rcps", "wsr"),
        "CRC": ("crc", None)
    }

    clf = MultiOutputClassifier(GaussianNB()).fit(X_train, y_train)
    alpha = np.arange(0.01, 1, 0.01)  # grid of desired risks (1 - recall)
    y_pss, recalls, thresholds, r_hats, r_hat_pluss = {}, {}, {}, {}, {}
    # one copy of the test labels per value of alpha, along the last axis
    y_test_repeat = np.repeat(y_test[:, :, np.newaxis], len(alpha), 2)
    for name, (method, bound) in method_params.items():

        mapie = PrecisionRecallController(
            estimator=clf,
            method=method,
            metric_control="recall"
        )
        mapie.fit(X_cal, y_cal)

        # prediction sets of shape (n_test, n_labels, n_alpha)
        _, y_pss[name] = mapie.predict(
            X_test, alpha=alpha, bound=bound, delta=.1
        )
        # recall of each test point, averaged over the test set
        recalls[name] = (
            (y_test_repeat * y_pss[name]).sum(axis=1) /
            y_test_repeat.sum(axis=1)
        ).mean(axis=0)
        thresholds[name] = mapie.lambdas_star
        r_hats[name] = mapie.r_hat
        r_hat_pluss[name] = mapie.r_hat_plus
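
The two selection rules can be illustrated outside MAPIE with a minimal
NumPy sketch. It is not MAPIE's implementation and rests on assumptions: a
label is kept when its score is at least :math:`\lambda`, the per-sample
recall loss :math:`1 - |Y \cap \mathcal{T}_\lambda| / |Y|` is bounded by
:math:`B = 1` and grows with :math:`\lambda`, and the helper names
(``recall_loss``, ``crc_threshold``, ``rcps_hoeffding_threshold``) and the
synthetic scores are hypothetical. CRC keeps the largest :math:`\lambda`
with :math:`(n\hat{R}_n(\lambda) + B)/(n + 1) \leq \alpha`, while RCPS with
the Hoeffding bound keeps the largest :math:`\lambda` reached before
:math:`\hat{R}_n(\lambda) + \sqrt{\log(1/\delta)/(2n)}` hits :math:`\alpha`.
Only the Hoeffding bound is sketched; the Bernstein and WSR bounds are not.

.. code-block:: python

    import numpy as np


    def recall_loss(y_true, y_score, lambdas):
        """Hypothetical helper: per-sample 1 - recall for every threshold.

        A label is kept when its score is >= lambda. Assumes every sample has
        at least one true label. Returns shape (n_samples, n_lambdas).
        """
        kept = y_score[:, :, np.newaxis] >= lambdas[np.newaxis, np.newaxis, :]
        true_pos = (y_true[:, :, np.newaxis] * kept).sum(axis=1)
        n_true = y_true.sum(axis=1, keepdims=True)
        return 1.0 - true_pos / n_true


    def crc_threshold(losses, lambdas, alpha, bound=1.0):
        """Largest lambda with (n * R_hat + B) / (n + 1) <= alpha."""
        n = losses.shape[0]
        inflated = (n * losses.mean(axis=0) + bound) / (n + 1)
        ok = inflated <= alpha
        # fall back to the smallest (most conservative) lambda if none qualifies
        return lambdas[ok].max() if ok.any() else lambdas[0]


    def rcps_hoeffding_threshold(losses, lambdas, alpha, delta):
        """Last lambda before the Hoeffding UCB of the risk reaches alpha."""
        n = losses.shape[0]
        r_hat_plus = losses.mean(axis=0) + np.sqrt(np.log(1 / delta) / (2 * n))
        best = lambdas[0]  # conservative fallback
        for lam, ucb in zip(lambdas, r_hat_plus):  # lambdas sorted increasingly
            if ucb >= alpha:
                break
            best = lam
        return best


    # Hypothetical calibration data: present labels score in [0.6, 1),
    # absent labels score in [0, 0.4)
    rng = np.random.default_rng(0)
    n_demo, n_labels = 500, 3
    y_demo = rng.integers(0, 2, size=(n_demo, n_labels))
    y_demo[y_demo.sum(axis=1) == 0, 0] = 1  # every sample has a true label
    scores_demo = 0.6 * y_demo + 0.4 * rng.random((n_demo, n_labels))
    lambdas_grid = np.arange(0, 1, 0.01)

    losses_demo = recall_loss(y_demo, scores_demo, lambdas_grid)
    lam_crc = crc_threshold(losses_demo, lambdas_grid, alpha=0.1)
    lam_rcps = rcps_hoeffding_threshold(
        losses_demo, lambdas_grid, alpha=0.1, delta=0.1
    )
    print(f"CRC lambda: {lam_crc:.2f}, RCPS-Hoeffding lambda: {lam_rcps:.2f}")

With :math:`n = 500` and :math:`\delta = 0.1`, the Hoeffding margin is about
0.048, so in this sketch RCPS stops at a threshold no larger than the CRC
one, since CRC only inflates the empirical risk by about :math:`1/(n + 1)`.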

.. GENERATED FROM PYTHON SOURCE LINES 150-163

2.2. Results
----------------------------------------------------------------------------
To check the results of the methods, we propose two types of plots:

1 - Plots where the desired recall :math:`1 - \alpha` varies. Here two
metrics are plotted for each method and for each UCB:

* The actual recall (which should always be close to the required one): we
  can see that they are close to each other.
* The value of the threshold: we see that the threshold decreases as
  :math:`1 - \alpha` increases, which is expected because a smaller threshold
  gives larger prediction sets, hence a larger recall.

.. GENERATED FROM PYTHON SOURCE LINES 163-179

.. code-block:: default


    vars_y = [recalls, thresholds]
    labels_y = ["Recall", "Threshold"]  # one y-axis label per entry of vars_y

    fig, axs = plt.subplots(1, len(vars_y), figsize=(8*len(vars_y), 8))
    for i, var in enumerate(vars_y):
        for name in method_params:
            axs[i].plot(1 - alpha, var[name], label=name, linewidth=2)
        if i == 0:
            # identity line: achieved recall equal to desired recall
            axs[i].plot([0, 1], [0, 1], ls="--", color="k")
        axs[i].set_xlabel("Desired recall : 1 - alpha", fontsize=20)
        axs[i].set_ylabel(labels_y[i], fontsize=20)
        if i == (len(vars_y) - 1):
            axs[i].legend(fontsize=20, loc=[1, 0])
    plt.show()

.. image-sg:: /examples_multilabel_classification/1-quickstart/images/sphx_glr_plot_tutorial_risk_control_002.png
   :alt: plot tutorial risk control
   :srcset: /examples_multilabel_classification/1-quickstart/images/sphx_glr_plot_tutorial_risk_control_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 180-189

2 - Plots where we choose a specific risk value (0.1 in our case) and look at
the average risk, the UCB of the risk (for RCPS methods) and the choice of
the threshold :math:`\lambda`.

* Among the RCPS methods, the Bernstein bound gives the best results: for a
  given value of :math:`\alpha`, the recall stays above the required level
  while :math:`\lambda` is larger than with the other two bounds.

* The CRC method gives the best results, since it guarantees the required
  recall (in expectation) with a larger threshold.

.. GENERATED FROM PYTHON SOURCE LINES 189-217

.. code-block:: default


    fig, axs = plt.subplots(
        1,
        len(method_params),
        figsize=(8*len(method_params), 8)
    )
    # index 9 of the alpha grid corresponds to alpha = 0.1;
    # mapie.lambdas is the threshold grid of the last fitted controller
    for i, (name, (method, bound)) in enumerate(method_params.items()):
        axs[i].plot(
            mapie.lambdas, r_hats[name], label=r"$\hat{R}$", linewidth=2
        )
        if name != "CRC":
            axs[i].plot(
                mapie.lambdas,
                r_hat_pluss[name],
                label=r"$\hat{R}^+$",
                linewidth=2
            )
        axs[i].plot([0, 1], [alpha[9], alpha[9]], label=r"$\alpha$")
        axs[i].plot(
            [thresholds[name][9], thresholds[name][9]], [0, 1],
            label=r"$\lambda^*" + f" = {thresholds[name][9]}$"
        )
        axs[i].legend(fontsize=20)
        axs[i].set_title(
            f"{name} - Recall = {round(recalls[name][9], 2)}",
            fontsize=20
        )
    plt.show()

.. image-sg:: /examples_multilabel_classification/1-quickstart/images/sphx_glr_plot_tutorial_risk_control_003.png
   :alt: RCPS - Hoeffding - Recall = 0.93, RCPS - Bernstein - Recall = 0.92, RCPS - WSR - Recall = 0.92, CRC - Recall = 0.9
   :srcset: /examples_multilabel_classification/1-quickstart/images/sphx_glr_plot_tutorial_risk_control_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 218-242

3. Precision risk control with LTT
----------------------------------------------------------------------------
3.1 Fitting PrecisionRecallController
----------------------------------------------------------------------------
In this part, we will use LTT to control precision.
Unlike the two previous methods, LTT can handle non-monotonic losses.
The procedure relies on multiple hypothesis testing, which is why its
output is not reduced to a single value of :math:`\lambda`.

More precisely, we look for all the :math:`\lambda` that satisfy
:math:`\mathbb{P}(R(\hat{\mathcal{T}}_{\hat{\lambda}}) \leq \alpha) \geq 1 - \delta`,
where :math:`R(\hat{\mathcal{T}}_{\hat{\lambda}})` is the risk we want to
control, and each :math:`\lambda` should satisfy family-wise error rate
(FWER) control. :math:`\alpha` is the desired risk.
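
The multiple-testing idea behind LTT can be sketched in NumPy, under
assumptions that may differ from what MAPIE's ``ltt`` method computes: each
:math:`\lambda` of the grid :math:`\Lambda` gets the null hypothesis
:math:`H_\lambda: R(\lambda) > \alpha`, tested with the Hoeffding p-value
:math:`p_\lambda = \exp(-2n(\alpha - \hat{R}_n(\lambda))_+^2)`, and
:math:`H_\lambda` is rejected when :math:`p_\lambda \leq \delta / |\Lambda|`
(Bonferroni correction), which keeps the FWER below :math:`\delta`. Because
every :math:`\lambda` is tested separately, monotonicity of the loss is not
needed. The helper names, the zero loss assigned to empty prediction sets
and the synthetic scores are hypothetical.

.. code-block:: python

    import numpy as np


    def precision_loss(y_true, y_score, lambdas):
        """Hypothetical helper: per-sample 1 - precision for every threshold.

        A label is kept when its score is >= lambda. By assumption, an empty
        prediction set has zero loss. Returns shape (n_samples, n_lambdas).
        """
        kept = y_score[:, :, np.newaxis] >= lambdas[np.newaxis, np.newaxis, :]
        true_pos = (y_true[:, :, np.newaxis] * kept).sum(axis=1)
        n_kept = kept.sum(axis=1)
        precision = np.divide(
            true_pos, n_kept, out=np.ones(true_pos.shape), where=n_kept > 0
        )
        return 1.0 - precision


    def ltt_bonferroni(losses, alpha, delta):
        """Test H0: R(lambda) > alpha for every lambda with FWER <= delta.

        Returns a boolean mask of the lambdas whose null hypothesis is
        rejected (the valid lambdas) and the empirical risks.
        """
        n, n_lambdas = losses.shape
        r_hat = losses.mean(axis=0)
        # Hoeffding p-value; equals 1 when r_hat >= alpha
        p_values = np.exp(-2 * n * np.clip(alpha - r_hat, 0, None) ** 2)
        return p_values <= delta / n_lambdas, r_hat


    # Hypothetical data: present labels score in [0.5, 1), absent in [0, 0.5)
    rng = np.random.default_rng(0)
    n_demo, n_labels = 1000, 3
    y_demo = rng.integers(0, 2, size=(n_demo, n_labels))
    scores_demo = 0.5 * y_demo + 0.5 * rng.random((n_demo, n_labels))
    lambdas_grid = np.arange(0, 1, 0.01)

    losses_demo = precision_loss(y_demo, scores_demo, lambdas_grid)
    valid, r_hat_demo = ltt_bonferroni(losses_demo, alpha=0.1, delta=0.1)
    print("valid lambdas:", lambdas_grid[valid].round(2))

In this sketch, a :math:`\lambda` is declared valid only when
:math:`\hat{R}_n(\lambda)` is below :math:`\alpha` by at least
:math:`\sqrt{\log(|\Lambda|/\delta)/(2n)}`, so a threshold whose empirical
risk is only slightly below :math:`\alpha` can still fail the test.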

Notice that the procedure only examines the :math:`\lambda` values whose
risk remains below the level :math:`\alpha`, so not every :math:`\lambda` is
considered. Moreover, a :math:`\lambda` whose risk is below :math:`\alpha`
does not necessarily pass the FWER control! This is what we are going to
explore.

.. GENERATED FROM PYTHON SOURCE LINES 242-267

.. code-block:: default


    mapie_clf = PrecisionRecallController(
        estimator=clf,
        method='ltt',
        metric_control='precision'
    )
    mapie_clf.fit(X_cal, y_cal)

    alpha = 0.1
    _, y_ps = mapie_clf.predict(
        X_test,
        alpha=alpha,
        delta=0.1
    )

    valid_index = mapie_clf.valid_index[0]  # valid_index is a list of lists
    lambdas = mapie_clf.lambdas[valid_index]  # thresholds that pass the test

    mini = lambdas.min()  # smallest valid threshold
    maxi = lambdas.max()  # largest valid threshold

    r_hat = mapie_clf.r_hat  # empirical risk for every threshold of the grid
    # position (among valid thresholds) of the one with the lowest risk
    idx_star = np.argmin(r_hat[valid_index])

.. GENERATED FROM PYTHON SOURCE LINES 268-274

3.2 Valid parameters for precision control
----------------------------------------------------------------------------
We can see that not all the :math:`\lambda` values whose risk is below the
orange line are chosen by the procedure. On the other hand, all the
:math:`\lambda` values in the red rectangle satisfy family-wise error rate
control and allow precision to be controlled at the desired level with high
probability.

.. GENERATED FROM PYTHON SOURCE LINES 274-288

.. code-block:: default


    plt.figure(figsize=(8, 8))
    plt.plot(mapie_clf.lambdas, r_hat, label=r"$\hat{R}_\lambda$")
    plt.plot([0, 1], [alpha, alpha], label=r"$\alpha$")
    plt.axvspan(mini, maxi, facecolor='red', alpha=0.3, label=r"LTT-$\lambda$")
    plt.plot(
        [lambdas[idx_star], lambdas[idx_star]], [0, 1],
        label=r"$\lambda^* =" + f"{lambdas[idx_star]}$"
    )
    plt.xlabel(r"Threshold $\lambda$")
    plt.ylabel(r"Empirical risk: $\hat{R}_\lambda$")
    plt.title("Precision risk curve", fontsize=20)
    plt.legend()
    plt.show()

.. image-sg:: /examples_multilabel_classification/1-quickstart/images/sphx_glr_plot_tutorial_risk_control_004.png
   :alt: Precision risk curve
   :srcset: /examples_multilabel_classification/1-quickstart/images/sphx_glr_plot_tutorial_risk_control_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.624 seconds)


.. _sphx_glr_download_examples_multilabel_classification_1-quickstart_plot_tutorial_risk_control.py:

.. only :: html

  .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_tutorial_risk_control.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_tutorial_risk_control.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_