1113 lines
41 KiB
Python
1113 lines
41 KiB
Python
"""
|
|
research.py - Model Research & Development Utilities
|
|
====================================================
|
|
Reusable boilerplate code for machine learning research and model development.
|
|
This module eliminates copy-paste by centralizing common functions for:
|
|
|
|
- Data preprocessing and time series aggregation
|
|
- Feature engineering for financial data
|
|
- Model architecture inspection and debugging
|
|
- Visualization and exploratory data analysis
|
|
- Training utilities
|
|
|
|
Import this in all your research notebooks to save time and maintain consistency.
|
|
|
|
Author: MemLabs
|
|
Course: Build a Quant Trading System
|
|
"""
|
|
|
|
# ============================================================================
|
|
# IMPORTS
|
|
# ============================================================================
|
|
|
|
# Data manipulation and analysis
|
|
import polars as pl # Fast dataframes for financial data
|
|
from typing import Dict, List, Tuple, Union # Type hints for function signatures
|
|
|
|
# Machine learning framework
|
|
import torch # PyTorch for neural networks
|
|
import torch.nn as nn # Neural network modules
|
|
import torch.optim as optim # Optimization algorithms
|
|
|
|
# Numerical computing and datetime
|
|
import numpy as np # Numerical operations
|
|
import numpy.typing as npt
|
|
from datetime import datetime, timedelta # Date and time handling
|
|
|
|
# Visualization
|
|
import altair # Interactive plotting library
|
|
import matplotlib.pyplot as plt
|
|
|
|
import random
|
|
import re
|
|
import itertools
|
|
from pathlib import Path
|
|
from tqdm import tqdm
|
|
import os
|
|
|
|
# Fixed seed constant for reproducible experiments.
# NOTE(review): no consumer is visible in this part of the file — presumably
# used to seed random / numpy / torch; confirm against the callers.
SEED = 42
|
|
|
|
# ============================================================================
|
|
# TIME SERIES AGGREGATION
|
|
# ============================================================================
|
|
# Aggregation expressions that collapse the trades of one time bucket into an
# OHLC bar. Every expression reads the "price" column of the grouped frame.
OHLC_AGGS = [
    # Price statistics (core OHLC data)
    pl.col("price").first().alias("open"),   # First price in the bucket
    pl.col("price").max().alias("high"),     # Highest price in the bucket
    pl.col("price").min().alias("low"),      # Lowest price in the bucket
    pl.col("price").last().alias("close"),   # Last price in the bucket
]
|
|
|
|
|
|
def get_trade_files(directory: str, sym: str) -> List[Path]:
    """
    Collect every entry in ``directory`` whose name starts with '{sym}-trades'.

    Args:
        directory: Path to the directory to search (not searched recursively)
        sym: Symbol prefix (e.g., 'BTCUSDT')

    Returns:
        Sorted list of Path objects matching the glob '{sym}-trades*'.
        The list is empty when nothing matches.

    Example:
        >>> files = get_trade_files('./data', 'BTCUSDT')
        >>> # e.g. [Path('data/BTCUSDT-trades-2024.csv'), ...]
    """
    matches = list(Path(directory).glob(f"{sym}-trades*"))
    matches.sort()
    return matches
|
|
|
|
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
|
|
def load_ohlc_timeseries(
    sym: str,
    time_interval: str,
    data_path: Optional[str] = None
) -> pl.DataFrame:
    """
    Load all trade files for ``sym`` and aggregate them into OHLC bars.

    Convenience wrapper around load_timeseries() using the module-level
    OHLC_AGGS expressions.

    Args:
        sym: Symbol prefix (e.g., 'BTCUSDT')
        time_interval: Aggregation interval (e.g., '1h', '5m')
        data_path: Optional directory holding the trade files. When None,
            load_timeseries() applies its own default.

    Returns:
        DataFrame of OHLC bars as returned by load_timeseries().
    """
    return load_timeseries(sym, time_interval, OHLC_AGGS, data_path=data_path)
|
|
|
|
def load_timeseries(
    sym: str,
    time_interval: str,
    aggs: List[pl.Expr],
    data_path: Optional[str] = None
) -> pl.DataFrame:
    """
    Load trade parquet files one by one, aggregate each to a time series, and concatenate.

    Files are discovered with get_trade_files(), i.e. every entry in
    ``data_path`` whose name starts with '{sym}-trades'. Each file is read with
    polars.read_parquet and aggregated on its own, so a bar never combines
    trades from two different files.

    Args:
        sym: Symbol prefix (e.g., 'BTCUSDT')
        time_interval: Time interval for aggregation (e.g., '1h', '5m')
        aggs: List of aggregation expressions
        data_path: Optional directory path. Defaults to './cache' if not provided

    Returns:
        Concatenated time series DataFrame, sorted by 'datetime' with at most
        one row per 'datetime' value.

    Raises:
        FileNotFoundError: If no file matches the symbol in ``data_path``.
        ValueError: If a file has no 'datetime' column.

    Example:
        >>> # Use default './cache' directory
        >>> ts = load_timeseries('BTCUSDT', '1h', OHLC_AGGS)

        >>> # Specify custom directory
        >>> ts = load_timeseries('BTCUSDT', '1h', OHLC_AGGS, data_path='./my_data')
    """
    # Default to './cache' if not provided
    if data_path is None:
        data_path = './cache'

    files = get_trade_files(data_path, sym)

    if not files:
        raise FileNotFoundError(f"No files found for {sym} in {data_path}")

    # Process each file and collect results
    ts_list = []

    for file in tqdm(files, desc=f"Loading {sym}", unit="file"):
        trades = pl.read_parquet(file)

        if "datetime" not in trades.columns:
            raise ValueError(f"Column 'datetime' not found in {file.name}")

        # group_by_dynamic needs the index column sorted
        trades = trades.with_columns(
            pl.col("datetime").cast(pl.Datetime)
        ).sort("datetime")

        ts_list.append(timeseries(trades, time_interval, aggs))

    # De-duplicate first (keeping the row from the earliest file), then sort.
    # unique() does not preserve row order by default, so sorting must come
    # last for the result to be reliably ordered by time.
    return (
        pl.concat(ts_list)
        .unique(subset=["datetime"], keep="first", maintain_order=True)
        .sort("datetime")
    )
|
|
|
|
|
|
def load_timeseries_range(
    sym: str,
    time_interval: str,
    start_date: datetime,
    end_date: datetime,
    agg_cols: Union[pl.Expr,List[pl.Expr]],
    data_path: Optional[str] = None
) -> pl.DataFrame:
    """
    Load and aggregate trade data for a symbol between start_date and end_date
    into a time series using the given time interval and aggregations.

    Expects daily files named like:
        {symbol}-trades-YYYY-MM-DD.parquet

    Example filename:
        BTCUSDT-trades-2025-09-22.parquet

    Every calendar day from start_date's day to end_date's day is visited and
    the whole daily file is loaded; trades are not filtered by time of day.
    Missing files are reported and skipped. A file that fails to load or
    aggregate is reported and skipped as well (best effort).

    Args:
        sym: Symbol prefix (e.g., 'BTCUSDT')
        time_interval: Aggregation interval (e.g., '1h', '5m')
        start_date: Start datetime (inclusive)
        end_date: End datetime (inclusive)
        agg_cols: Aggregation expression(s) applied to each time bucket
        data_path: Directory containing cached trade parquet files (default: './cache')

    Returns:
        Polars DataFrame with the aggregated time series for the given range,
        sorted by 'datetime' with at most one row per 'datetime' value.

    Raises:
        ValueError: If start_date is after end_date, or if no file in the
            range produced any data.
    """
    if data_path is None:
        data_path = "./cache"

    if start_date > end_date:
        raise ValueError("start_date must be before or equal to end_date")

    # Count calendar days, not 24h spans: with time-of-day components the
    # timedelta's .days can be one short and silently skip the last file.
    first_day = start_date.date() if isinstance(start_date, datetime) else start_date
    last_day = end_date.date() if isinstance(end_date, datetime) else end_date
    total_days = (last_day - first_day).days + 1

    ts_list = []

    for i in tqdm(range(total_days), desc=f"Loading {sym}", unit="day"):
        current_date = first_day + timedelta(days=i)
        file_name = f"{sym}-trades-{current_date.strftime('%Y-%m-%d')}.parquet"
        file_path = os.path.join(data_path, file_name)

        if not os.path.exists(file_path):
            tqdm.write(f"[WARNING] Missing file: {file_name}")
            continue

        try:
            trades = pl.read_parquet(file_path)

            if "datetime" not in trades.columns:
                raise ValueError(f"Column 'datetime' not found in {file_name}")

            # group_by_dynamic needs the index column sorted (load_timeseries
            # sorts the same way)
            trades = trades.with_columns(
                pl.col("datetime").cast(pl.Datetime)
            ).sort("datetime")

            ts = trades.group_by_dynamic("datetime", every=time_interval, offset="0m").agg(agg_cols)
            ts_list.append(ts)

        except Exception as e:
            # Deliberate best effort: one bad file must not abort the range.
            tqdm.write(f"[ERROR] {file_name}: {e}")

    if not ts_list:
        raise ValueError(f"No trade data found for {sym} in range {start_date} to {end_date}")

    # unique() does not preserve order by default, so de-duplicate first
    # (keeping the earliest day's row) and sort last.
    return (
        pl.concat(ts_list)
        .unique(subset=["datetime"], keep="first", maintain_order=True)
        .sort("datetime")
    )
