File size: 5,093 Bytes
8c27766
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import google.generativeai as genai
from google.generativeai.types import HarmBlockThreshold, HarmCategory
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import json

# Fetch bounding boxes and labels
async def get_bounding_boxes(prompt: str, image: "Image.Image", api_key: str):
    """Ask Gemini to locate the objects described by ``prompt`` in ``image``.

    Args:
        prompt: Free-text description of the object(s) to detect.
        image: The image to analyse; ``process_image`` passes the PIL image
            returned by ``Image.open``.
        api_key: Gemini API key used for this request.

    Returns:
        A ``(bounding_boxes, explanation)`` tuple. ``bounding_boxes`` is the
        ``"bounding_boxes"`` list from the model's JSON reply, and
        ``explanation`` is its ``"explanation"`` string.

    Raises:
        gr.Error: if the API call fails or the response carries no text.
            Also raised when the reply is not JSON with the expected keys.
    """
    system_prompt = """

You are a helpful assistant, who always responds with the bounding box and label with the explanation JSON based on the user input, and nothing else.

Your response can also include multiple bounding boxes and their labels in the list.

The values in the list should be integers.

Each box is [ymin, xmin, ymax, xmax], with every coordinate normalized to the range 0-1000.

Here are some example responses:

{

    "explanation": "User asked for the bounding box of the dragon, so I will provide the bounding box of the dragon.",

    "bounding_boxes": [

        {"label": "dragon", "box": [ymin, xmin, ymax, xmax]}

    ]

}

{

    "explanation": "User asked for the bounding box of the fruits which are red in color, so I will provide the bounding box of the Apple and the Tomato.",

    "bounding_boxes": [

        {"label": "apple", "box": [ymin, xmin, ymax, xmax]},

        {"label": "tomato", "box": [ymin, xmin, ymax, xmax]}

    ]

}

""".strip()

    prompt = f"Return the bounding boxes and labels of: {prompt}"

    messages = [
        {"role": "user", "parts": [prompt, image]},
    ]

    # NOTE(review): genai.configure sets process-wide state. If this handler
    # is ever allowed to run concurrently, one request could be sent with
    # another user's key. Confirm the event's concurrency limit before
    # raising it.
    genai.configure(api_key=api_key)

    generation_config = {
        "temperature": 1,
        "max_output_tokens": 8192,
        "response_mime_type": "application/json",
    }

    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=generation_config,
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE
        },
        system_instruction=system_prompt
    )

    try:
        response = await model.generate_content_async(messages)
        # The `.text` accessor raises when the response has no text part
        # (e.g. the candidate was blocked), so read it inside the try as well.
        response_text = response.text
    except Exception as e:
        message = str(e)
        lowered = message.lower()
        if "API key not valid" in message:
            raise gr.Error(
                "Invalid API key. Please provide a valid Gemini API key.") from e
        # Gemini reports HTTP 429 as "Resource has been exhausted (e.g. check
        # quota)", which never contains the words "rate limit".
        if ("rate limit" in lowered or "quota" in lowered
                or "resource has been exhausted" in lowered):
            raise gr.Error("Rate limit exceeded for the API key.") from e
        raise gr.Error(f"Failed to generate content: {message}") from e

    return _parse_bounding_box_response(response_text)


def _parse_bounding_box_response(text: str):
    """Parse the model's JSON reply into ``(bounding_boxes, explanation)``.

    Raises:
        gr.Error: if ``text`` is not JSON, or is not an object containing both
            ``"bounding_boxes"`` and ``"explanation"``.
    """
    try:
        payload = json.loads(text)
        # TypeError covers a JSON array/scalar being indexed with a str key.
        return payload["bounding_boxes"], payload["explanation"]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        raise gr.Error(
            f"The model returned an unexpected response: {text}") from e

