dokster committed on
Commit
9f8e9e6
•
1 Parent(s): cbafdbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -29
app.py CHANGED
@@ -72,46 +72,53 @@ def main():
72
 
73
  st.write("---")
74
 
75
- question = st.text_input("❔ Enter question prompt: ", "")
 
 
 
 
 
 
76
 
77
- try:
78
- tfile = tempfile.NamedTemporaryFile(delete=False)
79
- tfile.write(uploaded_file.read())
 
80
 
81
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
82
- val_embeddings = []
83
- val_captions = []
84
- result = ''
85
- text = f'Question: {question}? Answer:'
86
 
87
- #read video -> get_ans
88
- video = read_video(tfile.name, transform=None, frames_num=4)
89
 
90
- if len(video) > 0:
91
- i = image_grid(video, 2, 2)
92
- image = preprocess(i).unsqueeze(0).to(device)
93
 
94
- with torch.no_grad():
95
- prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
96
 
97
- val_embeddings.append(prefix)
98
- val_captions.append(text)
99
 
100
- answers = []
101
 
102
- for i in tqdm(range(len(val_embeddings))):
103
- emb = val_embeddings[i]
104
- caption = val_captions[i]
105
 
106
- ans = get_ans(model, tokenizer, emb, prefix_length, caption)
107
- answers.append(ans['answer'])
108
 
109
- result = answers[0].split(' A: ')[0]
110
-
111
- res = st.text_input('βœ… Answer to the question', result, disabled=False)
112
 
113
- except:
114
- pass
115
 
116
  if __name__ == '__main__':
117
  main()
 
72
 
73
  st.write("---")
74
 
75
+ a, b = st.columns([4, 1])
76
+ question = a.text_input(
77
+ label="❔ Enter question prompt: ",
78
+ placeholder="",
79
+ # label_visibility="collapsed",
80
+ )
81
+ button = b.button("Send", use_container_width=True)
82
 
83
+ if button:
84
+ try:
85
+ tfile = tempfile.NamedTemporaryFile(delete=False)
86
+ tfile.write(uploaded_file.read())
87
 
88
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
89
+ val_embeddings = []
90
+ val_captions = []
91
+ result = ''
92
+ text = f'Question: {question}? Answer:'
93
 
94
+ # read video -> get_ans
95
+ video = read_video(tfile.name, transform=None, frames_num=4)
96
 
97
+ if len(video) > 0:
98
+ i = image_grid(video, 2, 2)
99
+ image = preprocess(i).unsqueeze(0).to(device)
100
 
101
+ with torch.no_grad():
102
+ prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
103
 
104
+ val_embeddings.append(prefix)
105
+ val_captions.append(text)
106
 
107
+ answers = []
108
 
109
+ for i in tqdm(range(len(val_embeddings))):
110
+ emb = val_embeddings[i]
111
+ caption = val_captions[i]
112
 
113
+ ans = get_ans(model, tokenizer, emb, prefix_length, caption)
114
+ answers.append(ans['answer'])
115
 
116
+ result = answers[0].split(' A: ')[0]
117
+
118
+ res = st.text_input('βœ… Answer to the question', result, disabled=False)
119
 
120
+ except:
121
+ pass
122
 
123
  if __name__ == '__main__':
124
  main()