from dataclasses import dataclass, asdict, fields
from typing import List, Dict, Any, Optional

import bt
import numpy as np
import pandas as pd
class SelectTopK(bt.AlgoStack):
    """Rank candidates by ``signal`` and keep ranks ``dropN + 1`` .. ``dropN + K``.

    Internally runs ``SetStat(signal)`` followed by ``SelectN(K + dropN)`` and
    then discards the first ``dropN`` names of that ranked selection.

    Args:
        signal: ranking statistic handed to ``bt.algos.SetStat``.
        K: number of securities to keep after skipping.
        dropN: number of best-ranked securities to skip.
        sort_descending: forwarded to ``SelectN``; True ranks highest values first.
        all_or_none: forwarded to ``SelectN``.
        filter_selected: forwarded to ``SelectN`` (restrict ranking to the
            names already in ``target.temp['selected']``).
    """

    def __init__(self, signal, K=1, dropN=0, sort_descending=True, all_or_none=False, filter_selected=True):
        super(SelectTopK, self).__init__(
            bt.algos.SetStat(signal),
            bt.algos.SelectN(int(K) + int(dropN), sort_descending, all_or_none, filter_selected),
        )
        self.dropN = int(dropN)

    def __call__(self, target):
        # Honour the AlgoStack contract: if SetStat/SelectN fail (e.g. no signal
        # row for today), stop the strategy's algo chain instead of letting the
        # un-ranked selection flow on to weighting.
        if not super(SelectTopK, self).__call__(target):
            return False
        if self.dropN > 0:
            # Slicing past the end yields [], so "drop everything" needs no special case.
            target.temp['selected'] = list(target.temp['selected'])[self.dropN:]
        return True
from matplotlib import rcParams
from dataclasses import dataclass, field
from polar_loader import PolarDataloader
# Switch matplotlib to a CJK-capable font so Chinese strategy names render in plots.
rcParams.update({'font.family': 'SimHei'})
@dataclass
class MultiStrategies:
    """Configuration of a combined strategy that allocates capital across children.

    ``period``, ``select`` and ``weight`` hold names of ``bt.algos`` classes.
    """

    name: str = '多策略组合'  # display name of the combined strategy
    # Children: security codes and/or preset task ids.
    id_or_symbols: List[str] = field(default_factory=list)
    start_date: str = '20100101'  # 'YYYYMMDD'
    end_date: Optional[str] = None  # 'YYYYMMDD', or None for no explicit end date
    benchmark: str = '510300.SH'  # benchmark security code
    weight: str = 'WeighEqually'  # bt.algos weighting algo name
    select: str = 'SelectAll'  # bt.algos selection algo name
    # Optional fixed weights (presumably for WeighSpecified); weights are fractions, hence float.
    weight_fixed: Dict[str, float] = field(default_factory=dict)
    period: str = 'RunMonthly'  # bt.algos timing algo name
@dataclass
class Task:
    """Configuration of a single rule-based rotation strategy.

    Rule and signal strings are expressions evaluated by the data loader
    (e.g. ``'roc(close,20)>0.08'``); ``period`` and ``weight`` name
    ``bt.algos`` classes.
    """

    name: str = '策略'  # strategy name
    symbols: List[str] = field(default_factory=list)  # security codes in the universe
    min_date: bool = False  # forwarded to the data loader as ``min_date``
    start_date: str = '20100101'  # 'YYYYMMDD'
    end_date: Optional[str] = None  # 'YYYYMMDD', or None for no explicit end date
    benchmark: str = '510300.SH'  # benchmark security code
    select: str = 'SelectAll'  # selection algo name
    select_buy: List[str] = field(default_factory=list)  # buy rule expressions
    buy_at_least_count: int = 0  # buy when at least this many buy rules hold; <= 0 means all of them
    select_sell: List[str] = field(default_factory=list)  # sell rule expressions
    sell_at_least_count: int = 1  # sell when at least this many sell rules hold
    order_by_signal: str = ''  # ranking expression; '' disables top-K ranking
    order_by_topK: int = 1  # number of names kept after ranking
    order_by_dropN: int = 0  # number of best-ranked names skipped before keeping topK
    order_by_DESC: bool = True  # rank highest signal value first
    weight: str = 'WeighEqually'  # weighting algo name
    # symbol -> target weight, used when weight == 'WeighSpecified'; weights are fractions.
    weight_fixed: Dict[str, float] = field(default_factory=dict)
    period: str = 'RunDaily'  # timing algo name
    period_days: Optional[int] = None  # ``n`` for period == 'RunEveryNPeriods'
