"""Gradio demo: point-prompted segmentation with TinySAM.

The user uploads an image into the editor and paints a mark on the object of
interest. The centroid of the painted pixels becomes a single point prompt for
TinySAM, and the predicted mask is shown as an annotated overlay.
"""
from huggingface_hub import snapshot_download
import gradio as gr
import numpy as np
import torch
import sys
from tinysam import sam_model_registry, SamPredictor

# Download the merve/tinysam Hub repo (which provides tinysam.pth) into ./tinysam.
snapshot_download("merve/tinysam", local_dir="tinysam")

model_type = "vit_t"
sam = sam_model_registry[model_type](checkpoint="./tinysam/tinysam.pth")
predictor = SamPredictor(sam)


def _prompt_point(layer):
    """Return the integer (x, y) centroid of the drawn pixels in ``layer``.

    A pixel counts as drawn when any of its channels is non-zero. Each pixel
    is counted exactly once, however many of its channels are set.

    Raises:
        gr.Error: if the layer contains no drawn pixels.
    """
    arr = np.asarray(layer)
    # Collapse the channel axis so the centroid is per-pixel, not per-channel.
    drawn = arr.any(axis=-1) if arr.ndim == 3 else arr != 0
    rows, cols = np.nonzero(drawn)
    if rows.size == 0:
        raise gr.Error("Please draw a point on the object you want to segment.")
    return int(cols.mean()), int(rows.mean())


def infer(img):
    """Segment the object under the user's drawn point prompt.

    Args:
        img: value of the ``gr.ImageEditor`` (``type="pil"``). It is a dict with
            ``"background"`` (the uploaded image), ``"layers"`` (drawing
            layers; ``layers[0]`` holds the point prompt) and ``"composite"``.

    Returns:
        ``(image, [(mask, "mask")])``, the value format ``gr.AnnotatedImage``
        displays. It holds the RGB input image and the highest-scoring mask
        returned by the predictor.

    Raises:
        gr.Error: if no image was uploaded or no point was drawn.
    """
    if img is None or img.get("background") is None:
        raise gr.Error("Please upload an image first.")
    layers = img.get("layers") or []
    if not layers or layers[0] is None:
        raise gr.Error("Please draw a point on the object you want to segment.")

    image = img["background"].convert("RGB")
    # Validate the prompt before paying for the image embedding below.
    center_x, center_y = _prompt_point(layers[0])

    predictor.set_image(np.array(image))
    input_point = np.array([[center_x, center_y]])
    # Label 1 marks a foreground (positive) point in SAM's prompt convention.
    input_label = np.array([1])
    masks, scores, _logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
    )
    # The first axis of `masks` indexes candidate masks; keep the one the
    # model scores highest instead of an arbitrary first candidate.
    best = int(np.argmax(scores))
    return image, [(masks[best, :, :], "mask")]


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            im = gr.ImageEditor(type="pil")
            submit_btn = gr.Button()
        output = gr.AnnotatedImage()
    # Send results to the AnnotatedImage declared in the layout above.
    submit_btn.click(infer, inputs=im, outputs=output)

demo.launch(debug=True)