import streamlit as st
import akshare as ak
import pandas as pd
import numpy as np
import plotly.graph_objs as go
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score
from datetime import datetime, timedelta
# Default query window: the 365 days ending at the moment the script runs.
today = datetime.now()
default_end = today
default_start = default_end - timedelta(days=365)

st.set_page_config(page_title="股票预测", layout="wide")
st.title('股票涨跌预测xgboost')


def _adjust_label(option):
    """Return the display label (first element) of a radio option pair."""
    return option[0]


# Sidebar: every user-tunable parameter of the demo lives here.
with st.sidebar:
    st.header("参数设置")
    stock_code = st.text_input('股票代码', value='600000')
    start_date = st.date_input('开始日期', value=default_start)
    end_date = st.date_input('结束日期', value=default_end)
    train_ratio = st.slider(
        '训练集比例',
        min_value=0.6,
        max_value=0.95,
        value=0.8,
        step=0.05,
    )
    # Each option is a (label, code) pair: the label is shown in the UI and
    # the code is what gets stored in ``adjust_type``.
    adjust_choice = st.radio(
        "复权类型",
        options=[("前复权", "qfq"), ("后复权", "hfq"), ("不复权", '')],
        index=0,
        format_func=_adjust_label,
        help="前复权(qfq)/后复权(hfq)/不复权",
    )
    adjust_type = adjust_choice[1]
@st.cache_data(ttl=3600, show_spinner="正在获取股票数据...")
def _fetch_stock_data(code, start, end, adjust):
    """Download daily bars through akshare and normalise them to OHLCV.

    This helper deliberately raises on failure instead of returning a
    placeholder, so that a failed download is not stored in the cache and
    the next call retries it.
    """
    df = ak.stock_zh_a_hist(
        symbol=str(code).strip(),  # tolerate stray whitespace from the text box
        period="daily",
        start_date=start.strftime("%Y%m%d"),
        end_date=end.strftime("%Y%m%d"),
        adjust=adjust,
    )
    if df is None or df.empty:
        # Raise explicitly rather than failing later with a KeyError on the
        # missing date column.
        raise ValueError("未返回任何行情数据")

    df = df.set_index('日期')
    df.index = pd.to_datetime(df.index)
    df = df.sort_index()
    df = df.rename(columns={
        '开盘': 'open',
        '最高': 'high',
        '最低': 'low',
        '收盘': 'close',
        '成交量': 'volume',
    })
    return df[['open', 'high', 'low', 'close', 'volume']]


def get_stock_data(code: str, start, end, adjust: str) -> pd.DataFrame:
    """Return daily OHLCV data for ``code``, or an empty frame on failure.

    Args:
        code: Stock symbol passed to ``ak.stock_zh_a_hist``.
        start: First day of the range; must provide ``strftime``.
        end: Last day of the range; must provide ``strftime``.
        adjust: Price-adjustment mode forwarded to akshare.

    Returns:
        A DataFrame indexed by date (ascending) with the columns ``open``,
        ``high``, ``low``, ``close`` and ``volume``. When the download or the
        parsing fails, the error is shown with ``st.error`` and an empty
        DataFrame is returned.
    """
    # The try/except lives outside the cached helper on purpose: only
    # successful downloads may be cached, otherwise a transient failure would
    # be replayed until the cache entry expires.
    try:
        return _fetch_stock_data(code, start, end, adjust)
    except Exception as e:  # UI boundary: report the failure, keep the app alive
        st.error(f"数据获取失败: {str(e)}")
        return pd.DataFrame()
# Fetch the raw price history while showing a progress spinner.
with st.spinner('正在加载数据...'):
    data = get_stock_data(stock_code, start_date, end_date, adjust_type)

# Nothing to work with: tell the user what to check and end this run.
no_data = data.empty
if no_data:
    st.error(
        "无法获取数据,请检查:\n"
        "1. 股票代码是否正确\n"
        "2. 日期范围是否有效\n"
        "3. 网络连接是否正常"
    )
    st.stop()
