import json

import google.generativeai as genai
import gradio as gr
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from PIL import Image, ImageDraw, ImageFont


# Fetch bounding boxes and labels from Gemini
async def get_bounding_boxes(prompt: str, image: Image.Image, api_key: str):
    system_prompt = """
You are a helpful assistant who always responds with bounding boxes and labels, together with an explanation, as JSON based on the user input, and nothing else. Your response can also include multiple bounding boxes and their labels in the list. The values in the list should be integers.
Here are some example responses:
{
  "explanation": "User asked for the bounding box of the dragon, so I will provide the bounding box of the dragon.",
  "bounding_boxes": [
    {"label": "dragon", "box": [ymin, xmin, ymax, xmax]}
  ]
}
{
  "explanation": "User asked for the bounding box of the fruits which are red in color, so I will provide the bounding boxes of the apple and the tomato.",
  "bounding_boxes": [
    {"label": "apple", "box": [ymin, xmin, ymax, xmax]},
    {"label": "tomato", "box": [ymin, xmin, ymax, xmax]}
  ]
}
""".strip()

    prompt = f"Return the bounding boxes and labels of: {prompt}"
    messages = [
        {"role": "user", "parts": [prompt, image]},
    ]

    genai.configure(api_key=api_key)
    generation_config = {
        "temperature": 1,
        "max_output_tokens": 8192,
        "response_mime_type": "application/json",
    }
    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=generation_config,
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        },
        system_instruction=system_prompt,
    )

    try:
        response = await model.generate_content_async(messages)
    except Exception as e:
        if "API key not valid" in str(e):
            raise gr.Error("Invalid API key. Please provide a valid Gemini API key.")
        elif "rate limit" in str(e).lower():
            raise gr.Error("Rate limit exceeded for the API key.")
        else:
            raise gr.Error(f"Failed to generate content: {str(e)}")

    response_json = json.loads(response.text)
    explanation = response_json["explanation"]
    bounding_boxes = response_json["bounding_boxes"]
    return bounding_boxes, explanation


# Scale bounding boxes from Gemini's normalized coordinates to pixel coordinates
async def adjust_bounding_box(bounding_boxes, image):
    width, height = image.size
    adjusted_boxes = []
    for item in bounding_boxes:
        label = item["label"]
        # Gemini returns boxes as [ymin, xmin, ymax, xmax] normalized to 0-1000
        ymin, xmin, ymax, xmax = [coord / 1000 for coord in item["box"]]
        xmin *= width
        xmax *= width
        ymin *= height
        ymax *= height
        adjusted_boxes.append({"label": label, "box": [xmin, ymin, xmax, ymax]})
    return adjusted_boxes


# Process the image and draw the bounding boxes and labels
async def process_image(image, text, api_key):
    if not api_key:
        raise gr.Error("Please provide a Gemini API key.")

    # Open the image using PIL
    image = Image.open(image)

    # Ask Gemini for the bounding boxes and the explanation
    bounding_boxes, explanation = await get_bounding_boxes(text, image, api_key)

    # Adjust the bounding boxes to the image dimensions
    adjusted_boxes = await adjust_bounding_box(bounding_boxes, image)

    # Draw the bounding boxes and labels on the image
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default(size=20)
    for item in adjusted_boxes:
        box = item["box"]
        label = item["label"]
        draw.rectangle(box, outline="red", width=3)
        # Draw the label above the bounding box
        draw.text((box[0], box[1] - 25), label, fill="red", font=font)

    # Format the adjusted boxes for display
    adjusted_boxes_str = "\n".join(f"{item['label']}: {item['box']}" for item in adjusted_boxes)

    return explanation, image, adjusted_boxes_str


# Gradio app entry point
async def gradio_app(image, text, api_key):
    return await process_image(image, text, api_key)


# Build and launch the Gradio interface
iface = gr.Interface(
    fn=gradio_app,
    inputs=[
        gr.Image(type="filepath"),
        gr.Textbox(label="Object(s) to detect", value="person"),
        gr.Textbox(label="Your Gemini API Key", type="password"),
    ],
    outputs=[
        gr.Textbox(label="Explanation"),
        gr.Image(type="pil", label="Output Image"),
        gr.Textbox(label="Coordinates of the detected objects"),
    ],
    title="Gemini Object Detection ✨",
    description="Detect objects in images using the Gemini 1.5 Flash model.",
    allow_flagging="never",
)

iface.launch()