Module sam_onnx

Expand source code
from copy import deepcopy
from typing import Any, Tuple, Union, List
import cv2
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import glob
import gdown
import os

def check_and_download_weights(model_name='l0'):

    __supported_modelnames = ['l0', 'xl0']

    assert model_name in __supported_modelnames, f'Model name not supported. Please use one of : {__supported_modelnames}'
    
    l0_weights = {'encoder' : 'https://drive.google.com/file/d/1a0tRmHQeGTAbSeMqBMhu4DinsOR3cSv6/view?usp=sharing',
                 'decoder': 'https://drive.google.com/file/d/13J7pNfh016sBqOQ17CludkUFdKgkkyQM/view?usp=sharing'}
    
    xl0_weights = {'encoder': 'https://drive.google.com/file/d/1NzavgCAqk6mSzTnQ_LKfl78V_O68lWNX/view?usp=sharing',
                  'decoder': 'https://drive.google.com/file/d/1lrn5bQRE01Mwtp-nr9DBNTHcxk4Q6iiP/view?usp=sharing'}
    
    if os.path.exists('model_weights'):
        model_weights_folder_path = os.path.abspath('model_weights')
    else:
        os.makedirs('model_weights', 
                    exist_ok = False)
        model_weights_folder_path = os.path.abspath('model_weights')
    
    if os.path.exists(f'model_weights/{model_name}/encoder.onnx'):
        encoder_weights_path = os.path.abspath(f'model_weights/{model_name}/encoder.onnx')
    else:
        os.makedirs(f'model_weights/{model_name}', 
                    exist_ok = True)
        if model_name == 'l0':
            gdown.download(l0_weights['encoder'],
                           f'model_weights/{model_name}/encoder.onnx', 
                           fuzzy=True)
        if model_name == 'xl0':
            gdown.download(xl0_weights['encoder'],
                           f'model_weights/{model_name}/encoder.onnx', 
                           fuzzy=True)
        encoder_weights_path = os.path.abspath(f'model_weights/{model_name}/encoder.onnx')
            
    if os.path.exists(f'model_weights/{model_name}/decoder.onnx'):
        decoder_weights_path = os.path.abspath(f'model_weights/{model_name}/decoder.onnx')
    else:
        if model_name == 'l0':
            gdown.download(l0_weights['decoder'],
                           f'model_weights/{model_name}/decoder.onnx', 
                           fuzzy=True)
        if model_name == 'xl0':
            gdown.download(xl0_weights['decoder'],
                           f'model_weights/{model_name}/decoder.onnx', 
                           fuzzy=True)
        decoder_weights_path = os.path.abspath(f'model_weights/{model_name}/decoder.onnx')

    return encoder_weights_path, decoder_weights_path
    
