oscarwang2 commited on
Commit
dbc69f1
1 Parent(s): c5963a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -72
app.py CHANGED
@@ -1,94 +1,79 @@
1
- import numpy as np
 
 
 
2
  import pandas as pd
3
  import logging
4
- import matplotlib.pyplot as plt
5
- from statsmodels.tsa.arima.model import ARIMA
6
- import yfinance as yf
7
- import gradio as gr
8
 
9
  logging.basicConfig(level=logging.INFO)
10
 
11
- def fetch_data(period='1d'):
12
- logging.info(f"Fetching data for the period {period}...")
13
- data = yf.download(tickers='ETH-USD', period=period, interval='1m')
14
- if data.empty:
15
- logging.error("No data fetched. Check the period or ticker symbol.")
16
- return None
17
- logging.info(f"Fetched {len(data)} data points for the period {period}.")
18
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def make_predictions(data, predict_steps, freq):
21
  if data is None or data.empty:
22
- logging.error("No data available for model training.")
23
- return None
24
 
25
  logging.info(f"Starting model training with {len(data)} data points...")
26
-
27
- # Check for NaN values in the data
28
- if data['Close'].isna().any():
29
- logging.error("Data contains NaN values. Please clean the data before model training.")
30
- return None
31
-
32
- try:
33
- model = ARIMA(data['Close'], order=(5, 1, 0))
34
- model_fit = model.fit()
35
- logging.info(model_fit.summary())
36
- except Exception as e:
37
- logging.error(f"Model training failed: {e}")
38
- return None
39
-
40
  logging.info("Model training completed.")
41
-
42
- logging.info("Generating predictions...")
43
- try:
44
- forecast = model_fit.forecast(steps=predict_steps)
45
- if np.isnan(forecast).any():
46
- logging.error("Generated predictions contain NaN values. Model might be improperly configured.")
47
- return None
48
- except Exception as e:
49
- logging.error(f"Prediction generation failed: {e}")
50
- return None
51
-
52
- future_dates = pd.date_range(start=data.index[-1], periods=predict_steps + 1, freq=freq, inclusive='right')
53
  forecast_df = pd.DataFrame(forecast, index=future_dates[1:], columns=['Prediction'])
54
-
55
- logging.info(f"Forecast Data:\n{forecast_df.head()}")
56
  logging.info("Predictions generated successfully.")
57
-
58
  return forecast_df
59
 
60
- def plot_eth(period='1d'):
61
- data = fetch_data(period)
62
- predict_steps = 5 # Modify as needed
63
- freq = 'T' # 'T' stands for minutes
64
-
65
  forecast_df = make_predictions(data, predict_steps, freq)
66
- if forecast_df is None:
67
- logging.error("Failed to generate predictions.")
68
- return None
69
 
70
- plt.figure(figsize=(10, 5))
71
- plt.plot(data['Close'], label='Actual ETH Price')
72
- plt.plot(forecast_df['Prediction'], label='Forecasted ETH Price', linestyle='dotted', color='orange')
73
- plt.title('ETH Price Prediction')
74
- plt.xlabel('Time')
75
- plt.ylabel('Price (USD)')
76
- plt.legend()
77
- plt.grid(True)
78
- plt.tight_layout()
79
-
80
- # Save the plot to a file
81
- plot_filename = '/home/user/app/eth_price_prediction.png'
82
- plt.savefig(plot_filename)
83
- logging.info("Plotting completed.")
84
 
85
- return plot_filename
 
86
 
87
  def refresh_predictions(period):
88
- plot_filename = plot_eth(period)
89
- if plot_filename is None:
90
- return "Error in generating plot."
91
- return plot_filename
92
 
93
- iface = gr.Interface(fn=refresh_predictions, inputs="text", outputs="image", live=True)
 
 
 
 
 
 
 
94
  iface.launch()
 
