File size: 1,253 Bytes
2d58eea
95aa084
9ba309e
 
e0734d3
9b99880
e0734d3
2d58eea
e0734d3
 
 
 
 
 
9b99880
e0734d3
9b99880
e0734d3
 
2d58eea
e0734d3
 
2d58eea
e0734d3
9b99880
e0734d3
 
9b99880
e0734d3
0606046
c74ba54
e0734d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from transformers import pipeline
import torch
from PIL import Image
import base64
from io import BytesIO

class EndpointHandler:
    """Custom inference handler serving a zero-shot object-detection pipeline.

    The ``transformers`` pipeline is built once at construction time; each call
    decodes a base64-encoded image from the request and detects the requested
    candidate labels in it.
    """

    # Minimum score a detection must reach to be returned. Kept deliberately
    # low (same value the handler always used); a request may override it via
    # ``data["parameters"]["threshold"]``.
    DEFAULT_THRESHOLD = 0.01

    def __init__(self, model_path=""):
        """Load the detection pipeline on GPU 0 when CUDA is available, else on CPU.

        Args:
            model_path (str): Model identifier or local directory, passed to
                ``transformers.pipeline`` as ``model``.
        """
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"
        print(f"Using {'GPU: ' + torch.cuda.get_device_name(0) if use_cuda else 'CPU'}")

        # transformers takes a device index here: 0 = first GPU, -1 = CPU.
        self.pipeline = pipeline(
            "zero-shot-object-detection",
            model=model_path,
            device=0 if use_cuda else -1,
        )

    @staticmethod
    def _decode_base64(payload):
        """Return the raw bytes of a base64 payload, tolerating a data-URI prefix.

        Clients frequently send ``data:image/png;base64,<payload>``. A plain
        ``base64.b64decode`` does not reject that header: it silently discards
        the ``:``, ``;`` and ``,`` characters and decodes the remaining header
        letters as data, which corrupts the image. Strip it first.

        Args:
            payload (str | bytes): Base64 text, optionally data-URI prefixed.

        Returns:
            bytes: The decoded bytes.
        """
        if isinstance(payload, str) and payload.startswith("data:"):
            # Everything after the first comma is the base64 body; raw base64
            # never contains a comma, so this cannot truncate a plain payload.
            _, _, payload = payload.partition(",")
        return base64.b64decode(payload)

    def __call__(self, data):
        """Decode the request image, run zero-shot object detection, and return results.

        Args:
            data (dict): Request payload of the form
                ``{"inputs": {"image": <base64 str>, "candidates": <labels>},
                "parameters": {"threshold": <float>}}``. ``parameters`` (and
                ``threshold`` inside it) is optional; the threshold defaults
                to ``DEFAULT_THRESHOLD``.

        Returns:
            list[dict]: The pipeline's detections. Per the transformers
            documentation, each dict carries a ``label``, its ``score`` and a
            bounding ``box``.

        Raises:
            KeyError: If ``inputs``, ``inputs["image"]`` or
                ``inputs["candidates"]`` is missing.
        """
        inputs = data["inputs"]
        parameters = data.get("parameters") or {}
        threshold = parameters.get("threshold", self.DEFAULT_THRESHOLD)

        # Decode the base64 image into a PIL image for the pipeline.
        image = Image.open(BytesIO(self._decode_base64(inputs["image"])))

        return self.pipeline(
            image=image,
            candidate_labels=inputs["candidates"],
            threshold=threshold,
        )