def show_mask(mask, ax, random_color=False):
    """
    Visualize a mask image on the given axis.

    Parameters
    ----------
    mask : np.ndarray
        The mask image to visualize.
    ax : matplotlib.axes.Axes
        The axis to plot on.
    random_color : bool, optional
        Whether to use a random color for the mask, by default False
    """
    if random_color:
        # Create a random color with some transparency
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # Use a specific color with some transparency
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """
    Show points on the axis.

    Parameters
    ----------
    coords : np.ndarray
        The coordinates of the points to show.
    labels : np.ndarray
        The labels of the points.
    ax : matplotlib.axes.Axes
        The axis to plot on.
    marker_size : int, optional
        The size of the markers, by default 375
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def show_box(box, ax):
    """
    Show a bounding box on the axis.

    Parameters
    ----------
    box : list
        The bounding box coordinates as [x0, y0, x1, y1].
    ax : matplotlib.axes.Axes
        The axis to plot on.
    """
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    )


class SamEncoder:
    """
    The encoder class that loads and runs the SAM encoder model.

    Parameters
    ----------
    model_path: str
        The path to the encoder model.
    device: str, optional (default is 'cpu')
        The device to run the model, either 'cuda' or 'cpu'.
    kwargs: dict
        Additional arguments to be passed to the `InferenceSession` class from
        the onnxruntime library.

    Attributes
    ----------
    session: InferenceSession
        The loaded encoder model.
    input_name: str
        The name of the input layer of the model.
    """

    def __init__(self, model_path: str, device: str = "cpu", **kwargs):
        opt = ort.SessionOptions()

        if device == "cuda":
            provider = ["CUDAExecutionProvider"]
        elif device == "cpu":
            provider = ["CPUExecutionProvider"]
        else:
            raise ValueError("Invalid device, please use 'cuda' or 'cpu' device.")

        print(f"loading encoder model from {model_path}...")
        self.session = ort.InferenceSession(
            model_path, opt, providers=provider, **kwargs
        )
        self.input_name = self.session.get_inputs()[0].name

    def _extract_feature(self, tensor: np.ndarray) -> np.ndarray:
        """
        Extract the feature from the input image tensor using the loaded
        encoder model.

        Parameters
        ----------
        tensor: numpy.ndarray
            The input image tensor.

        Returns
        -------
        feature: numpy.ndarray
            The feature extracted from the input image.
        """
        feature = self.session.run(None, {self.input_name: tensor})[0]
        return feature

    def __call__(self, img: np.array, *args: Any, **kwds: Any) -> Any:
        """
        Call the encoder with the input image.

        Parameters
        ----------
        img: numpy.ndarray
            The input image.
        args, kwargs:
            Additional positional and keyword arguments to be passed to the
            encoder.

        Returns
        -------
        feature: numpy.ndarray
            The feature extracted from the input image.
        """
        return self._extract_feature(img)


class SamDecoder:
    """
    The decoder class that loads and runs the SAM decoder model.

    Parameters
    ----------
    model_path: str
        The path to the decoder model.
    device: str, default="cpu"
        The device to run the model, either "cuda" or "cpu".
    target_size: int, default=1024
        The target size of the output mask. The final mask size may be
        smaller if the original image is too small.
    mask_threshold: float, default=0.0
        The threshold value to binarize the output mask.
    kwargs: Any
        Additional arguments to be passed to onnxruntime.InferenceSession.

    Attributes
    ----------
    target_size: int
        The target size of the output mask.
    mask_threshold: float
        The threshold value to binarize the output mask.
    session: onnxruntime.InferenceSession
        The inference session of the loaded decoder model.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cpu",
        target_size: int = 1024,
        mask_threshold: float = 0.0,
        **kwargs,
    ):
        opt = ort.SessionOptions()

        if device == "cuda":
            provider = ["CUDAExecutionProvider"]
        elif device == "cpu":
            provider = ["CPUExecutionProvider"]
        else:
            raise ValueError("Invalid device, please use 'cuda' or 'cpu' device.")

        print(f"loading decoder model from {model_path}...")
        self.target_size = target_size
        self.mask_threshold = mask_threshold
        self.session = ort.InferenceSession(
            model_path, opt, providers=provider, **kwargs
        )

    @staticmethod
    def get_preprocess_shape(
        oldh: int, oldw: int, long_side_length: int
    ) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.

        Parameters
        ----------
        oldh: int
            The height of the input image.
        oldw: int
            The width of the input image.
        long_side_length: int
            The target long side length of the output image.

        Returns
        -------
        Tuple[int, int]
            The (height, width) of the output image after resizing.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def run(
        self,
        img_embeddings: np.ndarray,
        origin_image_size: Union[list, tuple],
        point_coords: Union[list, np.ndarray] = None,
        point_labels: Union[list, np.ndarray] = None,
        boxes: Union[list, np.ndarray] = None,
        return_logits: bool = False,
    ) -> Tuple[np.ndarray, Any, Any]:
        """
        Run the SAM decoder to segment an input image.

        Parameters
        ----------
        img_embeddings: np.ndarray
            The image embeddings obtained from SAM encoder.
            The shape should be (1, 256, 64, 64).
        origin_image_size: Union[list, tuple]
            The original size of the input image, (height, width)
        point_coords: Union[list, np.ndarray], optional
            The coordinates of the points in the input image.
            The shape should be (N, 2), where N is the number of points.
        point_labels: Union[list, np.ndarray], optional
            The labels of the points.
            The shape should be (N,) where N is the number of points.
        boxes: Union[list, np.ndarray], optional
            The coordinates of the bounding boxes in the input image.
            The shape should be (M, 4), where M is the number of boxes.
        return_logits: bool, default False
            Whether to return the logits (before sigmoid) of the mask predictions.

        Returns
        -------
        Tuple[np.ndarray, Any, Any]
            The segmentation masks, IoU scores, and low-resolution masks.
        """
        input_size = self.get_preprocess_shape(
            *origin_image_size, long_side_length=self.target_size
        )

        if point_coords is None and point_labels is None and boxes is None:
            raise ValueError(
                "Unable to segment, please input at least one box or point."
            )

        if img_embeddings.shape != (1, 256, 64, 64):
            raise ValueError("Got wrong embedding shape!")

        if point_coords is not None:
            point_coords = self.apply_coords(
                point_coords, origin_image_size, input_size
            ).astype(np.float32)

            prompts, labels = point_coords, point_labels

        if boxes is not None:
            boxes = self.apply_boxes(boxes, origin_image_size, input_size).astype(
                np.float32
            )
            box_labels = np.array(
                [[2, 3] for _ in range(boxes.shape[0])], dtype=np.float32
            ).reshape((-1, 2))

            if point_coords is not None:
                prompts = np.concatenate([prompts, boxes], axis=1)
                labels = np.concatenate([labels, box_labels], axis=1)
            else:
                prompts, labels = boxes, box_labels

        input_dict = {
            "image_embeddings": img_embeddings,
            "point_coords": prompts,
            "point_labels": labels,
        }

        # Run the inference
        low_res_masks, iou_predictions = self.session.run(None, input_dict)

        # Post-process the masks
        masks = np_mask_postprocessing(low_res_masks, np.array(origin_image_size))

        if not return_logits:
            masks = masks > self.mask_threshold

        return masks, iou_predictions, low_res_masks

    def apply_coords(self, coords, original_size, new_size):
        """
        Applies the resizing to the coordinates.

        Parameters
        ----------
        coords : np.ndarray
            The coordinates to be resized.
            The shape should be (N, 2), where N is the number of points.
        original_size : Union[list, tuple]
            The original size of the input image, (height, width)
        new_size : Union[list, tuple]
            The new size of the input image, (height, width)

        Returns
        -------
        np.ndarray
            The resized coordinates.
        """
        old_h, old_w = original_size
        new_h, new_w = new_size
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes, original_size, new_size):
        """
        Applies the resizing to the bounding boxes.

        Parameters
        ----------
        boxes : np.ndarray
            The coordinates of the bounding boxes in the input image.
            The shape should be (M, 4), where M is the number of boxes.
        original_size : Union[list, tuple]
            The original size of the input image, (height, width)
        new_size : Union[list, tuple]
            The new size of the input image, (height, width)

        Returns
        -------
        np.ndarray
            The resized bounding boxes.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size, new_size)
        return boxes


def np_resize_longest_image_size(
    input_image_size: np.array, longest_side: int
) -> np.array:
    """Resizes the image size to the longest side.

    Parameters
    ----------
    input_image_size : np.array
        Size of the input image in (height, width) format.
    longest_side : int
        Desired longest side of the resized image.

    Returns
    -------
    np.array
        Size of the resized image in (height, width) format.
    """
    scale = longest_side / np.max(input_image_size)
    transformed_size = scale * input_image_size
    transformed_size = np.floor(transformed_size + 0.5).astype(np.int64)
    return transformed_size


def np_interp(x: np.array, size: tuple) -> np.array:
    """Interpolates a batch of masks to a given size.

    Parameters
    ----------
    x : np.array
        A batch of masks with shape (batch_size, 1, height, width).
    size : tuple
        Desired size of the masks (height, width) format.

    Returns
    -------
    np.array
        A batch of interpolated masks with shape (batch_size, 1, height, width).
    """
    _rmsk = []
    for m in range(x.shape[0]):
        msk = x[m, 0, :, :]
        resized_array = cv2.resize(msk, size, interpolation=cv2.INTER_LINEAR)
        _rmsk.append(resized_array)
    np_rmsk = np.array(_rmsk)
    np_rmsk = np_rmsk[:, np.newaxis, :, :]
    return np_rmsk


def np_mask_postprocessing(masks: np.array, orig_im_size: np.array) -> np.array:
    """
    Perform postprocessing on predicted masks by interpolating them to
    desired size and then resizing them back to original image size.

    Parameters
    ----------
    masks : np.array
        Predicted masks.
    orig_im_size : np.array
        Original image size.

    Returns
    -------
    np.array
        Postprocessed masks.
    """

    img_size = 1024  # Desired output size
    masks = np_interp(masks, (img_size, img_size))

    # Pad predicted masks to desired output size
    prepadded_size = np_resize_longest_image_size(orig_im_size, img_size)
    masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]

    # Resize padded masks back to original image size
    origin_image_size = orig_im_size.astype(np.int64)
    w, h = origin_image_size[0], origin_image_size[1]
    masks = np_interp(masks, (h, w))
    return masks

def preprocess_np(x, img_size):
    """
    Preprocess an image with mean and std normalization and padding to
    desired size.

    Parameters
    ----------
    x : numpy.ndarray
        Image to be preprocessed.
    img_size : int
        Desired size of the longer edge of the image.

    Returns
    -------
    numpy.ndarray
        Preprocessed image.

    """
    pixel_mean = np.array([123.675 / 255, 116.28 / 255, 103.53 / 255]).astype(np.float32)
    pixel_std = np.array([58.395 / 255, 57.12 / 255, 57.375 / 255]).astype(np.float32)

    oh, ow, _ = x.shape
    long_side = max(oh, ow)
    if long_side != img_size:
        # Resize the image with long side == img_size
        scale = img_size * 1.0 / max(oh, ow)
        newh, neww = int(oh * scale + 0.5), int(ow * scale + 0.5)
        x = cv2.resize(x, (neww, newh))

    h, w = x.shape[:2]
    x = x.astype(np.float32) / 255  # Normalize to [0, 1]
    x = (x - pixel_mean) / pixel_std  # Normalize pixel values
    th, tw = img_size, img_size
    assert th >= h and tw >= w, "image is too small"

    # Pad the image with zeros if shorter than desired size
    x = np.pad(
        x,
        ((0, th - h), (0, tw - w), (0, 0)),
        mode="constant",
        constant_values=0,  # (top, bottom), (left, right)
    ).astype(np.float32)

    # Transpose the image from HWC to CHW and add batch dimension
    x = x.transpose((2, 0, 1))[np.newaxis, :, :, :]

    return x
            
class InferSAM:
    """
    Class for inference with SAM models.

    Parameters
    ----------
    model_dir : str
        Directory containing trained SAM model.
    model_name : str, default 'l0'
        Name of the model to use.
        Must be one of ['l0', 'l1', 'l2', 'xl0', 'xl1'].

    Attributes
    ----------
    model_name : str
        Name of the model to use.
    encoder : SamEncoder
        The encoder part of the SAM model.
    decoder : SamDecoder
        The decoder part of the SAM model.

    """

    def __init__(self, model_name: str = "l0"):
        # assert model_dir is not None, "model_dir is null"
        assert model_name is not None, "model_name is null"

        self.model_name = model_name
        encoder_weights_path, decoder_weights_path = check_and_download_weights(model_name)
        
        # Find encoder and decoder models
        encoder_path = encoder_weights_path # glob.glob(model_dir + "/*_encoder.onnx")[0]
        decoder_path = decoder_weights_path # glob.glob(model_dir + "/*_decoder.onnx")[0]

        self.encoder = SamEncoder(encoder_path)
        self.decoder = SamDecoder(decoder_path)

        self.figsize = (10,10)

    def infer(
        self,
        img_path: str,
        boxes: List[list] = [[80, 50, 320, 420], [300, 20, 530, 420]],
        visualize=False,
    ) -> np.array:
        """
        Infer segmentation masks for a given image using the SAM model.

        Parameters
        ----------
        img_path : str
            Path to the input image.
        boxes : list of lists, default [[80, 50, 320, 420], [300, 20, 530, 420]]
            List of boxes, each box is a list of 4 ints, representing
            [xmax, ymax, xmin, ymin] coordinates.

        Returns
        -------
        masks : np.array
            A numpy array of shape (N, 1, H, W) containing segmentation masks,
            where N is the number of boxes, H and W are the height and width of
            the input image.

        """
        assert img_path is not None, "img_path is null"

        raw_img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        assert raw_img is not None, "raw_img is null"

        origin_image_size = raw_img.shape[:2]

        img = None
        if self.model_name in ["l0", "l1", "l2"]:
            img = preprocess_np(raw_img, img_size=512)
        elif self.model_name in ["xl0", "xl1"]:
            img = preprocess_np(raw_img, img_size=1024)
        assert img is not None, "img is null"

        boxes = np.array(boxes, dtype=np.float32)  # xmax, ymax, xmin, ymin

        img_embeddings = self.encoder(img)
        masks, _, _ = self.decoder.run(
            img_embeddings=img_embeddings,
            origin_image_size=origin_image_size,
            boxes=boxes,
        )
        if visualize:
            plt.figure(figsize=self.figsize)
            plt.imshow(raw_img)
            for mask in masks:
                show_mask(mask, plt.gca(), 
                          random_color=True)
            for box in boxes:
                show_box(box, plt.gca())
            plt.show()
        return masks

    def set_figsize(self,figsize=(10,10)):
        self.figsize = figsize

Functions

def check_and_download_weights(model_name='l0')
Expand source code
def check_and_download_weights(model_name='l0'):

    __supported_modelnames = ['l0', 'xl0']

    assert model_name in __supported_modelnames, f'Model name not supported. Please use one of : {__supported_modelnames}'
    
    l0_weights = {'encoder' : 'https://drive.google.com/file/d/1a0tRmHQeGTAbSeMqBMhu4DinsOR3cSv6/view?usp=sharing',
                 'decoder': 'https://drive.google.com/file/d/13J7pNfh016sBqOQ17CludkUFdKgkkyQM/view?usp=sharing'}
    
    xl0_weights = {'encoder': 'https://drive.google.com/file/d/1NzavgCAqk6mSzTnQ_LKfl78V_O68lWNX/view?usp=sharing',
                  'decoder': 'https://drive.google.com/file/d/1lrn5bQRE01Mwtp-nr9DBNTHcxk4Q6iiP/view?usp=sharing'}
    
    if os.path.exists('model_weights'):
        model_weights_folder_path = os.path.abspath('model_weights')
    else:
        os.makedirs('model_weights', 
                    exist_ok = False)
        model_weights_folder_path = os.path.abspath('model_weights')
    
    if os.path.exists(f'model_weights/{model_name}/encoder.onnx'):
        encoder_weights_path = os.path.abspath(f'model_weights/{model_name}/encoder.onnx')
    else:
        os.makedirs(f'model_weights/{model_name}', 
                    exist_ok = True)
        if model_name == 'l0':
            gdown.download(l0_weights['encoder'],
                           f'model_weights/{model_name}/encoder.onnx', 
                           fuzzy=True)
        if model_name == 'xl0':
            gdown.download(xl0_weights['encoder'],
                           f'model_weights/{model_name}/encoder.onnx', 
                           fuzzy=True)
        encoder_weights_path = os.path.abspath(f'model_weights/{model_name}/encoder.onnx')
            
    if os.path.exists(f'model_weights/{model_name}/decoder.onnx'):
        decoder_weights_path = os.path.abspath(f'model_weights/{model_name}/decoder.onnx')
    else:
        if model_name == 'l0':
            gdown.download(l0_weights['decoder'],
                           f'model_weights/{model_name}/decoder.onnx', 
                           fuzzy=True)
        if model_name == 'xl0':
            gdown.download(xl0_weights['decoder'],
                           f'model_weights/{model_name}/decoder.onnx', 
                           fuzzy=True)
        decoder_weights_path = os.path.abspath(f'model_weights/{model_name}/decoder.onnx')

    return encoder_weights_path, decoder_weights_path
def np_interp(x: , size: tuple) ‑> 

Interpolates a batch of masks to a given size.

Parameters

x : np.array
A batch of masks with shape (batch_size, 1, height, width).
size : tuple
Desired size of the masks (height, width) format.

Returns

np.array
A batch of interpolated masks with shape (batch_size, 1, height, width).
Expand source code
def np_interp(x: np.array, size: tuple) -> np.array:
    """Interpolates a batch of masks to a given size.

    Parameters
    ----------
    x : np.array
        A batch of masks with shape (batch_size, 1, height, width).
    size : tuple
        Desired size of the masks (height, width) format.

    Returns
    -------
    np.array
        A batch of interpolated masks with shape (batch_size, 1, height, width).
    """
    _rmsk = []
    for m in range(x.shape[0]):
        msk = x[m, 0, :, :]
        resized_array = cv2.resize(msk, size, interpolation=cv2.INTER_LINEAR)
        _rmsk.append(resized_array)
    np_rmsk = np.array(_rmsk)
    np_rmsk = np_rmsk[:, np.newaxis, :, :]
    return np_rmsk
def np_mask_postprocessing(masks: , orig_im_size: ) ‑> 

Perform postprocessing on predicted masks by interpolating them to desired size and then resizing them back to original image size.

Parameters

masks : np.array
Predicted masks.
orig_im_size : np.array
Original image size.

Returns

np.array
Postprocessed masks.
Expand source code
def np_mask_postprocessing(masks: np.array, orig_im_size: np.array) -> np.array:
    """
    Perform postprocessing on predicted masks by interpolating them to
    desired size and then resizing them back to original image size.

    Parameters
    ----------
    masks : np.array
        Predicted masks.
    orig_im_size : np.array
        Original image size.

    Returns
    -------
    np.array
        Postprocessed masks.
    """

    img_size = 1024  # Desired output size
    masks = np_interp(masks, (img_size, img_size))

    # Pad predicted masks to desired output size
    prepadded_size = np_resize_longest_image_size(orig_im_size, img_size)
    masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]

    # Resize padded masks back to original image size
    origin_image_size = orig_im_size.astype(np.int64)
    w, h = origin_image_size[0], origin_image_size[1]
    masks = np_interp(masks, (h, w))
    return masks
def np_resize_longest_image_size(input_image_size: , longest_side: int) ‑> 

Resizes the image size to the longest side.

Parameters

input_image_size : np.array
Size of the input image in (height, width) format.
longest_side : int
Desired longest side of the resized image.

Returns

np.array
Size of the resized image in (height, width) format.
Expand source code
def np_resize_longest_image_size(
    input_image_size: np.array, longest_side: int
) -> np.array:
    """Resizes the image size to the longest side.

    Parameters
    ----------
    input_image_size : np.array
        Size of the input image in (height, width) format.
    longest_side : int
        Desired longest side of the resized image.

    Returns
    -------
    np.array
        Size of the resized image in (height, width) format.
    """
    scale = longest_side / np.max(input_image_size)
    transformed_size = scale * input_image_size
    transformed_size = np.floor(transformed_size + 0.5).astype(np.int64)
    return transformed_size
def preprocess_np(x, img_size)

Preprocess an image with mean and std normalization and padding to desired size.

Parameters

x : numpy.ndarray
Image to be preprocessed.
img_size : int
Desired size of the longer edge of the image.

Returns

numpy.ndarray
Preprocessed image.
Expand source code
def preprocess_np(x, img_size):
    """
    Preprocess an image with mean and std normalization and padding to
    desired size.

    Parameters
    ----------
    x : numpy.ndarray
        Image to be preprocessed.
    img_size : int
        Desired size of the longer edge of the image.

    Returns
    -------
    numpy.ndarray
        Preprocessed image.

    """
    pixel_mean = np.array([123.675 / 255, 116.28 / 255, 103.53 / 255]).astype(np.float32)
    pixel_std = np.array([58.395 / 255, 57.12 / 255, 57.375 / 255]).astype(np.float32)

    oh, ow, _ = x.shape
    long_side = max(oh, ow)
    if long_side != img_size:
        # Resize the image with long side == img_size
        scale = img_size * 1.0 / max(oh, ow)
        newh, neww = int(oh * scale + 0.5), int(ow * scale + 0.5)
        x = cv2.resize(x, (neww, newh))

    h, w = x.shape[:2]
    x = x.astype(np.float32) / 255  # Normalize to [0, 1]
    x = (x - pixel_mean) / pixel_std  # Normalize pixel values
    th, tw = img_size, img_size
    assert th >= h and tw >= w, "image is too small"

    # Pad the image with zeros if shorter than desired size
    x = np.pad(
        x,
        ((0, th - h), (0, tw - w), (0, 0)),
        mode="constant",
        constant_values=0,  # (top, bottom), (left, right)
    ).astype(np.float32)

    # Transpose the image from HWC to CHW and add batch dimension
    x = x.transpose((2, 0, 1))[np.newaxis, :, :, :]

    return x
def show_box(box, ax)

Show a bounding box on the axis.

Parameters

box : list
The bounding box coordinates as [x0, y0, x1, y1].
ax : matplotlib.axes.Axes
The axis to plot on.
Expand source code
def show_box(box, ax):
    """
    Show a bounding box on the axis.

    Parameters
    ----------
    box : list
        The bounding box coordinates as [x0, y0, x1, y1].
    ax : matplotlib.axes.Axes
        The axis to plot on.
    """
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    )
def show_mask(mask, ax, random_color=False)

Visualize a mask image on the given axis.

Parameters

mask : np.ndarray
The mask image to visualize.
ax : matplotlib.axes.Axes
The axis to plot on.
random_color : bool, optional
Whether to use a random color for the mask, by default False
Expand source code
def show_mask(mask, ax, random_color=False):
    """
    Visualize a mask image on the given axis.

    Parameters
    ----------
    mask : np.ndarray
        The mask image to visualize.
    ax : matplotlib.axes.Axes
        The axis to plot on.
    random_color : bool, optional
        Whether to use a random color for the mask, by default False
    """
    if random_color:
        # Create a random color with some transparency
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # Use a specific color with some transparency
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375)

Show points on the axis.

Parameters

coords : np.ndarray
The coordinates of the points to show.
labels : np.ndarray
The labels of the points.
ax : matplotlib.axes.Axes
The axis to plot on.
marker_size : int, optional
The size of the markers, by default 375
Expand source code
def show_points(coords, labels, ax, marker_size=375):
    """
    Show points on the axis.

    Parameters
    ----------
    coords : np.ndarray
        The coordinates of the points to show.
    labels : np.ndarray
        The labels of the points.
    ax : matplotlib.axes.Axes
        The axis to plot on.
    marker_size : int, optional
        The size of the markers, by default 375
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )

