# Source code for util.utils
import os
import re
from .holiday import get_holiday
from datetime import datetime, timedelta
from typing import Union, Optional
def traverse_path(path: str) -> list:
    """Collect every file found under ``path``, descending into subdirectories.

    Args:
        path: The path to traverse.

    Returns:
        A list of file paths in depth-first order. When ``path`` itself is
        not a directory, the result is a single-element list holding ``path``.
    """
    files = []
    # Explicit stack instead of recursion; children are pushed in reverse so
    # they are popped (and therefore visited) in ``os.listdir`` order.
    pending = [path]
    while pending:
        current = pending.pop()
        if not os.path.isdir(current):
            files.append(current)
            continue
        children = [os.path.join(current, name) for name in os.listdir(current)]
        pending.extend(reversed(children))
    return files
def is_symbol_code(code: str) -> bool:
    """Check whether ``code`` looks like an A-share stock code.

    The check passes when ``code`` contains a run of six consecutive digits
    and is at most nine characters long.

    Args:
        code: The code to check.

    Returns:
        ``True`` if both conditions hold, otherwise ``False``.
    """
    # Search first so that a non-string argument fails in the regex call,
    # before the length is inspected.
    has_six_digits = re.search(r"(\d{6})", code) is not None
    if len(code) > 9:
        return False
    return has_six_digits
def is_sz_symbol_code(code: str) -> bool:
    """Check whether a stock code belongs to the Shenzhen Stock Exchange.

    A dotted code (``"000001.SZ"`` or ``"SZ.000001"``) is accepted when either
    side of the dot equals ``SZ`` (case-insensitive). A code without a dot
    must pass ``is_symbol_code`` and start with one of the prefixes below.

    Args:
        code: The stock code to check.

    Returns:
        ``True`` if the code is recognised as a Shenzhen code. Input with more
        than one dot is treated as malformed and yields ``False`` instead of
        raising ``ValueError``.
    """
    if '.' in code:
        parts = code.split('.')
        # Only the "<code>.<exchange>" / "<exchange>.<code>" shapes are valid;
        # anything with extra dots (e.g. a stray file stem) is not a code.
        if len(parts) != 2:
            return False
        return any(part.upper() == 'SZ' for part in parts)
    # NOTE(review): only prefixes 000/002/300 are recognised here; confirm
    # whether other Shenzhen prefixes (e.g. 001, 003, 301) should be added.
    return is_symbol_code(code) and code[:3] in ('000', '002', '300')
def is_sh_symbol_code(code: str) -> bool:
    """Check whether a stock code belongs to the Shanghai Stock Exchange.

    A dotted code (``"600000.SH"`` or ``"SH.600000"``) is accepted when either
    side of the dot equals ``SH`` (case-insensitive). A code without a dot
    must pass ``is_symbol_code`` and start with one of the prefixes below.

    Args:
        code: The stock code to check.

    Returns:
        ``True`` if the code is recognised as a Shanghai code. Input with more
        than one dot is treated as malformed and yields ``False`` instead of
        raising ``ValueError``.
    """
    if '.' in code:
        parts = code.split('.')
        # Only the "<code>.<exchange>" / "<exchange>.<code>" shapes are valid;
        # anything with extra dots (e.g. a stray file stem) is not a code.
        if len(parts) != 2:
            return False
        return any(part.upper() == 'SH' for part in parts)
    # NOTE(review): only prefixes 600/601/603 are recognised here; confirm
    # whether other Shanghai prefixes (e.g. 605, 688) should be added.
    return is_symbol_code(code) and code[:3] in ('600', '603', '601')
def get_all_stock(root_path: str, date: Optional[datetime] = None) -> list:
    """Collect the stock codes found under a data directory.

    Every file below the chosen directory is visited; the file name without
    its extension is kept when it is recognised as a Shenzhen or Shanghai
    stock code.

    Args:
        root_path: Root of the data directory.
        date: Optional date. When given, only
            ``<root_path>/<YYYY>/<YYYYMM>/<YYYYMMDD>`` is scanned; otherwise
            the whole ``root_path`` is scanned.

    Returns:
        The de-duplicated stock codes, in first-seen order.

    Raises:
        FileNotFoundError: If the directory to scan does not exist.

    Examples:
        # Get all symbol codes from the source path
        all_stock = get_all_stock("/source/path")
        # Or limit to a single date
        all_stock = get_all_stock("/source/path", datetime.today())
    """
    if date:
        day = date.strftime('%Y%m%d')
        path = os.path.join(root_path, day[:4], day[:6], day)
    else:
        path = root_path
    if not os.path.exists(path):
        raise FileNotFoundError(f"{path} not exists.")

    # File name without directory and extension is the candidate code.
    stems = (
        os.path.splitext(os.path.basename(file_path))[0]
        for file_path in traverse_path(path)
    )
    # Keep only names recognised as A-share stock codes.
    codes = [s for s in stems if is_sz_symbol_code(s) or is_sh_symbol_code(s)]
    # The same code shows up once per date directory when scanning the whole
    # root; dict keys drop the repeats while preserving first-seen order.
    return list(dict.fromkeys(codes))
def get_trading_day(start: Union[datetime, str], end: Optional[Union[str, datetime]] = None) -> list:
    """List the trading days in a time range, skipping weekends and holidays.

    Args:
        start: Start of the range, as a ``datetime`` or a string.
            ``A string must be in YYYYMMDD format.``
        end: End of the range (exclusive), optional, as a ``datetime`` or a
            string. ``A string must be in YYYYMMDD format.`` When omitted,
            ``datetime.today()`` is used.

    Returns:
        The ``datetime`` of each trading day, stepping one day at a time
        from ``start``.
    """
    first = datetime.strptime(start, "%Y%m%d") if isinstance(start, str) else start

    if end is None:
        last = datetime.today()
    elif isinstance(end, str):
        last = datetime.strptime(end, "%Y%m%d")
    else:
        last = end

    holidays = get_holiday(first, last)

    days = []
    day = first
    while day < last:
        # Monday..Friday are weekday() 0..4; holidays are matched by their
        # YYYYMMDD string.
        is_weekday = day.weekday() < 5
        if is_weekday and day.strftime("%Y%m%d") not in holidays:
            days.append(day)
        day += timedelta(days=1)
    return days