社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  Python

【Python技术】利用机器学习xgboost算法预测股票涨跌例子

子晓聊技术 • 4 月前 • 153 次点击  

群里有同学发了一些预测股票涨停第二天溢价率的方法, 星球有同学私底下问我 能不能写下对应的思路例子。  既然别人有这种软件了,我就不写这种例子了。 那就用类似的例子举例,分享下处理思路。 注意下,这只是示例,真正使用还是需要调优的。
之前已经写过随机森林,那这次就写xgboost(极端梯度提升树)举例把。

机器学习算法 ,我们得先要明白 它算法原理。比如xgboost是什么?它是怎么预测股票涨跌的

先简单介绍下XGBoost
基于Boosting(梯度提升框架),通过串行训练决策树,每一棵树纠正前一棵的残差,逐步减少偏差。

XGBoost的训练方式
  • 树按顺序生成,新树拟合前序模型的残差(梯度方向)。
  • 优化复杂的目标函数(损失函数 + 正则化项,如L1/L2),防止过拟合。

为什么选择XGBoost?

1 传统方法的局限性

传统股票分析方法主要分为技术分析(K线图、均线系统)和基本面分析(财务报表、行业数据),但存在两大痛点:

  • 滞后性
    :技术指标往往基于历史价格计算,难以反映实时市场情绪。
  • 主观性
    :分析师的经验判断容易受情绪影响,缺乏量化依据。

2 XGBoost的独特优势

XGBoost作为一种集成学习算法,在结构化数据预测任务中表现尤为突出:

  • 高效处理非线性关系
    :通过决策树的组合捕捉特征间的复杂交互。
  • 抗过拟合能力强
    :正则化项和特征重要性排序有效降低噪声干扰。
  • 灵活适应金融场景
    :支持缺失值处理、自定义损失函数,适合高频交易数据。

相关文章推荐:
【Python技术】通过akshare、随机森林算法预测股票涨跌可视化例子

题外话:
最近遇到一些哭笑不得的事情,有同学问我能不能实现下XX公众号的代码,  我一看别人公众号是付费的,进入知识星球, 一年几千块。
合着我写的文章附上代码可以白嫖,别人的代码没法白嫖对吧? 

最后还是附上完整源代码
import streamlit as stimport akshare as akimport pandas as pdimport numpy as npimport plotly.graph_objs as gofrom xgboost import XGBClassifierfrom sklearn.metrics import accuracy_scorefrom datetime import datetime, timedelta
today = datetime.now()default_start = today - timedelta(days=365)default_end = today
# Streamlit界面设置st.set_page_config(page_title="股票预测", layout="wide")st.title('股票涨跌预测xgboost')
# 侧边栏输入with st.sidebar:    st.header("参数设置")    stock_code = st.text_input('股票代码''600000')    start_date = st.date_input('开始日期', default_start)    end_date = st.date_input('结束日期', default_end)    train_ratio = st.slider('训练集比例'0.60.950.80.05)    adjust_type = st.radio(        "复权类型",        options=[("前复权""qfq"), ("后复权""hfq"), ("不复权"'')],  # 显示文本与传值分离        index=0,        format_func=lambda x: x[0],  # 只显示文本部分        help="前复权(qfq)/后复权(hfq)/不复权"    )[1]

# 数据获取函数@st.cache_data(ttl=3600, show_spinner="正在获取股票数据...")def get_stock_data(code, start, end, adjust):    try:        # 自动处理代码前缀        symbol =  code
        df = ak.stock_zh_a_hist(            symbol=symbol,            period="daily",            start_date=start.strftime("%Y%m%d"),            end_date=end.strftime("%Y%m%d"),            adjust=adjust        )
        # 数据清洗        df = df.set_index('日期').sort_index()        df.index = pd.to_datetime(df.index)        df = df.rename(columns={            '开盘''open',            '最高''high',            '最低''low',            '收盘''close',            '成交量''volume'        })         return df[['open''high''low''close''volume']]    except Exception as e:        st.error(f"数据获取失败: {str(e)}")        return pd.DataFrame()

# 获取数据with st.spinner('正在加载数据...'):    data = get_stock_data(stock_code, start_date, end_date, adjust_type)
if data.empty:    st.error("无法获取数据,请检查:\n1. 股票代码是否正确\n2. 日期范围是否有效\n3. 网络连接是否正常")    st.stop()