Classes

class InferSAM (model_name: str = 'l0')

Class for inference with SAM models.

Parameters

model_dir : str
Directory containing trained SAM model.
model_name : str, default 'l0'
Name of the model to use. Must be one of ['l0', 'l1', 'l2', 'xl0', 'xl1'].

Attributes

model_name : str
Name of the model to use.
encoder : SamEncoder
The encoder part of the SAM model.
decoder : SamDecoder
The decoder part of the SAM model.
Expand source code
class InferSAM:
    """
    Class for inference with SAM models.

    Parameters
    ----------
    model_dir : str
        Directory containing trained SAM model.
    model_name : str, default 'l0'
        Name of the model to use.
        Must be one of ['l0', 'l1', 'l2', 'xl0', 'xl1'].

    Attributes
    ----------
    model_name : str
        Name of the model to use.
    encoder : SamEncoder
        The encoder part of the SAM model.
    decoder : SamDecoder
        The decoder part of the SAM model.

    """

    def __init__(self, model_name: str = "l0"):
        # assert model_dir is not None, "model_dir is null"
        assert model_name is not None, "model_name is null"

        self.model_name = model_name
        encoder_weights_path, decoder_weights_path = check_and_download_weights(model_name)
        
        # Find encoder and decoder models
        encoder_path = encoder_weights_path # glob.glob(model_dir + "/*_encoder.onnx")[0]
        decoder_path = decoder_weights_path # glob.glob(model_dir + "/*_decoder.onnx")[0]

        self.encoder = SamEncoder(encoder_path)
        self.decoder = SamDecoder(decoder_path)

        self.figsize = (10,10)

    def infer(
        self,
        img_path: str,
        boxes: List[list] = [[80, 50, 320, 420], [300, 20, 530, 420]],
        visualize=False,
    ) -> np.array:
        """
        Infer segmentation masks for a given image using the SAM model.

        Parameters
        ----------
        img_path : str
            Path to the input image.
        boxes : list of lists, default [[80, 50, 320, 420], [300, 20, 530, 420]]
            List of boxes, each box is a list of 4 ints, representing
            [xmax, ymax, xmin, ymin] coordinates.

        Returns
        -------
        masks : np.array
            A numpy array of shape (N, 1, H, W) containing segmentation masks,
            where N is the number of boxes, H and W are the height and width of
            the input image.

        """
        assert img_path is not None, "img_path is null"

        raw_img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        assert raw_img is not None, "raw_img is null"

        origin_image_size = raw_img.shape[:2]

        img = None
        if self.model_name in ["l0", "l1", "l2"]:
            img = preprocess_np(raw_img, img_size=512)
        elif self.model_name in ["xl0", "xl1"]:
            img = preprocess_np(raw_img, img_size=1024)
        assert img is not None, "img is null"

        boxes = np.array(boxes, dtype=np.float32)  # xmax, ymax, xmin, ymin

        img_embeddings = self.encoder(img)
        masks, _, _ = self.decoder.run(
            img_embeddings=img_embeddings,
            origin_image_size=origin_image_size,
            boxes=boxes,
        )
        if visualize:
            plt.figure(figsize=self.figsize)
            plt.imshow(raw_img)
            for mask in masks:
                show_mask(mask, plt.gca(), 
                          random_color=True)
            for box in boxes:
                show_box(box, plt.gca())
            plt.show()
        return masks

    def set_figsize(self,figsize=(10,10)):
        self.figsize = figsize

