ndhieunguyen commited on
Commit
ad32d4f
1 Parent(s): 23a7a4b

feat: use gpu space

Browse files
Files changed (1) hide show
  1. app.py +25 -21
app.py CHANGED
@@ -6,6 +6,7 @@ from src.improved_diffusion import gaussian_diffusion as gd
6
  from src.improved_diffusion.respace import SpacedDiffusion
7
  from src.improved_diffusion.transformer_model import TransformerNetModel
8
  import streamlit as st
 
9
  import os
10
 
11
 
@@ -57,28 +58,18 @@ def get_diffusion():
57
  training_mode="e2e",
58
  )
59
 
60
-
61
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
62
-
63
- tokenizer = get_tokenizer()
64
- encoder = get_encoder(device)
65
- model = get_model(device)
66
- diffusion = get_diffusion()
67
-
68
- st.title("Lang2mol-Diff")
69
- text_input = st.text_area("Enter molecule description")
70
- button = st.button("Submit")
71
- if button:
72
  with st.spinner("Please wait..."):
73
  output = tokenizer(
74
- text_input,
75
- max_length=256,
76
- truncation=True,
77
- padding="max_length",
78
- add_special_tokens=True,
79
- return_tensors="pt",
80
- return_attention_mask=True,
81
- )
82
  caption_state = encoder(
83
  input_ids=output["input_ids"].to(device),
84
  attention_mask=output["attention_mask"].to(device),
@@ -103,5 +94,18 @@ if button:
103
  result = sf.decoder(
104
  outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
105
  ).replace("\t", "")
 
 
 
 
 
 
 
 
106
 
107
- st.write(result)
 
 
 
 
 
 
6
  from src.improved_diffusion.respace import SpacedDiffusion
7
  from src.improved_diffusion.transformer_model import TransformerNetModel
8
  import streamlit as st
9
+ import spaces
10
  import os
11
 
12
 
 
58
  training_mode="e2e",
59
  )
60
 
61
+ @spaces.GPU
62
+ def generate(text_input):
 
 
 
 
 
 
 
 
 
 
63
  with st.spinner("Please wait..."):
64
  output = tokenizer(
65
+ text_input,
66
+ max_length=256,
67
+ truncation=True,
68
+ padding="max_length",
69
+ add_special_tokens=True,
70
+ return_tensors="pt",
71
+ return_attention_mask=True,
72
+ )
73
  caption_state = encoder(
74
  input_ids=output["input_ids"].to(device),
75
  attention_mask=output["attention_mask"].to(device),
 
94
  result = sf.decoder(
95
  outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
96
  ).replace("\t", "")
97
+ return result
98
+
99
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
100
+
101
+ tokenizer = get_tokenizer()
102
+ encoder = get_encoder(device)
103
+ model = get_model(device)
104
+ diffusion = get_diffusion()
105
 
106
+ st.title("Lang2mol-Diff")
107
+ text_input = st.text_area("Enter molecule description")
108
+ button = st.button("Submit")
109
+ if button:
110
+ result = generate(text_input)
111
+ st.write(result)