from typing import Sequence, Union, cast
import numpy as np
from numpy.typing import ArrayLike, NDArray
from mapie.risk_control import MultiLabelClassificationController
from mapie.utils import check_is_fitted
class SemanticSegmentationController(MultiLabelClassificationController):
    """
    Risk controller for semantic segmentation tasks,
    inheriting from MultiLabelClassificationController.

    Per-pixel class scores are flattened per image when risks are computed
    on the calibration data, and thresholded pixel-wise at prediction time.
    """

    def _transform_pred_proba(
        self, y_pred_proba: Union[Sequence[NDArray], NDArray], ravel: bool = True
    ) -> NDArray:
        """
        Transform predicted probabilities for semantic segmentation tasks.

        If any value lies outside [0, 1], the whole input is treated as logits
        and mapped to probabilities with a sigmoid. Logits that all happen to
        fall inside [0, 1] are therefore left untouched.

        Parameters
        ----------
        y_pred_proba: Union[Sequence[NDArray], NDArray]
            Predicted probabilities or logits for each class and for each pixel.
            Shape: (n_samples, n_classes, height, width)

        ravel: bool, default=True
            Whether to ravel the output array. Ravel is used when computing risks
            on the calibration dataset.

        Returns
        -------
        NDArray
            Transformed predicted probabilities of shape
            (n_samples, n_classes*height*width, 1) if ravel=True, else the
            input shape (n_samples, n_classes, height, width).
        """
        # Sequences (e.g. a list of per-image arrays) are stacked into a single
        # array; an ndarray input is used as-is without copying.
        y_pred_proba_array = cast(NDArray, np.asarray(y_pred_proba))  # for mypy
        if np.min(y_pred_proba_array) < 0 or np.max(y_pred_proba_array) > 1:
            # Values outside [0, 1] indicate logits. Use the numerically stable
            # form sigmoid(x) = exp(-log(1 + exp(-x))): the naive
            # 1 / (1 + exp(-x)) overflows (with a RuntimeWarning) for large
            # negative logits, while logaddexp never does.
            y_pred_proba_array = np.exp(-np.logaddexp(0, -y_pred_proba_array))
        if ravel:
            # Merge the class and pixel axes into one axis per sample and add
            # a trailing singleton axis.
            return y_pred_proba_array.reshape(len(y_pred_proba_array), -1)[
                ..., np.newaxis
            ]
        return y_pred_proba_array

    def predict(
        self,
        X: ArrayLike,
    ) -> NDArray:
        """
        Prediction sets on new samples based on the target risk level.

        Prediction sets for a given `alpha` are deduced from the computed
        risks: a pixel is kept for a given `alpha` when its probability is
        strictly greater than the corresponding entry of
        ``best_predict_param``.

        Parameters
        ----------
        X: ArrayLike
            Samples to predict on (e.g. images), forwarded as-is to the
            underlying predict function.

        Returns
        -------
        NDArray of bool
            Predicted masks. For single-class scores of shape
            (n_samples, 1, height, width) the result has shape
            (n_samples, n_alpha, height, width); with a single `alpha` it keeps
            the shape (n_samples, n_classes, height, width). (Both assume
            ``best_predict_param`` holds one threshold per alpha.)

        Raises
        ------
        ValueError
            If the scores contain several classes while several `alpha` values
            are controlled: the per-alpha thresholds cannot be broadcast
            against a multi-class score map.
        """
        check_is_fitted(self)
        # Estimate prediction sets
        y_pred_proba = self._predict_function(X)
        y_pred_proba_array = self._transform_pred_proba(y_pred_proba, ravel=False)
        n_alpha = len(self._alpha)
        n_classes = y_pred_proba_array.shape[1]
        if n_classes != 1 and n_alpha != 1:
            # The repeated array would have n_classes * n_alpha channels, which
            # numpy cannot broadcast against n_alpha thresholds; fail with a
            # clear message instead of an opaque broadcasting error.
            raise ValueError(
                "Thresholding multi-class scores "
                f"(n_classes={n_classes}) for several alpha values "
                f"(n_alpha={n_alpha}) is not supported; expected "
                "single-class scores of shape (n_samples, 1, height, width) "
                "or a single alpha."
            )
        # Replicate the class axis once per alpha so axis 1 lines up with the
        # thresholds in best_predict_param.
        y_pred_proba_array = np.repeat(y_pred_proba_array, n_alpha, axis=1)
        return (
            y_pred_proba_array
            > self.best_predict_param[np.newaxis, :, np.newaxis, np.newaxis]
        )