# tinysam / demo_hierachical_everything.py
# Uploaded by merve (Hugging Face staff) in commit cd6bcbd ("Upload 19 files").
# File size: 1.07 kB; Hub malware scan: no virus detected.
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_anns(anns, *, alpha=0.35, rng=None):
    """Overlay segmentation masks, each in a random colour, on the current axes.

    Annotations are painted in descending order of ``ann['area']``, so
    annotations with a smaller area are painted later and overwrite any
    pixels they share with larger ones.

    Args:
        anns: Sequence of annotation dicts. Each must provide ``'area'``
            (used only as the sort key) and ``'segmentation'``, a 2-D
            per-pixel mask. The overlay canvas takes its height/width from
            the largest annotation's mask; all masks are assumed to share
            that shape (NOTE(review): not checked here).
        alpha: Opacity applied to every coloured mask. Defaults to 0.35,
            the value previously hard-coded.
        rng: Optional ``numpy.random.Generator`` for reproducible colours.
            When ``None`` the global ``np.random`` state is used, exactly as
            before.

    Returns:
        None. Draws onto ``plt.gca()``; does nothing when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    random = np.random.random if rng is None else rng.random

    ax = plt.gca()
    # Disable autoscaling so adding the overlay does not change the view
    # limits already established by the underlying image.
    ax.set_autoscale_on(False)

    height, width = sorted_anns[0]['segmentation'].shape[:2]
    # RGBA canvas; alpha channel starts at 0 so unpainted pixels are transparent.
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        # Coerce to bool: an integer 0/1 mask would otherwise be interpreted
        # as fancy row indices (rows 0 and 1) instead of a per-pixel mask.
        m = np.asarray(ann['segmentation'], dtype=bool)
        color_mask = np.concatenate([random(3), [alpha]])
        img[m] = color_mask
    ax.imshow(img)
import sys

# NOTE(review): presumably lets the `tinysam` package be imported when the
# script is launched from a subdirectory of the repo — confirm. All relative
# paths below resolve against the current working directory, not this file.
sys.path.append("..")
from tinysam import sam_model_registry, SamHierarchicalMaskGenerator

# Demo configuration (values unchanged from the original script).
model_type = "vit_t"
checkpoint_path = "./weights/tinysam.pth"
image_path = 'fig/picture3.jpg'
output_path = "test_everthing.png"

sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
# Put the model in evaluation mode before running inference.
sam.eval()
mask_generator = SamHierarchicalMaskGenerator(sam)

image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports failure by returning None instead of raising; without
    # this check the problem only surfaces as an opaque cv2.error in cvtColor.
    raise OSError(
        f"cv2.imread could not read {image_path!r} "
        "(missing file or unsupported format)"
    )
# OpenCV loads images as BGR; convert to RGB for matplotlib display (and,
# presumably, for the mask generator — confirm against tinysam).
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = mask_generator.hierarchical_generate(image)

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.savefig(output_path)
# Release the figure; nothing further is drawn on it.
plt.close()