core.a 源代码

"""
Created by x2h1z on 2021/2/25.
"""
import pandas as pd
import os

from enum import Enum
from typing import Union
from datetime import datetime, timedelta
from glob import glob
from concurrent.futures import ThreadPoolExecutor


[文档]class Market(Enum): SH = 0 SZ = 1
[文档]class A: """ A Backtesting Framework 通常你只需要继承 :class:`A` 然后实现它的两个方法即可, Simple Example:: class Demo(A): def __init__(self): self.date = datetime.today() # 设置回测所用到的数据所在路径 self.set_source_path("<source/path>") # 设置买入时段和卖出时段 self.init(143000, 30 * 60, 93000, 30 * 60) def on_day_stock(self, symbol_code: str, date: str, stock_df: pd.DataFrame) -> bool: # 因子计算等操作 pass def on_day_result(self): # 对回测结果的分析. 比如计算收益率 result = self.rate_of_return_calc(self.date, ["000001"]) if __name__ == '__main__': d = Demo() # 启动回测 d.start() """ COLUMNS = ["BeginTime", "EndTime", "OpenPrice", "HighPrice", "LowPrice", "ClosePrice", "ThisVolume", "ThisTurnover", "PreClose", "UpLimit", "DropLimit", "Volume", "Turnover"] def __init__(self): self.__buy_time = ... self.__b_T = ... self.__sell_time = ... self.__s_T = ... self.__top_limit = ... self.__encoding = ... self.__source_path = "" self.__thread_workers = 50 self.__top: dict = {} self.__market: Market = ... self.__symbol_filter: list = [] def _c(self, filepath, date): symbol_code = os.path.split(filepath)[1].split('.')[0] df = pd.read_csv(filepath, encoding=self.__encoding, names=self.COLUMNS, index_col=False) # self.condition(symbol_code, date, df) if self.on_day_stock(symbol_code, date, df): self.__top[symbol_code] = df
[文档] def set_source_path(self, source_path: str): """设置资源路径 Args: source_path: 资源路径 Returns: """ if not os.path.exists(source_path): print(f"source path {source_path} not exists.") return self.__source_path = source_path
[文档] def init(self, buy_time: int, buy_time_end: int, sell_time: int, sell_time_end: int, encoding: str = "utf-8") -> None: """初始化参数配置 Args: buy_time: 买入时间 buy_time_end: 买入时长(秒) sell_time: 卖出时间 sell_time_end: 卖出时长(秒) encoding: 文件的编码格式. 作用于程序的所有读取和写入 Returns: None """ self.__buy_time = buy_time self.__b_T = buy_time_end self.__sell_time = sell_time self.__s_T = sell_time_end self.__encoding = encoding
@staticmethod def _calc_average_price(df: pd.DataFrame) -> float: """计算均价""" turnover = df['Turnover'].iloc[-1] - df['Turnover'].iloc[0] volume = df['Volume'].iloc[-1] - df['Volume'].iloc[0] average_price = turnover / volume return average_price def _rate_of_return_calc(self, date: str): """计算收益率""" # 权重 average_weight = 1 / len(self.__top) for symbol_code, df in self.__top.items(): buy_time_end = datetime.strptime(str(self.__buy_time), "%H%M%S") + timedelta(seconds=self.__b_T) buy_time_end = int(buy_time_end.strftime("%H%M%S")) b_interval_df = df[(df['EndTime'] >= self.__buy_time) & (df['EndTime'] <= buy_time_end)] buy_price = self._calc_average_price(b_interval_df) next_date = self.get_next_day(date).strftime("%Y%m%d") files = glob(os.path.join(self.__source_path, next_date, f"{symbol_code}*")) if len(files) <= 0: print(f"找不到下一个交易日, 跳过. date:{date}, next_date:{next_date}, symbol_code:{symbol_code}") continue s_df = pd.read_csv(files[0], encoding=self.__encoding, names=self.COLUMNS, index_col=False) sell_time_end = datetime.strptime(str(self.__sell_time), "%H%M%S") + timedelta(seconds=self.__s_T) sell_time_end = int(sell_time_end.strftime("%H%M%S")) s_interval_df = s_df[ (s_df['EndTime'] >= self.__sell_time) & (s_df['EndTime'] <= sell_time_end)] sell_price = self._calc_average_price(s_interval_df) pnl = (sell_price - buy_price) / buy_price s = average_weight * pnl rate = round(s * 100, 3) yield { "symbol": symbol_code, "pnl": pnl, "weight": average_weight, # 可能出现 -0.0 "rate": 0.0 if rate == 0 else rate }
[文档] @staticmethod def get_next_day(date: Union[str, datetime]) -> datetime: """获取下一个交易日 Args: date: 日期 Returns: 下一个交易日的datetime对象 """ global dt if isinstance(date, str): dt = datetime.strptime(date, '%Y%m%d') else: dt = date while True: dt = dt + timedelta(days=1) # 过滤掉周末 if dt.weekday() < 5: break return dt
@property def market(self): assert(isinstance(self.__market, Market)) return self.__market @market.setter def market(self, market: Market): """设置市场 Args: market: 市场 Returns: """ self.__market = market if market == Market.SZ: self.__symbol_filter = ['000', '002', '3'] else: self.__symbol_filter = []
[文档] def start(self, date: str = "*"): """启动回测 Args: date: 回测日期. 默认目录下所有日期 Returns: """ path = os.path.join(self.__source_path, date) if date == '*': date_paths = glob(path) else: date_paths = [path] for date_path in date_paths: # clean cache self.__top.clear() with ThreadPoolExecutor(max_workers=self.__thread_workers) as e: for filepath in glob(os.path.join(date_path, '*')): filename = os.path.split(filepath)[-1] # symbol_code = os.path.splitext(filename)[0] if len(self.__symbol_filter) > 0: if filename[0] in self.__symbol_filter or filename[:3] in self.__symbol_filter: e.submit(self._c, filepath, date) self.on_day_result()
# # wait for result # symbol_codes = [str(i) for i in self.day_result()] # self._top = {code: self._top[code] for code in symbol_codes} # self._rate_of_return_calc(date)
[文档] def rate_of_return_calc(self, date: str, symbol_codes: Union[list, set, tuple]): """收益率计算 Args: date: 日期 symbol_codes: 要计算的股票代码 Returns: """ self.__top = {code: self.__top[code] for code in symbol_codes} return self._rate_of_return_calc(date)
[文档] def on_day_result(self): """当遍历完一天的个股时会调用这个函数 ``子类必须实现`` Returns: Raises: NotImplementedError """ raise NotImplementedError
[文档] def on_day_stock(self, symbol_code: str, date: str, stock_df: pd.DataFrame) -> bool: """当遍历完个股一天行情时会调用这个函数 ``子类必须实现`` ``该方法会运行在子线程`` Args: symbol_code: 股票代码 date: 日期 stock_df: 行情数据 Returns: 是否缓存当前 stock_df. 返回 true 会将当前 stock_df 缓存 ``该缓存的数据会在计算收益率中使用, 无特殊需求, 返回True即可`` Raises: NotImplementedError """ raise NotImplementedError