Methods

def infer(self, img_path: str, boxes: List[list] = [[80, 50, 320, 420], [300, 20, 530, 420]], visualize=False) ‑> 

Infer segmentation masks for a given image using the SAM model.

Parameters

img_path : str
Path to the input image.
boxes : list of lists, default [[80, 50, 320, 420], [300, 20, 530, 420]]
List of boxes, each box is a list of 4 ints, representing [xmax, ymax, xmin, ymin] coordinates.

Returns

masks : np.array
A numpy array of shape (N, 1, H, W) containing segmentation masks, where N is the number of boxes, H and W are the height and width of the input image.
Expand source code
def infer(
    self,
    img_path: str,
    boxes: List[list] = [[80, 50, 320, 420], [300, 20, 530, 420]],
    visualize=False,
) -> np.array:
    """
    Infer segmentation masks for a given image using the SAM model.

    Parameters
    ----------
    img_path : str
        Path to the input image.
    boxes : list of lists, default [[80, 50, 320, 420], [300, 20, 530, 420]]
        List of boxes, each box is a list of 4 ints, representing
        [xmax, ymax, xmin, ymin] coordinates.

    Returns
    -------
    masks : np.array
        A numpy array of shape (N, 1, H, W) containing segmentation masks,
        where N is the number of boxes, H and W are the height and width of
        the input image.

    """
    assert img_path is not None, "img_path is null"

    raw_img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    assert raw_img is not None, "raw_img is null"

    origin_image_size = raw_img.shape[:2]

    img = None
    if self.model_name in ["l0", "l1", "l2"]:
        img = preprocess_np(raw_img, img_size=512)
    elif self.model_name in ["xl0", "xl1"]:
        img = preprocess_np(raw_img, img_size=1024)
    assert img is not None, "img is null"

    boxes = np.array(boxes, dtype=np.float32)  # xmax, ymax, xmin, ymin

    img_embeddings = self.encoder(img)
    masks, _, _ = self.decoder.run(
        img_embeddings=img_embeddings,
        origin_image_size=origin_image_size,
        boxes=boxes,
    )
    if visualize:
        plt.figure(figsize=self.figsize)
        plt.imshow(raw_img)
        for mask in masks:
            show_mask(mask, plt.gca(), 
                      random_color=True)
        for box in boxes:
            show_box(box, plt.gca())
        plt.show()
    return masks
