# Source: beast-trader-strategies/strategy.py (424 lines, 15 KiB, Python)
#
# Viewer note: this file contains Unicode characters that might be confused
# with other, similar-looking characters.

"""
Structure Flow Swing Strategy v3.1
==================================
波段交易策略 — 基于4H震荡区间,保守参数 v2
v3.1 改动(基于v3.0诊断结果):
1. 双边测试 AND→OR:在10根K线内测试过支撑 OR 阻力即可(不需两者都测过)
2. 区间稳定性 15%→25%:放宽波动容忍度
3. 入场范围 2%→3%:增加候选信号密度
4. 冷却期 3根→1根:减少过度过滤
保留:纯震荡定位、ATR×1.5止损、区间70%止盈、D1趋势过滤
预期年交易量:从9笔 → 50-80笔(约1-2单/周)
版本历史:
v3.0 (2026-06-10): 初版,基于冯总波段交易新思路
v3.1 (2026-06-10): 降低条件门槛,提升交易频率
"""
from datetime import datetime

import numpy as np
import pandas as pd
from pandas import DataFrame

from freqtrade.persistence import Trade
from freqtrade.strategy import IStrategy, IntParameter, informative, stoploss_from_absolute
class StructureFlowSwingV31(IStrategy):
    """
    Structure Flow Swing Strategy v3.1 - 4H range swing trading with relaxed range detection.

    Longs are opened near the support of a detected 4H trading range and shorts
    near its resistance, unless the D1 swing structure trends against the
    trade. Exits come from ``custom_exit`` (take-profit at a fraction of the
    range height) and ``custom_stoploss`` (ATR buffer beyond the range boundary).

    Swing points are fractal highs/lows that need ``window`` future bars before
    they are known. Every indicator here delays them to the bar on which they
    are confirmed, so backtests never see future data and agree with what a
    live bot computes on its latest candle.
    """

    can_short = True
    stoploss = -0.20
    use_custom_stoploss = True
    minimal_roi = {"0": 100}  # effectively disables ROI-based exits
    max_open_trades = 1
    timeframe = "4h"

    # =====================
    # Hyperoptable parameters (relaxed v3.1 defaults)
    # =====================
    # Fractal half-width, in bars, used to detect 4H swing highs/lows.
    swing_lookback = IntParameter(4, 8, default=5, space="buy")
    # Max deviation (percent) of recent range widths from their mean. v3.1: 15 -> 25
    zone_stability_threshold = IntParameter(15, 40, default=25, space="buy")
    # Max distance (percent of close) from the range boundary to enter. v3.1: 2 -> 3
    entry_zone_pct = IntParameter(1, 5, default=3, space="buy")
    # ATR multiplier for the stop buffer, in tenths (15 -> 1.5 x ATR).
    atr_stop_mult = IntParameter(10, 25, default=15, space="buy")
    # Take-profit target as a percentage of the range height.
    take_profit_pct = IntParameter(50, 80, default=70, space="sell")

    # Fixed parameters
    zone_touch_lookback = 10  # bars searched for a touch of support/resistance
    breakout_bars = 2  # consecutive closes outside the range that invalidate it
    # A signal needs its condition to have been False on this many previous bars (v3.1: 3 -> 1).
    entry_cooldown_bars = 1
    d1_swing_window = 5  # fractal half-width for D1 swing points

    # =====================
    # Helper: swing point detection
    # =====================
    @staticmethod
    def _detect_swing_points(
        high: pd.Series,
        low: pd.Series,
        window: int = 5,
    ) -> tuple[pd.Series, pd.Series]:
        """
        Find fractal swing highs and lows.

        Bar ``i`` is a swing high when its high is strictly greater than every
        high in the ``window`` bars before it AND the ``window`` bars after it.
        Swing lows mirror this with lows. Bars without a full ``window`` of
        neighbours on both sides are never swings.

        The result is placed on the swing bar itself and therefore depends on
        ``window`` future bars. Delay it with ``_confirm_swings`` before using
        it for signals.

        Args:
            high: High prices.
            low: Low prices (same index as ``high``).
            window: Number of neighbouring bars on each side.

        Returns:
            ``(swing_highs, swing_lows)``: float Series aligned to the inputs,
            holding the swing price on swing bars and NaN elsewhere.
        """
        if window < 1:
            return (
                pd.Series(np.nan, index=high.index, dtype=float),
                pd.Series(np.nan, index=low.index, dtype=float),
            )

        def before(series: pd.Series):
            # Rolling view over the `window` bars immediately preceding each bar.
            return series.shift(1).rolling(window, min_periods=window)

        def after(series: pd.Series):
            # Rolling view over the `window` bars immediately following each bar,
            # built on the reversed series; results are flipped back with [::-1].
            return series.iloc[::-1].shift(1).rolling(window, min_periods=window)

        highs = high.astype(float)
        lows = low.astype(float)
        h = highs.to_numpy()
        lo = lows.to_numpy()

        # NaN neighbour extremes (edges of the data) compare False -> no swing.
        is_sh = (h > before(highs).max().to_numpy()) & (h > after(highs).max().to_numpy()[::-1])
        is_sl = (lo < before(lows).min().to_numpy()) & (lo < after(lows).min().to_numpy()[::-1])

        sh = pd.Series(np.where(is_sh, h, np.nan), index=high.index, dtype=float)
        sl = pd.Series(np.where(is_sl, lo, np.nan), index=low.index, dtype=float)
        return sh, sl

    @staticmethod
    def _confirm_swings(
        sh: pd.Series,
        sl: pd.Series,
        window: int,
    ) -> tuple[pd.Series, pd.Series]:
        """Move swing points ``window`` bars forward, to the bar on which they become known."""
        # Without this shift, a backtest would use a swing point `window` bars before
        # a live bot could possibly have detected it (look-ahead bias).
        return sh.shift(window), sl.shift(window)

    @staticmethod
    def _last_two_swings(points: pd.Series) -> tuple[pd.Series, pd.Series]:
        """
        For every row, return the most recent and the previous swing price seen so far.

        Both are NaN until enough swing points have appeared.
        """
        known = points.dropna()
        latest = known.reindex(points.index).ffill()
        previous = known.shift(1).reindex(points.index).ffill()
        return latest, previous

    # =====================
    # Helper: range (consolidation) detection
    # =====================
    def _detect_range(
        self,
        sh: pd.Series,
        sl: pd.Series,
        high: pd.Series,
        low: pd.Series,
        close: pd.Series,
    ) -> DataFrame:
        """
        Classify each bar as ranging or not and report the active range.

        For every bar ``i`` the most recent swing high/low seen so far (the last
        5 of each are kept) define resistance/support. The bar is ``is_ranging``
        when all of the following hold:

        1. At least 3 swing highs and 3 swing lows are known, and resistance > support.
        2. The widths of the 3 most recent (swing high, swing low) pairs deviate
           from their mean by at most ``zone_stability_threshold`` percent.
        3. Within the last ``zone_touch_lookback`` bars (inclusive), a low came
           within 1% above support OR a high came within 1% below resistance
           (v3.1: OR instead of AND).
        4. The last ``breakout_bars`` closes are not all outside the range.

        ``sh``/``sl`` are consumed as given. Pass swing points already delayed
        to their confirmation bar to avoid look-ahead bias.

        Returns:
            DataFrame indexed like ``high`` with columns ``is_ranging`` (bool),
            ``support``, ``resistance`` and ``zone_width`` ((resistance - support)
            / support). The last three are NaN until condition 1 holds, and are
            filled even on bars that then fail conditions 2-4.
        """
        sh_arr = sh.to_numpy(dtype=float)
        sl_arr = sl.to_numpy(dtype=float)
        high_arr = high.to_numpy(dtype=float)
        low_arr = low.to_numpy(dtype=float)
        close_arr = close.to_numpy(dtype=float)

        n = len(high_arr)
        is_ranging = np.full(n, False)
        support_arr = np.full(n, np.nan)
        resistance_arr = np.full(n, np.nan)
        zone_width_arr = np.full(n, np.nan)

        stability_threshold = self.zone_stability_threshold.value / 100.0
        touch_lookback = self.zone_touch_lookback
        breakout_bars = self.breakout_bars

        sh_prices: list[float] = []
        sl_prices: list[float] = []
        for i in range(n):
            if not np.isnan(sh_arr[i]):
                sh_prices.append(sh_arr[i])
                if len(sh_prices) > 5:
                    sh_prices.pop(0)
            if not np.isnan(sl_arr[i]):
                sl_prices.append(sl_arr[i])
                if len(sl_prices) > 5:
                    sl_prices.pop(0)
            if len(sh_prices) < 3 or len(sl_prices) < 3:
                continue

            current_sh = sh_prices[-1]
            current_sl = sl_prices[-1]
            if current_sh <= current_sl:
                continue
            support_arr[i] = current_sl
            resistance_arr[i] = current_sh
            zone_width_arr[i] = (current_sh - current_sl) / current_sl

            # Condition 1: range-width stability.
            # Pair from the END of both lists, so the k-th most recent swing high is
            # matched with the k-th most recent swing low even when the lists differ
            # in length.
            widths = [(h - lo) / lo for h, lo in zip(sh_prices[-3:], sl_prices[-3:])]
            mean_width = float(np.mean(widths))
            if mean_width <= 0:
                continue
            max_dev = max(abs(w - mean_width) / mean_width for w in widths)
            if max_dev > stability_threshold:
                continue

            # Condition 2: price recently tested support OR resistance (v3.1: AND -> OR).
            start = max(0, i - touch_lookback)
            touched_support = np.any(low_arr[start:i + 1] <= current_sl * 1.01)
            touched_resistance = np.any(high_arr[start:i + 1] >= current_sh * 0.99)
            if not (touched_support or touched_resistance):
                continue

            # Condition 3: no breakout, i.e. not all of the last `breakout_bars`
            # closes are outside [support, resistance].
            recent_closes = close_arr[max(0, i - breakout_bars + 1):i + 1]
            outside = (recent_closes > current_sh) | (recent_closes < current_sl)
            if len(recent_closes) >= breakout_bars and outside.all():
                continue

            is_ranging[i] = True

        return DataFrame({
            "is_ranging": is_ranging,
            "support": support_arr,
            "resistance": resistance_arr,
            "zone_width": zone_width_arr,
        }, index=high.index)

    # =====================
    # Helper: ATR
    # =====================
    @staticmethod
    def _calc_atr(high: pd.Series, low: pd.Series, close: pd.Series, period: int = 14) -> pd.Series:
        """Average True Range as a simple rolling mean of True Range over ``period`` bars."""
        prev_close = close.shift(1)
        true_range = pd.DataFrame({
            "hl": high - low,
            "hc": (high - prev_close).abs(),
            "lc": (low - prev_close).abs(),
        }).max(axis=1)
        return true_range.rolling(period).mean()

    # ================================================================
    # D1 informative timeframe - macro trend reference
    # ================================================================
    @informative("1d")
    def populate_indicators_1d(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """
        Flag the D1 swing-structure trend on every D1 row.

        ``d1_uptrend`` is True when the latest confirmed D1 swing high and swing
        low are both higher than the previous ones. ``d1_downtrend`` is True when
        both are lower. Only swing points confirmed by that row are used, so each
        row reflects information available at that time.
        """
        window = self.d1_swing_window
        sh, sl = self._detect_swing_points(dataframe["high"], dataframe["low"], window=window)
        sh, sl = self._confirm_swings(sh, sl, window)

        sh_last, sh_prev = self._last_two_swings(sh)
        sl_last, sl_prev = self._last_two_swings(sl)
        # NaN comparisons (too few swings so far) evaluate False -> no trend.
        dataframe["d1_uptrend"] = (sh_last > sh_prev) & (sl_last > sl_prev)
        dataframe["d1_downtrend"] = (sh_last < sh_prev) & (sl_last < sl_prev)
        return dataframe

    # ================================================================
    # Main timeframe - 4H indicators
    # ================================================================
    def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """
        Compute the 4H range, ATR and distance-to-boundary columns.

        Adds the following columns:

        * ``is_ranging``, ``range_support``, ``range_resistance``.
        * ``zone_width_pct``, a fraction despite the name.
        * ``atr``.
        * ``zone_position``: 0 at support, 1 at resistance.
        * ``dist_to_support`` and ``dist_to_resistance``: fractions of close.

        Undefined positions/distances are filled with the sentinel 999, so that
        entry comparisons against them evaluate False.
        """
        high, low, close = dataframe["high"], dataframe["low"], dataframe["close"]
        window = self.swing_lookback.value
        sh, sl = self._detect_swing_points(high, low, window)
        # Use swing points only from the bar on which they are confirmed (no look-ahead).
        sh, sl = self._confirm_swings(sh, sl, window)

        range_info = self._detect_range(sh, sl, high, low, close)
        dataframe["is_ranging"] = range_info["is_ranging"]
        dataframe["range_support"] = range_info["support"]
        dataframe["range_resistance"] = range_info["resistance"]
        dataframe["zone_width_pct"] = range_info["zone_width"]
        dataframe["atr"] = self._calc_atr(high, low, close, 14)

        support = dataframe["range_support"]
        resistance = dataframe["range_resistance"]
        denom = resistance - support
        # Position of the close inside the range.
        dataframe["zone_position"] = np.where(denom > 0, (close - support) / denom, np.nan)
        # Distance from the close to each boundary, as a fraction of the close.
        dataframe["dist_to_support"] = np.where(support > 0, (close - support) / close, np.nan)
        dataframe["dist_to_resistance"] = np.where(resistance > 0, (resistance - close) / close, np.nan)

        dataframe["is_ranging"] = dataframe["is_ranging"].fillna(False).astype(bool)
        sentinel_cols = ["zone_position", "dist_to_support", "dist_to_resistance"]
        dataframe[sentinel_cols] = dataframe[sentinel_cols].fillna(999)
        return dataframe

    # ================================================================
    # Entry signals - v3.1: cooldown 3 -> 1
    # ================================================================
    @staticmethod
    def _no_recent_signal(conds: pd.Series, cooldown: int) -> pd.Series:
        """True where ``conds`` was False on each of the previous ``cooldown`` bars."""
        recent = conds.astype(int).rolling(cooldown, min_periods=1).max()
        # No history before the first row counts as "no recent signal".
        return recent.shift(1, fill_value=0) == 0

    def populate_entry_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """
        Set ``enter_long`` / ``enter_short`` signals.

        Long: ranging market, close above support by at most ``entry_zone_pct``
        percent, and no D1 downtrend. Short mirrors this against resistance and
        a D1 uptrend. A signal is emitted only on the first bar of a run, i.e.
        when the condition was False on the previous ``entry_cooldown_bars`` bars.
        """
        entry_zone = self.entry_zone_pct.value / 100.0
        cooldown = self.entry_cooldown_bars
        d1_downtrend_col = "d1_downtrend_1d"
        d1_uptrend_col = "d1_uptrend_1d"

        # Merged D1 columns are NaN before the first D1 candle, which can leave
        # them as object dtype. There `~True` is the integer -2, which pandas then
        # treats as True, silently disabling the trend filter. Force real booleans.
        for col in ["is_ranging", d1_uptrend_col, d1_downtrend_col]:
            if col in dataframe.columns:
                dataframe[col] = dataframe[col].fillna(False).astype(bool)
            else:
                dataframe[col] = False

        # Long: ranging market with the close just above support.
        long_conds = (
            dataframe["is_ranging"]
            & (dataframe["dist_to_support"] <= entry_zone)
            & (dataframe["dist_to_support"] > 0)
            & ~dataframe[d1_downtrend_col]
        )
        dataframe.loc[long_conds & self._no_recent_signal(long_conds, cooldown), "enter_long"] = 1

        # Short: ranging market with the close just below resistance.
        short_conds = (
            dataframe["is_ranging"]
            & (dataframe["dist_to_resistance"] <= entry_zone)
            & (dataframe["dist_to_resistance"] > 0)
            & ~dataframe[d1_uptrend_col]
        )
        dataframe.loc[short_conds & self._no_recent_signal(short_conds, cooldown), "enter_short"] = 1
        return dataframe

    # ================================================================
    # Exit signals
    # ================================================================
    def populate_exit_trend(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
        """No signal-based exits; trades close via ``custom_exit`` or the stop loss."""
        return dataframe

    # ================================================================
    # Custom stop loss: ATR x 1.5 buffer beyond support/resistance
    # ================================================================
    def custom_stoploss(
        self,
        pair: str,
        trade: Trade,
        current_time: datetime,
        current_rate: float,
        current_profit: float,
        after_fill: bool,
        **kwargs,
    ) -> float:
        """
        Structural stop: ``atr_stop_mult / 10`` x ATR beyond the current range boundary.

        Longs stop below ``range_support`` and shorts above ``range_resistance``,
        both read from the latest analyzed candle. Without a positive ATR the
        buffer falls back to 1.5% of the boundary. Without analyzed data or a
        usable boundary, a 2% stop is returned.

        Returns:
            Stop distance relative to ``current_rate``, as produced by
            ``stoploss_from_absolute`` (which also accounts for
            ``trade.leverage``), capped at ``abs(self.stoploss)`` for both
            sides. If price has already crossed the structural stop, the
            distance is 0, i.e. a stop at the current rate.
        """
        fallback_stop = 0.02
        max_stop = abs(self.stoploss)

        dataframe, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
        if dataframe is None or len(dataframe) == 0:
            return fallback_stop

        last = dataframe.iloc[-1]
        atr = last.get("atr", np.nan)
        atr_mult = self.atr_stop_mult.value / 10.0
        has_atr = pd.notna(atr) and atr > 0

        if trade.is_short:
            boundary = last.get("range_resistance", np.nan)
            if pd.isna(boundary) or boundary <= 0:
                return fallback_stop
            stop_price = boundary + atr * atr_mult if has_atr else boundary * 1.015
        else:
            boundary = last.get("range_support", np.nan)
            if pd.isna(boundary) or boundary <= 0:
                return fallback_stop
            stop_price = boundary - atr * atr_mult if has_atr else boundary * 0.985

        stop = stoploss_from_absolute(
            stop_price,
            current_rate,
            is_short=trade.is_short,
            leverage=trade.leverage or 1.0,
        )
        return min(stop, max_stop)

    # ================================================================
    # Custom take profit: 70% of the range height
    # ================================================================
    def custom_exit(
        self,
        pair: str,
        trade: Trade,
        current_time: datetime,
        current_rate: float,
        current_profit: float,
        **kwargs,
    ) -> str | None:
        """
        Take profit once profit reaches ``take_profit_pct`` percent of the range height.

        The range height is (resistance - support), taken relative to support
        for longs and to resistance for shorts, using the latest analyzed candle.

        Returns:
            ``"take_profit"`` when ``current_profit`` reaches the target,
            otherwise None (including when no valid range is available).
        """
        dataframe, _ = self.dp.get_analyzed_dataframe(pair, self.timeframe)
        if dataframe is None or len(dataframe) == 0:
            return None

        last = dataframe.iloc[-1]
        support = last.get("range_support", np.nan)
        resistance = last.get("range_resistance", np.nan)
        if pd.isna(support) or pd.isna(resistance) or not resistance > support:
            return None

        reference = resistance if trade.is_short else support
        zone_height = (resistance - support) / reference
        tp_target = zone_height * (self.take_profit_pct.value / 100.0)
        if current_profit >= tp_target:
            return "take_profit"
        return None

    # ================================================================
    # Plot config
    # ================================================================
    @property
    def plot_config(self) -> dict:
        """Plot layout: range boundaries on the main chart, range state and distances in subplots."""
        # Must be a property (or class-level dict): Freqtrade reads `strategy.plot_config`
        # as a dict, and a staticmethod would hand it a function instead.
        return {
            "main_plot": {
                "range_support": {"color": "green", "type": "line"},
                "range_resistance": {"color": "red", "type": "line"},
            },
            "subplots": {
                "range": {
                    "is_ranging": {"color": "blue", "type": "line"},
                    "zone_width_pct": {"color": "purple", "type": "line"},
                },
                "position": {
                    "dist_to_support": {"color": "green", "type": "line"},
                    "dist_to_resistance": {"color": "red", "type": "line"},
                },
            },
        }