# Adjust bounding boxes based on image size
async def adjust_bounding_box(bounding_boxes, image, scale=1000):
    """Convert model boxes from normalized coordinates to pixel coordinates.

    Each input box is ``[ymin, xmin, ymax, xmax]`` on a ``0..scale`` grid.
    The output is ``[xmin, ymin, xmax, ymax]`` in pixels, which is the corner
    order ``ImageDraw.rectangle`` takes.

    Args:
        bounding_boxes: Iterable of ``{"label": ..., "box": [ymin, xmin, ymax, xmax]}``.
        image: Object whose ``size`` is ``(width, height)``, e.g. a PIL image.
        scale: Extent of the normalized coordinate grid (1000 by default).

    Returns:
        A list of ``{"label": label, "box": [xmin, ymin, xmax, ymax]}`` dicts.
        Each min/max pair is ordered and every edge is clamped to the image.
    """
    width, height = image.size

    def to_pixels(coord, extent):
        # Multiply before dividing so integer inputs scale without float drift;
        # clamp so boxes the model places partly off-image stay on the canvas.
        return min(max(coord * extent / scale, 0.0), float(extent))

    adjusted_boxes = []
    for item in bounding_boxes:
        label = item["label"]
        ymin, xmin, ymax, xmax = item["box"]
        # The model does not guarantee min <= max, and Pillow rejects
        # rectangles whose second corner precedes the first.
        ymin, ymax = sorted((ymin, ymax))
        xmin, xmax = sorted((xmin, xmax))
        adjusted_boxes.append({
            "label": label,
            "box": [
                to_pixels(xmin, width),
                to_pixels(ymin, height),
                to_pixels(xmax, width),
                to_pixels(ymax, height),
            ],
        })
    return adjusted_boxes

# Process the image and draw bounding boxes and labels
async def process_image(image, text, api_key):
    """Detect the objects described by ``text`` in ``image`` and annotate them.

    Args:
        image: Path to the uploaded image (the input component is created with
            ``type="filepath"``), or ``None`` when nothing was uploaded.
        text: Description of the object(s) to detect.
        api_key: Gemini API key.

    Returns:
        ``(explanation, annotated_image, coordinates_text)``, one value per
        output component of the interface.

    Raises:
        gr.Error: if the API key or image is missing, or detection fails.
    """
    if not api_key:
        raise gr.Error("Please provide a Gemini API key.")
    if image is None:
        raise gr.Error("Please upload an image.")

    source = Image.open(image)

    # Ask the model for boxes, then map them from its normalized grid to pixels.
    bounding_boxes, explanation = await get_bounding_boxes(text, source, api_key)
    adjusted_boxes = await adjust_bounding_box(bounding_boxes, source)

    # Palette/grayscale/CMYK images would not render "red" as red, and a full
    # 256-color palette cannot allocate it at all, so draw on an RGB copy.
    if source.mode in ("RGB", "RGBA"):
        annotated = source
    else:
        annotated = source.convert("RGB")

    draw = ImageDraw.Draw(annotated)
    font = ImageFont.load_default(size=20)

    for item in adjusted_boxes:
        box = item["box"]
        label = str(item["label"])
        draw.rectangle(box, outline="red", width=3)
        # Put the label above the box, or just inside it when the box touches
        # the top edge and there is no room above.
        label_y = box[1] - 25 if box[1] >= 25 else box[1] + 3
        draw.text((box[0], label_y), label, fill="red", font=font)

    # Pixel coordinates read best as whole numbers in the UI.
    adjusted_boxes_str = "\n".join(
        f"{item['label']}: {[round(coord) for coord in item['box']]}"
        for item in adjusted_boxes
    )

    return explanation, annotated, adjusted_boxes_str

# Gradio app
async def gradio_app(image, text, api_key):
    """Gradio callback: forward the inputs to ``process_image``.

    Returns the ``(explanation, image, coordinates)`` tuple that
    ``process_image`` produces, unchanged.
    """
    outputs = await process_image(image, text, api_key)
    return outputs

# Launch the Gradio interface
iface = gr.Interface(
    fn=gradio_app,
    inputs=[
        gr.Image(type="filepath"),
        gr.Textbox(label="Object(s) to detect", value="person"),
        gr.Textbox(label="Your Gemini API Key", type="password")
    ],
    outputs=[
        gr.Textbox(label="Explanation"),
        gr.Image(type="pil", label="Output Image"),
        gr.Textbox(label="Coordinates and Labels of the Bounding Box(es)")
    ],
    title="Gemini Object Detection ✨",
    description="Detect objects in images using the Gemini 1.5 Flash model.",
    allow_flagging="never"
)

# Start the server only when executed as a script, so importing this module
# (from tests or another app) does not block inside `launch()`.
if __name__ == "__main__":
    iface.launch()