首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >【Python技术】利用LSTM时间序列算法预测股票涨跌例子

【Python技术】利用LSTM时间序列算法预测股票涨跌例子

作者头像
子晓聊技术
发布2026-04-23 14:12:24
发布2026-04-23 14:12:24
1870
举报
文章被收录于专栏:子晓AI量化子晓AI量化

之前写了xgboost、随机森林预测股票涨跌的例子,有同学问LSTM时间序列,那这里写一个简单示例。仅做演示哈

先简单讲一下LSTM预测股票涨跌原理。 怎么解释呢?简单来说,LSTM就像一个能记住重点、过滤杂音的“股票分析师”,通过分析历史数据中的关键规律,预测未来价格走势。 它的核心能力在于分辨哪些信息需要长期记住(比如牛市趋势),哪些可以忽略(比如某天的偶然波动)。


一、LSTM的三大“智能开关”

  1. 遗忘门:负责断舍离
    • 功能:丢掉过时的信息。
    • 举例:如果央行突然加息,模型会自动降低对过去低利率时期股价规律的依赖。
  2. 输入门:负责抓重点
    • 功能:识别当前的重要信号。
    • 举例:当某只股票出现“成交量暴增+突破压力位”,输入门会重点记录这一组合特征。
  3. 输出门:负责做决策
    • 功能:结合记忆和当前信息输出预测结果。
    • 举例:如果过去3个月是震荡上行趋势,且今天出现利好消息,输出门可能判断明天上涨概率更高。

二、数据准备的三大步骤

  1. 数据清洗
    • 删除明显错误数据(如股价突然归零的系统故障记录)
    • 补全缺失值(如用前后两天的平均成交量填充节假日数据)
  2. 特征加工
    • 新闻情绪分(例如用AI判断“某公司疫苗获批”是利好)
    • 行业指数(如半导体板块整体走势)
    • 均线(例如5日均线代表短期趋势)
    • MACD(判断买卖力道的“多空能量柱”)
    • 基础数据:每日开盘价、收盘价、成交量(相当于股票的“心跳、体温”)
    • 技术指标
    • 外部信息
  3. 时间切片
    • 把数据切成连续的时间段,就像用短视频片段训练模型: 输入:过去60天的价格+成交量+新闻数据 输出:第61天涨跌概率

三、训练模型的三个关键

  1. 试错学习
    • 模型先根据历史数据猜测,如果预测错误(例如实际跌了却预测涨),就通过“反向传播”自动调整参数,类似于学生通过错题修正知识盲点。
  2. 防止过拟合
    • 早停法:如果连续10次训练预测准确率不提升,立即停止,避免死记硬背历史数据。
    • 随机屏蔽:训练时随机关闭部分神经元,强迫模型学习多种特征组合(类似考试前藏起一半笔记,必须灵活运用知识)。

最后附上源代码,需要的自取。 备注:如果发现格式有多余的特殊字符,用普通浏览器打开复制应该没问题。

代码语言:javascript
复制
import streamlit as st
import akshare as ak
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
import plotly.graph_objects as go
# 获取股票数据
def get_stock_data(stock_code, start_date, end_date):
    start_date_str = start_date.strftime("%Y%m%d")
    end_date_str = end_date.strftime("%Y%m%d")
    stock_zh_a_hist_df = ak.stock_zh_a_hist(
        symbol=stock_code,
        period="daily",
        start_date=start_date_str,
        end_date=end_date_str,
        adjust="qfq"
    )
    return stock_zh_a_hist_df
# 数据预处理
def preprocess_data(data):
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_data = scaler.fit_transform(data[['收盘']].values)
    return scaled_data, scaler
# 创建训练数据
def create_dataset(dataset, time_step=1):
    dataX, dataY = [], []
    for i in range(len(dataset) - time_step - 1):
        a = dataset[i:(i + time_step), 0]
        dataX.append(a)
        dataY.append(dataset[i + time_step, 0])
    return np.array(dataX), np.array(dataY)
# 构建LSTM模型
def build_lstm_model(input_shape):
    model = Sequential()
    model.add(LSTM(50, return_sequences=True, input_shape=input_shape))
    model.add(LSTM(50, return_sequences=False))
    model.add(Dense(25))
    model.add(Dense(1))
    model.compile(optimizer='adam', loss='mean_squared_error')
    return model