def set_figsize(self, figsize=(10, 10))
Expand source code
def set_figsize(self,figsize=(10,10)):
    self.figsize = figsize
class SamDecoder (model_path: str, device: str = 'cpu', target_size: int = 1024, mask_threshold: float = 0.0, **kwargs)

The decoder class that loads and runs the SAM decoder model.

Parameters

model_path : str
The path to the decoder model.
device : str, default="cpu"
The device to run the model, either "cuda" or "cpu".
target_size : int, default=1024
The target size of the output mask. The final mask size may be smaller if the original image is too small.
mask_threshold : float, default=0.0
The threshold value to binarize the output mask.
kwargs : Any
Additional arguments to be passed to onnxruntime.InferenceSession.

Attributes

target_size : int
The target size of the output mask.
mask_threshold : float
The threshold value to binarize the output mask.
session : onnxruntime.InferenceSession
The inference session of the loaded decoder model.
Expand source code
class SamDecoder:
    """
    The decoder class that loads and runs the SAM decoder model.

    Parameters
    ----------
    model_path: str
        The path to the decoder model.
    device: str, default="cpu"
        The device to run the model, either "cuda" or "cpu".
    target_size: int, default=1024
        The target size of the output mask. The final mask size may be
        smaller if the original image is too small.
    mask_threshold: float, default=0.0
        The threshold value to binarize the output mask.
    kwargs: Any
        Additional arguments to be passed to onnxruntime.InferenceSession.

    Attributes
    ----------
    target_size: int
        The target size of the output mask.
    mask_threshold: float
        The threshold value to binarize the output mask.
    session: onnxruntime.InferenceSession
        The inference session of the loaded decoder model.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cpu",
        target_size: int = 1024,
        mask_threshold: float = 0.0,
        **kwargs,
    ):
        opt = ort.SessionOptions()

        if device == "cuda":
            provider = ["CUDAExecutionProvider"]
        elif device == "cpu":
            provider = ["CPUExecutionProvider"]
        else:
            raise ValueError("Invalid device, please use 'cuda' or 'cpu' device.")

        print(f"loading decoder model from {model_path}...")
        self.target_size = target_size
        self.mask_threshold = mask_threshold
        self.session = ort.InferenceSession(
            model_path, opt, providers=provider, **kwargs
        )

    @staticmethod
    def get_preprocess_shape(
        oldh: int, oldw: int, long_side_length: int
    ) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.

        Parameters
        ----------
        oldh: int
            The height of the input image.
        oldw: int
            The width of the input image.
        long_side_length: int
            The target long side length of the output image.

        Returns
        -------
        Tuple[int, int]
            The (height, width) of the output image after resizing.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def run(
        self,
        img_embeddings: np.ndarray,
        origin_image_size: Union[list, tuple],
        point_coords: Union[list, np.ndarray] = None,
        point_labels: Union[list, np.ndarray] = None,
        boxes: Union[list, np.ndarray] = None,
        return_logits: bool = False,
    ) -> Tuple[np.ndarray, Any, Any]:
        """
        Run the SAM decoder to segment an input image.

        Parameters
        ----------
        img_embeddings: np.ndarray
            The image embeddings obtained from SAM encoder.
            The shape should be (1, 256, 64, 64).
        origin_image_size: Union[list, tuple]
            The original size of the input image, (height, width)
        point_coords: Union[list, np.ndarray], optional
            The coordinates of the points in the input image.
            The shape should be (N, 2), where N is the number of points.
        point_labels: Union[list, np.ndarray], optional
            The labels of the points.
            The shape should be (N,) where N is the number of points.
        boxes: Union[list, np.ndarray], optional
            The coordinates of the bounding boxes in the input image.
            The shape should be (M, 4), where M is the number of boxes.
        return_logits: bool, default False
            Whether to return the logits (before sigmoid) of the mask predictions.

        Returns
        -------
        Tuple[np.ndarray, Any, Any]
            The segmentation masks, IoU scores, and low-resolution masks.
        """
        input_size = self.get_preprocess_shape(
            *origin_image_size, long_side_length=self.target_size
        )

        if point_coords is None and point_labels is None and boxes is None:
            raise ValueError(
                "Unable to segment, please input at least one box or point."
            )

        if img_embeddings.shape != (1, 256, 64, 64):
            raise ValueError("Got wrong embedding shape!")

        if point_coords is not None:
            point_coords = self.apply_coords(
                point_coords, origin_image_size, input_size
            ).astype(np.float32)

            prompts, labels = point_coords, point_labels

        if boxes is not None:
            boxes = self.apply_boxes(boxes, origin_image_size, input_size).astype(
                np.float32
            )
            box_labels = np.array(
                [[2, 3] for _ in range(boxes.shape[0])], dtype=np.float32
            ).reshape((-1, 2))

            if point_coords is not None:
                prompts = np.concatenate([prompts, boxes], axis=1)
                labels = np.concatenate([labels, box_labels], axis=1)
            else:
                prompts, labels = boxes, box_labels

        input_dict = {
            "image_embeddings": img_embeddings,
            "point_coords": prompts,
            "point_labels": labels,
        }

        # Run the inference
        low_res_masks, iou_predictions = self.session.run(None, input_dict)

        # Post-process the masks
        masks = np_mask_postprocessing(low_res_masks, np.array(origin_image_size))

        if not return_logits:
            masks = masks > self.mask_threshold

        return masks, iou_predictions, low_res_masks

    def apply_coords(self, coords, original_size, new_size):
        """
        Applies the resizing to the coordinates.

        Parameters
        ----------
        coords : np.ndarray
            The coordinates to be resized.
            The shape should be (N, 2), where N is the number of points.
        original_size : Union[list, tuple]
            The original size of the input image, (height, width)
        new_size : Union[list, tuple]
            The new size of the input image, (height, width)

        Returns
        -------
        np.ndarray
            The resized coordinates.
        """
        old_h, old_w = original_size
        new_h, new_w = new_size
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes, original_size, new_size):
        """
        Applies the resizing to the bounding boxes.

        Parameters
        ----------
        boxes : np.ndarray
            The coordinates of the bounding boxes in the input image.
            The shape should be (M, 4), where M is the number of boxes.
        original_size : Union[list, tuple]
            The original size of the input image, (height, width)
        new_size : Union[list, tuple]
            The new size of the input image, (height, width)

        Returns
        -------
        np.ndarray
            The resized bounding boxes.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size, new_size)
        return boxes