|
|
|
|
def load_ohlc_timeseries_range(
    sym: str,
    time_interval: str,
    start_date: datetime,
    end_date: datetime,
    data_path: Optional[str] = None
) -> pl.DataFrame:
    """
    Load and aggregate trade data for a symbol between start_date and end_date
    into OHLC time series using the given time interval.

    Delegates to load_timeseries_range() with the module-level OHLC_AGGS
    expressions, so file discovery, warnings, and errors are exactly those of
    load_timeseries_range().

    Expects daily files named like:
        {symbol}-trades-YYYY-MM-DD.parquet

    Example filename:
        BTCUSDT-trades-2025-09-22.parquet

    Args:
        sym: Symbol prefix (e.g., 'BTCUSDT')
        time_interval: Aggregation interval (e.g., '1h', '5m')
        start_date: Start datetime (inclusive)
        end_date: End datetime (inclusive)
        data_path: Directory containing cached trade parquet files (default: './cache')

    Returns:
        Polars DataFrame with aggregated OHLC time series for the given range.

    Raises:
        ValueError: If start_date is after end_date, or if no trade data was
            found in the range (raised by load_timeseries_range()).
    """
    return load_timeseries_range(
        sym,
        time_interval,
        start_date,
        end_date,
        OHLC_AGGS,
        data_path=data_path,
    )
|
|
|
|
|
|
def sharpe_annualization_factor(interval: str,
                                trading_days_per_year: int = 365,
                                trading_hours_per_day: float = 24) -> float:
    """
    Compute annualization factor (sqrt of periods per year) given a return interval.

    interval : str
        Frequency string like '1d', '1h', '30m', '15s': a positive integer
        followed by exactly one unit letter (d, h, m or s). Case-insensitive.
    trading_days_per_year : int
        Number of trading days in a year (default 365).
    trading_hours_per_day : float
        Number of trading hours in a trading day (default 24).

    Returns
    -------
    float : annualization factor

    Raises
    ------
    ValueError
        If ``interval`` is not of the form described above (this includes
        strings with trailing text such as '1mo' or '5min') or if its
        numeric part is zero.
    """
    # fullmatch: an unanchored match would read '1mo' or '5min' as minutes.
    match = re.fullmatch(r"(\d+)([dhms])", interval.strip().lower())
    if not match:
        raise ValueError("Interval must be like '1d', '2h', '15m', '30s'")

    value, unit = int(match.group(1)), match.group(2)
    if value == 0:
        # Would otherwise surface as an unrelated ZeroDivisionError below.
        raise ValueError("Interval must be like '1d', '2h', '15m', '30s'")

    # periods per year
    if unit == 'd':
        periods = trading_days_per_year / value
    else:
        units_per_day = {
            'h': trading_hours_per_day,
            'm': trading_hours_per_day * 60,
            's': trading_hours_per_day * 3600,
        }
        periods = trading_days_per_year * (units_per_day[unit] / value)

    return np.sqrt(periods)
|
|
|
|
|
|
def ohlc_timeseries(df: pl.DataFrame, time_interval: str) -> pl.DataFrame:
    """
    Convert tick-level trade data into OHLC (Open, High, Low, Close) bars.

    Thin wrapper that calls timeseries() with the module-level OHLC_AGGS
    expressions. To compute additional statistics, call timeseries() directly
    with your own aggregation expressions.

    Args:
        df: DataFrame containing trade data. The columns this function
            actually relies on are:
            - datetime: Timestamp of each trade (grouping column)
            - price: Execution price (read by OHLC_AGGS)
        time_interval: Aggregation period (e.g., '1m', '5m', '15m', '1h', '1d')

    Returns:
        DataFrame with one row per bar. OHLC_AGGS produces:
            - open: First price in interval
            - high: Highest price in interval
            - low: Lowest price in interval
            - close: Last price in interval
        together with the 'datetime' grouping column from timeseries().

    Example:
        >>> # Create 15-minute OHLC bars
        >>> bars_15m = ohlc_timeseries(trades_df, '15m')
        >>>
        >>> # Create hourly bars for longer-term analysis
        >>> bars_1h = ohlc_timeseries(trades_df, '1h')
    """
    # Use the generic time series aggregation function
    return timeseries(df, time_interval, OHLC_AGGS)
|
|
|
|
def lag_col_names(col: str, n: int) -> List[str]:
    """Return the column names '{col}_lag_1' ... '{col}_lag_{n}' (empty when n < 1)."""
    names = []
    for lag in range(1, n + 1):
        names.append(f'{col}_lag_{lag}')
    return names
|
|
|
|
def auto_reg_corr_matrx(df, target, max_no_lags) -> pl.DataFrame:
    """
    Correlation matrix of ``target`` and its lag columns.

    Rows containing a null in any column of ``df`` are dropped first; then the
    target column and the columns named by lag_col_names(target, max_no_lags)
    are selected and passed to DataFrame.corr(). The lag columns must already
    exist in ``df``.
    """
    wanted = [target] + lag_col_names(target, max_no_lags)
    complete_rows = df.drop_nulls()
    return complete_rows.select(wanted).corr()
|
|
|
|
def log_returns_col(name: str, step_size = 1) -> pl.Expr:
    """
    Build an expression for the log return of column ``name``.

    The expression is log(value / value shifted by ``step_size`` rows) and is
    aliased '{name}_log_return'.
    """
    current = pl.col(name)
    previous = pl.col(name).shift(step_size)
    ratio = current / previous
    return ratio.log().alias(f'{name}_log_return')
|
|
|
|
def timeseries(
    df: pl.DataFrame,
    time_interval: str,
    aggs: Union[List[pl.Expr],pl.Expr]
) -> pl.DataFrame:
    """
    Aggregate a DataFrame into regular time buckets.

    Groups ``df`` on its 'datetime' column with group_by_dynamic (bucket width
    ``time_interval``, offset '0m') and applies ``aggs`` to every bucket. This
    is the building block behind ohlc_timeseries().

    Args:
        df: DataFrame with a 'datetime' column
        time_interval: Aggregation period (Polars duration string)
                       Examples: '1m', '5m', '15m', '1h', '4h', '1d'
        aggs: Polars expression, or list of expressions, to compute per bucket

    Returns:
        DataFrame with time-aggregated data

    Technical Details:
        - Window boundaries follow the polars group_by_dynamic defaults
          (left-closed intervals: [start_time, end_time))
        - offset="0m" keeps bars aligned to round times (e.g., 09:00, 09:15)
        - Buckets without any rows do not appear in the output

    Example:
        >>> # Custom aggregation for volatility analysis
        >>> custom_aggs = [
        ...     pl.col("price").std().alias("price_volatility"),
        ...     pl.col("volume").sum().alias("total_volume"),
        ... ]
        >>> df_volatility = timeseries(df, '1h', custom_aggs)
    """
    buckets = df.group_by_dynamic(
        "datetime",            # grouping column (must be a datetime type)
        every=time_interval,   # bucket width
        offset="0m",           # no offset: buckets start at round times
    )
    return buckets.agg(aggs)
