from pathlib import Path
import pandas as pd
import polars as pl
from loguru import logger
from datetime import datetime
from alpha.dataset.utility import DataProxy
from config import DATA_DIR
def compare_symbol_dates(df: pl.DataFrame, symbol1: str, symbol2: str) -> None:
    """Print the rows whose ``date`` exists for only one of two symbols.

    Helps diagnose why two symbols' series do not line up after loading
    (e.g. a trading day present in one CSV but missing from the other).

    Args:
        df: Long-format frame containing at least ``symbol`` and ``date``.
        symbol1: First symbol to compare.
        symbol2: Second symbol to compare.
    """
    df1 = df.filter(pl.col("symbol") == symbol1)
    df2 = df.filter(pl.col("symbol") == symbol2)
    dates1 = df1.get_column("date").unique()
    dates2 = df2.get_column("date").unique()
    only1 = dates1.filter(~dates1.is_in(dates2))
    only2 = dates2.filter(~dates2.is_in(dates1))

    # Only report "identical" when NEITHER side has extra dates; previously
    # this was printed whenever just one side had no unique dates.
    if only1.is_empty() and only2.is_empty():
        print(f"{symbol1} 和 {symbol2} 的日期完全一致")
        return

    if not only1.is_empty():
        print(f"=== {symbol1} 独有的日期数据 ===")
        print(df1.filter(pl.col("date").is_in(only1)))
    else:
        print(f"{symbol1} 没有独有的日期")

    if not only2.is_empty():
        print(f"\n=== {symbol2} 独有的日期数据 ===")
        print(df2.filter(pl.col("date").is_in(only2)))
    else:
        print(f"{symbol2} 没有独有的日期")