Static methods

def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) ‑> Tuple[int, int]

Compute the output size given input size and target long side length.

Parameters

oldh : int
The height of the input image.
oldw : int
The width of the input image.
long_side_length : int
The target long side length of the output image.

Returns

Tuple[int, int]
The (height, width) of the output image after resizing.
Expand source code
@staticmethod
def get_preprocess_shape(
    oldh: int, oldw: int, long_side_length: int
) -> Tuple[int, int]:
    """
    Compute the output size given input size and target long side length.

    Parameters
    ----------
    oldh: int
        The height of the input image.
    oldw: int
        The width of the input image.
    long_side_length: int
        The target long side length of the output image.

    Returns
    -------
    Tuple[int, int]
        The (height, width) of the output image after resizing.
    """
    scale = long_side_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    neww = int(neww + 0.5)
    newh = int(newh + 0.5)
    return (newh, neww)

Methods

def apply_boxes(self, boxes, original_size, new_size)

Applies the resizing to the bounding boxes.

Parameters

boxes : np.ndarray
The coordinates of the bounding boxes in the input image. The shape should be (M, 4), where M is the number of boxes.
original_size : Union[list, tuple]
The original size of the input image, (height, width)
new_size : Union[list, tuple]
The new size of the input image, (height, width)

Returns

