# normflows / app.py — Streamlit front end for the `normflows` package.
# Origin: file page for app.py (author: apsys, revision 5c89480, 1.56 kB).
import streamlit as st
import torch
from normflows import nflow
import numpy as np
import seaborn as sns
import pandas as pd
# Input widgets: the dataset picker, then three training hyper-parameters
# laid out side by side.
uploaded_file = st.file_uploader("Choose original dataset")
col1, col2, col3 = st.columns(3)
# Forwarded as ``bw`` to api.compile() inside compute().
bw = col1.number_input('Scale', value=3.05)
# Streamlit's default float display is "%0.2f" with a 0.01 step, which showed
# the 0.0002 default as "0.00" and made small values unreachable; use a
# resolution that matches the default instead.
wd = col2.number_input(
    'Weight Decay',
    value=0.0002,
    min_value=0.0,
    step=0.0001,
    format='%.4f',
)
# Iteration count must be a positive integer.
iters = col3.number_input('Iterations', value=400, min_value=1, step=1)
def compute(dim, n_iters=None, n_samples=1000):
    """Train a flow on the uploaded dataset and return generated rows.

    Builds an ``nflow`` model from the module-level ``uploaded_file`` and
    trains it while updating a Streamlit progress bar. It then plots the
    model's output for the real scaled data against the real points and
    draws ``n_samples`` new rows. Those rows are mapped back through
    ``api.scaler.inverse_transform``.

    Args:
        dim: Number of columns in the uploaded dataset.
        n_iters: Training iterations. ``None`` (default) uses the value of
            the "Iterations" widget (module-level ``iters``), which the
            original code ignored in favour of a hard-coded 10000.
        n_samples: Number of rows to generate.

    Returns:
        The generated rows after ``api.scaler.inverse_transform``.
    """
    # Guard against zero/negative counts so the progress ratio is defined.
    total = max(1, int(iters if n_iters is None else n_iters))

    api = nflow(dim=dim, latent=16, dataset=uploaded_file)
    api.compile(optim=torch.optim.ASGD, bw=bw, lr=0.0001, wd=wd)

    my_bar = st.progress(0)
    # NOTE(review): assumes api.train yields sequences whose first element is
    # the current iteration index — confirm against normflows.
    for step in api.train(iters=total):
        # st.progress expects a float in [0.0, 1.0]; clamp in case the index
        # reaches or passes ``total``.
        my_bar.progress(min(float(step[0]) / total, 1.0))

    # Model output for the real (scaled) data, plotted against the real points.
    fitted = (
        api.model.sample(torch.tensor(api.scaled).float())
        .detach()
        .cpu()
        .numpy()
    )
    if api.scaled.shape[-1] >= 2:
        g = sns.jointplot(
            x=fitted[:, 0],
            y=fitted[:, 1],
            kind='kde',
            cmap=sns.color_palette("Blues", as_cmap=True),
            fill=True,
            label='Gaussian KDE',
            levels=50,
        )
        w = sns.scatterplot(
            x=api.scaled[:, 0],
            y=api.scaled[:, 1],
            ax=g.ax_joint,
            c='orange',
            marker='+',
            s=100,
            label='Real',
        )
        st.pyplot(w.get_figure())
    else:
        # The joint plot needs two columns; skip it rather than raise IndexError.
        st.info("Plot skipped: the dataset needs at least two columns.")

    # Standard-normal inputs (same distribution as the original
    # zeros().normal_()) passed to model.sample; already float32, so no
    # torch.tensor() copy is needed.
    z = torch.randn(n_samples, api.scaled.shape[-1])
    generated = api.model.sample(z).detach().cpu().numpy()
    return api.scaler.inverse_transform(generated)
# Runs once a dataset has been uploaded: train, then offer the generated rows
# as a CSV download.
if uploaded_file is not None:
    original = pd.read_csv(uploaded_file)
    dim = original.shape[-1]
    # read_csv leaves the upload's cursor at EOF. Rewind it in case nflow reads
    # from the same file object inside compute().
    uploaded_file.seek(0)
    samples = compute(dim)

    generated = pd.DataFrame(samples)
    # Reuse the uploaded headers when the widths match so the download has
    # the same schema as the input.
    if generated.shape[1] == len(original.columns):
        generated.columns = original.columns

    # st.download_button's third positional parameter is file_name, so the
    # MIME type must be passed by keyword. index=False keeps the row index out
    # of the file so it has the same column count as the uploaded dataset.
    st.download_button(
        'Download generated CSV',
        generated.to_csv(index=False),
        file_name='generated.csv',
        mime='text/csv',
    )