# Reference for `ultralytics/models/fastsam/predict.py`

> **Note**
>
> This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/predict.py. If you spot a problem please help fix it by contributing a Pull Request 🛠️. Thank you 🙏!


## ultralytics.models.fastsam.predict.FastSAMPredictor

```python
FastSAMPredictor(cfg=DEFAULT_CFG, overrides=None, _callbacks=None)
```

Bases: `SegmentationPredictor`

FastSAMPredictor is specialized for Fast SAM (Segment Anything Model) segmentation prediction tasks.

This class extends SegmentationPredictor, customizing the prediction pipeline specifically for Fast SAM. It adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for single-class segmentation.
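
For orientation, here is a minimal usage sketch. The predictor is normally driven through the high-level `FastSAM` model class, which forwards prompt keyword arguments down to this predictor; the weight file name, image path, and prompt values below are illustrative.

```python
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")  # weight file name is illustrative

# Segment everything, then variants filtered by prompts
results = model("path/to/image.jpg")  # no prompts
results = model("path/to/image.jpg", bboxes=[50, 40, 300, 280])  # box prompt, XYXY pixels
results = model("path/to/image.jpg", points=[[160, 120]], labels=[1])  # foreground point prompt
results = model("path/to/image.jpg", texts="a photo of a dog")  # text prompt matched via CLIP
```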

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `prompts` | `dict` | Dictionary containing prompt information for segmentation (bboxes, points, labels, texts). |
| `device` | `torch.device` | Device on which the model and tensors are processed. |
| `clip_model` | `Any` | CLIP model for text-based prompting, loaded on demand. |
| `clip_preprocess` | `Any` | CLIP preprocessing function for images, loaded on demand. |

Methods:

| Name | Description |
| --- | --- |
| `postprocess` | Applies box postprocessing for FastSAM predictions. |
| `prompt` | Performs image segmentation inference based on various prompt types. |
| `_clip_inference` | Performs CLIP inference to calculate similarity between images and text prompts. |
| `set_prompts` | Sets prompts to be used during inference. |

This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression optimized for single-class segmentation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `dict` | Configuration for the predictor. Defaults to Ultralytics `DEFAULT_CFG`. | `DEFAULT_CFG` |
| `overrides` | `dict`, optional | Configuration overrides. | `None` |
| `_callbacks` | `list`, optional | List of callback functions. | `None` |
Source code in `ultralytics/models/fastsam/predict.py`

```python
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the FastSAMPredictor with configuration and callbacks.

    This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
    extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
    optimized for single-class segmentation.

    Args:
        cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides.
        _callbacks (list, optional): List of callback functions.
    """
    super().__init__(cfg, overrides, _callbacks)
    self.prompts = {}
```
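
The predictor can also be constructed directly. A minimal sketch, assuming the usual Ultralytics predictor `overrides` keys (values below are illustrative); running inference additionally requires a loaded Fast SAM model:

```python
from ultralytics.models.fastsam import FastSAMPredictor

# Standard Ultralytics override keys; values are illustrative
overrides = {"task": "segment", "mode": "predict", "imgsz": 640, "conf": 0.4}
predictor = FastSAMPredictor(overrides=overrides)
```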

### postprocess

```python
postprocess(preds, img, orig_imgs)
```

Apply postprocessing to FastSAM predictions and handle prompts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `preds` | `List[torch.Tensor]` | Raw predictions from the model. | required |
| `img` | `torch.Tensor` | Input image tensor that was fed to the model. | required |
| `orig_imgs` | `List[np.ndarray]` | Original images before preprocessing. | required |

Returns:

| Type | Description |
| --- | --- |
| `List[Results]` | Processed results with prompts applied. |

Source code in `ultralytics/models/fastsam/predict.py`

```python
def postprocess(self, preds, img, orig_imgs):
    """
    Apply postprocessing to FastSAM predictions and handle prompts.

    Args:
        preds (List[torch.Tensor]): Raw predictions from the model.
        img (torch.Tensor): Input image tensor that was fed to the model.
        orig_imgs (List[numpy.ndarray]): Original images before preprocessing.

    Returns:
        (List[Results]): Processed results with prompts applied.
    """
    bboxes = self.prompts.pop("bboxes", None)
    points = self.prompts.pop("points", None)
    labels = self.prompts.pop("labels", None)
    texts = self.prompts.pop("texts", None)
    results = super().postprocess(preds, img, orig_imgs)
    for result in results:
        full_box = torch.tensor(
            [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
        )
        boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
        idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
        if idx.numel() != 0:
            result.boxes.xyxy[idx] = full_box

    return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
```
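
A toy check, not part of the source, of the IoU rule applied above: a detection that nearly spans the image (IoU with the full-image box above 0.9) is snapped to the exact image border, while ordinary boxes are left untouched.

```python
import torch

from ultralytics.utils.metrics import box_iou

h, w = 480, 640  # result.orig_shape is (height, width)
full_box = torch.tensor([[0.0, 0.0, w, h]])
boxes = torch.tensor(
    [
        [2.0, 1.0, 638.0, 479.0],  # nearly spans the image: IoU ~ 0.99 -> snapped to full_box
        [100.0, 100.0, 200.0, 200.0],  # ordinary object box: IoU ~ 0.03 -> kept as-is
    ]
)
print(box_iou(full_box, boxes) > 0.9)  # tensor([[ True, False]])
```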

### prompt

```python
prompt(results, bboxes=None, points=None, labels=None, texts=None)
```

Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `results` | `Results \| List[Results]` | Original inference results from FastSAM models without any prompts. | required |
| `bboxes` | `np.ndarray \| List`, optional | Bounding boxes with shape (N, 4), in XYXY format. | `None` |
| `points` | `np.ndarray \| List`, optional | Points indicating object locations with shape (N, 2), in pixels. | `None` |
| `labels` | `np.ndarray \| List`, optional | Labels for point prompts with shape (N,), where 1 = foreground and 0 = background. | `None` |
| `texts` | `str \| List[str]`, optional | Text prompts; a string or a list of strings. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `List[Results]` | Output results filtered and determined by the provided prompts. |

Source code in `ultralytics/models/fastsam/predict.py`

```python
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
    """
    Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

    Args:
        results (Results | List[Results]): Original inference results from FastSAM models without any prompts.
        bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
        points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
        labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
        texts (str | List[str], optional): Textual prompts, a list containing string objects.

    Returns:
        (List[Results]): Output results filtered and determined by the provided prompts.
    """
    if bboxes is None and points is None and texts is None:
        return results
    prompt_results = []
    if not isinstance(results, list):
        results = [results]
    for result in results:
        if len(result) == 0:
            prompt_results.append(result)
            continue
        masks = result.masks.data
        if masks.shape[1:] != result.orig_shape:
            masks = scale_masks(masks[None], result.orig_shape)[0]
        # bboxes prompt
        idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
        if bboxes is not None:
            bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
            bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
            mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
            full_mask_areas = torch.sum(masks, dim=(1, 2))

            union = bbox_areas[:, None] + full_mask_areas - mask_areas
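            # `mask_areas / union` is the IoU between each prompt box (as a filled region)
            # and every predicted mask; for each box, keep the single best-matching mask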
            idx[torch.argmax(mask_areas / union, dim=1)] = True
        if points is not None:
            points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
            points = points[None] if points.ndim == 1 else points
            if labels is None:
                labels = torch.ones(points.shape[0])
            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
            assert len(labels) == len(points), (
                f"Expected `labels` to have the same size as `points`, but got {len(labels)} and {len(points)}"
            )
            point_idx = (
                torch.ones(len(result), dtype=torch.bool, device=self.device)
                if labels.sum() == 0  # all negative points
                else torch.zeros(len(result), dtype=torch.bool, device=self.device)
            )
            for point, label in zip(points, labels):
                point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
            idx |= point_idx
        if texts is not None:
            if isinstance(texts, str):
                texts = [texts]
            crop_ims, filter_idx = [], []
            for i, b in enumerate(result.boxes.xyxy.tolist()):
                x1, y1, x2, y2 = (int(x) for x in b)
                if masks[i].sum() <= 100:
                    filter_idx.append(i)
                    continue
                crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
            similarity = self._clip_inference(crop_ims, texts)
            text_idx = torch.argmax(similarity, dim=-1)  # (M, )
            if len(filter_idx):
                text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
            idx[text_idx] = True

        prompt_results.append(result[idx])

    return prompt_results
```
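
A hedged sketch of calling this method directly on already computed results; `predictor` and `results` are assumed from earlier steps, and the prompt values are illustrative.

```python
# Keep, for each box, the best-overlapping mask
filtered = predictor.prompt(results, bboxes=[[50, 40, 300, 280]])

# Keep masks covering foreground points (label 1); drop those under background points (label 0)
filtered = predictor.prompt(results, points=[[160, 120]], labels=[1])

# Keep the mask whose crop is most similar to the text, scored via CLIP
filtered = predictor.prompt(results, texts="a photo of a dog")
```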

### set_prompts

```python
set_prompts(prompts)
```

Set prompts to be used during inference.

Source code in `ultralytics/models/fastsam/predict.py`

```python
def set_prompts(self, prompts):
    """Set prompts to be used during inference."""
    self.prompts = prompts
```
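
A small sketch of the intended flow, assuming the predictor has a model loaded: prompts staged here are popped and applied by `postprocess` on the next inference call. Keys mirror the attributes documented above; values are illustrative.

```python
predictor.set_prompts({"bboxes": [[50, 40, 300, 280]], "texts": "a photo of a dog"})
results = predictor("path/to/image.jpg")  # prompts are consumed during postprocess
```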