np.ndarray
The resized bounding boxes.
Expand source code
def apply_boxes(self, boxes, original_size, new_size):
    """
    Applies the resizing to the bounding boxes.

    Parameters
    ----------
    boxes : np.ndarray
        The coordinates of the bounding boxes in the input image.
        The shape should be (M, 4), where M is the number of boxes.
    original_size : Union[list, tuple]
        The original size of the input image, (height, width)
    new_size : Union[list, tuple]
        The new size of the input image, (height, width)

    Returns
    -------
    np.ndarray
        The resized bounding boxes.
    """
    boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size, new_size)
    return boxes
def apply_coords(self, coords, original_size, new_size)

Applies the resizing to the coordinates.

Parameters

coords : np.ndarray
The coordinates to be resized. The shape should be (N, 2), where N is the number of points.
original_size : Union[list, tuple]
The original size of the input image, (height, width)
new_size : Union[list, tuple]
The new size of the input image, (height, width)

Returns

np.ndarray
The resized coordinates.
Expand source code
def apply_coords(self, coords, original_size, new_size):
    """
    Applies the resizing to the coordinates.

    Parameters
    ----------
    coords : np.ndarray
        The coordinates to be resized.
        The shape should be (N, 2), where N is the number of points.
    original_size : Union[list, tuple]
        The original size of the input image, (height, width)
    new_size : Union[list, tuple]
        The new size of the input image, (height, width)

    Returns
    -------
    np.ndarray
        The resized coordinates.
    """
    old_h, old_w = original_size
    new_h, new_w = new_size
    coords = deepcopy(coords).astype(float)
    coords[..., 0] = coords[..., 0] * (new_w / old_w)
    coords[..., 1] = coords[..., 1] * (new_h / old_h)
    return coords
def run(self, img_embeddings: numpy.ndarray, origin_image_size: Union[list, tuple], point_coords: Union[list, numpy.ndarray] = None, point_labels: Union[list, numpy.ndarray] = None, boxes: Union[list, numpy.ndarray] = None, return_logits: bool = False) ‑> Tuple[numpy.ndarray, Any, Any]

