import pandas as pd
import numpy as np
import talib as ta
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
def convert_drawgbk(ax, condition, color1, color2):
    """Emulate the ``DRAWGBK`` gradient-background effect on a matplotlib axes.

    Paints one column per bar behind the existing artists.

    - Bars where *condition* is false get *color1*.
    - Bars where it is true get *color2*.

    These are the two ends of a linear colormap. Because *condition* is
    boolean, only the two endpoint colours actually appear.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to paint on. Assumes bar ``i`` is plotted at ``x == i``
        (e.g. a default 0..n-1 RangeIndex).
    condition : array-like of bool
        One entry per bar: a pandas Series, NumPy array or list.
    color1, color2 : matplotlib colour spec
        Colours for the false / true state.

    Returns
    -------
    matplotlib.image.AxesImage or None
        The background image, or ``None`` when *condition* is empty.
    """
    mask = np.asarray(condition, dtype=float)
    if mask.size == 0:
        # Nothing to paint; imshow cannot render a zero-width image.
        return None
    cmap = LinearSegmentedColormap.from_list('custom', [color1, color2], N=256)
    # Fill the y-range currently shown so the colour acts as a panel
    # background instead of a thin strip at an arbitrary data height.
    y_bottom, y_top = ax.get_ylim()
    return ax.imshow(
        mask.reshape(1, -1),
        cmap=cmap,
        # Pin the colour scale. Otherwise an all-True (or all-False) mask has
        # vmin == vmax, and Normalize paints every bar with color1.
        vmin=0.0,
        vmax=1.0,
        aspect='auto',
        # Centre pixel i on x == i so each column lines up with its bar.
        extent=(-0.5, mask.size - 0.5, y_bottom, y_top),
        interpolation='nearest',
        zorder=0,
    )
def main():
data = pd.DataFrame({
'open': np.random.rand(100)*100,
'high': np.random.rand(100)*100,
'low': np.random.rand(100)*100,
'close': np.random.rand(100)*100,
'volume': np.random.randint(100000,1000000,100)
}).sort_values('date')
# 参数设置
N, M, N1 = 9, 14, 6
# 计算基础指标
TYP = (data['high'] + data['low'] + data['close']) / 3
MA_TYP = ta.EMA(TYP, M)
CC1 = (TYP - MA_TYP) / (0.015 * ta.SMA(np.abs(TYP - MA_TYP), M))
CC1 = (CC1 - CC1.mean()) / CC1.std() # 标准化处理
# 计算CCI指标
C1 = ta.SMA(CC1, 3)
C2 = ta.SMA(C1, 3)
C3 = 3*C1 - 2*C2
CB1 = ta.EMA(C3, 5)
# 绘制带状区域CCI+RSI+KDJ三重叠加共振指标
ax1 = plt.subplot(211)
plt.plot(data.index, CC1, color='#FF00FF', label='CCI')
convert_drawgbk(ax1, CC1 > 0, (0,0,0.12), (0.06,0,
0.02)) # 渐变背景
# 绘制RSI指标CCI+RSI+KDJ三重叠加共振指标
RSI1 = ta.RSI(data['close'], timeperiod=N1)
RSI = (RSI1 - 50) * 5
R1 = ta.SMA(RSI, 3)
D1 = ta.SMA(R1, 3)
J1 = 3*R1 - 2*D1
RR1 = ta.EMA(J1, 5)
# 绘制KDJ指标CCI+RSI+KDJ三重叠加共振指标
RSV = (data['close'] - data['low'].rolling(N).min()) / \
(data['high'].rolling(N).max() - data['low'].rolling(N).min()) * 100
K = ta.SMA(RSV, 3)
D = ta.SMA(K, 3)
J = 3*K - 2*D
BB1 = ta.EMA(J, 5)
# 绘制信号指标CCI+RSI+KDJ三重叠加共振指标
ax2 = plt.subplot(212, sharex=ax1)
plt.plot(data.index, RSI, color='#FF00FF', label='RSI')
plt.plot(data.index, K, color='#FFFF00', label='K')
plt.plot(data.index, D, color='#00FFFF', label='D')
plt.plot(data.index, J, color='#00FF00', label='J')
# 绘制信号CCI+RSI+KDJ三重叠加共振指标
buy_signal = (RR1 > RR1.shift(1)) & (CB1 -100) & (BB1 -100)
ax1.scatter(buy_signal[buy_signal].index,
CC1[buy_signal],
marker='^', color='#640000',
s=50, edgecolors='#FFFFFF')
# 图表美化
plt.setp(ax1.get_xticklabels(), visible=False)
ax2.legend(loc='upper left')
ax1.set_ylim(-150, 150)
ax2.set_ylim(0, 100)
plt.show()
if __name__ == "__main__":
main()