# Streamlit应用
def app():
    st.title('股票涨跌预测系统')
    stock_code = st.text_input('请输入股票代码(例如:000001):', '000001')
    # 调整日期输入顺序并添加验证
    col1, col2 = st.columns(2)
    with col1:
        start_date = st.date_input('选择开始日期', pd.Timestamp.today().date() - pd.Timedelta(days=365))
    with col2:
        end_date = st.date_input('选择结束日期', pd.Timestamp.today().date())
    if start_date > end_date:
        st.error("错误:结束日期不能早于开始日期!")
        return
    if st.button('开始预测'):
        # 获取数据
        stock_data = get_stock_data(stock_code, start_date, end_date)
        if stock_data.empty:
            st.error("未获取到数据,请检查股票代码和日期范围!")
            return
        # 数据预处理
        scaled_data, scaler = preprocess_data(stock_data)
        time_step = 60
        # 确保有足够的数据进行预测
        if len(scaled_data) < time_step:
            st.error(f"数据量不足,至少需要{time_step}个交易日的数据!")
            return
        # 创建数据集
        X_train, y_train = create_dataset(scaled_data, time_step)
        if len(X_train) == 0:
            st.error("数据量不足创建训练集,请选择更长的日期范围!")
            return
        X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
        # 构建并训练模型
        model = build_lstm_model((X_train.shape[1], 1))
        model.fit(X_train, y_train, epochs=1, batch_size=1, verbose=2)
        # 获取最近收盘价
        latest_close = stock_data['收盘'].iloc[-1]
        # 进行预测
        test_data = scaled_data[-time_step:]
        test_data = test_data.reshape((1, time_step, 1))
        predicted_stock_price = model.predict(test_data)
        predicted_close = scaler.inverse_transform(predicted_stock_price)[0][0]
        # 计算涨跌
        change = predicted_close - latest_close
        change_percent = (change / latest_close) * 100
        # 显示结果
        st.subheader("预测结果")
        col_a, col_b = st.columns(2)
        with col_a:
            st.metric("最近收盘价", f"{latest_close:.2f}")
        with col_b:
            display_text = f"{predicted_close:.2f}"
            delta_sign = ""
            if change != 0:
                delta_sign = "↑" if change > 0 else "↓"
            st.metric(
                label="预测收盘价",
                value=f"{predicted_close:.2f}",
                delta=f"{delta_sign}{abs(change):.2f} ({abs(change_percent):.2f}%)"
            )
        # 可视化图表
        fig = go.Figure()
        # 添加历史数据
        fig.add_trace(
            go.Scatter(
                x=stock_data['日期'],
                y=stock_data['收盘'],
                name='历史收盘价',
                line=dict(color='#1f77b4', width=2),
                hovertemplate='日期: %{x}<br>收盘价: %{y:.2f}'
            )
        )
        # 添加预测点
        last_date = pd.to_datetime(stock_data['日期'].iloc[-1])
        next_date = last_date + pd.Timedelta(days=1)
        pred_color = 'green' if predicted_close >= latest_close else 'red'
        symbol_icon = 'triangle-up' if predicted_close >= latest_close else 'triangle-down'
        fig.add_trace(
            go.Scatter(
                x=[next_date],
                y=[predicted_close],
                name='预测收盘价',
                mode='markers',
                marker=dict(
                    color=pred_color,
                    size=12,
                    symbol=symbol_icon,
                    line=dict(width=2, color='white')
                ),
                hovertemplate=f'预测日期: {next_date.strftime("%Y-%m-%d")}<br>预测价格: {predicted_close:.2f}'
            )
        )
        # 更新布局
        fig.update_layout(
            title=dict(
                text=f'{stock_code} 股票价格走势预测',
                x=0.05,
                xanchor='left',
                font=dict(size=20)
            ),
            xaxis=dict(
                title='日期',
                rangeslider=dict(visible=True),
                type='date'
            ),
            yaxis=dict(
                title='收盘价 (元)',
                tickprefix='¥'
            ),
            hoverlabel=dict(
                bgcolor="white",
                font_size=12,
            ),
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            margin=dict(l=50, r=30, t=80, b=50),
            plot_bgcolor='rgba(240,240,240,0.9)',
            height=500
        )
        # 添加涨跌注释
        fig.add_annotation(
            x=next_date,
            y=predicted_close,
            text=f'{change_percent:.2f}%',
            showarrow=True,
            arrowhead=2,
            ax=0,
            ay=-40,
            font=dict(
                color=pred_color,
                size=14
            )
        )
        st.plotly_chart(fig, use_container_width=True)
if __name__ == '__main__':
    app()
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-05-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 子晓聊技术 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、LSTM的三大“智能开关”
  • 二、数据准备的三大步骤
  • 三、训练模型的三个关键
  • 最后附上源代码,需要的自取。 备注:如果发现格式有多余的特殊字符,用普通浏览器打开复制应该没问题。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档