.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples_risk_control/2-advanced-analysis/plot_semantic_segmentation_recall_control.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_risk_control_2-advanced-analysis_plot_semantic_segmentation_recall_control.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_risk_control_2-advanced-analysis_plot_semantic_segmentation_recall_control.py:


========================================
Recall control for semantic segmentation
========================================

This example illustrates how to control the recall of a semantic segmentation
model using MAPIE.

We use :class:`~mapie.risk_control.SemanticSegmentationController` to calibrate
a decision threshold that statistically guarantees a target recall level on
unseen data.

The dataset, model and utility functions are loaded from Hugging Face for
simplicity and reproducibility.

.. GENERATED FROM PYTHON SOURCE LINES 16-30

.. code-block:: Python


    import importlib.util
    import warnings
    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    from huggingface_hub import hf_hub_download, snapshot_download

    from mapie.risk_control import SemanticSegmentationController

    warnings.filterwarnings("ignore")


.. GENERATED FROM PYTHON SOURCE LINES 31-35

To keep this example self-contained, we load the dataset utilities and the
segmentation LightningModule definition directly from a repository hosted
on Hugging Face.

.. GENERATED FROM PYTHON SOURCE LINES 35-50

.. code-block:: Python


    module_path = hf_hub_download(
        repo_id="mapie-library/rooftop_segmentation",
        filename="model_and_lightning_module.py",
        repo_type="dataset",
    )

    spec = importlib.util.spec_from_file_location("hf_module", module_path)
    hf_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(hf_module)

    SegmentationLightningModule = hf_module.SegmentationLightningModule
    RoofSegmentationDataset = hf_module.RoofSegmentationDataset
    get_validation_transforms = hf_module.get_validation_transforms


.. GENERATED FROM PYTHON SOURCE LINES 51-53

Load a pretrained segmentation model checkpoint from Hugging Face.

.. GENERATED FROM PYTHON SOURCE LINES 53-75

.. code-block:: Python


    model_ckpt = hf_hub_download(
        repo_id="mapie-library/rooftop_segmentation",
        filename="best_model-v1.ckpt",
        repo_type="dataset",
    )

    data_root = Path(
        snapshot_download(
            repo_id="mapie-library/rooftop_segmentation",
            repo_type="dataset",
            allow_patterns=["calib/**", "test/**"],
        )
    )

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    model = SegmentationLightningModule.load_from_checkpoint(model_ckpt)
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully!")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Fetching ... files: 1154it [00:26, 43.23it/s]
    Model loaded successfully!


.. GENERATED FROM PYTHON SOURCE LINES 76-80

Next, two datasets are loaded from Hugging Face: a calibration set used to
estimate risks and select an appropriate decision threshold, and a test set
reserved for evaluating controlled predictions on unseen data.

.. GENERATED FROM PYTHON SOURCE LINES 80-103

.. code-block:: Python


    CALIB_IMAGES_DIR = data_root / "calib" / "images"
    CALIB_MASKS_DIR = data_root / "calib" / "masks"
    TEST_IMAGES_DIR = data_root / "test" / "images"
    TEST_MASKS_DIR = data_root / "test" / "masks"

    calib_dataset = RoofSegmentationDataset(
        images_dir=CALIB_IMAGES_DIR,
        masks_dir=CALIB_MASKS_DIR,
        transform=get_validation_transforms(),
    )
    calib_loader = torch.utils.data.DataLoader(calib_dataset, batch_size=8)

    test_dataset = RoofSegmentationDataset(
        images_dir=TEST_IMAGES_DIR,
        masks_dir=TEST_MASKS_DIR,
        transform=get_validation_transforms(),
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8)

    print(f"Calibration set size: {len(calib_dataset)}")
    print(f"Test set size: {len(test_dataset)}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Dataset initialized with 289 image-mask pairs
    Dataset initialized with 288 image-mask pairs
    Calibration set size: 289
    Test set size: 288


.. GENERATED FROM PYTHON SOURCE LINES 104-108

A :class:`~mapie.risk_control.SemanticSegmentationController` is instantiated
to control the recall risk (1 - recall) and automatically select a threshold
that meets the target recall level with high confidence.

.. GENERATED FROM PYTHON SOURCE LINES 108-119

.. code-block:: Python


    TARGET_RECALL = 0.9

    recall_controller = SemanticSegmentationController(
        predict_function=model,
        risk="recall",
        target_level=TARGET_RECALL,
        confidence_level=None,
    )
    print(f"Target recall level: {TARGET_RECALL}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Target recall level: 0.9


.. GENERATED FROM PYTHON SOURCE LINES 120-124

During calibration, the controller evaluates the recall risk over a range of
thresholds on the calibration dataset in order to identify an optimal
decision threshold.

.. GENERATED FROM PYTHON SOURCE LINES 124-144

.. code-block:: Python


    for i, sample in enumerate(calib_loader):
        image, mask = sample["image"], sample["mask"]
        image = image.to(DEVICE)
        mask = mask.cpu().numpy()

        # Filter images that contain masks
        has_mask = mask.sum(axis=(1, 2)) > 0
        image = image[has_mask]
        mask = mask[has_mask]

        if len(image) > 0:
            with torch.no_grad():
                recall_controller.compute_risks(image, mask)

    # Compute the best threshold
    recall_controller.compute_best_predict_param()

    print("Controller calibrated successfully!")
    print(f"Optimal threshold found: {recall_controller.best_predict_param[0]:.4f}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Controller calibrated successfully!
    Optimal threshold found: 0.0300


.. GENERATED FROM PYTHON SOURCE LINES 145-148

Controlled predictions are visually inspected on a few test images to
illustrate the effect of MAPIE thresholding compared to raw model outputs.

.. GENERATED FROM PYTHON SOURCE LINES 148-212

.. code-block:: Python


    def denormalize_image(tensor_image: torch.Tensor) -> np.ndarray:
        """Denormalize image tensor for visualization."""
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])

        image = tensor_image.cpu().numpy().transpose(1, 2, 0)
        image = std * image + mean
        image = np.clip(image, 0, 1)
        return image


    # Select random test images
    NUM_EXAMPLES = 4
    np.random.seed(42)

    # Get indices of images with masks
    indices_with_masks = []
    for idx in range(len(test_dataset)):
        sample = test_dataset[idx]
        if sample["mask"].sum() > 0:
            indices_with_masks.append(idx)

    random_indices = np.random.choice(indices_with_masks, NUM_EXAMPLES, replace=False)

    fig, axes = plt.subplots(2, NUM_EXAMPLES, figsize=(4 * NUM_EXAMPLES, 10))

    for col, idx in enumerate(random_indices):
        sample = test_dataset[idx]
        image = sample["image"].unsqueeze(0).to(DEVICE)
        mask = sample["mask"].cpu().numpy()

        with torch.no_grad():
            # Get MAPIE prediction
            mapie_pred = recall_controller.predict(image)[0]

        # Denormalize image
        img_display = denormalize_image(sample["image"])

        # Plot original image (top row)
        axes[0, col].imshow(img_display)
        axes[0, col].set_title("Original Image")
        axes[0, col].axis("off")

        # Plot MAPIE prediction with correct pixels in white and false negatives in red (bottom row)
        pred_visualization = np.zeros((*mapie_pred[0].shape, 3))
        true_positives = mask * mapie_pred[0]
        pred_visualization[true_positives > 0] = [1, 1, 1]
        false_negatives = mask * (1 - mapie_pred[0])
        pred_visualization[false_negatives > 0] = [1, 0, 0]

        axes[1, col].imshow(pred_visualization)
        axes[1, col].set_title(
            f"MAPIE Prediction (threshold={recall_controller.best_predict_param[0]:.2f})\n"
            "White: Correct | Red: False Negatives"
        )
        axes[1, col].axis("off")

    plt.tight_layout()
    plt.show()


.. image-sg:: /examples_risk_control/2-advanced-analysis/images/sphx_glr_plot_semantic_segmentation_recall_control_001.png
   :alt: Original Image, Original Image, Original Image, Original Image, MAPIE Prediction (threshold=0.03) White: Correct | Red: False Negatives, MAPIE Prediction (threshold=0.03) White: Correct | Red: False Negatives, MAPIE Prediction (threshold=0.03) White: Correct | Red: False Negatives, MAPIE Prediction (threshold=0.03) White: Correct | Red: False Negatives
   :srcset: /examples_risk_control/2-advanced-analysis/images/sphx_glr_plot_semantic_segmentation_recall_control_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 213-217

The controller is finally evaluated on the test set by computing the
achieved recall on each image to verify that the target recall level
is satisfied on unseen data.

.. GENERATED FROM PYTHON SOURCE LINES 217-243

.. code-block:: Python


    recalls_list = []

    for i, sample in enumerate(test_loader):
        image, mask = sample["image"], sample["mask"]
        image = image.to(DEVICE)
        mask = mask.cpu().numpy()

        # Filter images with masks
        has_mask = mask.sum(axis=(1, 2)) > 0
        image = image[has_mask]
        mask = mask[has_mask]

        if len(image) > 0:
            with torch.no_grad():
                pred = recall_controller.predict(image)

            # Compute recall for each image
            for j in range(len(image)):
                tp = (mask[j] * pred[j]).sum()
                fn = (mask[j] * (1 - pred[j])).sum()
                recall = tp / (tp + fn + 1e-8)
                recalls_list.append(recall)

    recalls_array = np.array(recalls_list)


.. GENERATED FROM PYTHON SOURCE LINES 244-247

Finally, the distribution of recall values over the test set is plotted to
summarize the controlled performance.

.. GENERATED FROM PYTHON SOURCE LINES 247-274

.. code-block:: Python


    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(recalls_array, bins=30, alpha=0.7, color="steelblue", edgecolor="black")
    ax.axvline(
        TARGET_RECALL,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target Recall ({TARGET_RECALL})",
    )
    ax.axvline(
        recalls_array.mean(),
        color="orange",
        linestyle="--",
        linewidth=2,
        label=f"Mean Recall ({recalls_array.mean():.3f})",
    )
    ax.set_xlabel("Recall", fontsize=12)
    ax.set_ylabel("Frequency", fontsize=12)
    ax.set_title("Distribution of Recall on Test Set", fontsize=14, fontweight="bold")
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


.. image-sg:: /examples_risk_control/2-advanced-analysis/images/sphx_glr_plot_semantic_segmentation_recall_control_002.png
   :alt: Distribution of Recall on Test Set
   :srcset: /examples_risk_control/2-advanced-analysis/images/sphx_glr_plot_semantic_segmentation_recall_control_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 275-279

The histogram shows that most test images achieve or exceed the target recall
level, illustrating the effectiveness of MAPIE’s risk control for semantic
segmentation tasks.

.. GENERATED FROM PYTHON SOURCE LINES 281-284

Bootstrap the mean recall over different samplings of the test set
(resampling images with replacement).

.. GENERATED FROM PYTHON SOURCE LINES 284-328

.. code-block:: Python


    N_BOOTSTRAP = 2000
    BOOTSTRAP_SEED = 123

    rng = np.random.default_rng(BOOTSTRAP_SEED)
    bootstrap_means = np.empty(N_BOOTSTRAP, dtype=float)
    n = recalls_array.size

    for b in range(N_BOOTSTRAP):
        bootstrap_sample = rng.choice(recalls_array, size=n, replace=True)
        bootstrap_means[b] = bootstrap_sample.mean()

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(
        bootstrap_means,
        bins=40,
        alpha=0.7,
        color="slateblue",
        edgecolor="black",
    )
    ax.axvline(
        bootstrap_means.mean(),
        color="green",
        linestyle="--",
        linewidth=2,
        label=f"Bootstrap mean ({bootstrap_means.mean():.3f})",
    )
    ax.axvline(
        TARGET_RECALL,
        color="orange",
        linestyle="--",
        linewidth=2,
        label=f"Target recall ({TARGET_RECALL:.2f})",
    )
    ax.set_xlabel("Bootstrap mean recall", fontsize=12)
    ax.set_ylabel("Frequency", fontsize=12)
    ax.set_title(
        "Bootstrap distribution of mean recall (test set resampling)",
        fontsize=14,
        fontweight="bold",
    )
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


.. image-sg:: /examples_risk_control/2-advanced-analysis/images/sphx_glr_plot_semantic_segmentation_recall_control_003.png
   :alt: Bootstrap distribution of mean recall (test set resampling)
   :srcset: /examples_risk_control/2-advanced-analysis/images/sphx_glr_plot_semantic_segmentation_recall_control_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (6 minutes 14.015 seconds)


.. _sphx_glr_download_examples_risk_control_2-advanced-analysis_plot_semantic_segmentation_recall_control.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_semantic_segmentation_recall_control.ipynb <plot_semantic_segmentation_recall_control.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_semantic_segmentation_recall_control.py <plot_semantic_segmentation_recall_control.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_semantic_segmentation_recall_control.zip <plot_semantic_segmentation_recall_control.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
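
The bootstrap of the mean recall shown earlier can also be summarized as a percentile interval. What follows is a minimal sketch and not part of the original example: ``bootstrap_mean_interval`` is a hypothetical helper, and the per-image recalls are synthetic values generated only for illustration (in the example itself, ``recalls_array`` would take their place).

```python
import numpy as np


def bootstrap_mean_interval(values, n_bootstrap=2000, level=0.95, seed=123):
    """Percentile bootstrap interval for the mean of `values` (hypothetical helper)."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # One row of resampled indices per bootstrap replicate (sampling with replacement).
    indices = rng.integers(0, values.size, size=(n_bootstrap, values.size))
    means = values[indices].mean(axis=1)
    # Cut equal tails on both sides of the bootstrap distribution of the mean.
    tail = (1.0 - level) / 2.0
    lower, upper = np.quantile(means, [tail, 1.0 - tail])
    return float(lower), float(upper)


# Synthetic per-image recalls, for illustration only (not the example's results).
demo_rng = np.random.default_rng(0)
demo_recalls = np.clip(demo_rng.normal(loc=0.92, scale=0.05, size=200), 0.0, 1.0)

lower, upper = bootstrap_mean_interval(demo_recalls)
print(f"95% bootstrap interval for mean recall: [{lower:.3f}, {upper:.3f}]")
```

If the lower bound of such an interval sits above the target recall, the test-set evidence supports the target being met on average. The interval only describes sampling variability of the test mean and is separate from the guarantee provided by the controller.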