File size: 971 Bytes
e8e1c24
 
adff5a8
28b27d8
51090f6
28b27d8
e8e1c24
 
 
 
3a54e1f
e8e1c24
 
 
 
 
 
28b27d8
3005834
41efffb
28b27d8
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from ldm.util import load_and_preprocess
from carvekit.api.high import HiInterface
import spaces


def load_preprocess_model():
    '''
    Build the carvekit ``HiInterface`` with this module's fixed settings.

    :return: the newly constructed ``HiInterface`` instance.
    '''
    settings = dict(
        object_type="object",  # Can be "object" or "hairs-like".
        batch_size_seg=5,
        batch_size_matting=1,
        # CPU is hard-coded here; the "cuda if available" selection is
        # intentionally left disabled.
        device="cpu",
        seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
        matting_mask_size=2048,
        trimap_prob_threshold=231,
        trimap_dilation=30,
        trimap_erosion_iters=5,
        fp16=False,
    )
    return HiInterface(**settings)
    
# @spaces.GPU
def preprocess_image(models, input_im):
    '''
    Run ``load_and_preprocess`` on one image and hand back its result.

    :param models: forwarded unchanged as the first argument of
        ``load_and_preprocess`` (presumably the object returned by
        ``load_preprocess_model`` -- confirm at the call site).
    :param input_im: input image (PIL Image).
    :return: preprocessed image, (H, W, 3) array.
    '''
    return load_and_preprocess(models, input_im)