apsys commited on
Commit
240432b
1 Parent(s): 5c89480
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -19,14 +19,16 @@ def compute(dim):
19
 
20
  my_bar = st.progress(0)
21
 
22
- for idx in api.train(iters=10000):
23
- my_bar.progress(idx[0]/10000)
24
-
25
  samples = np.array(api.model.sample(
26
  torch.tensor(api.scaled).float()).detach())
 
 
27
 
28
  # fig, ax = plt.subplots()
29
- g = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind='kde',cmap=sns.color_palette("Blues", as_cmap=True),fill=True,label='Gaussian KDE',levels=50)
30
 
31
  w = sns.scatterplot(x=api.scaled[:,0],y=api.scaled[:,1],ax=g.ax_joint,c='orange',marker='+',s=100,label='Real')
32
  st.pyplot(w.get_figure())
@@ -39,9 +41,8 @@ def compute(dim):
39
 
40
  return api.scaler.inverse_transform(samples)
41
 
42
-
43
 
44
  if uploaded_file is not None:
45
- dim = pd.read_csv(uploaded_file).shape[-1]
46
- samples=compute(dim)
47
  st.download_button('Download generated CSV', pd.DataFrame(samples).to_csv(), 'text/csv')
 
19
 
20
  my_bar = st.progress(0)
21
 
22
+ for idx in api.train(iters=iters):
23
+ my_bar.progress(idx[0]/iters)
24
+ my_bar.progress(100)
25
  samples = np.array(api.model.sample(
26
  torch.tensor(api.scaled).float()).detach())
27
+
28
+
29
 
30
  # fig, ax = plt.subplots()
31
+ g = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind='kde',cmap=sns.color_palette("Blues", as_cmap=True),fill=True,label='Gaussian KDE',levels=1000)
32
 
33
  w = sns.scatterplot(x=api.scaled[:,0],y=api.scaled[:,1],ax=g.ax_joint,c='orange',marker='+',s=100,label='Real')
34
  st.pyplot(w.get_figure())
 
41
 
42
  return api.scaler.inverse_transform(samples)
43
 
 
44
 
45
  if uploaded_file is not None:
46
+ dims = len(uploaded_file.getvalue().decode("utf-8").split('\n')[0].split(','))-1
47
+ samples=compute(dims)
48
  st.download_button('Download generated CSV', pd.DataFrame(samples).to_csv(), 'text/csv')