|
|
|
|
|
|
# ============================================================================
|
|
# VISUALIZATION
|
|
# ============================================================================
|
|
|
|
def plot(df: pl.DataFrame, col: str, title: str = "") -> altair.Chart:
    """
    Build a smooth density (KDE) area chart for one column.

    Useful for inspecting a feature's distribution before modeling:
    spotting skew and outliers, or checking a preprocessing step.

    Args:
        df: DataFrame containing the column to plot
        col: Name of the column to visualize
        title: Optional chart title. When empty (the default), the title
            'Distribution of {col}' is used.

    Returns:
        Altair Chart object (600x400)

    Example:
        >>> # Plot distribution of returns
        >>> plot(df, 'returns', title='Return Distribution')

    Note:
        The density comes from Altair's transform_density; the area mark uses
        'basis' interpolation to draw a smooth curve.
    """
    chart_title = title if title else f'Distribution of {col}'

    area = altair.Chart(df).mark_area(
        opacity=0.7,           # semi-transparent fill
        interpolate='basis',   # smooth curve interpolation
    )
    density = area.transform_density(
        col,                   # column to estimate the density of
        as_=[col, 'density'],  # output field names
    )
    encoded = density.encode(
        x=altair.X(f'{col}:Q', title=col),
        y=altair.Y('density:Q', title='Density'),
    )
    return encoded.properties(width=600, height=400, title=chart_title)
|
|
|
|
def plot_distribution(data: pl.DataFrame, col: str, label = None, no_bins = 100):
    """
    Build an interactive histogram of one column.

    Args:
        data: DataFrame containing the column
        col: Column to bin on the x-axis
        label: Name shown in the title; falls back to ``col`` when falsy
        no_bins: Upper bound on the number of bins (Altair ``maxbins``)

    Returns:
        Altair chart (600x400) with non-zero-based scales and an interval
        selection bound to the scales.
    """
    shown_name = label if label else col

    bars = altair.Chart(data).mark_bar().encode(
        altair.X(f'{col}:Q', bin=altair.Bin(maxbins=no_bins)),
        y='count()',
    )
    sized = bars.properties(
        width=600,
        height=400,
        title=f'Distribution of {shown_name}',
    )
    zoom = altair.selection_interval(bind='scales')
    return sized.configure_scale(zero=False).add_params(zoom)
|
|
|
|
def plot_static_timeseries(ts: pl.DataFrame, sym: str, col: str, interval_size: str):
    """
    Draw a static matplotlib line plot of ``ts[col]`` against ``ts['datetime']``.

    Args:
        ts: Time series DataFrame with a 'datetime' column and column ``col``
        sym: Symbol name, used in the title
        col: Column to plot; also used as legend entry and y-axis label
        interval_size: Bar size label, used in the title

    Returns:
        None. The figure is shown with plt.show().
    """
    fig_size = (12, 6)
    heading = f'{sym} {interval_size} Bars'

    plt.figure(figsize=fig_size)
    plt.plot(ts['datetime'], ts[col], label=col)

    plt.title(heading)
    plt.xlabel('time')
    plt.ylabel(col)
    plt.legend()
    plt.xticks(rotation=45)  # slanted tick labels
    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
def plot_multiple_lines(
    df: pl.DataFrame,
    cols_to_plot: List[str],
    sym: str,
    width: int = 15,
    height: int = 6,
    xlabel_unit: str = "Time Step"
):
    """
    Plots multiple columns from a Polars DataFrame on the same axes using Matplotlib.
    The x-axis is a simple numerical row index (0 .. len(df) - 1).

    Parameters:
    -----------
    df : polars.DataFrame
        The Polars DataFrame containing the columns to plot.
    cols_to_plot : list[str]
        A list of column names to plot (e.g., ['log_return', 'mean']).
        Names not present in ``df`` are skipped with a printed warning.
    sym : str
        A symbol or identifier for the series (used in the title).
    width : int, default 15
        Width of the plot in inches.
    height : int, default 6
        Height of the plot in inches.
    xlabel_unit : str, default 'Time Step'
        Label for the X-axis (the numerical index).

    Returns:
    --------
    None. The figure is shown with plt.show().
    """
    # The docstring must be the first statement of the function to be picked
    # up as __doc__, so the local import comes after it.
    import matplotlib.pyplot as plt

    # 1. Create the numerical index for the x-axis
    x_index = np.arange(len(df))

    # 2. Set the figure size (controls the width/height)
    plt.figure(figsize=(width, height))

    # 3. Loop through the list of columns and plot each one
    for col in cols_to_plot:
        if col in df.columns:
            # Extract column data as a NumPy array
            y_values = df[col].to_numpy()

            # Plot the line, using the column name for the label
            plt.plot(x_index, y_values, label=col)
        else:
            print(f"Warning: Column '{col}' not found in DataFrame.")

    # 4. Finalize the plot

    # Title lists every requested column name
    title_cols = ', '.join(cols_to_plot)
    plt.title(f'{sym} Series: {title_cols}')

    plt.xlabel(xlabel_unit)
    plt.ylabel('Value')  # Generic Y-label since multiple series are plotted
    plt.legend(loc='best')
    plt.grid(True, linestyle=':', alpha=0.6)

    # Adjust layout to prevent labels from being cut off
    plt.tight_layout()
    plt.show()
|
|
|
|
def plot_dyn_timeseries(ts: pl.DataFrame, sym: str, col: str, time_interval: str ):
    """
    Build an interactive Altair line chart of ``col`` over 'datetime'.

    Args:
        ts: Time series DataFrame with a 'datetime' column and column ``col``
        sym: Symbol name, used in the title
        col: Column plotted on the y-axis
        time_interval: Bar size label, used in the title

    Returns:
        Altair chart (800x400) with tooltips, non-zero-based scales, and two
        interval selections bound to the scales (one for x, one for y).
    """
    # add_params replaces the deprecated add_selection and matches what
    # plot_distribution already uses.
    return altair.Chart(ts).mark_line(tooltip=True).encode(
        x="datetime",
        y=col
    ).properties(
        width=800,
        height=400,
        title=f"{sym} {time_interval} {col}"
    ).configure_scale(zero=False).add_params(
        altair.selection_interval(bind='scales', encodings=['x']),  # Only zoom x-axis
        altair.selection_interval(bind='scales', encodings=['y'])   # Only zoom y-axis
    )
|
|
|
|
def to_tensor(x, dtype=None) -> torch.Tensor:
    """
    Convert an object exposing ``to_numpy()`` into a torch tensor.

    Args:
        x: Any object with a ``to_numpy()`` method
        dtype: Target torch dtype; torch.float32 when None

    Returns:
        A new tensor holding a copy of the data.
    """
    target_dtype = dtype if dtype is not None else torch.float32
    values = x.to_numpy()
    return torch.tensor(values, dtype=target_dtype)
|
|
|
|
# ============================================================================
|
|
# MODEL ANALYSIS
|
|
# ============================================================================
|
|
|
|
def print_model_complexity_ratio(m1, m1_name, m2, m2_name):
    """
    Print how the parameter count of ``m2`` compares to that of ``m1``.

    Args:
        m1: Baseline model
        m1_name: Display name of the baseline model
        m2: Model compared against the baseline
        m2_name: Display name of the compared model

    Returns:
        None (prints to console). When ``m1`` has no parameters the ratio is
        undefined; a message says so instead of dividing by zero.
    """
    m1_params = total_model_params(m1)
    m2_params = total_model_params(m2)

    print(f"Complexity Comparsion:")
    if m1_params == 0:
        # e.g. a parameter-free baseline such as nn.Identity()
        print(f"\t{m1_name} has no parameters; complexity ratio is undefined")
    else:
        complexity_ratio = m2_params / m1_params
        print(f"\t{m2_name} has {complexity_ratio:.1f}x more parameters than {m1_name}")
    print(f"\tParametric difference: {m2_params - m1_params:,} additional parameters")