@dataclass
class StrategyConfig:
    """Named, described strategy configuration with its raw config payload."""

    name: str = '策略'  # strategy name
    desc: str = '策略描述'  # human-readable description
    # Raw config dict (presumably Task fields, e.g. for dict_to_task — confirm);
    # values are heterogeneous (str, list, int, ...), so Any rather than int.
    config_json: Dict[str, Any] = field(default_factory=dict)
    author: str = ''  # author name
import importlib
class Engine:
    """Build bt backtests from Task / MultiStrategies configurations and run them.

    The data loader for the task currently being built lives on ``self.loader``;
    the result of the last ``run``/``run_tasks`` call is kept on ``self.res``.
    """

    def __init__(self, path='quotes'):
        # NOTE(review): ``self.path`` is stored but not read anywhere in this class;
        # run()/run_tasks() forward their own ``path`` argument (default None) to
        # PolarDataloader instead. Confirm whether self.path should be that default.
        self.path = path

    def _make_loader(self, task: "Task", path=None):
        """Create the PolarDataloader for ``task``.

        Every buy/sell rule and the ordering signal is requested as an extra
        column, named after its own expression text.
        """
        exprs = list(set(task.select_buy + task.select_sell + [task.order_by_signal]))
        return PolarDataloader(task.symbols, task.start_date, task.end_date, exprs, exprs,
                               min_date=task.min_date, path=path)

    @staticmethod
    def _benchmark_backtest(bench, data, name, **kwargs):
        """Equal-weight backtest of ``data`` that rebalances once (RunOnce), named ``name``."""
        s = bt.Strategy(bench, [bt.algos.RunOnce(),
                                bt.algos.SelectAll(),
                                bt.algos.WeighEqually(),
                                bt.algos.Rebalance()])
        return bt.Backtest(s, data, name=name, **kwargs)

    def _parse_rules(self, task: "Task"):
        """Combine the task's buy and sell rule columns into boolean masks.

        Each rule names a column (fetched via ``self.loader.get_col_df``) that is
        truthy where the rule fires. A cell is flagged when at least
        ``*_at_least_count`` rules fire; a buy count <= 0 means every non-empty
        buy rule must fire.

        Returns:
            ``(all_buy, all_sell)``: each a boolean DataFrame, or None when there
            are no usable rules. Missing buy cells become False, missing sell
            cells become True.
        """
        def _rules(rules, at_least):
            # Sum the firing rules per cell in nullable Int64 so missing data stays pd.NA.
            total = None
            for rule in rules or []:
                if rule == '':
                    continue
                df_r = self.loader.get_col_df(rule)
                if df_r is None:
                    continue
                df_r = df_r.replace({True: 1, False: 0}).astype('Int64')
                total = df_r if total is None else total + df_r
            if total is None:
                # Only empty or unavailable rules: nothing to combine.
                return None
            return total >= at_least

        buy_at_least_count = task.buy_at_least_count
        if buy_at_least_count <= 0:
            # "All rules": count only the non-empty ones, since _rules skips ''.
            buy_at_least_count = sum(1 for rule in task.select_buy if rule != '')
        all_buy = _rules(task.select_buy, buy_at_least_count)
        all_sell = _rules(task.select_sell, task.sell_at_least_count)
        if all_sell is not None:
            # Cells with missing data count as a sell signal ...
            all_sell = all_sell.fillna(True)
        if all_buy is not None:
            # ... and never as a buy signal.
            all_buy = all_buy.fillna(False)
        return all_buy, all_sell

    def _get_algos(self, task: "Task"):
        """Assemble the bt algo list: timing -> selection -> optional ranking -> weighting -> rebalance."""
        bt_algos = importlib.import_module('bt.algos')

        if task.period == 'RunEveryNPeriods':
            # Upstream bt defines RunEveryNPeriods(n, offset=0); it has no
            # run_on_last_date keyword, so passing one raised TypeError.
            algo_period = bt.algos.RunEveryNPeriods(n=task.period_days)
        else:
            # Assumes a RunPeriod-style algo (RunDaily, RunMonthly, ...) that accepts run_on_last_date.
            algo_period = getattr(bt_algos, task.period)(run_on_last_date=True)

        algo_select_where = None
        signal_buy, signal_sell = self._parse_rules(task)
        if signal_buy is not None or signal_sell is not None:
            df_close = self.loader.get_col_df('close')
            if signal_buy is None:
                # No buy rules: everything is held until a sell rule fires.
                select_signal = pd.DataFrame(np.ones(df_close.shape), columns=df_close.columns, index=df_close.index)
            else:
                # 1 on buy cells, NaN elsewhere so the state can be carried forward.
                select_signal = np.where(signal_buy, 1, np.nan)
            if signal_sell is not None:
                # A sell on the same cell overrides a buy.
                select_signal = np.where(signal_sell, 0, select_signal)
            select_signal = pd.DataFrame(select_signal, index=df_close.index, columns=df_close.columns)
            # Keep the last buy(1)/sell(0) state until the next signal; before any signal -> not held.
            select_signal = select_signal.ffill().fillna(0)
            algo_select_where = bt.algos.SelectWhere(signal=select_signal)

        algo_order_by = None
        if task.order_by_signal:
            signal_order_by = self.loader.get_col_df(col=task.order_by_signal)
            algo_order_by = SelectTopK(signal=signal_order_by, K=task.order_by_topK, dropN=task.order_by_dropN,
                                       sort_descending=task.order_by_DESC)

        algos = [algo_period]
        algos.append(algo_select_where if algo_select_where is not None else bt.algos.SelectAll())
        if algo_order_by is not None:
            algos.append(algo_order_by)

        # Work on a local name so the caller's Task is not mutated.
        weight_name = task.weight
        if weight_name == 'WeighERC':
            # Presumably gives ERC's covariance estimate enough history (~1 trading year) first.
            algos.insert(0, bt.algos.RunAfterDays(days=256))
            algo_weight = getattr(bt_algos, weight_name)()
        elif weight_name == 'WeighSpecified':
            algo_weight = bt.algos.WeighSpecified(**task.weight_fixed)
        else:
            if weight_name == 'WeighInVol':
                # Map this misspelling to bt's actual class name.
                weight_name = 'WeighInvVol'
            algo_weight = getattr(bt_algos, weight_name)()
        algos.append(algo_weight)
        algos.append(bt.algos.Rebalance())
        return algos

    def run_tasks(self, tasks: "list[Task]", path=None):
        """Backtest several tasks side by side plus one backtest per distinct benchmark.

        Each task is loaded with its own PolarDataloader (same flow as run()).
        Strategy backtests are named after the task; benchmark backtests are
        named "基准:<symbol>" and priced from the first task that references them.
        """
        bkts = []
        bench_prices = {}
        for task in tasks:
            self.loader = self._make_loader(task, path)
            strategy = bt.Strategy(task.name, self._get_algos(task))
            df_close = self.loader.get_col_df('close')
            bkts.append(bt.Backtest(strategy, df_close, name=task.name))
            if task.benchmark not in bench_prices:
                bench_prices[task.benchmark] = self.loader.get_col_df_by_symbols([task.benchmark])
        for bench, data in bench_prices.items():
            bkts.append(self._benchmark_backtest(bench, data, "基准:" + bench))
        res = bt.run(*bkts)
        self.res = res
        return res

    def _get_bkt(self, task):
        """Return a '策略' Backtest for ``task`` using ``self.loader``; plain strings pass through unchanged."""
        if isinstance(task, str):
            return task
        s = bt.Strategy(task.name, self._get_algos(task))
        df_close = self.loader.get_col_df('close')
        return bt.Backtest(s, df_close, name='策略', integer_positions=True)

    def run(self, task: "Task", path=None):
        """Backtest ``task`` against its benchmark and store the result on ``self.res``.

        The backtests are named '策略' and 'benchmark' (the columns get_equities() reads).
        """
        self.loader = self._make_loader(task, path)
        bkt = self._get_bkt(task)
        bench_data = self.loader.get_col_df_by_symbols([task.benchmark])
        bench_bkt = self._benchmark_backtest(task.benchmark, bench_data, 'benchmark', progress_bar=True)
        res = bt.run(bkt, bench_bkt)
        self.res = res
        return res

    def _get_task_by_id(self, id: str):
        """Return the preset Task for ``id``.

        NOTE(review): ``id`` is currently ignored — the large/small-cap rotation
        preset below is always returned.
        """
        def astock_rolling():
            t = Task()
            t.name = '大小盘轮动'
            t.start_date = '20200101'
            t.weight = 'WeighEqually'
            t.symbols = [
                '159915.SZ'
            ]
            t.benchmark = '512890.SH'
            t.select_buy = ['roc(close,20)>0.08']
            t.select_sell = ['roc(close,20)<0']
            t.order_by_signal = 'roc(close,20)'
            return t
        return astock_rolling()

    def run_multi_tasks(self, strategy: "MultiStrategies", path=None):
        """Backtest a parent strategy that allocates across child tasks and/or raw symbols.

        Entries of ``strategy.id_or_symbols`` up to 10 characters are treated as
        security codes; longer ones are resolved with _get_task_by_id.
        """
        tasks = []
        for item in strategy.id_or_symbols:
            # NOTE(review): the comparison operator was missing (`len(t) 10`, a
            # SyntaxError); `<= 10` is assumed so 9-character codes such as
            # '510300.SH' stay plain symbols. Confirm the intended threshold.
            if len(item) <= 10:
                tasks.append(item)
            else:
                tasks.append(self._get_task_by_id(item))

        instruments = set()
        for t in tasks:
            if isinstance(t, Task):
                instruments.update(t.symbols)
            else:
                instruments.add(t)
        instruments = list(instruments)

        # NOTE(review): CSVDataloader is not imported in this module; this load
        # path predates PolarDataloader and raises NameError unless provided elsewhere.
        data = CSVDataloader.get_df(instruments, set_index=True, start_date=strategy.start_date)
        data.dropna(inplace=True)
        df_close = CSVDataloader.get_col_df(data)

        children = []
        for t in tasks:
            if isinstance(t, str):
                children.append(t)
            else:
                # _get_algos reads signals from self.loader, so load this task's data first.
                self.loader = self._make_loader(t, path)
                children.append(self._get_bkt(t).strategy)

        bt_algos = importlib.import_module('bt.algos')
        combined_strategy = bt.Strategy(
            strategy.name,
            algos=[
                getattr(bt_algos, strategy.period)(),
                getattr(bt_algos, strategy.select)(),
                getattr(bt_algos, strategy.weight)(),
                bt.algos.Rebalance(),
            ],
            children=children,
        )
        combined_test = bt.Backtest(
            combined_strategy,
            df_close,
            integer_positions=False,
            progress_bar=False,
        )
        return bt.run(combined_test)

    def get_equities(self):
        """Return the last run's growth-of-1 equity curves for '策略' and 'benchmark'.

        Returns:
            dict mapping each of those two column names to a list of
            ``(timestamp_ns, equity)`` tuples; rows containing NaN are dropped.
        """
        equity = (self.res.prices.pct_change() + 1).cumprod().dropna()
        # Epoch nanoseconds as keys (e.g. for JSON charting front-ends).
        equity.index = pd.to_datetime(equity.index).map(lambda ts: ts.value)
        equity = equity[['策略', 'benchmark']]
        return {name: list(zip(col.index, col.values)) for name, col in equity.items()}
import requests, json
def dict_to_task(data: Dict[str, Any]) -> Task:
    """Safely convert a dict into a Task instance.

    Keys that are not Task fields are silently ignored; fields absent from
    ``data`` keep their defaults. Values are passed through without type
    conversion or validation.

    Args:
        data: mapping of Task field name to value (e.g. parsed JSON).

    Returns:
        A new Task built from the recognised keys.
    """
    valid_fields = {f.name for f in fields(Task)}
    return Task(**{key: value for key, value in data.items() if key in valid_fields})
if __name__ == '__main__':
    # Demo: rotate across a small cross-asset ETF universe ranked by roc(close,22),
    # with the sell rule roc(close,21)>0.17; then print stats and plot.
    universe = [
        '510300.SH',
        '159915.SZ',
        '518880.SH',
        '513100.SH',
        '159985.SZ',
        '511880.SH',
    ]
    demo_task = Task(
        name='全球大类资产-修正斜率轮动',
        symbols=universe,
        select_sell=["roc(close,21)>0.17"],
        order_by_signal="roc(close,22)",
    )
    engine = Engine()
    result = engine.run(demo_task)
    print(result.stats)

    import matplotlib.pyplot as plt

    result.plot()
    plt.show()