narugo commited on
Commit
9f181f5
1 Parent(s): 6453ee2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -41,12 +41,13 @@ example_images = sorted(
41
 
42
  def predict(
43
  image: Image.Image,
 
44
  threshold: float = 0.5,
45
  ):
46
  # join variant for cache key
47
- model, transform = load_model_and_transform(MODEL_REPO)
48
  # load labels
49
- labels: LabelData = load_labels_hf(MODEL_REPO)
50
  # preprocess image
51
  image = preprocess_image(image, (448, 448))
52
  image = transform(image).unsqueeze(0)
@@ -128,10 +129,10 @@ with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=c
128
  general = gr.Label(label="General")
129
 
130
  with gr.Row():
131
- examples = [[imgpath, 0.35] for imgpath in example_images]
132
  examples = gr.Examples(
133
  examples=examples,
134
- inputs=[img_input, threshold],
135
  )
136
 
137
  # tell clear button which components to clear
@@ -139,7 +140,7 @@ with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=c
139
 
140
  submit.click(
141
  predict,
142
- inputs=[img_input, threshold],
143
  outputs=[heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
144
  api_name="predict",
145
  )
 
41
 
42
  def predict(
43
  image: Image.Image,
44
+ model_repo: str,
45
  threshold: float = 0.5,
46
  ):
47
  # join variant for cache key
48
+ model, transform = load_model_and_transform(model_repo)
49
  # load labels
50
+ labels: LabelData = load_labels_hf(model_repo)
51
  # preprocess image
52
  image = preprocess_image(image, (448, 448))
53
  image = transform(image).unsqueeze(0)
 
129
  general = gr.Label(label="General")
130
 
131
  with gr.Row():
132
+ examples = [[imgpath, MODEL_REPO, 0.35] for imgpath in example_images]
133
  examples = gr.Examples(
134
  examples=examples,
135
+ inputs=[img_input, model_to_use, threshold],
136
  )
137
 
138
  # tell clear button which components to clear
 
140
 
141
  submit.click(
142
  predict,
143
+ inputs=[img_input, model_to_use, threshold],
144
  outputs=[heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
145
  api_name="predict",
146
  )