
去年,我帮一个朋友重构他的量化代码。
打开项目的那一刻,我惊呆了:
strategy_v1.py
strategy_v2.py
strategy_v3_final.py
strategy_v3_final_真的final.py
backtest.py
backtest_new.py
data_fetch.py
utils.py
temp.py
每个策略都是从头开始写,数据获取、信号计算、回测逻辑全部耦合在一起。
我问他:「如果你想改一个数据源,要改多少文件?」
他挠挠头:「可能...十几个?」
这就是没有框架的代价。
今天我们来解决这个痛点:从零构建一个可扩展的Rust量化策略框架。
好的框架需要满足三个条件:
1. 解耦:数据获取、信号计算、交易执行各自独立,互不依赖;
2. 可扩展:新增策略或数据源时,只需实现对应接口,无需改动框架代码;
3. 可测试:每个模块都能单独替换和测试。
我们的框架结构:
┌─────────────────────────────────────────────┐
│                量化策略框架                 │
├─────────────────────────────────────────────┤
│                                             │
│  ┌──────────┐   ┌──────────┐   ┌──────────┐ │
│  │DataLoader│ → │  Signal  │ → │Execution │ │
│  └──────────┘   └──────────┘   └──────────┘ │
│       ↑              ↑              ↑       │
│       │              │              │       │
│  ┌────────────────────────────────────────┐ │
│  │               Trait 定义               │ │
│  └────────────────────────────────────────┘ │
│                                             │
└─────────────────────────────────────────────┘
三层架构:
- 数据层(DataLoader):负责加载股票列表、日线数据和实时数据;
- 信号层(Strategy):根据行情数据生成买卖信号;
- 执行层(Broker):负责查询账户、下单和撤单。
Trait是Rust的灵魂,它定义了「接口契约」。
use std::collections::HashMap;
use std::sync::Mutex;

use async_trait::async_trait;
use polars::prelude::*;
/// Data-source interface: supplies the stock universe and market data.
///
/// Implementors must be `Send + Sync` (supertraits). The async methods are
/// desugared by the `async_trait` macro.
#[async_trait]
pub trait DataLoader: Send + Sync {
    /// Loads the list of stock codes.
    async fn load_stock_list(&self) -> Result<Vec<String>, Box<dyn std::error::Error>>;
    /// Loads daily bar data for the single stock `code`.
    async fn load_daily_data(&self, code: &str) -> Result<DataFrame, Box<dyn std::error::Error>>;
    /// Loads real-time data for the given `codes`.
    async fn load_realtime_data(&self, codes: &[String]) -> Result<DataFrame, Box<dyn std::error::Error>>;
}
/// Strategy interface: turns a price `DataFrame` into trading signals.
///
/// Implementors must be `Send + Sync` (supertraits).
pub trait Strategy: Send + Sync {
    /// Human-readable strategy name.
    fn name(&self) -> &str;
    /// Initializes the strategy's parameters from `config`.
    fn init(&mut self, config: &StrategyConfig) -> Result<(), Box<dyn std::error::Error>>;
    /// Generates signals from the price data in `df`.
    ///
    /// Takes `&self`, so signal generation cannot mutate the strategy.
    fn generate_signals(&self, df: &DataFrame) -> Result<SignalResult, Box<dyn std::error::Error>>;
}
/// Trade-execution interface.
///
/// Every method takes `&self` and implementors must be `Send + Sync`, so an
/// implementation that tracks mutable state (cash, positions) needs interior
/// mutability (e.g. a mutex) rather than `&mut self` receivers.
#[async_trait]
pub trait Broker: Send + Sync {
    /// Returns a snapshot of the account.
    async fn get_account(&self) -> Result<Account, Box<dyn std::error::Error>>;
    /// Places `order`.
    async fn place_order(&self, order: &Order) -> Result<OrderResult, Box<dyn std::error::Error>>;
    /// Cancels the order identified by `order_id`.
    async fn cancel_order(&self, order_id: &str) -> Result<(), Box<dyn std::error::Error>>;
}
这些Trait定义了框架的「骨架」。任何具体实现都必须遵循这个契约。
下面先实现数据层:一个从本地CSV文件读取数据的加载器。
use std::path::PathBuf;
pub struct CsvDataLoader {
data_dir: PathBuf,
}
impl CsvDataLoader {
pub fn new(data_dir: &str) -> Self {
Self {
data_dir: PathBuf::from(data_dir),
}
}
}
#[async_trait]
impl DataLoader for CsvDataLoader {
async fn load_stock_list(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let path = self.data_dir.join("stock_list.csv");
let df = CsvReader::from_path(path)?.finish()?;
let codes = df.column("ts_code")?
.str()?
.into_iter()
.flatten()
.map(|s| s.to_string())
.collect();
Ok(codes)
}
async fn load_daily_data(&self, code: &str) -> Result<DataFrame, Box<dyn std::error::Error>> {
let path = self.data_dir.join("daily").join(format!("{}.csv", code));
let df = CsvReader::from_path(path)?.finish()?;
Ok(df)
}
async fn load_realtime_data(&self, codes: &[String]) -> Result<DataFrame, Box<dyn std::error::Error>> {
// 实时数据通常需要API调用,这里返回空DataFrame
unimplemented!("实时数据需要API实现")
}
}
换成Tushare在线数据源时,只需再实现一次DataLoader,其余代码无需改动:
use reqwest::Client;
pub struct TushareDataLoader {
client: Client,
token: String,
}
impl TushareDataLoader {
pub fn new(token: &str) -> Self {
Self {
client: Client::new(),
token: token.to_string(),
}
}
}
#[async_trait]
impl DataLoader for TushareDataLoader {
async fn load_daily_data(&self, code: &str) -> Result<DataFrame, Box<dyn std::error::Error>> {
let url = "http://api.tushare.pro";
let body = serde_json::json!({
"api_name": "daily",
"token": self.token,
"ts_code": code,
"fields": "ts_code,trade_date,open,high,low,close,vol,amount"
});
let resp = self.client.post(url)
.json(&body)
.send()
.await?
.json::<serde_json::Value>()
.await?;
// 解析并转换为DataFrame
let df = parse_tushare_response(&resp)?;
Ok(df)
}
// ... 其他方法实现
}
策略层:以均值回归策略为例,实现Strategy trait:
pub struct MeanReversionStrategy {
name: String,
period: usize,
std_dev: f64,
}
impl MeanReversionStrategy {
pub fn new(period: usize, std_dev: f64) -> Self {
Self {
name: "MeanReversion".to_string(),
period,
std_dev,
}
}
}
impl Strategy for MeanReversionStrategy {
fn name(&self) -> &str {
&self.name
}
fn init(&mut self, config: &StrategyConfig) -> Result<(), Box<dyn std::error::Error>> {
self.period = config.get("period").unwrap_or(20);
self.std_dev = config.get("std_dev").unwrap_or(2.0);
Ok(())
}
fn generate_signals(&self, df: &DataFrame) -> Result<SignalResult, Box<dyn std::error::Error>> {
let result = df.lazy()
.with_columns([
col("close")
.rolling_mean(self.period as i64)
.alias("ma"),
col("close")
.rolling_std(self.period as i64)
.alias("std"),
])
.with_columns([
((col("close") - col("ma")) / col("std")).alias("zscore"),
])
.with_columns([
col("zscore").lt(-self.std_dev).alias("buy_signal"),
col("zscore").gt(self.std_dev).alias("sell_signal"),
])
.collect()?;
let signals = result.filter(&col("buy_signal").or(col("sell_signal")))?;
Ok(SignalResult {
signals,
strategy_name: self.name.clone(),
})
}
}
再实现一个基于RSI指标的动量策略:
pub struct MomentumStrategy {
name: String,
rsi_period: usize,
rsi_oversold: f64,
rsi_overbought: f64,
}
impl MomentumStrategy {
pub fn new() -> Self {
Self {
name: "Momentum".to_string(),
rsi_period: 14,
rsi_oversold: 30.0,
rsi_overbought: 70.0,
}
}
}
impl Strategy for MomentumStrategy {
fn name(&self) -> &str {
&self.name
}
fn init(&mut self, config: &StrategyConfig) -> Result<(), Box<dyn std::error::Error>> {
self.rsi_period = config.get("rsi_period").unwrap_or(14);
self.rsi_oversold = config.get("rsi_oversold").unwrap_or(30.0);
self.rsi_overbought = config.get("rsi_overbought").unwrap_or(70.0);
Ok(())
}
fn generate_signals(&self, df: &DataFrame) -> Result<SignalResult, Box<dyn std::error::Error>> {
// RSI计算逻辑...
// 返回信号结果
todo!()
}
}
执行层:用于回测的模拟券商,实现Broker trait:
pub struct SimulatedBroker {
cash: f64,
positions: HashMap<String, Position>,
commission_rate: f64,
}
impl SimulatedBroker {
pub fn new(initial_cash: f64, commission_rate: f64) -> Self {
Self {
cash: initial_cash,
positions: HashMap::new(),
commission_rate,
}
}
}
#[async_trait]
impl Broker for SimulatedBroker {
async fn get_account(&self) -> Result<Account, Box<dyn std::error::Error>> {
let total_value = self.cash + self.positions.values()
.map(|p| p.quantity * p.current_price)
.sum::<f64>();
Ok(Account {
cash: self.cash,
total_value,
positions: self.positions.clone(),
})
}
async fn place_order(&mut self, order: &Order) -> Result<OrderResult, Box<dyn std::error::Error>> {
match order.side {
OrderSide::Buy => {
let cost = order.price * order.quantity;
let commission = (cost * self.commission_rate).max(5.0);
if self.cash < cost + commission {
return Err("资金不足".into());
}
self.cash -= cost + commission;
let position = self.positions.entry(order.code.clone()).or_insert(Position {
code: order.code.clone(),
quantity: 0.0,
avg_cost: 0.0,
current_price: order.price,
});
let total_quantity = position.quantity + order.quantity;
position.avg_cost = (position.avg_cost * position.quantity + cost) / total_quantity;
position.quantity = total_quantity;
Ok(OrderResult {
order_id: uuid::Uuid::new_v4().to_string(),
status: OrderStatus::Filled,
filled_quantity: order.quantity,
filled_price: order.price,
})
}
OrderSide::Sell => {
// 卖出逻辑...
todo!()
}
}
}
}
最后,用一个引擎把数据层、策略层和执行层组装起来:
pub struct QuantEngine {
data_loader: Box<dyn DataLoader>,
strategies: Vec<Box<dyn Strategy>>,
broker: Box<dyn Broker>,
config: EngineConfig,
}
impl QuantEngine {
pub fn new(
data_loader: Box<dyn DataLoader>,
broker: Box<dyn Broker>,
config: EngineConfig,
) -> Self {
Self {
data_loader,
strategies: Vec::new(),
broker,
config,
}
}
pub fn add_strategy(&mut self, strategy: Box<dyn Strategy>) {
self.strategies.push(strategy);
}
pub async fn run_backtest(&mut self) -> Result<BacktestResult, Box<dyn std::error::Error>> {
// 加载股票列表
let stock_list = self.data_loader.load_stock_list().await?;
let mut all_signals = Vec::new();
// 遍历每只股票
for code in &stock_list {
// 加载数据
let df = self.data_loader.load_daily_data(code).await?;
// 遍历每个策略
for strategy in &self.strategies {
let result = strategy.generate_signals(&df)?;
all_signals.push(result);
}
}
// 执行回测
let backtest_result = self.execute_backtest(all_signals).await?;
Ok(backtest_result)
}
async fn execute_backtest(&mut self, signals: Vec<SignalResult>) -> Result<BacktestResult, Box<dyn std::error::Error>> {
// 回测执行逻辑
todo!()
}
}
使用示例:
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// 创建数据加载器
let data_loader = Box::new(CsvDataLoader::new("./data"));
// 创建模拟券商
let broker = Box::new(SimulatedBroker::new(1_000_000.0, 0.0001));
// 创建引擎
let config = EngineConfig::default();
let mut engine = QuantEngine::new(data_loader, broker, config);
// 添加策略
engine.add_strategy(Box::new(MeanReversionStrategy::new(20, 2.0)));
engine.add_strategy(Box::new(MomentumStrategy::new()));
// 运行回测
let result = engine.run_backtest().await?;
// 输出结果
println!("年化收益: {:.2}%", result.annual_return * 100.0);
println!("夏普比率: {:.2}", result.sharpe_ratio);
println!("最大回撤: {:.2}%", result.max_drawdown * 100.0);
Ok(())
}
使用YAML配置文件:
# config.yaml
# Nesting matters: each section's keys must be indented under it, otherwise
# they become top-level keys and no longer match the Rust `Config` struct.
engine:
  initial_cash: 1000000
  commission_rate: 0.0001
data:
  source: csv
  path: ./data
strategies:
  - name: MeanReversion
    enabled: true
    params:
      period: 20
      std_dev: 2.0
  - name: Momentum
    enabled: true
    params:
      rsi_period: 14
      rsi_oversold: 30
      rsi_overbought: 70
加载配置:
use serde::Deserialize;
/// Top-level shape of the YAML configuration deserialized by `load_config`:
/// one field per top-level section.
#[derive(Debug, Deserialize)]
struct Config {
    /// The `engine:` section.
    engine: EngineConfigSection,
    /// The `data:` section.
    data: DataConfigSection,
    /// The `strategies:` list, one entry per strategy.
    strategies: Vec<StrategyConfigSection>,
}
/// Reads the YAML file at `path` and deserializes it into a [`Config`].
///
/// # Errors
/// Returns an error that names `path` if the file cannot be read or its
/// contents do not match [`Config`].
fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
    // Attach the path: a bare I/O or YAML error does not say which file failed.
    let content = std::fs::read_to_string(path)
        .map_err(|e| format!("无法读取配置文件 {path}: {e}"))?;
    let config: Config = serde_yaml::from_str(&content)
        .map_err(|e| format!("配置文件 {path} 格式错误: {e}"))?;
    Ok(config)
}
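有了这个框架,新增一个策略只需要:
1. 实现Strategy trait(`name`、`init`、`generate_signals`三个方法);
2. 调用`engine.add_strategy(...)`注册策略;
3. 在配置文件中填写策略参数。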
不需要改框架代码,不需要重复写数据加载逻辑。
对比之前朋友的情况:
| 操作 | 无框架 | 有框架 |
|---|---|---|
| 新增策略 | 复制粘贴,改十几处 | 实现Trait,配置参数 |
| 改数据源 | 改10+个文件 | 实现DataLoader |
| 调整参数 | 改代码,重新编译 | 改配置文件,重启 |
| 回测运行 | 每个策略单独运行 | 一次运行多策略 |
有些人觉得写框架浪费时间,不如直接写策略。
但我的经验是:框架是投资,不是成本。
一开始多花一点时间搭建框架,后续每个策略都能复用,开发效率提升10倍以上。
更重要的是,框架让策略代码变得可测试、可维护、可扩展——这是长期存活的基础。
下一篇,我们将聊聊策略参数优化系统。
敬请期待。
(全文完)