Run the SAM decoder to segment an input image.

Parameters

img_embeddings : np.ndarray
The image embeddings obtained from SAM encoder. The shape should be (1, 256, 64, 64).
origin_image_size : Union[list, tuple]
The original size of the input image, (height, width)
point_coords : Union[list, np.ndarray], optional
The coordinates of the points in the input image. The shape should be (N, 2), where N is the number of points.
point_labels : Union[list, np.ndarray], optional
The labels of the points. The shape should be (N,) where N is the number of points.
boxes : Union[list, np.ndarray], optional
The coordinates of the bounding boxes in the input image. The shape should be (M, 4), where M is the number of boxes.
return_logits : bool, default False
Whether to return the logits (before sigmoid) of the mask predictions.

Returns

Tuple[np.ndarray, Any, Any]
The segmentation masks, IoU scores, and low-resolution masks.
Expand source code
def run(
    self,
    img_embeddings: np.ndarray,
    origin_image_size: Union[list, tuple],
    point_coords: Union[list, np.ndarray] = None,
    point_labels: Union[list, np.ndarray] = None,
    boxes: Union[list, np.ndarray] = None,
    return_logits: bool = False,
) -> Tuple[np.ndarray, Any, Any]:
    """
    Run the SAM decoder to segment an input image.

    Parameters
    ----------
    img_embeddings: np.ndarray
        The image embeddings obtained from SAM encoder.
        The shape should be (1, 256, 64, 64).
    origin_image_size: Union[list, tuple]
        The original size of the input image, (height, width)
    point_coords: Union[list, np.ndarray], optional
        The coordinates of the points in the input image.
        The shape should be (N, 2), where N is the number of points.
    point_labels: Union[list, np.ndarray], optional
        The labels of the points.
        The shape should be (N,) where N is the number of points.
    boxes: Union[list, np.ndarray], optional
        The coordinates of the bounding boxes in the input image.
        The shape should be (M, 4), where M is the number of boxes.
    return_logits: bool, default False
        Whether to return the logits (before sigmoid) of the mask predictions.

    Returns
    -------
    Tuple[np.ndarray, Any, Any]
        The segmentation masks, IoU scores, and low-resolution masks.
    """
    input_size = self.get_preprocess_shape(
        *origin_image_size, long_side_length=self.target_size
    )

    if point_coords is None and point_labels is None and boxes is None:
        raise ValueError(
            "Unable to segment, please input at least one box or point."
        )

    if img_embeddings.shape != (1, 256, 64, 64):
        raise ValueError("Got wrong embedding shape!")

    if point_coords is not None:
        point_coords = self.apply_coords(
            point_coords, origin_image_size, input_size
        ).astype(np.float32)

        prompts, labels = point_coords, point_labels

    if boxes is not None:
        boxes = self.apply_boxes(boxes, origin_image_size, input_size).astype(
            np.float32
        )
        box_labels = np.array(
            [[2, 3] for _ in range(boxes.shape[0])], dtype=np.float32
        ).reshape((-1, 2))

        if point_coords is not None:
            prompts = np.concatenate([prompts, boxes], axis=1)
            labels = np.concatenate([labels, box_labels], axis=1)
        else:
            prompts, labels = boxes, box_labels

    input_dict = {
        "image_embeddings": img_embeddings,
        "point_coords": prompts,
        "point_labels": labels,
    }

    # Run the inference
    low_res_masks, iou_predictions = self.session.run(None, input_dict)

    # Post-process the masks
    masks = np_mask_postprocessing(low_res_masks, np.array(origin_image_size))

    if not return_logits:
        masks = masks > self.mask_threshold

    return masks, iou_predictions, low_res_masks
class SamEncoder (model_path: str, device: str = 'cpu', **kwargs)

The encoder class that loads and runs the SAM encoder model.

Parameters

model_path : str
The path to the encoder model.
device : str, optional (default is 'cpu')
The device to run the model, either 'cuda' or 'cpu'.
kwargs : dict
Additional arguments to be passed to the InferenceSession class from the onnxruntime library.

Attributes

session : InferenceSession
The loaded encoder model.
input_name : str
The name of the input layer of the model.
Expand source code
class SamEncoder:
    """
    The encoder class that loads and runs the SAM encoder model.

    Parameters
    ----------
    model_path: str
        The path to the encoder model.
    device: str, optional (default is 'cpu')
        The device to run the model, either 'cuda' or 'cpu'.
    kwargs: dict
        Additional arguments to be passed to the `InferenceSession` class from
        the onnxruntime library.

    Attributes
    ----------
    session: InferenceSession
        The loaded encoder model.
    input_name: str
        The name of the input layer of the model.
    """

    def __init__(self, model_path: str, device: str = "cpu", **kwargs):
        opt = ort.SessionOptions()

        if device == "cuda":
            provider = ["CUDAExecutionProvider"]
        elif device == "cpu":
            provider = ["CPUExecutionProvider"]
        else:
            raise ValueError("Invalid device, please use 'cuda' or 'cpu' device.")

        print(f"loading encoder model from {model_path}...")
        self.session = ort.InferenceSession(
            model_path, opt, providers=provider, **kwargs
        )
        self.input_name = self.session.get_inputs()[0].name

    def _extract_feature(self, tensor: np.ndarray) -> np.ndarray:
        """
        Extract the feature from the input image tensor using the loaded
        encoder model.

        Parameters
        ----------
        tensor: numpy.ndarray
            The input image tensor.

        Returns
        -------
        feature: numpy.ndarray
            The feature extracted from the input image.
        """
        feature = self.session.run(None, {self.input_name: tensor})[0]
        return feature

    def __call__(self, img: np.array, *args: Any, **kwds: Any) -> Any:
        """
        Call the encoder with the input image.

        Parameters
        ----------
        img: numpy.ndarray
            The input image.
        args, kwargs:
            Additional positional and keyword arguments to be passed to the
            encoder.

        Returns
        -------
        feature: numpy.ndarray
            The feature extracted from the input image.
        """
        return self._extract_feature(img)