1
+ import gradio as gr
2
+ import yfinance as yf
3
+ import plotly.graph_objects as go
4
+ from statsmodels.tsa.arima.model import ARIMA
5
  import pandas as pd
6
  import logging
 
 
 
 
7
 
8
  logging.basicConfig(level=logging.INFO)
9
 
10
+ def fetch_eth_price(period):
11
+ eth = yf.Ticker("ETH-USD")
12
+ if period == '1d':
13
+ data = eth.history(period="1d", interval="1m")
14
+ predict_steps = 60 # Next 60 minutes
15
+ freq = 'min' # Minute frequency
16
+ elif period == '5d':
17
+ data = eth.history(period="5d", interval="15m")
18
+ predict_steps = 96 # Next 24 hours
19
+ freq = '15min' # 15 minutes frequency
20
+ elif period == '1wk':
21
+ data = eth.history(period="1wk", interval="30m")
22
+ predict_steps = 336 # Next 7 days
23
+ freq = '30min' # 30 minutes frequency
24
+ elif period == '1mo':
25
+ data = eth.history(period="1mo", interval="1h")
26
+ predict_steps = 720 # Next 30 days
27
+ freq = 'H' # Hourly frequency
28
+ else:
29
+ return None, None, None
30
+
31
+ data.index = pd.DatetimeIndex(data.index)
32
+ data = data.asfreq(freq) # Ensure the data has a consistent frequency
33
+
34
+ # Limit the data to the last 200 points to reduce prediction time
35
+ data = data[-200:]
36
+
37
+ return data, predict_steps, freq
38
 
39
  def make_predictions(data, predict_steps, freq):
40
  if data is None or data.empty:
41
+ logging.error("No data available for prediction.")
42
+ return pd.DataFrame(index=pd.date_range(start=pd.Timestamp.now(), periods=predict_steps+1, freq=freq)[1:])
43
 
44
  logging.info(f"Starting model training with {len(data)} data points...")
45
+ model = ARIMA(data['Close'], order=(5, 1, 0))
46
+ model_fit = model.fit()
 
 
 
 
 
 
 
 
 
 
 
 
47
  logging.info("Model training completed.")
48
+
49
+ forecast = model_fit.forecast(steps=predict_steps)
50
+ future_dates = pd.date_range(start=data.index[-1], periods=predict_steps+1, freq=freq, inclusive='right')
 
 
 
 
 
 
 
 
 
51
  forecast_df = pd.DataFrame(forecast, index=future_dates[1:], columns=['Prediction'])
52
+
 
53
  logging.info("Predictions generated successfully.")
 
54
  return forecast_df
55
 
56
+ def plot_eth(period):
57
+ data, predict_steps, freq = fetch_eth_price(period)
 
 
 
58
  forecast_df = make_predictions(data, predict_steps, freq)
 
 
 
59
 
60
+ fig = go.Figure()
61
+ fig.add_trace(go.Scatter(x=data.index, y=data['Close'], mode='lines', name='ETH Price'))
62
+ fig.add_trace(go.Scatter(x=forecast_df.index, y=forecast_df['Prediction'], mode='lines', name='Prediction', line=dict(dash='dash', color='orange')))
63
+ fig.update_layout(title=f"ETH Price and Predictions ({period})", xaxis_title="Date", yaxis_title="Price (USD)")
 
 
 
 
 
 
 
 
 
 
64
 
65
+ logging.info("Plotting completed.")
66
+ return fig
67
 
68
  def refresh_predictions(period):
69
+ return plot_eth(period)
 
 
 
70
 
71
+ with gr.Blocks() as iface:
72
+ period = gr.Radio(["1d", "5d", "1wk", "1mo"], label="Select Period")
73
+ plot = gr.Plot()
74
+ refresh_button = gr.Button("Refresh Predictions and Prices")
75
+
76
+ period.change(fn=plot_eth, inputs=period, outputs=plot)
77
+ refresh_button.click(fn=refresh_predictions, inputs=period, outputs=plot)
78
+
79
  iface.launch()