Reference for ultralytics/models/nas/predict.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/nas/predict.py.
ultralytics.models.nas.predict.NASPredictor
Bases: DetectionPredictor
Ultralytics YOLO NAS Predictor for object detection.
This class extends the DetectionPredictor from the Ultralytics engine and is responsible for post-processing the raw predictions generated by YOLO NAS models. It applies operations such as non-maximum suppression and scales the bounding boxes to fit the original image dimensions.
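Greedy non-maximum suppression can be illustrated in plain Python. This is a minimal sketch, not the library's implementation (which operates on batched tensors). The helper names `iou_xyxy` and `greedy_nms` are hypothetical, boxes are assumed to be (x1, y1, x2, y2) tuples, and the 0.7 IoU threshold is only an example value.

```python
def iou_xyxy(a, b):
    """Intersection-over-union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)  # overlap area (0 if disjoint)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_thres=0.7):
    """Return indices of kept boxes, highest score first (hypothetical helper).

    Repeatedly keep the highest-scoring remaining box and discard every
    other remaining box whose IoU with it exceeds iou_thres.
    """
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)  # highest-scoring box still in play
        keep.append(best)
        order = [i for i in order if iou_xyxy(boxes[best], boxes[i]) <= iou_thres]
    return keep
```

For example, `greedy_nms([(0, 0, 10, 10), (1, 1, 10, 10), (20, 20, 30, 30)], [0.9, 0.8, 0.7])` returns `[0, 2]`: the second box overlaps the first with an IoU of 0.81 and is suppressed, while the third box does not overlap at all.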
Attributes:
| Name | Type | Description |
|---|---|---|
| args | Namespace | Namespace containing various configurations for post-processing, including the confidence threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options. |
| model | Module | The YOLO NAS model used for inference. |
| batch | list | Batch of inputs for processing. |
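The post-processing options carried by `args` can be pictured as simple filters over candidate detections. The sketch below is hypothetical: `filter_detections` is not part of Ultralytics, detections are assumed to be (x1, y1, x2, y2, score, class_id) tuples, and the `Namespace` field names are chosen to mirror the options listed in the table. Only `conf`, `classes`, and `max_det` are read here; the IoU threshold and agnostic flag would instead feed the NMS step.

```python
from argparse import Namespace


def filter_detections(dets, args):
    """Hypothetical filter: confidence threshold, optional class whitelist, max detections.

    Each detection is assumed to be a tuple (x1, y1, x2, y2, score, class_id).
    """
    kept = [d for d in dets if d[4] > args.conf]  # keep scores strictly above the threshold
    if args.classes is not None:  # None means "keep every class"
        kept = [d for d in kept if d[5] in args.classes]
    kept.sort(key=lambda d: d[4], reverse=True)  # highest score first
    return kept[: args.max_det]  # cap the number of detections per image


# Example configuration; the values are illustrative, not library defaults.
args = Namespace(conf=0.25, iou=0.7, agnostic_nms=False, max_det=300, classes=[0])
```

Keeping the thresholds on a single `Namespace` lets one configuration object drive every filtering step without changing function signatures.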
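Examples:
>>> from ultralytics import NAS
>>> predictor = NAS("yolo_nas_s").predictor
>>> # Assume that raw_preds, img, and orig_imgs are available
>>> results = predictor.postprocess(raw_preds, img, orig_imgs)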
Notes
Typically, this class is not instantiated directly. It is used internally within the NAS class.
postprocess
Postprocess NAS model predictions to generate final detection results.
This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies post-processing operations to generate the final detection results compatible with Ultralytics result visualization and analysis tools.
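The box-format conversion can be sketched as follows. For illustration this assumes the raw output provides corner-format boxes (x1, y1, x2, y2) alongside per-class scores, and that the downstream detection step expects centre-format boxes (cx, cy, w, h). Both helper names are hypothetical, and real post-processing operates on tensors rather than Python lists.

```python
def xyxy_to_xywh(box):
    """Convert a corner-format box (x1, y1, x2, y2) to centre format (cx, cy, w, h)."""
    x1, y1, x2, y2 = box
    return ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)


def merge_boxes_and_scores(boxes, class_scores):
    """Pair each converted box with its per-class scores as one flat row (hypothetical).

    boxes: list of N corner-format boxes; class_scores: list of N score lists.
    Returns N rows of the form [cx, cy, w, h, s_0, s_1, ...].
    """
    if len(boxes) != len(class_scores):
        raise ValueError("boxes and class_scores must have the same length")
    return [list(xyxy_to_xywh(b)) + list(s) for b, s in zip(boxes, class_scores)]
```

In tensor code the same merge is a concatenation along the last axis, producing one row per candidate that a generic detection post-processor can consume.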
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| preds_in | list | Raw predictions from the NAS model, typically containing bounding boxes and class scores. | required |
| img | Tensor | Input image tensor that was fed to the model, with shape (B, C, H, W). | required |
| orig_imgs | list \| Tensor \| ndarray | Original images before preprocessing, used for scaling coordinates back to original dimensions. | required |
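Scaling coordinates back to the original dimensions typically means undoing a letterbox resize. The sketch below assumes the original image was resized by a single aspect-preserving gain and padded symmetrically up to the input size (H, W) of `img`. `scale_box_to_original` is a hypothetical helper, and the library's own scaling utility may differ in rounding details.

```python
def scale_box_to_original(box, input_hw, orig_hw):
    """Map an (x1, y1, x2, y2) box from letterboxed input coordinates to the original image.

    Assumes a single aspect-preserving resize plus symmetric padding.
    Both shapes are given as (height, width).
    """
    in_h, in_w = input_hw
    orig_h, orig_w = orig_hw
    gain = min(in_h / orig_h, in_w / orig_w)  # resize factor applied during preprocessing
    pad_x = (in_w - orig_w * gain) / 2  # horizontal padding on each side
    pad_y = (in_h - orig_h * gain) / 2  # vertical padding on each side
    x1, y1, x2, y2 = box
    # Remove padding, undo the resize, then clip to the original image bounds.
    x1, x2 = [min(max((x - pad_x) / gain, 0.0), orig_w) for x in (x1, x2)]
    y1, y2 = [min(max((y - pad_y) / gain, 0.0), orig_h) for y in (y1, y2)]
    return (x1, y1, x2, y2)
```

For a 480×640 (H×W) image letterboxed into a 640×640 input, the gain is 1.0 and 80 pixels of padding sit above and below the image, so `scale_box_to_original((0, 80, 640, 560), (640, 640), (480, 640))` returns `(0.0, 0.0, 640.0, 480.0)`.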
Returns:
| Type | Description |
|---|---|
| list | List of Results objects containing the processed predictions for each image in the batch. |
Examples:
>>> predictor = NAS("yolo_nas_s").predictor
>>> results = predictor.postprocess(raw_preds, img, orig_imgs)