def create_features(df):
    """Build the next-day direction label and technical-indicator features.

    The input frame is left untouched; all work happens on a copy.

    Args:
        df: DataFrame with at least a ``close`` column, ordered oldest first.

    Returns:
        A new DataFrame holding the input columns (minus the final row, whose
        next-day outcome is unknown) plus:

        * ``target``: 1 when the next row's close is higher than this row's
          close, otherwise 0.
        * ``ma5`` / ``ma10`` / ``ma20``: simple moving averages of ``close``.
        * ``rsi``: 14-period RSI based on simple rolling means of gains and
          losses.
        * ``macd`` / ``signal``: EMA(12) - EMA(26) of ``close`` and its
          9-period EMA.

        Rows containing NaN (the indicator warm-up period) are dropped.
    """
    close_full = df['close']
    # The label is computed on the full series, before the last row is cut
    # off, so the second-to-last row can still see its next-day close.
    next_day_up = (close_full.shift(-1) > close_full).astype(int)

    # An explicit copy keeps the caller's frame unmodified and avoids
    # assigning columns on a slice view.
    out = df.iloc[:-1].copy()
    out['target'] = next_day_up.iloc[:-1].to_numpy()

    for window in (5, 10, 20):
        out[f'ma{window}'] = out['close'].rolling(window).mean()

    # dropna() removes the leading NaN of diff(); otherwise where() would
    # turn it into a fake zero that counts as a real observation in the
    # first rolling window.
    delta = out['close'].diff().dropna()
    gain = delta.where(delta > 0, 0)
    loss = -delta.where(delta < 0, 0)
    avg_gain = gain.rolling(14).mean()
    avg_loss = loss.rolling(14).mean()
    rs = avg_gain / (avg_loss + 1e-10)  # epsilon guards against division by zero
    out['rsi'] = 100 - (100 / (1 + rs))

    exp12 = out['close'].ewm(span=12, adjust=False).mean()
    exp26 = out['close'].ewm(span=26, adjust=False).mean()
    out['macd'] = exp12 - exp26
    out['signal'] = out['macd'].ewm(span=9, adjust=False).mean()

    return out.dropna()
# Turn the raw price history into the labelled feature table.
processed_data = create_features(data)

# Refuse to train on fewer than 100 usable rows: warn the user and end this
# run. (The comparison operator was missing in the original line.)
if len(processed_data) < 100:
    st.warning(f"数据量不足(仅{len(processed_data)}条),建议选择更长的时间范围")
    st.stop()
# Model inputs: raw OHLCV columns followed by the derived indicators.
features = [
    'open', 'high', 'low', 'close', 'volume',
    'ma5', 'ma10', 'ma20',
    'rsi', 'macd', 'signal',
]

# Positional split: the first ``train_ratio`` share of rows trains the model,
# the remaining rows are held out for testing (no shuffling).
split_idx = int(processed_data.shape[0] * train_ratio)

X = processed_data[features]
y = processed_data['target']

X_train = X.iloc[:split_idx]
X_test = X.iloc[split_idx:]
y_train = y.iloc[:split_idx]
y_test = y.iloc[split_idx:]
# Hyper-parameters of the gradient-boosted classifier, kept together so they
# are easy to find and tweak.
xgb_params = {
    'n_estimators': 200,
    'max_depth': 6,
    'learning_rate': 0.05,
    'subsample': 0.8,
    'colsample_bytree': 0.9,
    'random_state': 42,
}

# Fit on the training rows, then score the held-out rows.
with st.spinner('正在训练模型...'):
    model = XGBClassifier(**xgb_params)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
# Two-column layout: model metrics on the left, price chart on the right.
col1, col2 = st.columns(2)

with col1:
    st.subheader('模型表现')
    # The format string must stay on one line; the original literal was
    # broken across two lines, which is a syntax error.
    st.metric("测试集准确率", f"{accuracy:.2%}")

    st.subheader('特征重要性')
    # One row per feature, most important first.
    importance = pd.DataFrame({
        '特征': features,
        '重要性': model.feature_importances_,
    }).sort_values('重要性', ascending=False)
    st.bar_chart(importance.set_index('特征'))
with col2:
    st.subheader('价格走势与预测')

    # Base layer: candlesticks over every row of the feature table.
    fig = go.Figure()
    candles = go.Candlestick(
        x=processed_data.index,
        open=processed_data['open'],
        high=processed_data['high'],
        low=processed_data['low'],
        close=processed_data['close'],
        name='K线',
    )
    fig.add_trace(candles)

    # Re-attach the test-period dates to the raw prediction array so it can
    # be compared with ``y_test`` index-wise.
    test_dates = processed_data.index[split_idx:]
    predictions = pd.Series(y_pred, index=test_dates[:len(y_pred)])
    correct_dates = predictions[predictions == y_test].index
    wrong_dates = predictions[predictions != y_test].index

    # Overlay layers: hits are drawn 2% above the day's high, misses 2% below
    # the day's low, so the markers do not cover the candles.
    marker_specs = (
        (correct_dates, 'high', 1.02, 'lime', 'triangle-up', '正确预测'),
        (wrong_dates, 'low', 0.98, 'red', 'triangle-down', '错误预测'),
    )
    for dates, price_col, offset, colour, glyph, label in marker_specs:
        fig.add_trace(go.Scatter(
            x=dates,
            y=processed_data.loc[dates, price_col] * offset,
            mode='markers',
            marker={'color': colour, 'size': 8, 'symbol': glyph},
            name=label,
        ))

    fig.update_layout(
        height=600,
        xaxis_rangeslider_visible=False,
        legend={'orientation': "h", 'yanchor': "bottom", 'y': 1.02},
    )
    st.plotly_chart(fig, use_container_width=True)
# Footer: a horizontal rule followed by the investment-risk disclaimer.
risk_notice = """
**风险提示**
本工具仅为技术演示,不构成投资建议。股票市场存在风险,历史表现不代表未来趋势。实际投资请谨慎决策,作者不对任何投资结果负责。
"""
st.markdown("---")
st.warning(risk_notice)