from pathlib import Path
from datetime import datetime

import pandas as pd
import polars as pl
from loguru import logger

from alpha.dataset.utility import DataProxy
from config import DATA_DIR

def compare_symbol_dates(df: pl.DataFrame, symbol1: str, symbol2: str) -> None:
    """Print the trading dates that exist for one symbol but not for the other."""
    df1 = df.filter(pl.col("symbol") == symbol1)
    df2 = df.filter(pl.col("symbol") == symbol2)

    dates1 = df1.select(pl.col("date").unique()).to_series()
    dates2 = df2.select(pl.col("date").unique()).to_series()

    diff1 = dates1.filter(~dates1.is_in(dates2))
    diff2 = dates2.filter(~dates2.is_in(dates1))

    if not diff1.is_empty():
        print(f"=== dates present only for {symbol1} ===")
        print(df1.filter(pl.col("date").is_in(diff1)))
    else:
        print(f"{symbol1} has no dates missing from {symbol2}")

    if not diff2.is_empty():
        print(f"\n=== dates present only for {symbol2} ===")
        print(df2.filter(pl.col("date").is_in(diff2)))
    else:
        print(f"{symbol2} has no dates missing from {symbol1}")
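
# A minimal usage sketch (assumes the per-symbol quote CSVs referenced by
# PolarDataloader exist under DATA_DIR / 'quotes'; nothing below runs automatically):
#
#     loader = PolarDataloader(symbols=['510300.SH', '159915.SZ'], start='20200101')
#     compare_symbol_dates(loader.df, '510300.SH', '159915.SZ')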

class PolarDataloader:
    def __init__(
        self,
        symbols: list[str],
        start: str = '20100101',
        end: str | None = None,
        fields: list[str] | None = None,
        names: list[str] | None = None,
        path: Path = DATA_DIR.joinpath('quotes'),
    ):
        self.path = path
        self.symbols = symbols
        self.start = start
        # Resolve the end date lazily so a long-lived process does not keep a stale default.
        self.end = end or datetime.now().strftime('%Y%m%d')
        self.df = self._load_df_from_csvs()

        fields = fields or []
        names = names or []
        if len(fields) == len(names) and len(names) > 0:
            # Evaluate each expression once and attach the result as a new named column.
            results = []
            for name, field in zip(names, fields):
                if field == '':
                    continue
                expr = self.calculate_by_expression(self.df, field)
                results.append(expr['data'].alias(name))
            self.df = self.df.with_columns(results)
            self.df = self.df.drop_nans()

    def _load_df_from_csvs(self, symbols: list[str] | None = None) -> pl.DataFrame:
        if not symbols:
            symbols = self.symbols
        if not self.end:
            self.end = datetime.now().strftime("%Y%m%d")

        dfs: list[pl.DataFrame] = []
        for s in symbols:
            file_path: Path = self.path.joinpath(f"{s}.csv")
            if not file_path.exists():
                logger.error(f"File {file_path} does not exist")
                continue

            # Keep dates as "YYYYMMDD" strings so the range filter below is a plain
            # lexicographic comparison.
            df: pl.DataFrame = pl.read_csv(file_path, schema_overrides={'date': pl.Utf8})
            df = df.sort("date")
            df = df.filter((pl.col("date") >= self.start) & (pl.col("date") <= self.end))

            # Touch the expected price/volume columns so a malformed CSV fails fast here.
            df = df.with_columns(
                pl.col("open"),
                pl.col("high"),
                pl.col("low"),
                pl.col("close"),
                pl.col("volume"),
            )

            if df.is_empty():
                continue
            dfs.append(df)

        if not dfs:
            raise ValueError(f"No quote data loaded for {symbols} in [{self.start}, {self.end}]")

        # Trim every symbol to the latest common start date so the concatenated panel
        # covers the same period for all symbols.
        start_dates = [df.get_column("date").min() for df in dfs]
        latest_start_date = max(start_dates)
        print("Common start date after alignment:", latest_start_date)

        dfs_cut = [df.filter(pl.col("date") >= latest_start_date) for df in dfs]

        result_df: pl.DataFrame = pl.concat(dfs_cut)
        return result_df

    def calculate_by_expression(self, df: pl.DataFrame, expression: str) -> pl.DataFrame:
        from alpha.dataset.ts_function import (
            ts_delay, ts_min, ts_max, ts_argmax, ts_argmin, ts_rank,
            ts_sum, ts_mean, ts_std, ts_slope, ts_quantile, ts_rsquare,
            ts_resi, ts_corr, ts_less, ts_greater, ts_log, ts_abs,
        )
        from alpha.dataset.ts_function import ts_delay as shift
        from alpha.dataset.ts_function import ts_mean as ma
        from alpha.dataset.ts_function import ts_slope as slope
        from alpha.dataset.expr_extends import trend_score, roc, RSRS

        # eval() sees the operators imported above plus one DataProxy per data column,
        # so an expression can reference columns by name, e.g. "close" or "volume".
        d: dict = locals()
        for column in df.columns:
            if column in {"date", "symbol"}:
                continue
            column_df = df[["date", "symbol", column]]
            d[column] = DataProxy(column_df)

        other: DataProxy = eval(expression, {}, d)
        return other.df
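
    # Sketch of how an expression is consumed (mirrors the __main__ example; the exact
    # semantics of shift/ma/roc come from alpha.dataset.ts_function and
    # alpha.dataset.expr_extends, so treat the details as assumptions):
    #
    #     out = loader.calculate_by_expression(loader.df, 'close/shift(close,5)-1')
    #     out['data']   # evaluated values, one row per ("date", "symbol") pair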

    def get_col_df(self, col: str = 'close') -> pd.DataFrame:
        df_col = self.df.pivot(
            values=col, index='date', on='symbol', aggregate_function="first"
        ).sort('date')
        df_col = df_col.to_pandas()
        df_col.set_index('date', inplace=True)
        df_col = df_col.ffill()
        df_col.index = pd.to_datetime(df_col.index)
        return df_col

    def get_col_df_by_symbols(self, symbols: list[str], col: str = 'close') -> pd.DataFrame | None:
        df_all = self._load_df_from_csvs(symbols=symbols)
        if col not in df_all.columns:
            logger.error(f"Column {col} does not exist")
            return None
        df_close = df_all.pivot(values=col, index='date', on='symbol').sort('date')
        df_col = df_close.to_pandas()
        df_col.set_index('date', inplace=True)
        df_col = df_col.ffill()
        df_col.index = pd.to_datetime(df_col.index)
        return df_col

if __name__ == '__main__':
    loader = PolarDataloader(
        symbols=['510300.SH', '159915.SZ'],
        start='20200101',
        names=['roc_5'],
        fields=['close/shift(close,5)-1'],
    )
    df = loader.get_col_df('roc_5')
    print(df)
    df = loader.get_col_df('high')
    print(df)
    df = loader.get_col_df_by_symbols(symbols=['159915.SZ'], col='high')
    print(df)