class PolarDataloader:
    """Load per-symbol quote CSVs into one long polars frame plus derived columns.

    Every ``<symbol>.csv`` under ``path`` is read, filtered to the inclusive
    ``[start, end]`` date window and truncated so all symbols begin on the
    same date (the latest of the per-symbol first dates). Each non-empty
    expression in ``fields`` is then evaluated with
    :meth:`calculate_by_expression` and appended as a column named by the
    matching entry in ``names``.
    """

    def __init__(self, symbols: list[str],
                 start: str = '20100101',
                 end: "str | None" = None,
                 fields: "list[str] | None" = None,
                 names: "list[str] | None" = None,
                 path: Path = DATA_DIR.joinpath('quotes')):
        """Load the CSVs and compute the requested derived columns.

        Args:
            symbols: Symbols to load; each maps to ``path / f"{symbol}.csv"``.
            start: Inclusive lower bound on the ``date`` column.
            end: Inclusive upper bound on ``date``. ``None`` or empty means
                today (``%Y%m%d``), resolved when the instance is created.
            fields: Expressions for :meth:`calculate_by_expression`; empty
                strings are skipped.
            names: Output column name for each entry of ``fields``; must be
                the same length as ``fields``.
            path: Directory containing the per-symbol CSV files.
        """
        self.path = path
        self.symbols = symbols
        self.start = start
        # The old signature default `datetime.now()...` was evaluated once at
        # import time and went stale in long-running processes.
        self.end = end or datetime.now().strftime('%Y%m%d')
        self.df = self._load_df_from_csvs()

        fields = list(fields) if fields else []
        names = list(names) if names else []
        if len(fields) != len(names):
            logger.warning(
                "fields ({}) and names ({}) differ in length; no derived columns computed",
                len(fields), len(names),
            )
        elif names:
            results = []
            for name, field in zip(names, fields):
                if field == '':
                    continue
                expr = self.calculate_by_expression(self.df, field)
                # NOTE(review): the Series is attached positionally, so this
                # assumes DataProxy keeps the row order of self.df — confirm.
                results.append(expr['data'].alias(name))
            self.df = self.df.with_columns(results)
            # Drop every row holding a NaN in any float column (nulls are kept).
            self.df = self.df.drop_nans()

    def _load_df_from_csvs(self, symbols: "list[str] | None" = None) -> pl.DataFrame:
        """Read, date-filter and align the CSVs for ``symbols``.

        Args:
            symbols: Symbols to load; falsy means ``self.symbols``.

        Returns:
            The concatenated long frame of all loaded symbols, each truncated
            to start at the latest per-symbol first date.

        Raises:
            ValueError: if no CSV yielded any rows in ``[self.start, self.end]``.
        """
        if not symbols:
            symbols = self.symbols
        if not self.end:
            self.end = datetime.now().strftime("%Y%m%d")

        dfs: list[pl.DataFrame] = []
        for s in symbols:
            file_path: Path = self.path.joinpath(f"{s}.csv")
            if not file_path.exists():
                logger.error(f"File {file_path} does not exist")
                continue
            # `date` is read as a string and compared lexically with start/end.
            # NOTE(review): only correct if CSV dates share the zero-padded
            # 'YYYYMMDD' format of start/end — confirm the file format.
            df = pl.read_csv(file_path, schema_overrides={'date': pl.Utf8})
            df = df.sort("date")
            df = df.filter((pl.col("date") >= self.start) & (pl.col("date") <= self.end))
            # NOTE(review): re-selecting OHLCV under their own names is a no-op
            # except that it raises if a column is missing; a cast may have been intended.
            df = df.with_columns(
                pl.col("open"),
                pl.col("high"),
                pl.col("low"),
                pl.col("close"),
                pl.col("volume"),
            )
            if df.is_empty():
                continue
            dfs.append(df)

        if not dfs:
            # Previously surfaced as an opaque "max() arg is an empty sequence".
            raise ValueError(
                f"No quote rows loaded for {list(symbols)} in [{self.start}, {self.end}]"
            )

        # Align all symbols to the latest first date so the panel has no
        # leading gaps for late-listed symbols.
        latest_start_date = max(frame.get_column("date").min() for frame in dfs)
        print("统一截取的起始日期:", latest_start_date)
        return pl.concat([frame.filter(pl.col("date") >= latest_start_date) for frame in dfs])

    def calculate_by_expression(self, df: pl.DataFrame, expression: str) -> pl.DataFrame:
        """Evaluate a factor ``expression`` against the columns of ``df``.

        Every column except ``date``/``symbol`` is exposed to the expression
        as a ``DataProxy`` over ``[date, symbol, column]``, together with the
        ``ts_*`` helpers, the aliases ``shift``/``ma``/``slope`` and
        ``trend_score``/``roc``/``RSRS`` imported below. Because the namespace
        is copied from ``locals()``, ``self``, ``df`` and ``expression`` are
        visible as well.

        Returns:
            The ``.df`` of the resulting ``DataProxy``; ``__init__`` reads its
            ``data`` column.

        Security:
            Uses ``eval`` (builtins remain reachable). ``expression`` must come
            from trusted code/config, never from external input.
        """
        from alpha.dataset.ts_function import (
            ts_delay,
            ts_min, ts_max,
            ts_argmax, ts_argmin,
            ts_rank, ts_sum,
            ts_mean, ts_std,
            ts_slope, ts_quantile,
            ts_rsquare, ts_resi,
            ts_corr,
            ts_less, ts_greater,
            ts_log, ts_abs
        )
        from alpha.dataset.ts_function import ts_delay as shift
        from alpha.dataset.expr_extends import trend_score, roc, RSRS
        from alpha.dataset.ts_function import ts_mean as ma
        from alpha.dataset.ts_function import ts_slope as slope

        # Snapshot the function-local names (the imports above) as the eval namespace.
        namespace: dict = dict(locals())
        for column in df.columns:
            if column in {"date", "symbol"}:
                continue
            namespace[column] = DataProxy(df[["date", "symbol", column]])

        other: DataProxy = eval(expression, {}, namespace)
        return other.df

    @staticmethod
    def _pivot_to_pandas(df: pl.DataFrame, col: str) -> pd.DataFrame:
        """Pivot long ``df`` into a forward-filled date x symbol pandas frame of ``col``."""
        # "first" keeps the pivot from failing on duplicate (date, symbol) rows.
        wide = df.pivot(values=col, index='date', on='symbol',
                        aggregate_function="first").sort('date')
        out = wide.to_pandas().set_index('date')
        out = out.ffill()
        out.index = pd.to_datetime(out.index)
        return out

    def get_col_df(self, col='close'):
        """Return column ``col`` of the loaded data as a wide pandas frame.

        Rows are dates (a ``DatetimeIndex``), columns are symbols, and missing
        values are forward-filled.
        """
        return self._pivot_to_pandas(self.df, col)

    def get_col_df_by_symbols(self, symbols: list[str], col='close'):
        """Reload raw CSVs for ``symbols`` and return column ``col`` pivoted wide.

        Only raw CSV columns are available here; derived ``names`` columns
        from ``__init__`` are not recomputed.

        Returns:
            A date x symbol pandas frame like :meth:`get_col_df`, or ``None``
            (with an error logged) if ``col`` is not a loaded column.
        """
        df_all = self._load_df_from_csvs(symbols=symbols)
        if col not in df_all.columns:
            # The placeholder previously had no argument, so the name never appeared.
            logger.error('{}列不存在', col)
            return None
        return self._pivot_to_pandas(df_all, col)
if __name__ == '__main__':
    # Demo: load two ETFs from 2020, derive a 5-day rate of change, and dump
    # a few wide (date x symbol) views.
    demo_loader = PolarDataloader(
        symbols=['510300.SH', '159915.SZ'],
        start='20200101',
        names=['roc_5'],
        fields=['close/shift(close,5)-1'],
    )
    for column_name in ('roc_5', 'high'):
        print(demo_loader.get_col_df(column_name))
    print(demo_loader.get_col_df_by_symbols(symbols=['159915.SZ'], col='high'))