import os
import logging
import pandas as pd
import numpy as np
from transformers import pipeline
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import accuracy_score, f1_score
from sklearn.base import clone
import joblib
import torch
from tqdm import tqdm
import warnings
import gradio as gr
import matplotlib.pyplot as plt
import shap
from datetime import datetime
from wordcloud import WordCloud
from typing import Optional, Tuple, List
# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
# --------------------------
# Configuration & Setup
# --------------------------
class Config:
DATA_PATH = "data/ibm_cleaned.parquet"
MODEL_PATH = "models/stock_prediction_model.joblib"
CACHE_DIR = "./cache"
LOGS_DIR = "./logs"
PLOTS_DIR = "./plots"
SENTIMENT_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
SENTIMENT_BATCH_SIZE = 8
SENTIMENT_MAX_LEN = 256
CONFIDENCE_THRESHOLD = 0.3
DATA_REFRESH_DAYS = 7
MAX_HISTORY_ENTRIES = 50
@classmethod
def setup(cls):
os.makedirs(cls.CACHE_DIR, exist_ok=True)
os.makedirs(cls.LOGS_DIR, exist_ok=True)
os.makedirs(cls.PLOTS_DIR, exist_ok=True)
os.makedirs(os.path.dirname(cls.MODEL_PATH), exist_ok=True)
Config.setup()
# Logger config
logging.basicConfig(
filename=os.path.join(Config.LOGS_DIR, 'app.log'),
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
# --------------------------
# Data loading and optimization
# --------------------------
def load_data(test_mode: bool = False, force_refresh: bool = False) -> pd.DataFrame:
"""Load and preprocess the stock data with caching and validation."""
try:
        # Check data freshness. There is no remote fetch step, so stale data
        # (and force_refresh) still fall through to the same local parquet.
        if not force_refresh and os.path.exists(Config.DATA_PATH):
            file_age = datetime.now() - datetime.fromtimestamp(os.path.getmtime(Config.DATA_PATH))
            if file_age.days < Config.DATA_REFRESH_DAYS:
                logger.info("Using cached data")
            else:
                logger.warning(f"Cached data is {file_age.days} days old; using it anyway")
        df = pd.read_parquet(Config.DATA_PATH)
# Validate required columns
required_cols = {'Open', 'High', 'Low', 'Close', 'Volume', 'Date', 'News'}
missing_cols = required_cols - set(df.columns)
if missing_cols:
raise ValueError(f"Missing columns in data: {missing_cols}")
# Optimize dtypes
num_cols = ['Open', 'High', 'Low', 'Close']
df[num_cols] = df[num_cols].astype('float32')
if (df['Volume'] < 0).any():
logger.warning("Negative values found in 'Volume'; downcasting as signed int.")
df['Volume'] = pd.to_numeric(df['Volume'], downcast='integer')
else:
df['Volume'] = pd.to_numeric(df['Volume'], downcast='unsigned')
if not pd.api.types.is_datetime64_any_dtype(df['Date']):
df['Date'] = pd.to_datetime(df['Date'], errors='coerce')
if df['Date'].isna().any():
raise ValueError("Date column contains invalid dates after parsing.")
# Data quality checks
if df.isnull().sum().sum() > 0:
logger.warning("Data contains null values - filling with forward fill")
df = df.ffill()
        if test_mode:
            # Note: random sampling breaks the day-to-day continuity that the
            # rolling-window features assume; acceptable only as a smoke test.
            sample_size = min(5000, len(df))
            if len(df) < 5000:
                logger.warning(f"Data has only {len(df)} rows - using full dataset in test mode")
            return df.sample(sample_size, random_state=42).sort_values('Date').reset_index(drop=True)
return df.sort_values('Date').reset_index(drop=True)
except Exception as e:
logger.error(f"Data loading failed: {e}")
raise
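# For reference, load_data expects a local parquet file at Config.DATA_PATH with
# at least: Date (datetime), Open/High/Low/Close (numeric), Volume (integer),
# and News (str). Dtypes shown are post-optimization; the file itself is assumed
# to exist, since no fetch step is implemented.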
# --------------------------
# Sentiment Analysis
# --------------------------
class SentimentAnalyzer:
"""Handles sentiment analysis with caching and batching."""
def __init__(self):
self._initialize_pipeline()
self.sentiment_cache = {}
def _initialize_pipeline(self):
"""Initialize the sentiment analysis pipeline."""
try:
self.pipeline = pipeline(
"text-classification",
model=Config.SENTIMENT_MODEL,
device=0 if torch.cuda.is_available() else -1,
batch_size=Config.SENTIMENT_BATCH_SIZE,
truncation=True,
max_length=Config.SENTIMENT_MAX_LEN,
)
except Exception as e:
logger.error(f"Sentiment model load failed: {e}")
self.pipeline = None
raise
def analyze_texts(self, texts: List[str]) -> np.ndarray:
"""Analyze sentiment for a batch of texts with caching."""
if self.pipeline is None or not texts:
return np.zeros(len(texts), dtype='float32')
# Check cache first
uncached_texts = []
uncached_indices = []
cached_scores = np.zeros(len(texts), dtype='float32')
for i, text in enumerate(texts):
            text = str(text)[:Config.SENTIMENT_MAX_LEN] if text else ""  # char-level cap for the cache key; the pipeline truncates by tokens
if text in self.sentiment_cache:
cached_scores[i] = self.sentiment_cache[text]
else:
uncached_texts.append(text)
uncached_indices.append(i)
# Process uncached texts
if uncached_texts:
try:
results = self.pipeline(
uncached_texts,
truncation=True,
max_length=Config.SENTIMENT_MAX_LEN,
batch_size=Config.SENTIMENT_BATCH_SIZE
)
                for pos, (idx, res) in enumerate(zip(uncached_indices, results)):
                    # The pipeline returns one dict per text, e.g. {'label': 'POSITIVE', 'score': 0.99}
                    label = res['label']
                    score = res['score']
                    sentiment_score = score if label == "POSITIVE" else -score
                    cached_scores[idx] = sentiment_score
                    # idx indexes the original batch; pos indexes uncached_texts
                    self.sentiment_cache[uncached_texts[pos]] = sentiment_score
except Exception as e:
logger.error(f"Sentiment analysis error: {e}")
return cached_scores
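# Minimal usage sketch (scores shown are hypothetical):
#   analyzer = SentimentAnalyzer()
#   analyzer.analyze_texts(["Profit beats estimates", "CEO resigns abruptly"])
#   -> array([ 0.99, -0.97], dtype=float32)  # sign encodes POSITIVE/NEGATIVE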
# --------------------------
# Feature Engineering
# --------------------------
def calculate_rsi(series: pd.Series, window: int = 14) -> pd.Series:
"""Calculate Relative Strength Index."""
try:
series = pd.to_numeric(series, errors='coerce').ffill()
delta = series.diff()
gain = delta.clip(lower=0)
loss = -delta.clip(upper=0)
avg_gain = gain.rolling(window, min_periods=1).mean()
avg_loss = loss.rolling(window, min_periods=1).mean().replace(0, 1e-10)
rs = avg_gain / avg_loss
return (100 - (100 / (1 + rs))).astype('float32')
except Exception as e:
logger.error(f"RSI calculation failed: {e}")
return pd.Series(np.nan, index=series.index)
def calculate_macd(
series: pd.Series,
fast: int = 12,
slow: int = 26,
signal: int = 9
) -> Tuple[pd.Series, pd.Series, pd.Series]:
"""Calculate MACD line, signal line, and histogram."""
try:
series = pd.to_numeric(series, errors='coerce').ffill()
ema_fast = series.ewm(span=fast, adjust=False).mean()
ema_slow = series.ewm(span=slow, adjust=False).mean()
macd_line = ema_fast - ema_slow
signal_line = macd_line.ewm(span=signal, adjust=False).mean()
return (
macd_line.astype('float32'),
signal_line.astype('float32'),
(macd_line - signal_line).astype('float32')
)
except Exception as e:
logger.error(f"MACD calculation failed: {e}")
return (
pd.Series(np.nan, index=series.index),
pd.Series(np.nan, index=series.index),
pd.Series(np.nan, index=series.index)
)
def calculate_bollinger_bands(
series: pd.Series,
window: int = 20,
no_of_std: int = 2
) -> Tuple[pd.Series, pd.Series]:
"""Calculate Bollinger Bands."""
try:
series = pd.to_numeric(series, errors='coerce').ffill()
rolling_mean = series.rolling(window, min_periods=1).mean()
rolling_std = series.rolling(window, min_periods=1).std()
return (
(rolling_mean + (no_of_std * rolling_std)).astype('float32'),
(rolling_mean - (no_of_std * rolling_std)).astype('float32')
)
except Exception as e:
logger.error(f"Bollinger Bands calculation failed: {e}")
return (
pd.Series(np.nan, index=series.index),
pd.Series(np.nan, index=series.index)
)
def calculate_volatility(series: pd.Series, window: int = 10) -> pd.Series:
"""Calculate price volatility."""
try:
series = pd.to_numeric(series, errors='coerce').ffill()
return series.pct_change().rolling(window, min_periods=1).std().astype('float32')
except Exception as e:
logger.error(f"Volatility calculation failed: {e}")
return pd.Series(np.nan, index=series.index)
def create_features(df: pd.DataFrame, analyzer: SentimentAnalyzer) -> pd.DataFrame:
"""Create all features for the prediction model."""
try:
df = df.copy()
# Validate required columns
required_cols = {'Close', 'News'}
missing_cols = required_cols - set(df.columns)
if missing_cols:
raise ValueError(f"Missing columns for feature creation: {missing_cols}")
# Price features
df['Price_Change'] = df['Close'].pct_change().astype('float32')
df['Log_Return'] = np.log1p(df['Price_Change'].clip(lower=-0.9999)).astype('float32')
df['MA_5'] = df['Close'].rolling(5, min_periods=1).mean().astype('float32')
df['MA_20'] = df['Close'].rolling(20, min_periods=1).mean().astype('float32')
df['RSI'] = calculate_rsi(df['Close'])
df['Volatility_10'] = calculate_volatility(df['Close'], 10)
# MACD features
macd_line, signal_line, hist = calculate_macd(df['Close'])
df['MACD'] = macd_line
df['MACD_Signal'] = signal_line
df['MACD_Hist'] = hist
# Bollinger Bands
bb_upper, bb_lower = calculate_bollinger_bands(df['Close'])
df['BB_Upper'] = bb_upper
df['BB_Lower'] = bb_lower
# Sentiment features
if 'News' not in df or df['News'].isnull().all():
logger.warning("No news data - setting Sentiment features to zero")
df['Sentiment'] = 0.0
else:
df['Sentiment'] = analyzer.analyze_texts(df['News'].fillna('').tolist())
df['Sentiment_MA'] = df['Sentiment'].rolling(5, min_periods=1).mean().astype('float32')
df['Sentiment_Lag1'] = df['Sentiment'].shift(1).fillna(0).astype('float32')
        # Target: 1 if the next day's close is higher than today's. Note the
        # final row has no next-day close; its NaN comparison evaluates False,
        # so that label silently defaults to 0 instead of being dropped.
        df['Target'] = (df['Close'].shift(-1) > df['Close']).astype('int8')
        return df.dropna().reset_index(drop=True)
except Exception as e:
logger.error(f"Feature engineering error: {e}")
raise
# --------------------------
# Model Training
# --------------------------
def train_model() -> HistGradientBoostingClassifier:
"""Train and evaluate the prediction model."""
try:
logger.info("Starting model training...")
df = load_data(test_mode=True)
# Train-test split maintaining temporal order
split_idx = int(0.8 * len(df))
train_df = df.iloc[:split_idx].copy()
test_df = df.iloc[split_idx:].copy()
analyzer = SentimentAnalyzer()
features = []
# Process in chunks for memory efficiency
chunk_size = 1000
for i in tqdm(range(0, len(train_df), chunk_size), desc="Creating features"):
chunk = train_df.iloc[i:i+chunk_size]
features.append(create_features(chunk, analyzer))
train_processed = pd.concat(features)
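        # Caveat: features are computed per chunk, so rolling windows (MA_20,
        # RSI, etc.) restart at every 1000-row boundary and the leading rows of
        # each chunk carry partial-window values.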
# Define expected features
expected_features = [
'Sentiment', 'Sentiment_MA', 'Sentiment_Lag1', 'Price_Change',
'Log_Return', 'MA_5', 'MA_20', 'RSI', 'Volatility_10',
'MACD', 'MACD_Signal', 'MACD_Hist', 'BB_Upper', 'BB_Lower', 'Target'
]
# Validate features
missing_cols = [c for c in expected_features if c not in train_processed.columns]
if missing_cols:
raise ValueError(f"Missing features in training data: {missing_cols}")
        # Keep only numeric model features (Date is retained just long enough to drop)
        train_processed = train_processed[expected_features + ['Date']]
        X_train = train_processed.drop(columns=['Date', 'Target'])
        y_train = train_processed['Target']
# Initialize model with good defaults
base_model = HistGradientBoostingClassifier(
max_iter=100,
max_depth=5,
learning_rate=0.05,
random_state=42,
verbose=1,
early_stopping=True,
validation_fraction=0.1
)
# Time-series cross-validation
tscv = TimeSeriesSplit(n_splits=3)
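        # Each split trains on an expanding prefix of the data and validates on
        # the block that follows it, so no fold ever sees future data.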
f1_scores = []
for train_idx, val_idx in tscv.split(X_train):
model = clone(base_model)
model.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])
preds = model.predict(X_train.iloc[val_idx])
f1_scores.append(f1_score(y_train.iloc[val_idx], preds))
logger.info(f"Fold F1: {f1_scores[-1]:.3f}")
logger.info(f"Validation F1: {np.mean(f1_scores):.3f} ± {np.std(f1_scores):.3f}")
# Final training
model = clone(base_model)
model.fit(X_train, y_train)
# Test evaluation
test_processed = create_features(test_df, analyzer)
        X_test = test_processed[expected_features[:-1]]  # all features except Target
        y_test = test_processed['Target']
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
logger.info(f"Test Accuracy: {acc:.3f} | F1: {f1:.3f}")
# Save model
joblib.dump(model, Config.MODEL_PATH)
logger.info(f"Model saved to {Config.MODEL_PATH}")
return model
except Exception as e:
logger.error(f"Model training failed: {e}")
raise
# --------------------------
# Model Prediction
# --------------------------
def predict(df: pd.DataFrame, threshold: float = Config.CONFIDENCE_THRESHOLD) -> pd.DataFrame:
"""Make predictions with confidence scores."""
    try:
        # Relies on the module-level `model` and `analyzer` initialized under
        # "Initialize Components" below.
        if model is None:
            raise ValueError("Model is not loaded")
df_features = create_features(df, analyzer)
required_features = [
'Sentiment', 'Sentiment_MA', 'Sentiment_Lag1', 'Price_Change',
'Log_Return', 'MA_5', 'MA_20', 'RSI', 'Volatility_10',
'MACD', 'MACD_Signal', 'MACD_Hist', 'BB_Upper', 'BB_Lower'
]
X = df_features[required_features]
proba = model.predict_proba(X)[:, 1] # Probability of 'up' class
# Generate predictions with confidence threshold
preds = np.where(
proba > (0.5 + threshold), "Buy",
np.where(proba < (0.5 - threshold), "Sell", "Hold")
)
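        # With the default threshold of 0.3 this maps to: Buy if P(up) > 0.8,
        # Sell if P(up) < 0.2, otherwise Hold.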
df_features['Prediction'] = preds
df_features['Confidence'] = np.abs(proba - 0.5) * 2 # Normalized to 0-1
return df_features[['Date', 'Close', 'Prediction', 'Confidence', 'Sentiment']]
except Exception as e:
logger.error(f"Prediction error: {e}")
return pd.DataFrame(columns=['Date', 'Close', 'Prediction', 'Confidence', 'Sentiment'])
# --------------------------
# Visualization Functions
# --------------------------
def generate_shap_plot(df: pd.DataFrame) -> Optional[str]:
"""Generate SHAP explanation plot."""
try:
df = df.tail(50) # Use most recent 50 samples
required_features = [
'Sentiment', 'Sentiment_MA', 'Sentiment_Lag1', 'Price_Change',
'Log_Return', 'MA_5', 'MA_20', 'RSI', 'Volatility_10',
'MACD', 'MACD_Signal', 'MACD_Hist', 'BB_Upper', 'BB_Lower'
]
X = df[required_features]
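        # Assumes the installed shap version supports HistGradientBoostingClassifier
        # in TreeExplainer; older releases may raise here and take the except path.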
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
plt.figure(figsize=(10, 6))
shap.summary_plot(shap_values, X, show=False, plot_size=(10, 6))
plot_path = os.path.join(Config.PLOTS_DIR, "shap_summary.png")
plt.savefig(plot_path, bbox_inches='tight')
plt.close()
return plot_path
except Exception as e:
logger.error(f"SHAP plot generation failed: {e}")
return None
def plot_sentiment(df: pd.DataFrame) -> Optional[str]:
"""Generate sentiment trend plot."""
try:
plt.figure(figsize=(10, 4))
plt.plot(df['Date'], df['Sentiment'], label="Daily Sentiment")
plt.plot(df['Date'], df['Sentiment_MA'], label="5-Day MA", alpha=0.7)
plt.title("News Sentiment Over Time")
plt.xlabel("Date")
plt.ylabel("Sentiment Score")
plt.legend()
plt.grid(True)
plot_path = os.path.join(Config.PLOTS_DIR, "sentiment_plot.png")
plt.savefig(plot_path, bbox_inches='tight')
plt.close()
return plot_path
except Exception as e:
logger.error(f"Sentiment plot failed: {e}")
return None
def generate_wordcloud(texts: List[str]) -> Optional[str]:
"""Generate word cloud from news texts."""
try:
if not texts or all(pd.isna(texts)):
return None
text = " ".join(str(t) for t in texts if t and not pd.isna(t))
wc = WordCloud(width=600, height=400, background_color='white').generate(text)
plot_path = os.path.join(Config.PLOTS_DIR, "wordcloud.png")
wc.to_file(plot_path)
return plot_path
except Exception as e:
logger.error(f"Word cloud generation failed: {e}")
return None
# --------------------------
# UI Helper Functions
# --------------------------
# --------------------------
# UI Functions Implementation
# --------------------------
def predict_single(text_input: str, threshold: float, history_state: list) -> Tuple[pd.DataFrame, list, Optional[dict]]:
"""Make a single prediction based on news input."""
try:
# Create input DataFrame
today = datetime.today().strftime('%Y-%m-%d')
data = {
'Date': [today],
'Open': [100.0],
'High': [105.0],
'Low': [95.0],
'Close': [100.0],
'Volume': [1000000],
'News': [text_input]
}
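        # Open/High/Low/Volume above are synthetic placeholders; the engineered
        # features only use Close (patched from history below) and News.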
df = pd.DataFrame(data)
# Try to get real close price from historical data
try:
hist_data = load_data()
if not hist_data.empty:
df['Close'] = hist_data['Close'].iloc[-1]
except Exception as e:
logger.warning(f"Couldn't load historical data: {e}")
# Get predictions
preds = predict(df, threshold)
if preds.empty:
# Create error response
error_row = {
'Date': [today],
'Close': [0],
'Prediction': ["Error"],
'Confidence': [0],
'Sentiment': [0]
}
error_df = pd.DataFrame(error_row)
return error_df, history_state, None
# Update history
new_history = history_state.copy() if history_state else []
new_history.append(preds.iloc[0].to_dict())
# Keep only last MAX_HISTORY_ENTRIES
if len(new_history) > Config.MAX_HISTORY_ENTRIES:
new_history.pop(0)
return preds, new_history, preds.iloc[0].to_dict()
except Exception as e:
logger.error(f"Single prediction failed: {e}")
today = datetime.today().strftime('%Y-%m-%d')
error_df = pd.DataFrame({
'Date': [today],
'Close': [0],
'Prediction': ["Error"],
'Confidence': [0],
'Sentiment': [0]
})
return error_df, history_state, None
def batch_predict(csv_file: str, threshold: float) -> pd.DataFrame:
"""Process batch prediction from uploaded CSV file."""
try:
if not csv_file:
return pd.DataFrame()
# Read and validate CSV
df = pd.read_csv(csv_file)
required_cols = {'Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'News'}
missing_cols = required_cols - set(df.columns)
if missing_cols:
logger.error(f"Missing columns in CSV: {missing_cols}")
return pd.DataFrame(columns=["Date", "Close", "Prediction", "Confidence", "Sentiment"])
# Make predictions
preds = predict(df, threshold)
return preds
except Exception as e:
logger.error(f"Batch prediction failed: {e}")
return pd.DataFrame(columns=["Date", "Close", "Prediction", "Confidence", "Sentiment"])
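# Example CSV layout for batch_predict (values are illustrative only):
#   Date,Open,High,Low,Close,Volume,News
#   2024-01-02,187.15,188.44,183.89,185.64,3862000,"IBM announces new AI chip"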
def explain_current() -> Optional[str]:
"""Generate SHAP explanation plot using current data."""
try:
df = load_data().tail(100)
df = create_features(df, analyzer)
return generate_shap_plot(df)
except Exception as e:
logger.error(f"SHAP explanation failed: {e}")
return None
def show_sentiment_plot() -> Optional[str]:
"""Generate sentiment visualization plot."""
try:
df = load_data().tail(100)
df = create_features(df, analyzer)
return plot_sentiment(df)
except Exception as e:
logger.error(f"Sentiment plot failed: {e}")
return None
def show_history_table(history_state: list) -> pd.DataFrame:
"""Display prediction history as a DataFrame."""
try:
if not history_state:
return pd.DataFrame(columns=['Date', 'Close', 'Prediction', 'Confidence', 'Sentiment'])
return pd.DataFrame(history_state)
except Exception as e:
logger.error(f"History display failed: {e}")
return pd.DataFrame(columns=['Date', 'Close', 'Prediction', 'Confidence', 'Sentiment'])
def summarize_session(df: pd.DataFrame) -> str:
"""Generate summary statistics for the session."""
if df.empty:
return "No predictions to summarize"
return f"""
✅ **Total Predictions:** {len(df)}
📊 **Avg Confidence:** {df['Confidence'].mean():.2f}
🎯 **Predictions:** {df['Prediction'].value_counts().to_dict()}
😊 **Avg Sentiment:** {df['Sentiment'].mean():.2f}
📅 **Date Range:** {df['Date'].min()} → {df['Date'].max()}"""
def generate_downloadable(df: pd.DataFrame) -> str:
"""Generate downloadable CSV file."""
path = os.path.join(Config.CACHE_DIR, "predictions.csv")
df.to_csv(path, index=False)
return path
def prediction_distribution(df: pd.DataFrame) -> Optional[str]:
"""Generate prediction distribution plot."""
try:
if df.empty:
return None
        fig, ax = plt.subplots(figsize=(8, 4))
        counts = df['Prediction'].value_counts()
        # Colors must follow value_counts order, or bars and colors can mismatch
        colors = ['green' if x == 'Buy' else 'red' if x == 'Sell' else 'gray' for x in counts.index]
        counts.plot(kind='bar', ax=ax, color=colors)
plt.title("Prediction Distribution")
plt.ylabel("Count")
plot_path = os.path.join(Config.PLOTS_DIR, "pred_dist.png")
fig.savefig(plot_path, bbox_inches='tight')
plt.close()
return plot_path
except Exception as e:
logger.error(f"Prediction distribution plot failed: {e}")
return None
# --------------------------
# Initialize Components
# --------------------------
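# Note: this block runs at import time; if no saved model is found, importing
# the module triggers a full training run before the UI is built.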
model = None
try:
model = joblib.load(Config.MODEL_PATH)
logger.info("Model loaded successfully")
except Exception as e:
logger.warning(f"Model loading failed: {e}")
try:
model = train_model()
logger.info("New model trained successfully")
except Exception as e:
logger.error(f"Model training fallback failed: {e}")
raise
analyzer = SentimentAnalyzer()
# --------------------------
# Gradio UI
# --------------------------
with gr.Blocks(title="Stock Prediction", theme=gr.themes.Soft()) as demo:
# State variables
history_state = gr.State(value=[])
last_prediction_state = gr.State(value=None)
# Header
gr.Markdown("""
# 📈 Stock Price Prediction with News Sentiment Analysis
*Predict next-day stock movements using technical indicators and news sentiment*
""")
with gr.Tab("🔍 Single Prediction"):
with gr.Row():
text_input = gr.Textbox(
label="Latest News Headline/Text",
placeholder="Enter news text about the company...",
lines=3,
max_lines=5
)
with gr.Row():
threshold_slider = gr.Slider(
minimum=0,
maximum=0.5,
value=Config.CONFIDENCE_THRESHOLD,
step=0.01,
label="Confidence Threshold",
info="Higher values require more confidence for Buy/Sell decisions"
)
with gr.Row():
predict_btn = gr.Button("Predict", variant="primary")
last_pred_btn = gr.Button("Show Last Prediction")
refresh_data_btn = gr.Button("Refresh Data")
with gr.Row():
with gr.Column():
prediction_output = gr.Dataframe(
headers=["Date", "Close", "Prediction", "Confidence", "Sentiment"],
label="Prediction Results",
interactive=False
)
with gr.Column():
last_prediction_output = gr.Dataframe(
headers=["Date", "Close", "Prediction", "Confidence", "Sentiment"],
label="Last Prediction",
interactive=False,
)
with gr.Tab("📂 Batch Prediction"):
with gr.Row():
batch_file = gr.File(
label="Upload CSV File (CSV should contain: Date, Open, High, Low, Close, Volume, News)",
file_types=[".csv"],
type="filepath",
)
with gr.Row():
batch_threshold = gr.Slider(
label="Confidence Threshold",
minimum=0,
maximum=0.5,
value=Config.CONFIDENCE_THRESHOLD,
step=0.01
)
with gr.Row():
batch_predict_btn = gr.Button("Run Batch Prediction", variant="primary")
with gr.Row():
batch_output = gr.Dataframe(
headers=["Date", "Close", "Prediction", "Confidence", "Sentiment"],
label="Batch Predictions",
interactive=False
)
with gr.Row():
batch_summary = gr.Markdown()
with gr.Row():
download_btn = gr.Button("Download Predictions")
download_file = gr.File(label="Download", visible=False)
with gr.Tab("📊 Explanations"):
with gr.Row():
explain_btn = gr.Button("Generate SHAP Explanation")
sentiment_plot_btn = gr.Button("Show Sentiment Trend")
with gr.Row():
shap_image = gr.Image(label="SHAP Feature Importance", interactive=False)
sentiment_image = gr.Image(label="Sentiment Over Time", interactive=False)
with gr.Tab("📜 History"):
with gr.Row():
history_btn = gr.Button("Show Prediction History")
clear_history_btn = gr.Button("Clear History", variant="stop")
with gr.Row():
history_output = gr.Dataframe(
headers=["Date", "Close", "Prediction", "Confidence", "Sentiment"],
label="Prediction History",
interactive=False
)
with gr.Row():
history_summary = gr.Markdown()
dist_image = gr.Image(label="Prediction Distribution", interactive=False)
with gr.Tab("ℹ️ About"):
gr.Markdown("""
## Stock Prediction App
This application predicts next-day stock price movements using:
- **Technical Indicators**: RSI, MACD, Bollinger Bands, Moving Averages
- **News Sentiment Analysis**: Using transformer models
- **Machine Learning**: Gradient Boosted Decision Trees
### Features:
- Single prediction with news input
- Batch prediction from CSV files
- Model explainability with SHAP values
- Sentiment analysis visualization
- Prediction history tracking
### Technical Details:
- Built with Python, scikit-learn, and Gradio
- Uses DistilBERT for sentiment analysis
- Implements time-series validation
""")
# Event handlers
predict_btn.click(
fn=predict_single,
inputs=[text_input, threshold_slider, history_state],
outputs=[prediction_output, history_state, last_prediction_state],
api_name="predict"
)
last_pred_btn.click(
fn=lambda x: pd.DataFrame([x]) if x else pd.DataFrame(),
inputs=[last_prediction_state],
outputs=[last_prediction_output]
)
    def _refresh_data():
        # Discard the returned DataFrame; this handler only refreshes the cache
        load_data(force_refresh=True)
    refresh_data_btn.click(
        fn=_refresh_data,
        outputs=[]
    )
batch_predict_btn.click(
fn=batch_predict,
inputs=[batch_file, batch_threshold],
outputs=[batch_output]
).then(
fn=summarize_session,
inputs=[batch_output],
outputs=[batch_summary]
).then(
fn=prediction_distribution,
inputs=[batch_output],
outputs=[dist_image]
)
download_btn.click(
fn=generate_downloadable,
inputs=[batch_output],
outputs=[download_file]
).then(
fn=lambda: gr.File(visible=True),
outputs=[download_file]
)
explain_btn.click(
fn=explain_current,
outputs=[shap_image]
)
sentiment_plot_btn.click(
fn=show_sentiment_plot,
outputs=[sentiment_image]
)
    history_btn.click(
        fn=show_history_table,
        inputs=[history_state],
        outputs=[history_output]
).then(
fn=summarize_session,
inputs=[history_output],
outputs=[history_summary]
).then(
fn=prediction_distribution,
inputs=[history_output],
outputs=[dist_image]
)
clear_history_btn.click(
fn=lambda: [],
outputs=[history_state]
).then(
fn=lambda: pd.DataFrame(columns=["Date", "Close", "Prediction", "Confidence", "Sentiment"]),
outputs=[history_output]
).then(
fn=lambda: "History cleared",
outputs=[history_summary]
)
# --------------------------
# Main Execution
# --------------------------
if __name__ == "__main__":
try:
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True,
share=False
)
except Exception as e:
logger.error(f"Application failed: {e}")
raise