ndhieunguyen commited on
Commit
7cacf8f
1 Parent(s): ad32d4f

feat: use gradio

Browse files
Files changed (2) hide show
  1. app.py +53 -46
  2. app_streamlit.py +111 -0
app.py CHANGED
@@ -5,12 +5,11 @@ from src.scripts.mytokenizers import Tokenizer
5
  from src.improved_diffusion import gaussian_diffusion as gd
6
  from src.improved_diffusion.respace import SpacedDiffusion
7
  from src.improved_diffusion.transformer_model import TransformerNetModel
8
- import streamlit as st
9
  import spaces
10
  import os
11
 
12
 
13
- @st.cache_resource
14
  def get_encoder(device):
15
  model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
16
  model.to(device)
@@ -18,12 +17,10 @@ def get_encoder(device):
18
  return model
19
 
20
 
21
- @st.cache_resource
22
  def get_tokenizer():
23
  return Tokenizer()
24
 
25
 
26
- @st.cache_resource
27
  def get_model(device):
28
  model = TransformerNetModel(
29
  in_channels=32,
@@ -45,7 +42,6 @@ def get_model(device):
45
  return model
46
 
47
 
48
- @st.cache_resource
49
  def get_diffusion():
50
  return SpacedDiffusion(
51
  use_timesteps=[i for i in range(0, 2000, 10)],
@@ -58,43 +54,44 @@ def get_diffusion():
58
  training_mode="e2e",
59
  )
60
 
 
61
  @spaces.GPU
62
  def generate(text_input):
63
- with st.spinner("Please wait..."):
64
- output = tokenizer(
65
- text_input,
66
- max_length=256,
67
- truncation=True,
68
- padding="max_length",
69
- add_special_tokens=True,
70
- return_tensors="pt",
71
- return_attention_mask=True,
72
- )
73
- caption_state = encoder(
74
- input_ids=output["input_ids"].to(device),
75
- attention_mask=output["attention_mask"].to(device),
76
- ).last_hidden_state
77
- caption_mask = output["attention_mask"]
78
-
79
- outputs = diffusion.p_sample_loop(
80
- model,
81
- (1, 256, 32),
82
- clip_denoised=False,
83
- denoised_fn=None,
84
- model_kwargs={},
85
- top_p=1.0,
86
- progress=True,
87
- caption=(caption_state.to(device), caption_mask.to(device)),
88
- )
89
- logits = model.get_logits(torch.tensor(outputs))
90
- cands = torch.topk(logits, k=1, dim=-1)
91
- outputs = cands.indices
92
- outputs = outputs.squeeze(-1)
93
- outputs = tokenizer.decode(outputs)
94
- result = sf.decoder(
95
- outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
96
- ).replace("\t", "")
97
- return result
98
 
99
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
100
 
@@ -103,9 +100,19 @@ encoder = get_encoder(device)
103
  model = get_model(device)
104
  diffusion = get_diffusion()
105
 
106
- st.title("Lang2mol-Diff")
107
- text_input = st.text_area("Enter molecule description")
108
- button = st.button("Submit")
109
- if button:
110
- result = generate(text_input)
111
- st.write(result)
 
 
 
 
 
 
 
 
 
 
 
5
  from src.improved_diffusion import gaussian_diffusion as gd
6
  from src.improved_diffusion.respace import SpacedDiffusion
7
  from src.improved_diffusion.transformer_model import TransformerNetModel
8
+ import gradio as gr
9
  import spaces
10
  import os
11
 
12
 
 
13
  def get_encoder(device):
14
  model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
15
  model.to(device)
 
17
  return model
18
 
19
 
 
20
  def get_tokenizer():
21
  return Tokenizer()
22
 
23
 
 
24
  def get_model(device):
25
  model = TransformerNetModel(
26
  in_channels=32,
 
42
  return model
43
 
44
 
 
45
  def get_diffusion():
46
  return SpacedDiffusion(
47
  use_timesteps=[i for i in range(0, 2000, 10)],
 
54
  training_mode="e2e",
55
  )
56
 
57
+
58
  @spaces.GPU
59
  def generate(text_input):
60
+ output = tokenizer(
61
+ text_input,
62
+ max_length=256,
63
+ truncation=True,
64
+ padding="max_length",
65
+ add_special_tokens=True,
66
+ return_tensors="pt",
67
+ return_attention_mask=True,
68
+ )
69
+ caption_state = encoder(
70
+ input_ids=output["input_ids"].to(device),
71
+ attention_mask=output["attention_mask"].to(device),
72
+ ).last_hidden_state
73
+ caption_mask = output["attention_mask"]
74
+
75
+ outputs = diffusion.p_sample_loop(
76
+ model,
77
+ (1, 256, 32),
78
+ clip_denoised=False,
79
+ denoised_fn=None,
80
+ model_kwargs={},
81
+ top_p=1.0,
82
+ progress=True,
83
+ caption=(caption_state.to(device), caption_mask.to(device)),
84
+ )
85
+ logits = model.get_logits(torch.tensor(outputs))
86
+ cands = torch.topk(logits, k=1, dim=-1)
87
+ outputs = cands.indices
88
+ outputs = outputs.squeeze(-1)
89
+ outputs = tokenizer.decode(outputs)
90
+ result = sf.decoder(
91
+ outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
92
+ ).replace("\t", "")
93
+ return result
94
+
95
 
96
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
97
 
 
100
  model = get_model(device)
101
  diffusion = get_diffusion()
102
 
103
+ # Create a Gradio interface
104
+ iface = gr.Interface(
105
+ fn=generate,
106
+ inputs="text",
107
+ outputs="text",
108
+ title="Lang2mol-Diff",
109
+ description="Enter molecule description",
110
+ examples=[
111
+ [
112
+ "The molecule is a apoptosis, cholesterol translocation, stabilizing mitochondrial structure that impacts barth syndrome and non-alcoholic fatty liver disease. The molecule is a stabilizing cytochrome oxidase and a proton trap for oxidative phosphorylation that impacts aging, diabetic heart disease, and tangier disease."
113
+ ],
114
+ ],
115
+ )
116
+
117
+ # Run the interface
118
+ iface.launch()
app_streamlit.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import selfies as sf
3
+ from transformers import T5EncoderModel
4
+ from src.scripts.mytokenizers import Tokenizer
5
+ from src.improved_diffusion import gaussian_diffusion as gd
6
+ from src.improved_diffusion.respace import SpacedDiffusion
7
+ from src.improved_diffusion.transformer_model import TransformerNetModel
8
+ import streamlit as st
9
+ import spaces
10
+ import os
11
+
12
+
13
+ @st.cache_resource
14
+ def get_encoder(device):
15
+ model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
16
+ model.to(device)
17
+ model.eval()
18
+ return model
19
+
20
+
21
+ @st.cache_resource
22
+ def get_tokenizer():
23
+ return Tokenizer()
24
+
25
+
26
+ @st.cache_resource
27
+ def get_model(device):
28
+ model = TransformerNetModel(
29
+ in_channels=32,
30
+ model_channels=128,
31
+ dropout=0.1,
32
+ vocab_size=35073,
33
+ hidden_size=1024,
34
+ num_attention_heads=16,
35
+ num_hidden_layers=12,
36
+ )
37
+ model.load_state_dict(
38
+ torch.load(
39
+ os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
40
+ map_location=torch.device(device),
41
+ )
42
+ )
43
+ model.to(device)
44
+ model.eval()
45
+ return model
46
+
47
+
48
+ @st.cache_resource
49
+ def get_diffusion():
50
+ return SpacedDiffusion(
51
+ use_timesteps=[i for i in range(0, 2000, 10)],
52
+ betas=gd.get_named_beta_schedule("sqrt", 2000),
53
+ model_mean_type=(gd.ModelMeanType.START_X),
54
+ model_var_type=((gd.ModelVarType.FIXED_LARGE)),
55
+ loss_type=gd.LossType.E2E_MSE,
56
+ rescale_timesteps=True,
57
+ model_arch="transformer",
58
+ training_mode="e2e",
59
+ )
60
+
61
+ @spaces.GPU
62
+ def generate(text_input):
63
+ with st.spinner("Please wait..."):
64
+ output = tokenizer(
65
+ text_input,
66
+ max_length=256,
67
+ truncation=True,
68
+ padding="max_length",
69
+ add_special_tokens=True,
70
+ return_tensors="pt",
71
+ return_attention_mask=True,
72
+ )
73
+ caption_state = encoder(
74
+ input_ids=output["input_ids"].to(device),
75
+ attention_mask=output["attention_mask"].to(device),
76
+ ).last_hidden_state
77
+ caption_mask = output["attention_mask"]
78
+
79
+ outputs = diffusion.p_sample_loop(
80
+ model,
81
+ (1, 256, 32),
82
+ clip_denoised=False,
83
+ denoised_fn=None,
84
+ model_kwargs={},
85
+ top_p=1.0,
86
+ progress=True,
87
+ caption=(caption_state.to(device), caption_mask.to(device)),
88
+ )
89
+ logits = model.get_logits(torch.tensor(outputs))
90
+ cands = torch.topk(logits, k=1, dim=-1)
91
+ outputs = cands.indices
92
+ outputs = outputs.squeeze(-1)
93
+ outputs = tokenizer.decode(outputs)
94
+ result = sf.decoder(
95
+ outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
96
+ ).replace("\t", "")
97
+ return result
98
+
99
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
100
+
101
+ tokenizer = get_tokenizer()
102
+ encoder = get_encoder(device)
103
+ model = get_model(device)
104
+ diffusion = get_diffusion()
105
+
106
+ st.title("Lang2mol-Diff")
107
+ text_input = st.text_area("Enter molecule description")
108
+ button = st.button("Submit")
109
+ if button:
110
+ result = generate(text_input)
111
+ st.write(result)