# 特征工程函数def create_features(df):    # 创建目标变量(次日是否上涨)    df['target'] = (df['close'].shift(-1) > df['close']).astype(int)    df = df.iloc[:-1]  # 删除最后一个无法确定target的交易日
    # 移动平均线    windows = [51020]    for window in windows:        df[f'ma{window}'] = df['close'].rolling(window).mean()
    # RSI计算    delta = df['close'].diff().dropna()    gain = delta.where(delta > 00)    loss = -delta.where(delta 00)    avg_gain = gain.rolling(14).mean()    avg_loss = loss.rolling(14).mean()    rs = avg_gain / (avg_loss + 1e-10)  # 防止除零错误    df['rsi'] = 100 - (100 / (1 + rs))
    # MACD计算    exp12 = df['close'].ewm(span=12, adjust=False).mean()    exp26 = df['close'].ewm(span=26, adjust=False).mean()    df['macd'] = exp12 - exp26    df['signal'] = df['macd'].ewm(span=9, adjust=False).mean()
    return df.dropna()

# 处理数据processed_data = create_features(data)
# 数据校验if len(processed_data) 100:    st.warning(f"数据量不足(仅{len(processed_data)}条),建议选择更长的时间范围")    st.stop()
# 特征列定义features = ['open''high''low''close''volume',            'ma5''ma10''ma20''rsi''macd''signal']
# 划分数据集split_idx = int(len(processed_data) * train_ratio)X = processed_data[features]y = processed_data['target']
X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]
# 模型训练with st.spinner('正在训练模型...'):    model = XGBClassifier(        n_estimators=200,        max_depth=6,        learning_rate=0.05,        subsample=0.8,        colsample_bytree=0.9,        random_state=42    )    model.fit(X_train, y_train)
# 模型评估y_pred = model.predict(X_test)accuracy = accuracy_score(y_test, y_pred)
# 界面展示col1, col2 = st.columns(2)
with col1:    st.subheader('模型表现')    st.metric("测试集准确率"f" {accuracy:.2%}")
    # 特征重要性    st.subheader('特征重要性')    importance = pd.DataFrame({        '特征': features,        '重要性': model.feature_importances_    }).sort_values('重要性', ascending=False)    st.bar_chart(importance.set_index('特征'))
with col2:    st.subheader('价格走势与预测')
    # 创建K线图    fig = go.Figure()
    # K线主图    fig.add_trace(go.Candlestick(        x=processed_data.index,        open=processed_data['open'],        high=processed_data['high'],        low=processed_data['low'],        close=processed_data['close'],        name='K线'    ))
    # 预测标记(测试集部分)    test_dates = processed_data.index[split_idx:]    predictions = pd.Series(y_pred, index=test_dates[:len(y_pred)])
    # 正确预测标记    correct_dates = predictions[predictions == y_test].index    fig.add_trace(go.Scatter(        x=correct_dates,        y=processed_data.loc[correct_dates, 'high'] * 1.02,        mode='markers',        marker=dict(color='lime', size=8, symbol='triangle-up'),        name='正确预测'    ))
    # 错误预测标记    wrong_dates = predictions[predictions != y_test].index    fig.add_trace(go.Scatter(        x=wrong_dates,        y=processed_data.loc[wrong_dates, 'low'] * 0.98,        mode='markers',        marker=dict(color='red', size=8, symbol='triangle-down'),        name='错误预测'    ))
    fig.update_layout(        height=600,        xaxis_rangeslider_visible=False,        legend=dict(orientation="h", yanchor="bottom", y=1.02)    )    st.plotly_chart(fig, use_container_width=True)
# 风险提示st.markdown("---")st.warning("""**风险提示**  本工具仅为技术演示,不构成投资建议。股票市场存在风险,历史表现不代表未来趋势。实际投资请谨慎决策,作者不对任何投资结果负责。""")


最后打个广告,推广下我的知识星球。 本想随缘躺平的,但想起最近的一些事情我觉得有点扯。 星球里随缘上传一些资料,没啥干货,1年99,信者入。  加入后加我微信, 拉入专属星球微信群。

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/181398
 
153 次点击