|
|
|
|
def total_model_params(model: nn.Module) -> int:
    """Return the total number of elements over all of ``model``'s parameters."""
    count = 0
    for tensor in model.parameters():
        count += tensor.numel()
    return count
|
|
|
|
def print_model_info(model: torch.nn.Module, model_name: str) -> None:
    """
    Print a summary of a PyTorch model: its repr and its parameter counts.

    The report shows the total number of parameter elements and how many of
    them have ``requires_grad`` set. When the two differ, the number of
    frozen parameters is reported as well.

    Args:
        model: PyTorch model (nn.Module)
        model_name: Descriptive name printed as the report heading

    Returns:
        None (prints to console)

    Example:
        >>> model = MyTradingModel(input_size=10, hidden_size=64)
        >>> print_model_info(model, 'Trading LSTM v1')
    """
    # Single pass over the parameters, tallying both counts
    total_params = 0
    trainable_params = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total_params += size
        if tensor.requires_grad:
            trainable_params += size

    rule = '=' * 60

    print(f"\n{rule}")
    print(f"{model_name}")
    print(f"{rule}")
    print(f"\nArchitecture:")
    print(f" {model}")
    print(f"\nParameter Count:")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")

    # Extra lines only when some parameters are frozen
    frozen_params = total_params - trainable_params
    if frozen_params != 0:
        print(f" Frozen parameters: {frozen_params:,}")
        print(f"\n ⚠️ Note: {frozen_params:,} parameters are frozen")

    print(f"{rule}\n")
|
|
|
|
def _prefix_cols(df, prefix):
    """Return ``df`` with every column renamed to '{prefix}_{column}'."""
    renames = {}
    for name in df.columns:
        renames[name] = f"{prefix}_{name}"
    return df.rename(renames)
|
|
|
|
def _prefix_close_ts(trades, time_interval, prefix):
    """Aggregate ``trades`` to OHLC bars and prefix every resulting column with '{prefix}_'."""
    bars = ohlc_timeseries(trades, time_interval)
    return _prefix_cols(bars, prefix)
|
|
|
|
def compare_ts_corr(x_df, x_prefix, y_df, y_prefix, time_interval, col = 'close'):
    """
    Correlation between the same OHLC column of two trade series.

    Both trade frames are aggregated with ohlc_timeseries(), and the two bar
    series are then matched on their 'datetime' value before correlating.
    Matching on the timestamp (instead of pasting the frames side by side)
    matters because a bucket without trades yields no bar, so the two series
    may have gaps in different places or different lengths.

    Args:
        x_df: Trades of the first instrument
        x_prefix: Prefix naming the first series' column ('{x_prefix}_{col}')
        y_df: Trades of the second instrument
        y_prefix: Prefix naming the second series' column ('{y_prefix}_{col}')
        time_interval: Aggregation period passed to ohlc_timeseries()
        col: OHLC column to correlate (default 'close')

    Returns:
        The correlation as a Python scalar, computed over the bars present in
        both series.
    """
    x_col, y_col = f'{x_prefix}_{col}',f'{y_prefix}_{col}'

    x_ts = ohlc_timeseries(x_df, time_interval).select(
        "datetime", pl.col(col).alias(x_col)
    )
    y_ts = ohlc_timeseries(y_df, time_interval).select(
        "datetime", pl.col(col).alias(y_col)
    )

    # Inner join: keep only timestamps for which both series have a bar
    joined_ts = x_ts.join(y_ts, on="datetime", how="inner")
    return joined_ts.select(pl.corr(x_col, y_col)).item()
|
|
|
|
def log_return_col(col: str) -> str:
    """Return the name used for the log-return column derived from ``col``."""
    derived_name = f"{col}_log_return"
    return derived_name
|
|
|
|
def log_return(col: str, shift_size: int = 1) -> pl.Expr:
    """
    Build an expression for the log return of ``col``.

    The expression is log(value / value shifted by ``shift_size`` rows),
    aliased with the name given by log_return_col(col).
    """
    current = pl.col(col)
    previous = pl.col(col).shift(shift_size)
    out_name = log_return_col(col)
    return (current / previous).log().alias(out_name)
|
|
|
|
def lag_cols(col: str, forecast_horizon: int, no_lags: int) -> List[pl.Expr]:
    """
    Build lag expressions for ``col``.

    Lag ``i`` (for i = 1 .. no_lags) is the column shifted by
    ``forecast_horizon * i`` rows and aliased '{col}_lag_{i}'.

    Args:
        col: Source column name
        forecast_horizon: Number of rows between consecutive lags (an
            integer: it is multiplied by the lag index to get the shift)
        no_lags: Number of lag expressions to create

    Returns:
        List of ``no_lags`` expressions (empty when no_lags < 1).
    """
    return [
        pl.col(col).shift(forecast_horizon * i).alias(f'{col}_lag_{i}')
        for i in range(1, no_lags + 1)
    ]
|
|
|
|
def add_lags(df: pl.DataFrame, col: str, max_no_lags: int, forecast_step: int) -> pl.DataFrame:
    """
    Return ``df`` with lag columns of ``col`` appended.

    Adds '{col}_lag_1' ... '{col}_lag_{max_no_lags}', where lag ``i`` is
    ``col`` shifted by ``i * forecast_step`` rows.
    """
    lag_exprs = lag_cols(col, forecast_step, max_no_lags)
    return df.with_columns(lag_exprs)
