import os
import sys
import logging
import warnings
from datetime import datetime, timedelta
from typing import Optional, Tuple, Dict, List

import joblib
import numpy as np
import pandas as pd
import torch
import gradio as gr
import matplotlib.pyplot as plt
import shap
from sklearn.base import clone
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import TimeSeriesSplit
from tqdm import tqdm
from transformers import pipeline
from wordcloud import WordCloud

warnings.filterwarnings("ignore", category=UserWarning)


class Config:
    DATA_PATH = "data/ibm_cleaned.parquet"
    MODEL_PATH = "models/stock_prediction_model.joblib"
    CACHE_DIR = "./cache"
    LOGS_DIR = "./logs"
    PLOTS_DIR = "./plots"
    SENTIMENT_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
    SENTIMENT_BATCH_SIZE = 8
    SENTIMENT_MAX_LEN = 256
    CONFIDENCE_THRESHOLD = 0.3
    DATA_REFRESH_DAYS = 7
    MAX_HISTORY_ENTRIES = 50

    @classmethod
    def setup(cls):
        os.makedirs(cls.CACHE_DIR, exist_ok=True)
        os.makedirs(cls.LOGS_DIR, exist_ok=True)
        os.makedirs(cls.PLOTS_DIR, exist_ok=True)
        os.makedirs(os.path.dirname(cls.MODEL_PATH), exist_ok=True)


Config.setup()


logging.basicConfig(
    filename=os.path.join(Config.LOGS_DIR, 'app.log'),
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)


def load_data(test_mode: bool = False, force_refresh: bool = False) -> pd.DataFrame:
    """Load and preprocess the stock data with caching and validation."""
    try:
        if not os.path.exists(Config.DATA_PATH):
            raise FileNotFoundError(f"Data file not found: {Config.DATA_PATH}")

        # No remote source is wired up, so "refresh" can only mean re-reading
        # the parquet file; staleness is logged for visibility.
        file_age = datetime.now() - datetime.fromtimestamp(os.path.getmtime(Config.DATA_PATH))
        if not force_refresh and file_age.days < Config.DATA_REFRESH_DAYS:
            logger.info("Using cached data")
        else:
            logger.warning(f"Data file is {file_age.days} days old; re-reading from disk")

        df = pd.read_parquet(Config.DATA_PATH)

        required_cols = {'Open', 'High', 'Low', 'Close', 'Volume', 'Date', 'News'}
        missing_cols = required_cols - set(df.columns)
        if missing_cols:
            raise ValueError(f"Missing columns in data: {missing_cols}")

        num_cols = ['Open', 'High', 'Low', 'Close']
        df[num_cols] = df[num_cols].astype('float32')

        if (df['Volume'] < 0).any():
            logger.warning("Negative values found in 'Volume'; downcasting as signed int.")
            df['Volume'] = pd.to_numeric(df['Volume'], downcast='integer')
        else:
            df['Volume'] = pd.to_numeric(df['Volume'], downcast='unsigned')

        if not pd.api.types.is_datetime64_any_dtype(df['Date']):
            df['Date'] = pd.to_datetime(df['Date'], errors='coerce')
            if df['Date'].isna().any():
                raise ValueError("Date column contains invalid dates after parsing.")

        if df.isnull().sum().sum() > 0:
            logger.warning("Data contains null values - filling with forward fill")
            df = df.ffill()

        if test_mode:
            # Random subsample re-sorted by date; rolling features will span
            # the resulting gaps, which is acceptable for a smoke-test run.
            sample_size = min(5000, len(df))
            if len(df) < 5000:
                logger.warning(f"Data has only {len(df)} rows - using full dataset in test mode")
            return df.sample(sample_size, random_state=42).sort_values('Date').reset_index(drop=True)

        return df.sort_values('Date').reset_index(drop=True)

    except Exception as e:
        logger.error(f"Data loading failed: {e}")
        raise
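

# A minimal sketch of the parquet schema load_data() expects; the column set
# mirrors required_cols above, and the example values are invented.
#
#   example = pd.DataFrame({
#       'Date': pd.to_datetime(['2024-01-02', '2024-01-03']),
#       'Open': [187.2, 188.0], 'High': [189.0, 189.5],
#       'Low': [186.5, 187.1], 'Close': [188.4, 187.9],
#       'Volume': [3_120_000, 2_980_000],
#       'News': ["IBM announces ...", "Analysts expect ..."],
#   })
#   example.to_parquet(Config.DATA_PATH)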


class SentimentAnalyzer:
    """Handles sentiment analysis with caching and batching."""

    def __init__(self):
        self._initialize_pipeline()
        self.sentiment_cache = {}

    def _initialize_pipeline(self):
        """Initialize the sentiment analysis pipeline."""
        try:
            self.pipeline = pipeline(
                "text-classification",
                model=Config.SENTIMENT_MODEL,
                device=0 if torch.cuda.is_available() else -1,
                batch_size=Config.SENTIMENT_BATCH_SIZE,
                truncation=True,
                max_length=Config.SENTIMENT_MAX_LEN,
            )
        except Exception as e:
            logger.error(f"Sentiment model load failed: {e}")
            self.pipeline = None
            raise

    def analyze_texts(self, texts: List[str]) -> np.ndarray:
        """Analyze sentiment for a batch of texts with caching.

        Returns one score per text in [-1, 1]: the positive label maps to
        +score, the negative label to -score.
        """
        if self.pipeline is None or not texts:
            return np.zeros(len(texts), dtype='float32')

        uncached_texts = []
        uncached_indices = []
        cached_scores = np.zeros(len(texts), dtype='float32')

        for i, text in enumerate(texts):
            # Character-level pre-truncation keeps cache keys bounded; the
            # pipeline still truncates to SENTIMENT_MAX_LEN tokens.
            text = str(text)[:Config.SENTIMENT_MAX_LEN] if text else ""
            if text in self.sentiment_cache:
                cached_scores[i] = self.sentiment_cache[text]
            else:
                uncached_texts.append(text)
                uncached_indices.append(i)

        if uncached_texts:
            try:
                results = self.pipeline(
                    uncached_texts,
                    truncation=True,
                    max_length=Config.SENTIMENT_MAX_LEN,
                    batch_size=Config.SENTIMENT_BATCH_SIZE
                )

                for pos, (idx, res) in enumerate(zip(uncached_indices, results)):
                    # The pipeline yields one dict per input text (or a
                    # one-element list per input in some transformers versions).
                    if isinstance(res, list):
                        res = res[0]
                    label = res['label']
                    score = res['score']
                    sentiment_score = score if label == "POSITIVE" else -score
                    cached_scores[idx] = sentiment_score
                    # Key the cache by the text that was actually scored.
                    self.sentiment_cache[uncached_texts[pos]] = sentiment_score

            except Exception as e:
                logger.error(f"Sentiment analysis error: {e}")

        return cached_scores
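

# Illustrative usage of the analyzer; scores land in [-1, 1], with positive
# headlines near +1 and negative ones near -1 (exact values depend on the
# model). The first call downloads the DistilBERT checkpoint.
#
#   analyzer = SentimentAnalyzer()
#   analyzer.analyze_texts(["Record profits", "Shares tumble"])
#   # -> array of two scores, roughly [+0.99, -0.99]; repeated texts are
#   #    served from the in-memory cache.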


def calculate_rsi(series: pd.Series, window: int = 14) -> pd.Series:
    """Calculate the Relative Strength Index."""
    try:
        series = pd.to_numeric(series, errors='coerce').ffill()
        delta = series.diff()
        gain = delta.clip(lower=0)
        loss = -delta.clip(upper=0)

        avg_gain = gain.rolling(window, min_periods=1).mean()
        avg_loss = loss.rolling(window, min_periods=1).mean().replace(0, 1e-10)

        rs = avg_gain / avg_loss
        return (100 - (100 / (1 + rs))).astype('float32')
    except Exception as e:
        logger.error(f"RSI calculation failed: {e}")
        return pd.Series(np.nan, index=series.index)


def calculate_macd(
    series: pd.Series,
    fast: int = 12,
    slow: int = 26,
    signal: int = 9
) -> Tuple[pd.Series, pd.Series, pd.Series]:
    """Calculate the MACD line, signal line, and histogram."""
    try:
        series = pd.to_numeric(series, errors='coerce').ffill()
        ema_fast = series.ewm(span=fast, adjust=False).mean()
        ema_slow = series.ewm(span=slow, adjust=False).mean()
        macd_line = ema_fast - ema_slow
        signal_line = macd_line.ewm(span=signal, adjust=False).mean()
        return (
            macd_line.astype('float32'),
            signal_line.astype('float32'),
            (macd_line - signal_line).astype('float32')
        )
    except Exception as e:
        logger.error(f"MACD calculation failed: {e}")
        return (
            pd.Series(np.nan, index=series.index),
            pd.Series(np.nan, index=series.index),
            pd.Series(np.nan, index=series.index)
        )


def calculate_bollinger_bands(
    series: pd.Series,
    window: int = 20,
    no_of_std: int = 2
) -> Tuple[pd.Series, pd.Series]:
    """Calculate Bollinger Bands (upper, lower)."""
    try:
        series = pd.to_numeric(series, errors='coerce').ffill()
        rolling_mean = series.rolling(window, min_periods=1).mean()
        rolling_std = series.rolling(window, min_periods=1).std()
        return (
            (rolling_mean + (no_of_std * rolling_std)).astype('float32'),
            (rolling_mean - (no_of_std * rolling_std)).astype('float32')
        )
    except Exception as e:
        logger.error(f"Bollinger Bands calculation failed: {e}")
        return (
            pd.Series(np.nan, index=series.index),
            pd.Series(np.nan, index=series.index)
        )


def calculate_volatility(series: pd.Series, window: int = 10) -> pd.Series:
    """Calculate rolling volatility of daily returns."""
    try:
        series = pd.to_numeric(series, errors='coerce').ffill()
        return series.pct_change().rolling(window, min_periods=1).std().astype('float32')
    except Exception as e:
        logger.error(f"Volatility calculation failed: {e}")
        return pd.Series(np.nan, index=series.index)
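

# A quick, self-contained sanity check for the indicator helpers above; the
# synthetic price series is illustrative only and not part of the app flow.
def _demo_indicators() -> None:
    rng = np.random.default_rng(0)
    prices = pd.Series(100 + np.cumsum(rng.normal(0, 1, 60)))
    rsi = calculate_rsi(prices)
    macd_line, signal_line, hist = calculate_macd(prices)
    upper, lower = calculate_bollinger_bands(prices)
    # RSI is bounded once defined, the MACD histogram is the line minus its
    # signal, and the upper band never dips below the lower band.
    assert rsi.dropna().between(0, 100).all()
    assert np.allclose(hist, macd_line - signal_line, atol=1e-5)
    assert ((upper - lower).dropna() >= 0).all()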


def create_features(df: pd.DataFrame, analyzer: SentimentAnalyzer) -> pd.DataFrame:
    """Create all features for the prediction model."""
    try:
        df = df.copy()

        required_cols = {'Close', 'News'}
        missing_cols = required_cols - set(df.columns)
        if missing_cols:
            raise ValueError(f"Missing columns for feature creation: {missing_cols}")

        # Price-based features
        df['Price_Change'] = df['Close'].pct_change().astype('float32')
        df['Log_Return'] = np.log1p(df['Price_Change'].clip(lower=-0.9999)).astype('float32')
        df['MA_5'] = df['Close'].rolling(5, min_periods=1).mean().astype('float32')
        df['MA_20'] = df['Close'].rolling(20, min_periods=1).mean().astype('float32')
        df['RSI'] = calculate_rsi(df['Close'])
        df['Volatility_10'] = calculate_volatility(df['Close'], 10)

        macd_line, signal_line, hist = calculate_macd(df['Close'])
        df['MACD'] = macd_line
        df['MACD_Signal'] = signal_line
        df['MACD_Hist'] = hist

        bb_upper, bb_lower = calculate_bollinger_bands(df['Close'])
        df['BB_Upper'] = bb_upper
        df['BB_Lower'] = bb_lower

        # Sentiment features
        if 'News' not in df or df['News'].isnull().all():
            logger.warning("No news data - setting Sentiment features to zero")
            df['Sentiment'] = 0.0
        else:
            df['Sentiment'] = analyzer.analyze_texts(df['News'].fillna('').tolist())

        df['Sentiment_MA'] = df['Sentiment'].rolling(5, min_periods=1).mean().astype('float32')
        df['Sentiment_Lag1'] = df['Sentiment'].shift(1).fillna(0).astype('float32')

        # Next-day direction label. The final row has no next-day close, so
        # its label defaults to 0; training code should drop that row.
        df['Target'] = (df['Close'].shift(-1) > df['Close']).astype('int8')

        return df.dropna().reset_index(drop=True)

    except Exception as e:
        logger.error(f"Feature engineering error: {e}")
        raise
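

# Shape of the frame create_features() returns, sketched on toy inputs; the
# prices are invented, and the earliest rows fall out of the result because
# pct_change and the rolling std are NaN at the start of a series.
#
#   toy = pd.DataFrame({
#       'Date': pd.date_range('2024-01-01', periods=5),
#       'Close': [100.0, 101.0, 100.5, 102.0, 101.5],
#       'News': ["up", "down", "flat", "up", "down"],
#   })
#   feats = create_features(toy, analyzer)
#   # columns: Date, Close, News, Price_Change, Log_Return, MA_5, MA_20, RSI,
#   #          Volatility_10, MACD, MACD_Signal, MACD_Hist, BB_Upper, BB_Lower,
#   #          Sentiment, Sentiment_MA, Sentiment_Lag1, Target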


def train_model() -> HistGradientBoostingClassifier:
    """Train and evaluate the prediction model."""
    try:
        logger.info("Starting model training...")
        df = load_data(test_mode=True)

        # Chronological 80/20 split (no shuffling for time series)
        split_idx = int(0.8 * len(df))
        train_df = df.iloc[:split_idx].copy()
        test_df = df.iloc[split_idx:].copy()

        analyzer = SentimentAnalyzer()

        # Build features over the full training frame in one pass so rolling
        # windows stay continuous; chunked processing would reset them at
        # every chunk boundary.
        train_processed = create_features(train_df, analyzer)

        expected_features = [
            'Sentiment', 'Sentiment_MA', 'Sentiment_Lag1', 'Price_Change',
            'Log_Return', 'MA_5', 'MA_20', 'RSI', 'Volatility_10',
            'MACD', 'MACD_Signal', 'MACD_Hist', 'BB_Upper', 'BB_Lower', 'Target'
        ]

        missing_cols = [c for c in expected_features if c not in train_processed.columns]
        if missing_cols:
            raise ValueError(f"Missing features in training data: {missing_cols}")

        # Drop the final row: its Target has no next-day close behind it.
        train_processed = train_processed[expected_features + ['Date']].iloc[:-1]

        X_train = train_processed.drop(columns=['Date', 'Target'])
        y_train = train_processed['Target']

        base_model = HistGradientBoostingClassifier(
            max_iter=100,
            max_depth=5,
            learning_rate=0.05,
            random_state=42,
            verbose=1,
            early_stopping=True,
            validation_fraction=0.1
        )

        # Walk-forward cross-validation
        tscv = TimeSeriesSplit(n_splits=3)
        f1_scores = []

        for train_idx, val_idx in tscv.split(X_train):
            model = clone(base_model)
            model.fit(X_train.iloc[train_idx], y_train.iloc[train_idx])
            preds = model.predict(X_train.iloc[val_idx])
            f1_scores.append(f1_score(y_train.iloc[val_idx], preds))
            logger.info(f"Fold F1: {f1_scores[-1]:.3f}")

        logger.info(f"Validation F1: {np.mean(f1_scores):.3f} ± {np.std(f1_scores):.3f}")

        # Refit on the full training set
        model = clone(base_model)
        model.fit(X_train, y_train)

        # Held-out evaluation, again excluding the unlabeled final row
        test_processed = create_features(test_df, analyzer).iloc[:-1]
        X_test = test_processed[expected_features[:-1]]
        y_test = test_processed['Target']

        y_pred = model.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred)

        logger.info(f"Test Accuracy: {acc:.3f} | F1: {f1:.3f}")

        joblib.dump(model, Config.MODEL_PATH)
        logger.info(f"Model saved to {Config.MODEL_PATH}")

        return model

    except Exception as e:
        logger.error(f"Model training failed: {e}")
        raise
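

# How TimeSeriesSplit partitions n samples, sketched for n=10 and 3 splits;
# every fold trains strictly on the past and validates on the block after it.
#
#   list(TimeSeriesSplit(n_splits=3).split(range(10)))
#   # fold 1: train [0..3], val [4, 5]
#   # fold 2: train [0..5], val [6, 7]
#   # fold 3: train [0..7], val [8, 9]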


def predict(df: pd.DataFrame, threshold: float = Config.CONFIDENCE_THRESHOLD) -> pd.DataFrame:
    """Make predictions with confidence scores."""
    try:
        if model is None:
            raise ValueError("Model is not loaded")

        df_features = create_features(df, analyzer)

        required_features = [
            'Sentiment', 'Sentiment_MA', 'Sentiment_Lag1', 'Price_Change',
            'Log_Return', 'MA_5', 'MA_20', 'RSI', 'Volatility_10',
            'MACD', 'MACD_Signal', 'MACD_Hist', 'BB_Upper', 'BB_Lower'
        ]

        X = df_features[required_features]
        proba = model.predict_proba(X)[:, 1]

        # Map P(up) to a signal: Buy above 0.5 + threshold, Sell below
        # 0.5 - threshold, Hold in the dead zone between them.
        preds = np.where(
            proba > (0.5 + threshold), "Buy",
            np.where(proba < (0.5 - threshold), "Sell", "Hold")
        )

        df_features['Prediction'] = preds
        df_features['Confidence'] = np.abs(proba - 0.5) * 2

        return df_features[['Date', 'Close', 'Prediction', 'Confidence', 'Sentiment']]

    except Exception as e:
        logger.error(f"Prediction error: {e}")
        return pd.DataFrame(columns=['Date', 'Close', 'Prediction', 'Confidence', 'Sentiment'])
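

# Worked example of the signal mapping with the default threshold of 0.3
# (numbers are illustrative): P(up) = 0.85 -> Buy (0.85 > 0.8),
# P(up) = 0.60 -> Hold (inside [0.2, 0.8]), P(up) = 0.10 -> Sell (0.10 < 0.2).
# Confidence is |P(up) - 0.5| * 2, giving 0.70, 0.20, and 0.80 respectively.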


def generate_shap_plot(df: pd.DataFrame) -> Optional[str]:
    """Generate SHAP explanation plot."""
    try:
        df = df.tail(50)

        required_features = [
            'Sentiment', 'Sentiment_MA', 'Sentiment_Lag1', 'Price_Change',
            'Log_Return', 'MA_5', 'MA_20', 'RSI', 'Volatility_10',
            'MACD', 'MACD_Signal', 'MACD_Hist', 'BB_Upper', 'BB_Lower'
        ]

        X = df[required_features]
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(X)
        # Some SHAP versions return one array per class for classifiers;
        # plot the positive class in that case.
        if isinstance(shap_values, list):
            shap_values = shap_values[1]

        plt.figure(figsize=(10, 6))
        shap.summary_plot(shap_values, X, show=False, plot_size=(10, 6))

        plot_path = os.path.join(Config.PLOTS_DIR, "shap_summary.png")
        plt.savefig(plot_path, bbox_inches='tight')
        plt.close()

        return plot_path

    except Exception as e:
        logger.error(f"SHAP plot generation failed: {e}")
        return None


def plot_sentiment(df: pd.DataFrame) -> Optional[str]:
    """Generate sentiment trend plot."""
    try:
        plt.figure(figsize=(10, 4))
        plt.plot(df['Date'], df['Sentiment'], label="Daily Sentiment")
        plt.plot(df['Date'], df['Sentiment_MA'], label="5-Day MA", alpha=0.7)
        plt.title("News Sentiment Over Time")
        plt.xlabel("Date")
        plt.ylabel("Sentiment Score")
        plt.legend()
        plt.grid(True)

        plot_path = os.path.join(Config.PLOTS_DIR, "sentiment_plot.png")
        plt.savefig(plot_path, bbox_inches='tight')
        plt.close()

        return plot_path

    except Exception as e:
        logger.error(f"Sentiment plot failed: {e}")
        return None


def generate_wordcloud(texts: List[str]) -> Optional[str]:
    """Generate a word cloud from news texts."""
    try:
        if not texts or all(pd.isna(t) for t in texts):
            return None

        text = " ".join(str(t) for t in texts if t and not pd.isna(t))
        wc = WordCloud(width=600, height=400, background_color='white').generate(text)

        plot_path = os.path.join(Config.PLOTS_DIR, "wordcloud.png")
        wc.to_file(plot_path)

        return plot_path

    except Exception as e:
        logger.error(f"Word cloud generation failed: {e}")
        return None


def predict_single(text_input: str, threshold: float, history_state: list) -> Tuple[pd.DataFrame, list, Optional[dict]]:
    """Make a single prediction based on news input."""
    try:
        today = datetime.today().strftime('%Y-%m-%d')
        # Placeholder OHLCV row for today; Close is overwritten with the
        # latest historical close when history is available.
        df = pd.DataFrame({
            'Date': [pd.to_datetime(today)],
            'Open': [100.0],
            'High': [105.0],
            'Low': [95.0],
            'Close': [100.0],
            'Volume': [1000000],
            'News': [text_input]
        })

        # Prepend recent history so rolling indicators and pct_change have
        # context; a lone row would be dropped by create_features' dropna().
        try:
            hist_data = load_data()
            if not hist_data.empty:
                df['Close'] = hist_data['Close'].iloc[-1]
                df = pd.concat([hist_data.tail(30), df], ignore_index=True)
        except Exception as e:
            logger.warning(f"Couldn't load historical data: {e}")

        preds = predict(df, threshold)

        if preds.empty:
            error_df = pd.DataFrame({
                'Date': [today],
                'Close': [0],
                'Prediction': ["Error"],
                'Confidence': [0],
                'Sentiment': [0]
            })
            return error_df, history_state, None

        # Keep only the row for today's input.
        preds = preds.tail(1).reset_index(drop=True)

        new_history = history_state.copy() if history_state else []
        new_history.append(preds.iloc[0].to_dict())

        if len(new_history) > Config.MAX_HISTORY_ENTRIES:
            new_history.pop(0)

        return preds, new_history, preds.iloc[0].to_dict()

    except Exception as e:
        logger.error(f"Single prediction failed: {e}")
        today = datetime.today().strftime('%Y-%m-%d')
        error_df = pd.DataFrame({
            'Date': [today],
            'Close': [0],
            'Prediction': ["Error"],
            'Confidence': [0],
            'Sentiment': [0]
        })
        return error_df, history_state, None


def batch_predict(csv_file: str, threshold: float) -> pd.DataFrame:
    """Process batch prediction from an uploaded CSV file."""
    try:
        if not csv_file:
            return pd.DataFrame()

        df = pd.read_csv(csv_file)
        required_cols = {'Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'News'}
        missing_cols = required_cols - set(df.columns)

        if missing_cols:
            logger.error(f"Missing columns in CSV: {missing_cols}")
            return pd.DataFrame(columns=["Date", "Close", "Prediction", "Confidence", "Sentiment"])

        # Rolling features assume chronological order.
        df['Date'] = pd.to_datetime(df['Date'], errors='coerce')
        df = df.sort_values('Date').reset_index(drop=True)

        return predict(df, threshold)

    except Exception as e:
        logger.error(f"Batch prediction failed: {e}")
        return pd.DataFrame(columns=["Date", "Close", "Prediction", "Confidence", "Sentiment"])
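

# A minimal CSV that batch_predict() accepts; header names must match
# required_cols exactly, and these rows are invented:
#
#   Date,Open,High,Low,Close,Volume,News
#   2024-01-02,187.2,189.0,186.5,188.4,3120000,"IBM announces ..."
#   2024-01-03,188.0,189.5,187.1,187.9,2980000,"Analysts expect ..."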


def explain_current() -> Optional[str]:
    """Generate SHAP explanation plot using current data."""
    try:
        df = load_data().tail(100)
        df = create_features(df, analyzer)
        return generate_shap_plot(df)
    except Exception as e:
        logger.error(f"SHAP explanation failed: {e}")
        return None


def show_sentiment_plot() -> Optional[str]:
    """Generate sentiment visualization plot."""
    try:
        df = load_data().tail(100)
        df = create_features(df, analyzer)
        return plot_sentiment(df)
    except Exception as e:
        logger.error(f"Sentiment plot failed: {e}")
        return None


def show_history_table(history_state: list) -> pd.DataFrame:
    """Display prediction history as a DataFrame."""
    try:
        if not history_state:
            return pd.DataFrame(columns=['Date', 'Close', 'Prediction', 'Confidence', 'Sentiment'])
        return pd.DataFrame(history_state)
    except Exception as e:
        logger.error(f"History display failed: {e}")
        return pd.DataFrame(columns=['Date', 'Close', 'Prediction', 'Confidence', 'Sentiment'])


def summarize_session(df: pd.DataFrame) -> str:
    """Generate summary statistics for the session."""
    if df.empty:
        return "No predictions to summarize"

    return f"""
✅ **Total Predictions:** {len(df)}
📊 **Avg Confidence:** {df['Confidence'].mean():.2f}
🎯 **Predictions:** {df['Prediction'].value_counts().to_dict()}
😊 **Avg Sentiment:** {df['Sentiment'].mean():.2f}
📅 **Date Range:** {df['Date'].min()} → {df['Date'].max()}"""


def generate_downloadable(df: pd.DataFrame) -> str:
    """Write predictions to a downloadable CSV file and return its path."""
    path = os.path.join(Config.CACHE_DIR, "predictions.csv")
    df.to_csv(path, index=False)
    return path


def prediction_distribution(df: pd.DataFrame) -> Optional[str]:
    """Generate prediction distribution plot."""
    try:
        if df.empty:
            return None

        counts = df['Prediction'].value_counts()
        fig, ax = plt.subplots(figsize=(8, 4))
        # Color by label in the order value_counts() returns the bars.
        color_map = {'Buy': 'green', 'Sell': 'red'}
        counts.plot(
            kind='bar',
            ax=ax,
            color=[color_map.get(label, 'gray') for label in counts.index]
        )
        plt.title("Prediction Distribution")
        plt.ylabel("Count")

        plot_path = os.path.join(Config.PLOTS_DIR, "pred_dist.png")
        fig.savefig(plot_path, bbox_inches='tight')
        plt.close(fig)

        return plot_path

    except Exception as e:
        logger.error(f"Prediction distribution plot failed: {e}")
        return None


model = None
try:
    model = joblib.load(Config.MODEL_PATH)
    logger.info("Model loaded successfully")
except Exception as e:
    logger.warning(f"Model loading failed: {e}")
    try:
        model = train_model()
        logger.info("New model trained successfully")
    except Exception as e:
        logger.error(f"Model training fallback failed: {e}")
        raise

analyzer = SentimentAnalyzer()


with gr.Blocks(title="Stock Prediction", theme=gr.themes.Soft()) as demo:
    history_state = gr.State(value=[])
    last_prediction_state = gr.State(value=None)

    gr.Markdown("""
    # 📈 Stock Price Prediction with News Sentiment Analysis
    *Predict next-day stock movements using technical indicators and news sentiment*
    """)

    with gr.Tab("🔍 Single Prediction"):
        with gr.Row():
            text_input = gr.Textbox(
                label="Latest News Headline/Text",
                placeholder="Enter news text about the company...",
                lines=3,
                max_lines=5
            )

        with gr.Row():
            threshold_slider = gr.Slider(
                minimum=0,
                maximum=0.5,
                value=Config.CONFIDENCE_THRESHOLD,
                step=0.01,
                label="Confidence Threshold",
                info="Higher values require more confidence for Buy/Sell decisions"
            )

        with gr.Row():
            predict_btn = gr.Button("Predict", variant="primary")
            last_pred_btn = gr.Button("Show Last Prediction")
            refresh_data_btn = gr.Button("Refresh Data")

        with gr.Row():
            with gr.Column():
                prediction_output = gr.Dataframe(
                    headers=["Date", "Close", "Prediction", "Confidence", "Sentiment"],
                    label="Prediction Results",
                    interactive=False
                )

            with gr.Column():
                last_prediction_output = gr.Dataframe(
                    headers=["Date", "Close", "Prediction", "Confidence", "Sentiment"],
                    label="Last Prediction",
                    interactive=False,
                )

    with gr.Tab("📂 Batch Prediction"):
        with gr.Row():
            batch_file = gr.File(
                label="Upload CSV File (columns: Date, Open, High, Low, Close, Volume, News)",
                file_types=[".csv"],
                type="filepath",
            )

        with gr.Row():
            batch_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0,
                maximum=0.5,
                value=Config.CONFIDENCE_THRESHOLD,
                step=0.01
            )

        with gr.Row():
            batch_predict_btn = gr.Button("Run Batch Prediction", variant="primary")

        with gr.Row():
            batch_output = gr.Dataframe(
                headers=["Date", "Close", "Prediction", "Confidence", "Sentiment"],
                label="Batch Predictions",
                interactive=False
            )

        with gr.Row():
            batch_summary = gr.Markdown()

        with gr.Row():
            download_btn = gr.Button("Download Predictions")
            download_file = gr.File(label="Download", visible=False)

    with gr.Tab("📊 Explanations"):
        with gr.Row():
            explain_btn = gr.Button("Generate SHAP Explanation")
            sentiment_plot_btn = gr.Button("Show Sentiment Trend")

        with gr.Row():
            shap_image = gr.Image(label="SHAP Feature Importance", interactive=False)
            sentiment_image = gr.Image(label="Sentiment Over Time", interactive=False)

    with gr.Tab("📜 History"):
        with gr.Row():
            history_btn = gr.Button("Show Prediction History")
            clear_history_btn = gr.Button("Clear History", variant="stop")

        with gr.Row():
            history_output = gr.Dataframe(
                headers=["Date", "Close", "Prediction", "Confidence", "Sentiment"],
                label="Prediction History",
                interactive=False
            )

        with gr.Row():
            history_summary = gr.Markdown()
            dist_image = gr.Image(label="Prediction Distribution", interactive=False)

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## Stock Prediction App

        This application predicts next-day stock price movements using:

        - **Technical Indicators**: RSI, MACD, Bollinger Bands, Moving Averages
        - **News Sentiment Analysis**: Using transformer models
        - **Machine Learning**: Gradient-boosted decision trees

        ### Features:
        - Single prediction with news input
        - Batch prediction from CSV files
        - Model explainability with SHAP values
        - Sentiment analysis visualization
        - Prediction history tracking

        ### Technical Details:
        - Built with Python, scikit-learn, and Gradio
        - Uses DistilBERT for sentiment analysis
        - Implements time-series validation
        """)
    predict_btn.click(
        fn=predict_single,
        inputs=[text_input, threshold_slider, history_state],
        outputs=[prediction_output, history_state, last_prediction_state],
        api_name="predict"
    )

    last_pred_btn.click(
        fn=lambda x: pd.DataFrame([x]) if x else pd.DataFrame(),
        inputs=[last_prediction_state],
        outputs=[last_prediction_output]
    )

    def _refresh_data() -> None:
        # Reload the dataset; return nothing since no component is updated.
        load_data(force_refresh=True)

    refresh_data_btn.click(
        fn=_refresh_data,
        outputs=[]
    )

    batch_predict_btn.click(
        fn=batch_predict,
        inputs=[batch_file, batch_threshold],
        outputs=[batch_output]
    ).then(
        fn=summarize_session,
        inputs=[batch_output],
        outputs=[batch_summary]
    ).then(
        fn=prediction_distribution,
        inputs=[batch_output],
        outputs=[dist_image]
    )

    download_btn.click(
        fn=generate_downloadable,
        inputs=[batch_output],
        outputs=[download_file]
    ).then(
        fn=lambda: gr.File(visible=True),
        outputs=[download_file]
    )

    explain_btn.click(
        fn=explain_current,
        outputs=[shap_image]
    )

    sentiment_plot_btn.click(
        fn=show_sentiment_plot,
        outputs=[sentiment_image]
    )

    history_btn.click(
        fn=show_history_table,
        inputs=[history_state],
        outputs=[history_output]
    ).then(
        fn=summarize_session,
        inputs=[history_output],
        outputs=[history_summary]
    ).then(
        fn=prediction_distribution,
        inputs=[history_output],
        outputs=[dist_image]
    )

    clear_history_btn.click(
        fn=lambda: [],
        outputs=[history_state]
    ).then(
        fn=lambda: pd.DataFrame(columns=["Date", "Close", "Prediction", "Confidence", "Sentiment"]),
        outputs=[history_output]
    ).then(
        fn=lambda: "History cleared",
        outputs=[history_summary]
    )


if __name__ == "__main__":
    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            share=False
        )
    except Exception as e:
        logger.error(f"Application failed: {e}")
        raise