"""AlphaLab: a local research workspace for persisting datasets, models, and signals."""
import pickle
from pathlib import Path
from datetime import datetime, timedelta
from functools import partial

import numpy as np
import polars as pl
from loguru import logger

from constant import Interval
from dataset import to_datetime, AlphaDataset, process_drop_na, process_cs_norm, Segment
from dataset.datasets.alpha_158 import Alpha158
from model import AlphaModel
from model.models.lgb_model import LgbModel
class AlphaLab:
    """Manages on-disk storage of datasets, models, and signals under lab_path."""

    def __init__(self, lab_path: str, data_path: str) -> None:
        self.lab_path: Path = Path(lab_path)
        self.data_path: Path = Path(data_path)
        self.dataset_path: Path = self.lab_path.joinpath("dataset")
        self.model_path: Path = self.lab_path.joinpath("model")
        self.signal_path: Path = self.lab_path.joinpath("signal")

        for path in [
            self.lab_path,
            self.dataset_path,
            self.model_path,
            self.signal_path,
        ]:
            path.mkdir(parents=True, exist_ok=True)
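    # Directory layout under lab_path:
    #   dataset/  pickled AlphaDataset objects (<name>.pkl)
    #   model/    pickled AlphaModel objects (<name>.pkl)
    #   signal/   signal frames stored as Parquet (<name>.parquet)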
def save_dataset(self, name: str, dataset: AlphaDataset) -> None:
"""Save dataset"""
file_path: Path = self.dataset_path.joinpath(f"{name}.pkl")
with open(file_path, mode="wb") as f:
pickle.dump(dataset, f)
def load_dataset(self, name: str) -> AlphaDataset | None:
"""Load dataset"""
file_path: Path = self.dataset_path.joinpath(f"{name}.pkl")
if not file_path.exists():
logger.error(f"Dataset file {name} does not exist")
return None
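        # Note: pickle.load can execute arbitrary code; only load dataset files
        # written by this lab or another trusted source.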
with open(file_path, mode="rb") as f:
dataset: AlphaDataset = pickle.load(f)
return dataset
def save_model(self, name: str, model: AlphaModel) -> None:
"""Save model"""
file_path: Path = self.model_path.joinpath(f"{name}.pkl")
with open(file_path, mode="wb") as f:
pickle.dump(model, f)
def load_model(self, name: str) -> AlphaModel | None:
"""Load model"""
file_path: Path = self.model_path.joinpath(f"{name}.pkl")
if not file_path.exists():
logger.error(f"Model file {name} does not exist")
return None
with open(file_path, mode="rb") as f:
model: AlphaModel = pickle.load(f)
return model
def remove_model(self, name: str) -> bool:
"""Remove model"""
file_path: Path = self.model_path.joinpath(f"{name}.pkl")
if not file_path.exists():
logger.error(f"Model file {name} does not exist")
return False
file_path.unlink()
return True
def list_all_models(self) -> list[str]:
"""List all models"""
return [file.stem for file in self.model_path.glob("*.pkl")]
def save_signal(self, name: str, signal: pl.DataFrame) -> None:
"""Save signal"""
file_path: Path = self.signal_path.joinpath(f"{name}.parquet")
signal.write_parquet(file_path)
def load_signal(self, name: str) -> pl.DataFrame | None:
"""Load signal"""
file_path: Path = self.signal_path.joinpath(f"{name}.parquet")
if not file_path.exists():
logger.error(f"Signal file {name} does not exist")
return None
return pl.read_parquet(file_path)
def remove_signal(self, name: str) -> bool:
"""Remove signal"""
file_path: Path = self.signal_path.joinpath(f"{name}.parquet")
if not file_path.exists():
logger.error(f"Signal file {name} does not exist")
return False
file_path.unlink()
return True
def list_all_signals(self) -> list[str]:
"""List all signals"""
        return [file.stem for file in self.signal_path.glob("*.parquet")]
    def load_bar_df(
        self,
        symbols: list[str],
        start: datetime | str = '20100101',
        end: datetime | str | None = None,
        extended_days: int = 20,
    ) -> pl.DataFrame:
        """Load daily bars (date/open/high/low/close/volume) for the given symbols,
        padding the requested window by extended_days."""
        if end is None:
            # Resolve the default at call time; datetime.now() in the signature
            # would be frozen at class-definition time.
            end = datetime.now().strftime('%Y%m%d')
        start = to_datetime(start) - timedelta(days=extended_days)
        end = to_datetime(end) + timedelta(days=extended_days // 10)

        dfs: list[pl.DataFrame] = []
        for s in symbols:
            file_path: Path = self.data_path.joinpath(f"{s}.csv")
            if not file_path.exists():
                logger.error(f"File {file_path} does not exist")
                continue
            df: pl.DataFrame = pl.read_csv(file_path, schema_overrides={'date': pl.Utf8})
            df = df.with_columns(
                pl.col("date").str.strptime(pl.Date, "%Y%m%d").alias("date")
            )
            df = df.filter((pl.col("date") >= start) & (pl.col("date") <= end))
            if df.is_empty():
                continue
            dfs.append(df)

        if not dfs:
            # pl.concat raises on an empty list; return an empty frame instead.
            return pl.DataFrame()
        result_df: pl.DataFrame = pl.concat(dfs)
        return result_df
    def calc_exprs(self, df: pl.DataFrame, names: list[str], fields: list[str]) -> pl.DataFrame:
        """Evaluate factor expressions and attach each result as a column named after it."""
        from dataset.utility import calculate_by_expression

        if len(fields) == len(names) and len(names) > 0:
            results = []
            for name, field in zip(names, fields):
                if field == '':
                    continue
                expr = calculate_by_expression(df, field)
                results.append(expr['data'].alias(name))
            df = df.with_columns(results)
        return df
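
# Example usage of calc_exprs (an illustrative sketch, not from this file): the
# expression grammar is whatever dataset.utility.calculate_by_expression accepts,
# so the field string below is an assumed placeholder for a 20-day momentum factor.
#
#   df = lab.calc_exprs(df, names=['mom20'], fields=['close/ref(close,20)-1'])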
if __name__ == '__main__':
    from strategy import BacktestingEngine

    symbols = ['510300.SH', '159915.SZ']
    # Raw string so the backslashes in the Windows path are not read as escapes.
    lab = AlphaLab(lab_path='./run', data_path=r'D:\work\.aitrader_data\quotes_etf')
    df = lab.load_bar_df(symbols=symbols, start='20100101', end='20250506', extended_days=10)
    print(df)
dataset: AlphaDataset = Alpha158(
df,
train_period=("2010-01-01", "2014-12-31"),
valid_period=("2015-01-01", "2016-12-31"),
test_period=("2017-01-01", "2020-8-31"),
)
dataset.add_processor("learn", partial(process_drop_na, names=["label"]))
dataset.add_processor("learn", partial(process_cs_norm, names=["label"], method="zscore"))
    name = 'etf_rotation'
    dataset.prepare_data(filters=None, max_workers=3)
    lab.save_dataset(name=name, dataset=dataset)
    dataset = lab.load_dataset(name)

    model: AlphaModel = LgbModel(seed=42)
    model.fit(dataset)
    lab.save_model(name, model)
    model = lab.load_model(name)

    pre: np.ndarray = model.predict(dataset, Segment.TEST)
    df_t: pl.DataFrame = dataset.fetch_infer(Segment.TEST)
    df_t = df_t.with_columns(pl.Series(pre).alias("signal"))
    signal: pl.DataFrame = df_t.select("date", "symbol", "signal")
    dataset.show_signal_performance(signal)
    lab.save_signal(name, signal)
    signal = lab.load_signal(name)
    import bt
    from bt_algos_extend import SelectTopK

    s = bt.Strategy('s1', [
        bt.algos.RunDaily(),
        SelectTopK(signal=signal),
        bt.algos.WeighEqually(),
        bt.algos.Rebalance(),
    ])
engine = BacktestingEngine(lab)
engine.set_parameters(
vt_symbols=symbols,
interval=Interval.DAILY,
start=datetime(2017, 1, 1),
end=datetime(2020, 8, 1),
        capital=100_000_000,
)
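    # Running the backtest itself is not shown in this file; the call below is a
    # hypothetical sketch (method name assumed, not confirmed by this snippet):
    # engine.run_backtesting(s)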