|
|
|
|
def batch_train_reg(
    model: nn.Module,
    X_train,
    X_test,
    y_train,
    y_test,
    no_epochs: int,
    criterion=None,
    optimizer=None,
    logging=True,
    lr=None
):
    """
    Train a regression model full-batch and return its test predictions.

    Every epoch runs one optimizer step over the whole training set. After
    training, the model is switched to eval mode and evaluated on the test
    set; it is left in eval mode on return.

    Args:
        model: Model to train (updated in place)
        X_train: Training inputs
        X_test: Test inputs
        y_train: Training targets
        y_test: Test targets
        no_epochs: Number of optimizer steps (0 skips training)
        criterion: Loss function; nn.L1Loss() when None
        optimizer: Optimizer instance; when None an LBFGS optimizer with
            strong-Wolfe line search is created
        logging: Print model info, periodic losses, learned parameters and
            the final losses
        lr: Learning rate for the default LBFGS optimizer. None keeps the
            LBFGS step size of 1. Ignored when ``optimizer`` is supplied.

    Returns:
        Predictions of the trained model on ``X_test``.
    """
    if criterion is None:
        criterion = nn.L1Loss()

    # Default optimizer
    if optimizer is None:
        # Use strong_wolfe line search (more stable)
        optimizer = optim.LBFGS(
            model.parameters(),
            lr=1 if lr is None else lr,
            line_search_fn='strong_wolfe',
            tolerance_grad=1e-7,
            tolerance_change=1e-9
        )

    # Logging model info
    if logging:
        print(f"\nModel parameters: {sum(p.numel() for p in model.parameters())}")
        print("Model architecture:")
        for name, param in model.named_parameters():
            print(f" {name}: {param.shape} ({param.numel()} params)")
        print("\nTraining model...")

    train_loss = None
    log_tick_size = max(no_epochs // 10, 1)  # avoid zero division

    # This function ends with model.eval(); re-enter train mode so a model
    # that is trained again is not silently optimized in eval mode.
    model.train()

    # Training loop
    if isinstance(optimizer, torch.optim.LBFGS):
        # LBFGS re-evaluates the loss several times per step via a closure
        def closure():
            optimizer.zero_grad()
            predictions = model(X_train)
            loss = criterion(predictions, y_train)
            loss.backward()
            return loss

        for epoch in range(no_epochs):
            optimizer.step(closure)

            # Loss after the step, for reporting only
            with torch.no_grad():
                train_loss = criterion(model(X_train), y_train).item()

            if logging and (epoch + 1) % log_tick_size == 0:
                print(f"Epoch [{epoch+1}/{no_epochs}], Loss: {train_loss:.6f}")

    else:
        # SGD/Adam loop
        for epoch in range(no_epochs):
            optimizer.zero_grad()
            predictions = model(X_train)
            loss = criterion(predictions, y_train)
            loss.backward()
            optimizer.step()
            train_loss = loss.item()

            if logging and (epoch + 1) % log_tick_size == 0:
                print(f"Epoch [{epoch+1}/{no_epochs}], Loss: {train_loss:.6f}")

    # After training
    if logging:
        print("\nLearned parameters:")
        for name, param in model.named_parameters():
            if param.requires_grad:
                # detach().cpu() so parameters on another device can be printed
                print(f"{name}:\n{param.detach().cpu().numpy()}")

    # Evaluation
    model.eval()
    with torch.no_grad():
        y_hat = model(X_test)
        test_loss = criterion(y_hat, y_test)

    if logging:
        # train_loss stays None when no epoch ran
        train_loss_text = "n/a" if train_loss is None else f"{train_loss:.6f}"
        print(f'\nTest Loss: {test_loss.item():.6f}, Train Loss: {train_loss_text}')

    return y_hat
|
|
|
|
|
|
def timeseries_train_test_split(df: pl.DataFrame, features, target, test_size=0.25) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Turn a DataFrame into chronological train/test tensors.

    Rows with a null in any column are dropped; features and target are
    converted with to_tensor() (the target is reshaped to a column vector);
    both are then split with timeseries_split(), so the test set is the tail
    of the data.

    Args:
        df: Source DataFrame
        features: Feature column name(s)
        target: Target column name
        test_size: Proportion of rows used for testing

    Returns:
        (X_train, X_test, y_train, y_test)
    """
    clean = df.drop_nulls()

    feature_tensor = to_tensor(clean[features])
    target_tensor = to_tensor(clean[target]).reshape(-1, 1)

    X_train, X_test = timeseries_split(feature_tensor, test_size)
    y_train, y_test = timeseries_split(target_tensor, test_size)
    return X_train, X_test, y_train, y_test
|
|
|
|
def timeseries_split(t, test_size=0.25):
    """
    Split a sequence chronologically into a leading train part and a trailing test part.

    Parameters
    ----------
    t : sliceable sequence with a length (e.g. torch.Tensor, np.ndarray)
        Time series data, oldest first.
    test_size : float, default 0.25
        Proportion of data to use for testing. Must be between 0 and 1.

    Returns
    -------
    train, test : slices of t
        The first int(len(t) * (1 - test_size)) items, and the remainder.

    Raises
    ------
    ValueError
        If test_size is not strictly between 0 and 1.
    """
    in_range = 0 < test_size < 1
    if not in_range:
        raise ValueError(f"test_size must be between 0 and 1 (got {test_size})")

    n_train = int(len(t) * (1 - test_size))
    train_part = t[:n_train]
    test_part = t[n_train:]
    return train_part, test_part
|
|
|
|
|
|
def plot_column(df, col_name, figsize=(15, 6), title=None, xlabel='Index'):
    """
    Line chart of one DataFrame column, built with the column's ``.plot.line()``.

    Parameters:
    -----------
    df : polars.DataFrame
        The Polars DataFrame
    col_name : str
        Name of the column to plot
    figsize : tuple, default (15, 6)
        Accepted but not used; the chart is always sized 800x400.
    title : str, optional
        Chart title. If None, the column name is used.
    xlabel : str, default 'Index'
        Accepted but not used.

    Returns:
    --------
    The chart returned by ``.plot.line()`` with width, height and title set.
    """
    chart_title = col_name if title is None else title
    line_chart = df[col_name].plot.line()
    return line_chart.properties(width=800, height=400, title=chart_title)
|
|
|
|
|
|
def plot_columns(df, col_name, figsize=(15, 6), title=None, xlabel='Index'):
    """
    Line chart of ``df[col_name]``; behaves exactly like plot_column().

    Parameters:
    -----------
    df : polars.DataFrame
        The Polars DataFrame
    col_name : str
        Column selector passed to ``df[...]``
    figsize : tuple, default (15, 6)
        Accepted but not used; the chart is always sized 800x400.
    title : str, optional
        Chart title. If None, ``col_name`` is used.
    xlabel : str, default 'Index'
        Accepted but not used.

    Returns:
    --------
    The chart returned by ``.plot.line()`` with width, height and title set.

    NOTE(review): the name suggests multi-column plotting, but the logic is
    the same as plot_column — confirm the intended behaviour.
    """
    return plot_column(df, col_name, figsize=figsize, title=title, xlabel=xlabel)
|
|
|
|
|
|
def model_trade_results(y_true, y_pred) -> pl.DataFrame:
    """
    Generate trade-level results from model predictions.

    Each row is one trade whose position is the sign of the prediction.

    Returns a DataFrame with the columns:
        - y_pred, y_true: the squeezed inputs
        - is_won: True when prediction and outcome have the same sign
        - position: sign of the prediction
        - trade_log_return: position * y_true
        - equity_curve: cumulative sum of trade_log_return
        - drawdown_log_return: equity_curve minus its running maximum
    """
    base = pl.DataFrame({
        'y_pred': y_pred.squeeze(),
        'y_true': y_true.squeeze(),
    })

    predicted_sign = pl.col('y_pred').sign()
    actual_sign = pl.col('y_true').sign()
    equity = pl.col('equity_curve')

    # Each step reads columns created by the previous one, so the
    # with_columns calls must stay sequential.
    positioned = base.with_columns([
        (predicted_sign == actual_sign).alias('is_won'),
        predicted_sign.alias('position'),
    ])
    with_returns = positioned.with_columns([
        (pl.col('position') * pl.col('y_true')).alias('trade_log_return'),
    ])
    with_equity = with_returns.with_columns([
        pl.col('trade_log_return').cum_sum().alias('equity_curve'),
    ])
    trade_results = with_equity.with_columns(
        (equity - equity.cum_max()).alias('drawdown_log_return'),
    )
    return trade_results
|
|
|
|
|
|
def add_tx_fee(trades: pl.DataFrame, tx_fee: float, name: str):
    """
    Append a round-trip transaction fee column to ``trades``.

    The new column 'tx_fee_{name}' is
    exit_trade_value * tx_fee + entry_trade_value * tx_fee.
    """
    exit_fee = pl.col('exit_trade_value') * tx_fee
    entry_fee = pl.col('entry_trade_value') * tx_fee
    fee_expr = (exit_fee + entry_fee).alias(f"tx_fee_{name}")
    return trades.with_columns(fee_expr)
|
|
|
|
|
|
def add_tx_fees(trades: pl.DataFrame, maker_fee: float, taker_fee: float):
    """Append 'tx_fee_maker' and 'tx_fee_taker' columns via add_tx_fee()."""
    for fee_rate, fee_label in ((maker_fee, 'maker'), (taker_fee, 'taker')):
        trades = add_tx_fee(trades, fee_rate, fee_label)
    return trades
|
|
|
|
def add_tx_fees_log(trades: pl.DataFrame, maker_fee, taker_fee):
    """
    Append fee-adjusted log returns and their equity curves.

    Adds 'trade_log_return_net_maker' / 'trade_log_return_net_taker'
    (trade_log_return + log(fee)) and the cumulative sums
    'equity_curve_net_maker' / 'equity_curve_net_taker'.

    NOTE(review): log(fee) is added as-is, which only yields a small negative
    adjustment when the fee argument is a retention multiplier such as 0.999.
    Passing a rate such as 0.001 adds log(0.001), about -6.9, per trade.
    Confirm the convention callers use.
    """
    gross = pl.col('trade_log_return')

    net_returns = (
        (gross + np.log(maker_fee)).alias('trade_log_return_net_maker'),
        (gross + np.log(taker_fee)).alias('trade_log_return_net_taker'),
    )
    net_equity = (
        pl.col('trade_log_return_net_maker').cum_sum().alias('equity_curve_net_maker'),
        pl.col('trade_log_return_net_taker').cum_sum().alias('equity_curve_net_taker'),
    )

    # Equity curves read the net-return columns, so they need a second step
    return trades.with_columns(*net_returns).with_columns(*net_equity)
|
|
|
|
def eval_model_performance(y_actual, y_pred, feature_names: List[str], target_name: str, annualized_rate: float) -> Dict[str, object]:
    """
    Calculate performance metrics for the trading model.

    Trades are derived with model_trade_results(); all return figures are log
    returns.

    Args:
        y_actual: Realised target values
        y_pred: Model predictions
        feature_names: Feature names, stored comma-joined under 'features'
        target_name: Target name, stored under 'target'
        annualized_rate: Factor the per-trade Sharpe ratio is multiplied by

    Returns:
        Dict with the keys features, target, no_trades, win_rate, avg_win,
        avg_loss, best_trade, worst_trade, ev, std, total_log_return,
        compound_return, max_drawdown, equity_trough, equity_peak, sharpe.
        avg_win / avg_loss are None when there are no winning / losing
        trades, and std is None when it cannot be computed; the derived
        figures (ev, sharpe) then treat the missing value as 0.
    """
    trade_results = model_trade_results(y_actual, y_pred)
    returns = trade_results['trade_log_return']
    equity = trade_results['equity_curve']

    accuracy = trade_results['is_won'].mean()
    avg_win = trade_results.filter(pl.col('is_won'))['trade_log_return'].mean()
    avg_loss = trade_results.filter(~pl.col('is_won'))['trade_log_return'].mean()

    # mean() of an empty selection is None (all wins or all losses); that
    # side then contributes nothing to the expected value.
    win_part = 0.0 if avg_win is None else avg_win
    loss_part = 0.0 if avg_loss is None else avg_loss
    expected_value = accuracy * win_part + (1 - accuracy) * loss_part

    max_drawdown = (equity - equity.cum_max()).min()

    # std() is None when it cannot be computed (e.g. a single trade)
    std = returns.std()
    sharpe = returns.mean() / std if std is not None and std > 0 else 0
    annualized_sharpe = sharpe * annualized_rate

    total_log_return = returns.sum()

    return {
        'features': ','.join(list(feature_names)),
        'target': target_name,
        'no_trades': len(trade_results),
        'win_rate': accuracy,
        'avg_win': avg_win,
        'avg_loss': avg_loss,
        'best_trade': returns.max(),
        'worst_trade': returns.min(),
        'ev': expected_value,
        'std': std,
        'total_log_return': total_log_return,
        'compound_return': np.exp(total_log_return),
        'max_drawdown': max_drawdown,
        'equity_trough': equity.min(),
        'equity_peak': equity.max(),
        'sharpe': annualized_sharpe,
    }
|
|
|
|
def train_reg_model(df: pl.DataFrame, features: List[str], target: str, model: nn.Module, annualized_rate, test_size=0.25, loss = None, optimizer = None, no_epochs = None, log = False, lr = None):
    """
    Train ``model`` on the leading part of ``df`` and predict the trailing part.

    The frame is split chronologically with timeseries_split(); features and
    target are converted to float32 tensors (target as a column vector) and
    handed to batch_train_reg(). Nulls are not removed here, so callers
    should pass a frame without nulls in the used columns.

    Args:
        df: Source DataFrame
        features: Feature column names
        target: Target column name
        model: Model to train (updated in place)
        annualized_rate: Not used by this function; kept for signature parity
            with benchmark_reg_model
        test_size: Proportion of rows used for testing
        loss: Loss function forwarded as batch_train_reg's criterion
        optimizer: Optimizer forwarded to batch_train_reg
        no_epochs: Number of epochs; 6000 when None
        log: Forwarded as batch_train_reg's logging flag
        lr: Learning rate forwarded to batch_train_reg

    Returns:
        The test-set predictions from batch_train_reg(). Previously the
        predictions were computed and then discarded.
    """
    df_train, df_test = timeseries_split(df, test_size=test_size)
    if no_epochs is None:
        no_epochs = 6000

    X_train = torch.tensor(df_train[features].to_numpy(), dtype=torch.float32)
    y_train = torch.tensor(df_train[target].to_numpy(), dtype=torch.float32).reshape(-1, 1)
    X_test = torch.tensor(df_test[features].to_numpy(), dtype=torch.float32)
    y_test = torch.tensor(df_test[target].to_numpy(), dtype=torch.float32).reshape(-1, 1)

    y_hat = batch_train_reg(model, X_train, X_test, y_train, y_test, no_epochs, loss, optimizer, lr = lr, logging = log)
    return y_hat
|
|
|
|
|
|
|
|
def benchmark_reg_model(
    df: pl.DataFrame,
    features: List[str],
    target: str,
    model: nn.Module,
    annualized_rate,
    test_size=0.25,
    loss=None,
    optimizer=None,
    no_epochs=None,
    log=False,
    lr=None,
):
    """Train ``model`` on a split of ``df`` and return its performance record.

    Args:
        df: Frame holding the feature and target columns.
        features: Names of the input columns.
        target: Name of the column to predict.
        model: Network to train; must be readable by ``get_linear_params``.
        annualized_rate: Forwarded to ``eval_model_performance``.
        test_size: Forwarded to ``timeseries_split``.
        loss: Forwarded positionally to ``batch_train_reg``.
        optimizer: Forwarded positionally to ``batch_train_reg``.
        no_epochs: Epoch count; 6000 is used when ``None``.
        log: Forwarded to ``batch_train_reg`` as ``logging``.
        lr: Forwarded to ``batch_train_reg`` as ``lr``.

    Returns:
        The object produced by ``eval_model_performance`` with two extra
        entries, ``'weights'`` and ``'biases'``, holding the stringified
        parameters returned by ``get_linear_params``.
    """
    train_part, test_part = timeseries_split(df, test_size=test_size)
    epochs = 6000 if no_epochs is None else no_epochs

    def as_tensors(part):
        # Inputs keep their 2-D layout; targets become an (n, 1) column.
        inputs = torch.tensor(part[features].to_numpy(), dtype=torch.float32)
        labels = torch.tensor(part[target].to_numpy(), dtype=torch.float32)
        return inputs, labels.reshape(-1, 1)

    X_train, y_train = as_tensors(train_part)
    X_test, y_test = as_tensors(test_part)

    predictions = batch_train_reg(
        model, X_train, X_test, y_train, y_test, epochs, loss, optimizer,
        lr=lr, logging=log,
    )

    report = eval_model_performance(y_test, predictions, features, target, annualized_rate)

    # Store the learned parameters as strings alongside the metrics.
    learned = get_linear_params(model)
    report['weights'] = str(learned[0])
    report['biases'] = str(learned[1])
    return report
|
|
|
|
|
|
def learn_model_trades(
    df: pl.DataFrame,
    features: List[str],
    target: str,
    model: nn.Module,
    test_size=0.25,
    loss=None,
    optimizer=None,
    no_epochs=None,
    log=False,
    lr=None,
):
    """Train ``model`` and return the trade results for the test partition.

    Rows containing nulls are dropped before the frame is split.

    Args:
        df: Frame holding the feature and target columns.
        features: Names of the input columns.
        target: Name of the column to predict.
        model: Network handed to ``batch_train_reg``.
        test_size: Forwarded to ``timeseries_split``.
        loss: Forwarded to ``batch_train_reg`` as ``criterion``.
        optimizer: Forwarded to ``batch_train_reg`` as ``optimizer``.
        no_epochs: Epoch count; 6000 is used when ``None``.
        log: Forwarded to ``batch_train_reg`` as ``logging``.
        lr: Forwarded to ``batch_train_reg`` as ``lr``.

    Returns:
        Whatever ``model_trade_results`` yields for the test targets and the
        predictions returned by ``batch_train_reg``.
    """
    complete_rows = df.drop_nulls()
    train_part, test_part = timeseries_split(complete_rows, test_size=test_size)
    epochs = 6000 if no_epochs is None else no_epochs

    def as_tensors(part):
        # Inputs keep their 2-D layout; targets become an (n, 1) column.
        inputs = torch.tensor(part[features].to_numpy(), dtype=torch.float32)
        labels = torch.tensor(part[target].to_numpy(), dtype=torch.float32)
        return inputs, labels.reshape(-1, 1)

    X_train, y_train = as_tensors(train_part)
    X_test, y_test = as_tensors(test_part)

    predictions = batch_train_reg(
        model, X_train, X_test, y_train, y_test, epochs,
        criterion=loss, optimizer=optimizer, lr=lr, logging=log,
    )
    return model_trade_results(y_test, predictions)
|
|
|
|
def learn_model_trade_pnl(
    df: pl.DataFrame,
    features: List[str],
    target: str,
    model: nn.Module,
    test_size=0.25,
    loss=None,
    optimizer=None,
    no_epochs=None,
    log=False,
    lr=None,
):
    """Train ``model`` and return the trade results for the test partition.

    Previously the trade results were computed and then dropped, so the
    function always returned ``None``; they are now returned to the caller.

    NOTE(review): despite the name, no PnL columns are derived here - the
    result is exactly what ``model_trade_results`` produces. Unlike
    ``learn_model_trades`` this function does not drop null rows first;
    confirm whether that difference is intended.

    Args:
        df: Frame holding the feature and target columns.
        features: Names of the input columns.
        target: Name of the column to predict.
        model: Network handed to ``batch_train_reg``.
        test_size: Forwarded to ``timeseries_split``.
        loss: Forwarded positionally to ``batch_train_reg``.
        optimizer: Forwarded positionally to ``batch_train_reg``.
        no_epochs: Epoch count; 6000 is used when ``None``.
        log: Forwarded to ``batch_train_reg`` as ``logging``.
        lr: Forwarded to ``batch_train_reg`` as ``lr``.

    Returns:
        The value of ``model_trade_results(y_test, y_hat)``.
    """
    df_train, df_test = timeseries_split(df, test_size=test_size)
    if no_epochs is None:
        no_epochs = 6000

    def as_tensors(part):
        # Inputs keep their 2-D layout; targets become an (n, 1) column.
        inputs = torch.tensor(part[features].to_numpy(), dtype=torch.float32)
        labels = torch.tensor(part[target].to_numpy(), dtype=torch.float32)
        return inputs, labels.reshape(-1, 1)

    X_train, y_train = as_tensors(df_train)
    X_test, y_test = as_tensors(df_test)

    y_hat = batch_train_reg(
        model, X_train, X_test, y_train, y_test, no_epochs, loss, optimizer,
        lr=lr, logging=log,
    )

    return model_trade_results(y_test, y_hat)
|
|
|
|
|
|
|
|
def set_seed(seed):
    """Seed the Python, NumPy and torch RNGs and request deterministic kernels.

    Args:
        seed: Value passed unchanged to every seeding call.
    """
    # Same seed for the stdlib, NumPy, torch CPU and all CUDA generators.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    # Turn off cuDNN autotuning and ask for deterministic implementations.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
|
|
|
|
def init_weights(m):
    """Re-initialise an ``nn.Linear`` layer reproducibly; ignore other modules.

    Intended for use with ``module.apply(init_weights)``. The global torch
    seed is reset to 42 before each layer is initialised, so the draw is the
    same on every call.

    Args:
        m: Module to initialise. Anything that is not ``nn.Linear`` is left
            untouched.
    """
    if not isinstance(m, nn.Linear):
        return

    torch.manual_seed(42)  # ensures same init every time
    nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)
|
|
|
|
def get_linear_params(model: nn.Module) -> tuple[np.ndarray, float]:
    """Return ``(weights, bias)`` read from ``model.linear``.

    Args:
        model: Module exposing a ``linear`` attribute that has ``weight`` and
            ``bias`` tensors. The bias must hold exactly one element because
            it is converted with ``.item()``.

    Returns:
        A flattened NumPy copy of the weight tensor and the bias as a Python
        scalar.
    """
    layer = model.linear
    flat_weights = layer.weight.detach().cpu().numpy().flatten()
    scalar_bias = layer.bias.detach().cpu().numpy().item()
    return flat_weights, scalar_bias
|
|
|
|
def add_log_return_features(df: pl.DataFrame, col: str, forecast_horizon: int, max_no_lags=None):
    """Add the log return of ``col`` and, optionally, lagged copies of it.

    Fix: the lag step used to be hard-wired to ``log_return_col('close')``
    regardless of ``col``. The lagged column is now derived from ``col``,
    which is unchanged for ``col='close'``.

    Args:
        df: Input frame containing ``col``.
        col: Name of the column whose log return is computed.
        forecast_horizon: Forwarded to ``log_return`` and ``add_lags``.
        max_no_lags: Forwarded to ``add_lags``; ``None`` or a value <= 0
            skips the lag step.

    Returns:
        ``df`` with the log-return column added, plus whatever columns
        ``add_lags`` appends when lags are requested.
    """
    df = df.with_columns(log_return(col, forecast_horizon))

    if max_no_lags is not None and max_no_lags > 0:
        # Assumes log_return(col, ...) names its output log_return_col(col),
        # as the previous 'close'-only code implied - TODO confirm.
        df = add_lags(df, log_return_col(col), max_no_lags, forecast_horizon)

    return df
|
|
|
|
def benchmark_linear_models(
    ts: pl.DataFrame,
    target: str,
    feature_pool: List[str],
    annualized_rate: int,
    max_no_features: int = 1,
    no_epochs=200,
    loss=None,
    test_size=0.25,
) -> pl.DataFrame:
    """Benchmark a fresh ``LinearModel`` for every feature subset in a pool.

    Args:
        ts: Input frame; rows containing nulls are dropped first.
        target: Name of the column to predict.
        feature_pool: Candidate feature names.
        annualized_rate: Forwarded to ``benchmark_reg_model``.
        max_no_features: Largest subset size to try (sizes start at 1).
        no_epochs: Forwarded to ``benchmark_reg_model``.
        loss: Forwarded to ``benchmark_reg_model``.
        test_size: Forwarded to ``benchmark_reg_model``.

    Returns:
        One row per feature subset, sorted by the ``'sharpe'`` column in
        descending order.
    """
    import models

    clean_ts = ts.drop_nulls()

    # Every subset of size 1..max_no_features, smallest sizes first.
    candidate_sets = [
        subset
        for size in range(1, max_no_features + 1)
        for subset in itertools.combinations(feature_pool, size)
    ]

    rows = []
    for subset in candidate_sets:
        linear = models.LinearModel(len(subset))
        linear.apply(init_weights)
        rows.append(
            benchmark_reg_model(
                clean_ts, list(subset), target, linear, annualized_rate,
                no_epochs=no_epochs, loss=loss, test_size=test_size,
            )
        )

    return pl.DataFrame(rows).sort('sharpe', descending=True)
|
|
|
|
|
|
# print out our learned params
|
|
def print_model_params(model: nn.Module):
    """Print the name and value of every trainable parameter of ``model``.

    Parameters with ``requires_grad`` set to ``False`` are skipped. Values are
    detached and copied to the CPU before conversion, because calling
    ``.numpy()`` directly raises for tensors on another device.

    Args:
        model: Module whose ``named_parameters()`` are printed.
    """
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        print(f"{name}:\n{param.detach().cpu().numpy()}")
|
|
|
|
|
|
def add_model_predictions(test_trades: pl.DataFrame, model: nn.Module, features: Union[str, List[str]]) -> pl.DataFrame:
    """Run ``model`` on the feature columns and attach the output as ``y_hat``.

    Changes from the previous version:
      * the model is evaluated once instead of twice (the first result was
        never used);
      * a single feature name is detected with ``isinstance`` and any other
        sequence of names (e.g. a tuple) is accepted;
      * the forward pass runs under ``torch.no_grad()`` since only the
        detached values are needed;
      * a single-row frame no longer collapses the prediction to a 0-d array.

    Args:
        test_trades: Frame containing the feature columns.
        model: Callable module mapping a float32 feature tensor to
            predictions.
        features: One column name or a sequence of column names.

    Returns:
        ``test_trades`` with an added ``y_hat`` column.
    """
    feature_cols = [features] if isinstance(features, str) else list(features)
    X_test = torch.tensor(test_trades[feature_cols].to_numpy(), dtype=torch.float32)

    with torch.no_grad():
        predictions = model(X_test)

    # squeeze() drops size-1 axes; atleast_1d restores one axis for a lone row.
    values = np.atleast_1d(predictions.detach().cpu().numpy().squeeze())
    return test_trades.with_columns(pl.Series('y_hat', values))
|
|
|
|
|
|
def add_trade_log_returns(trades: pl.DataFrame, pre_trade_values: Union[List[float],npt.NDArray[np.float32]], tx_fee: float, initial_capital: float) -> pl.DataFrame:
    """Derive direction, returns, PnL, fees and equity curves for each trade.

    Fixes over the previous version:
      * expressions inside two ``with_columns`` calls were missing the
        separating commas, which turned them into a call on an ``Expr`` and
        failed at runtime;
      * ``pre_trade_values`` is documented as a list / ndarray but was used
        as if it were a polars expression (``.alias``); it is now converted
        to a column first;
      * the function now returns the augmented frame instead of ``None``.

    Args:
        trades: Frame with ``y_hat``, ``close_log_return`` and ``open``
            columns.
        pre_trade_values: Capital committed before each trade - one value per
            row (list, ndarray or polars Series), a polars expression, or a
            single number applied to every row.
        tx_fee: Fee rate charged on both the pre-trade and post-trade value.
        initial_capital: Starting value of both equity curves.

    Returns:
        ``trades`` with these columns added: ``dir_signal``,
        ``trade_log_return``, ``cum_trade_log_return``, ``pre_trade_value``,
        ``post_trade_value``, ``trade_qty``, ``signed_trade_qty``,
        ``trade_gross_pnl``, ``tx_fees``, ``trade_net_pnl``,
        ``equity_curve_gross`` and ``equity_curve_net``.
    """
    # Normalise the pre-trade capital into something with_columns accepts.
    if isinstance(pre_trade_values, pl.Expr):
        pre_trade_col = pre_trade_values.alias('pre_trade_value')
    elif isinstance(pre_trade_values, (int, float, np.integer, np.floating)):
        pre_trade_col = pl.lit(float(pre_trade_values)).alias('pre_trade_value')
    else:
        pre_trade_col = pl.Series('pre_trade_value', pre_trade_values)

    # add directional signal to indicate if we're going long or short
    trades = trades.with_columns(pl.col('y_hat').sign().alias('dir_signal'))
    # calculate trade log return
    trades = trades.with_columns(
        (pl.col('close_log_return') * pl.col('dir_signal')).alias('trade_log_return')
    )
    # cumulative sum of the trade log returns - the equity curve in log space
    trades = trades.with_columns(
        pl.col('trade_log_return').cum_sum().alias('cum_trade_log_return')
    )

    # add pre trade values first so the following expressions can reference them
    trades = trades.with_columns(pre_trade_col)
    trades = trades.with_columns(
        # add post trade values
        (pl.col('pre_trade_value') * pl.col('trade_log_return').exp()).alias('post_trade_value'),
        # add trade qty
        (pl.col('pre_trade_value') / pl.col('open')).alias('trade_qty'),
    )

    trades = trades.with_columns(
        # add signed trade quantities (the main output of our strategy)
        (pl.col('trade_qty') * pl.col('dir_signal')).alias('signed_trade_qty'),
        # add trade gross pnl
        (pl.col('post_trade_value') - pl.col('pre_trade_value')).alias('trade_gross_pnl'),
        # add tx fees
        (pl.col('pre_trade_value') * tx_fee + pl.col('post_trade_value') * tx_fee).alias('tx_fees'),
    )
    trades = trades.with_columns(
        # calculate each trade's profit after fees (net)
        (pl.col('trade_gross_pnl') - pl.col('tx_fees')).alias('trade_net_pnl')
    )
    trades = trades.with_columns(
        # calculate equity curve for gross profit
        (initial_capital + pl.col('trade_gross_pnl').cum_sum()).alias('equity_curve_gross'),
        # calculate equity curve for net profit
        (initial_capital + pl.col('trade_net_pnl').cum_sum()).alias('equity_curve_net'),
    )
    return trades
|
|
|
|
def add_equity_curve(trades: pl.DataFrame, initial_capital: float, col_name: str, suffix: str) -> pl.DataFrame:
    """Append an ``equity_curve_<suffix>`` column to ``trades``.

    Args:
        trades: Frame containing ``col_name``.
        initial_capital: Constant added to every cumulative value.
        col_name: Column whose cumulative sum is taken.
        suffix: Text appended to ``equity_curve_`` to name the new column.

    Returns:
        ``trades`` with the new column added.
    """
    running_total = pl.col(col_name).cum_sum()
    equity = (initial_capital + running_total).alias(f'equity_curve_{suffix}')
    return trades.with_columns(equity)
|
|
|
|
def add_compounding_trades(trades, capital, leverage, maker_fee, taker_fee):
    """Add compounding trade values, PnL and equity curves to ``trades``.

    Args:
        trades: Frame with ``cum_trade_log_return``, ``open`` and
            ``dir_signal`` columns.
        capital: Base capital; also the starting value of each equity curve.
        leverage: Multiplier applied to ``capital`` for the first entry value.
        maker_fee: Forwarded to ``add_tx_fees``.
        taker_fee: Forwarded to ``add_tx_fees``.

    Returns:
        ``trades`` with ``entry_trade_value``, ``exit_trade_value``,
        ``signed_trade_qty``, ``trade_gross_pnl``, the columns added by
        ``add_tx_fees``, ``trade_net_taker_pnl``, ``trade_net_maker_pnl`` and
        the ``gross`` / ``taker`` / ``maker`` equity curves.
    """
    notional = capital * leverage
    compounded_value = pl.col('cum_trade_log_return').exp() * notional

    # A trade enters at the previous row's exit value; the first row enters
    # at the leveraged starting capital.
    result = trades.with_columns(
        compounded_value.shift().fill_null(notional).alias('entry_trade_value'),
        compounded_value.alias('exit_trade_value'),
    )
    result = result.with_columns(
        (pl.col('entry_trade_value') / pl.col('open') * pl.col('dir_signal')).alias('signed_trade_qty'),
        (pl.col('exit_trade_value') - pl.col('entry_trade_value')).alias('trade_gross_pnl'),
    )

    # tx_fee_taker / tx_fee_maker are presumably supplied by add_tx_fees -
    # verify against that helper.
    result = add_tx_fees(result, maker_fee, taker_fee)
    result = result.with_columns(
        (pl.col('trade_gross_pnl') - pl.col('tx_fee_taker')).alias('trade_net_taker_pnl'),
        (pl.col('trade_gross_pnl') - pl.col('tx_fee_maker')).alias('trade_net_maker_pnl'),
    )

    # Gross curve first, then the two net curves (taker, maker).
    for pnl_col, suffix in (
        ('trade_gross_pnl', 'gross'),
        ('trade_net_taker_pnl', 'taker'),
        ('trade_net_maker_pnl', 'maker'),
    ):
        result = add_equity_curve(result, capital, pnl_col, suffix)
    return result
|