Files

1113 lines
41 KiB
Python

"""
research.py - Model Research & Development Utilities
====================================================
Reusable boilerplate code for machine learning research and model development.
This module eliminates copy-paste by centralizing common functions for:
- Data preprocessing and time series aggregation
- Feature engineering for financial data
- Model architecture inspection and debugging
- Visualization and exploratory data analysis
- Training utilities
Import this in all your research notebooks to save time and maintain consistency.
Author: MemLabs
Course: Build a Quant Trading System
"""
# ============================================================================
# IMPORTS
# ============================================================================
# Data manipulation and analysis
import polars as pl # Fast dataframes for financial data
from typing import Dict, List, Tuple, Union # Type hints for function signatures
# Machine learning framework
import torch # PyTorch for neural networks
import torch.nn as nn # Neural network modules
import torch.optim as optim # Optimization algorithms
# Numerical computing and datetime
import numpy as np # Numerical operations
import numpy.typing as npt
from datetime import datetime, timedelta # Date and time handling
# Visualization
import altair # Interactive plotting library
import matplotlib.pyplot as plt
import random
import re
import itertools
from pathlib import Path
from tqdm import tqdm
import os
# Default random seed for reproducible research runs.
SEED = 42
# ============================================================================
# TIME SERIES AGGREGATION
# ============================================================================
# Core OHLC aggregations over the raw trade ``price`` column. Each pair maps a
# Polars reducer to the name of the resulting bar column; the output columns
# appear in the order open, high, low, close.
OHLC_AGGS = [
    getattr(pl.col("price"), reducer)().alias(label)
    for reducer, label in (
        ("first", "open"),   # opening price
        ("max", "high"),     # highest price
        ("min", "low"),      # lowest price
        ("last", "close"),   # closing price
    )
]
def get_trade_files(directory: str, sym: str) -> List[Path]:
    """
    Return every entry of *directory* whose name starts with ``{sym}-trades``.

    The search is not recursive (a single ``Path.glob`` on the directory).

    Args:
        directory: Directory to search.
        sym: Symbol prefix (e.g., 'BTCUSDT').

    Returns:
        Matching paths, sorted lexicographically.

    Example:
        >>> files = get_trade_files('./data', 'BTCUSDT')
        >>> # e.g. [Path('data/BTCUSDT-trades-2024-01-01.parquet'), ...]
    """
    candidates = Path(directory).glob(sym + "-trades*")
    return sorted(candidates)
from pathlib import Path
from typing import List, Optional
def load_ohlc_timeseries(sym: str, time_interval: str, data_path: Optional[str] = None) -> pl.DataFrame:
    """
    Load all cached trade files for *sym* and aggregate them into OHLC bars.

    Thin wrapper around ``load_timeseries`` using ``OHLC_AGGS``.

    Args:
        sym: Symbol prefix (e.g., 'BTCUSDT').
        time_interval: Polars duration string for the bar size (e.g., '1h', '5m').
        data_path: Directory containing the trade files. ``None`` uses
            ``load_timeseries``'s default ('./cache').

    Returns:
        DataFrame with ``datetime``, ``open``, ``high``, ``low`` and ``close``.
    """
    return load_timeseries(sym, time_interval, OHLC_AGGS, data_path=data_path)
def load_timeseries(
    sym: str,
    time_interval: str,
    aggs: List[pl.Expr],
    data_path: Optional[str] = None
) -> pl.DataFrame:
    """
    Load trade files one by one, aggregate each to a time series, and concatenate.

    Files are discovered with ``get_trade_files`` (``{sym}-trades*`` inside
    *data_path*) and read with ``pl.read_parquet``. Every matching file must
    therefore be a parquet file with a ``datetime`` column.

    Args:
        sym: Symbol prefix (e.g., 'BTCUSDT').
        time_interval: Time interval for aggregation (e.g., '1h', '5m').
        aggs: List of aggregation expressions (e.g., ``OHLC_AGGS``).
        data_path: Optional directory path. Defaults to './cache' if not provided.

    Returns:
        Time series sorted by ``datetime`` with one row per bar timestamp.
        When two files produce a bar with the same timestamp (a bar straddling
        a file boundary), only the bar from the file that sorts first is kept.

    Raises:
        FileNotFoundError: No matching files were found in *data_path*.
        ValueError: A file has no ``datetime`` column.

    Example:
        >>> # Use default './cache' directory
        >>> ts = load_timeseries('BTCUSDT', '1h', OHLC_AGGS)
        >>> # Specify custom directory
        >>> ts = load_timeseries('BTCUSDT', '1h', OHLC_AGGS, data_path='./my_data')
    """
    if data_path is None:
        data_path = './cache'
    files = get_trade_files(data_path, sym)
    if not files:
        raise FileNotFoundError(f"No files found for {sym} in {data_path}")

    ts_list = []
    for file in tqdm(files, desc=f"Loading {sym}", unit="file"):
        trades = pl.read_parquet(file)
        if "datetime" not in trades.columns:
            raise ValueError(f"Column 'datetime' not found in {file.name}")
        # group_by_dynamic requires the index column sorted in ascending order.
        trades = trades.with_columns(
            pl.col("datetime").cast(pl.Datetime)
        ).sort("datetime")
        ts = trades.group_by_dynamic(
            "datetime",
            every=time_interval,
            offset="0m"
        ).agg(aggs)
        ts_list.append(ts)

    # Deduplicate BEFORE sorting: DataFrame.unique does not preserve row order
    # unless maintain_order=True, so sort-then-unique may return unsorted rows.
    # keep="first" makes the choice deterministic (earliest file by name).
    result = (
        pl.concat(ts_list)
        .unique(subset=["datetime"], keep="first", maintain_order=True)
        .sort("datetime")
    )
    return result
def load_timeseries_range(
    sym: str,
    time_interval: str,
    start_date: datetime,
    end_date: datetime,
    agg_cols: Union[pl.Expr, List[pl.Expr]],
    data_path: Optional[str] = None
) -> pl.DataFrame:
    """
    Load and aggregate daily trade files for *sym* from start_date to end_date.

    Expects daily files named like:
        {symbol}-trades-YYYY-MM-DD.parquet
    Example filename:
        BTCUSDT-trades-2025-09-22.parquet

    Every calendar day from ``start_date``'s date to ``end_date``'s date
    (inclusive) is loaded in full. Bars are not clipped to the time-of-day of
    the bounds. Missing files are reported with a warning. Files that fail to
    load are reported with an error and skipped.

    Args:
        sym: Symbol prefix (e.g., 'BTCUSDT').
        time_interval: Aggregation interval (e.g., '1h', '5m').
        start_date: First day to load (a ``datetime`` or ``date``; inclusive).
        end_date: Last day to load (a ``datetime`` or ``date``; inclusive).
        agg_cols: Aggregation expression(s) applied to each bar.
        data_path: Directory containing cached trade parquet files (default: './cache').

    Returns:
        Polars DataFrame of bars sorted by ``datetime``, one row per timestamp.

    Raises:
        ValueError: start_date is after end_date, or no day yielded data.
    """
    if data_path is None:
        data_path = "./cache"
    if start_date > end_date:
        raise ValueError("start_date must be before or equal to end_date")

    # Count calendar days rather than elapsed 24h periods. With time
    # components (e.g. 12:00 -> next day 06:00), `(end - start).days` is 0 and
    # the final day's file would be skipped.
    first_day = start_date.date() if isinstance(start_date, datetime) else start_date
    last_day = end_date.date() if isinstance(end_date, datetime) else end_date
    total_days = (last_day - first_day).days + 1

    ts_list = []
    for i in tqdm(range(total_days), desc=f"Loading {sym}", unit="day"):
        current_date = first_day + timedelta(days=i)
        file_name = f"{sym}-trades-{current_date.strftime('%Y-%m-%d')}.parquet"
        file_path = os.path.join(data_path, file_name)
        if not os.path.exists(file_path):
            tqdm.write(f"[WARNING] Missing file: {file_name}")
            continue
        try:
            trades = pl.read_parquet(file_path)
            if "datetime" not in trades.columns:
                raise ValueError(f"Column 'datetime' not found in {file_name}")
            # group_by_dynamic requires an ascending index column. Without this
            # sort, unsorted files failed and were silently skipped below.
            trades = trades.with_columns(pl.col("datetime").cast(pl.Datetime)).sort("datetime")
            ts = trades.group_by_dynamic("datetime", every=time_interval, offset="0m").agg(agg_cols)
            ts_list.append(ts)
        except Exception as e:
            # Best effort: report the bad day and continue with the rest.
            tqdm.write(f"[ERROR] {file_name}: {e}")

    if not ts_list:
        raise ValueError(f"No trade data found for {sym} in range {start_date} to {end_date}")
    # unique() before sort(): unique does not preserve order by default.
    return (
        pl.concat(ts_list)
        .unique(subset=["datetime"], keep="first", maintain_order=True)
        .sort("datetime")
    )
def load_ohlc_timeseries_range(
    sym: str,
    time_interval: str,
    start_date: datetime,
    end_date: datetime,
    data_path: Optional[str] = None
) -> pl.DataFrame:
    """
    Load daily trade files for *sym* between start_date and end_date as OHLC bars.

    Equivalent to ``load_timeseries_range(..., agg_cols=OHLC_AGGS, ...)``. It
    expects daily files named like ``{symbol}-trades-YYYY-MM-DD.parquet``
    (e.g. ``BTCUSDT-trades-2025-09-22.parquet``).

    Args:
        sym: Symbol prefix (e.g., 'BTCUSDT').
        time_interval: Aggregation interval (e.g., '1h', '5m').
        start_date: Start date/datetime (inclusive).
        end_date: End date/datetime (inclusive).
        data_path: Directory containing cached trade parquet files (default: './cache').

    Returns:
        Polars DataFrame with ``datetime``, ``open``, ``high``, ``low``, ``close``.
    """
    return load_timeseries_range(
        sym, time_interval, start_date, end_date, OHLC_AGGS, data_path=data_path
    )
def sharpe_annualization_factor(interval: str,
                                trading_days_per_year: int = 365,
                                trading_hours_per_day: float = 24) -> float:
    """
    Compute the annualization factor (sqrt of periods per year) for a return interval.

    Parameters
    ----------
    interval : str
        A positive integer followed by one unit: 'd' (days), 'h' (hours),
        'm' (minutes) or 's' (seconds), e.g. '1d', '1h', '30m', '15s'.
        Matching is case-insensitive and must cover the whole string, so
        strings such as '1ms' or '1mo' are rejected rather than misread as
        minutes.
    trading_days_per_year : int
        Number of trading days in a year (default 365).
    trading_hours_per_day : float
        Number of trading hours in a trading day (default 24).

    Returns
    -------
    float : annualization factor

    Raises
    ------
    ValueError
        If *interval* is malformed or its count is zero.
    """
    match = re.fullmatch(r"(\d+)([dhms])", interval.strip().lower())
    if not match:
        raise ValueError("Interval must be like '1d', '2h', '15m', '30s'")
    value, unit = int(match.group(1)), match.group(2)
    if value == 0:
        # A zero-length interval has no finite number of periods per year.
        raise ValueError(f"Interval count must be positive (got {interval!r})")

    # periods per year
    if unit == 'd':
        periods = trading_days_per_year / value
    elif unit == 'h':
        periods = trading_days_per_year * (trading_hours_per_day / value)
    elif unit == 'm':
        periods = trading_days_per_year * (trading_hours_per_day * 60 / value)
    else:  # 's' -- the regex admits no other unit
        periods = trading_days_per_year * (trading_hours_per_day * 3600 / value)
    return np.sqrt(periods)
def ohlc_timeseries(df: pl.DataFrame, time_interval: str) -> pl.DataFrame:
    """
    Convert tick-level trade data into OHLC (Open, High, Low, Close) bars.

    Only the four price columns from ``OHLC_AGGS`` are produced. For extra
    per-bar statistics (volume, trade counts, ...), call ``timeseries`` with
    your own aggregation expressions.

    Args:
        df: Trade data containing at least:
            - datetime: Timestamp of each trade
            - price: Execution price
        time_interval: Aggregation period (e.g., '1m', '5m', '15m', '1h', '1d')

    Returns:
        DataFrame with one row per non-empty bar:
            - datetime: Bar start timestamp
            - open: First price in interval
            - high: Highest price in interval
            - low: Lowest price in interval
            - close: Last price in interval

    Example:
        >>> # Create 15-minute OHLC bars
        >>> bars_15m = ohlc_timeseries(trades_df, '15m')
        >>>
        >>> # Create hourly bars for longer-term analysis
        >>> bars_1h = ohlc_timeseries(trades_df, '1h')
    """
    return timeseries(df, time_interval, OHLC_AGGS)
def lag_col_names(col: str, n: int) -> List[str]:
    """Return the lag column names ``{col}_lag_1`` ... ``{col}_lag_{n}`` (the naming used by ``add_lags``)."""
    prefix = f"{col}_lag_"
    return [prefix + str(lag) for lag in range(1, n + 1)]
def auto_reg_corr_matrx(df, target, max_no_lags) -> pl.DataFrame:
    """
    Correlation matrix between *target* and its first *max_no_lags* lag columns.

    Rows containing a null in ANY column of *df* are dropped before the
    selection. The lag columns must already exist and be named as by
    ``lag_col_names`` (e.g. created with ``add_lags``).
    """
    columns = [target, *lag_col_names(target, max_no_lags)]
    complete_rows = df.drop_nulls()
    return complete_rows.select(columns).corr()
def log_returns_col(name: str, step_size = 1) -> pl.Expr:
    """
    Expression for the log return of column *name* over *step_size* rows.

    Computes ``log(x_t / x_{t - step_size})`` aliased to ``{name}_log_return``.
    For a positive *step_size*, the first *step_size* rows are null.
    """
    price = pl.col(name)
    lagged = price.shift(step_size)
    return (price / lagged).log().alias(f'{name}_log_return')
def timeseries(
    df: pl.DataFrame,
    time_interval: str,
    aggs: Union[List[pl.Expr], pl.Expr]
) -> pl.DataFrame:
    """
    Aggregate rows into regular time windows using arbitrary Polars expressions.

    This is the foundation for OHLC bars (``ohlc_timeseries``) and other
    time-based features.

    Args:
        df: DataFrame with a 'datetime' column.
        time_interval: Aggregation period (Polars duration string).
            Examples: '1m', '5m', '15m', '1h', '4h', '1d'
        aggs: Polars expression(s) defining the aggregations to compute.

    Returns:
        DataFrame with one row per window: ``datetime`` plus the aggregates.

    Technical Details:
        - The input is sorted by ``datetime`` first, because
          ``group_by_dynamic`` requires an ascending index column.
        - Windows use Polars' defaults: left-closed [start, start + every).
        - Windows align to multiples of *time_interval* (offset "0m"),
          e.g. 09:00, 09:15, 09:30.
        - Windows without any rows are not emitted.

    Example:
        >>> # Custom aggregation for volatility analysis
        >>> custom_aggs = [
        ...     pl.col("price").std().alias("price_volatility"),
        ...     pl.col("volume").sum().alias("total_volume"),
        ... ]
        >>> df_volatility = timeseries(df, '1h', custom_aggs)
    """
    return (
        df.sort("datetime")
        .group_by_dynamic(
            "datetime",           # Column to group by (must be datetime type)
            every=time_interval,  # Aggregation frequency
            offset="0m"           # No offset (bars align to round times)
        )
        .agg(aggs)
    )
# ============================================================================
# VISUALIZATION
# ============================================================================
def plot(df: pl.DataFrame, col: str, title: str = "") -> altair.Chart:
    """
    Smooth density (KDE) area chart of one column.

    Useful for inspecting distributions, skew and outliers before modelling,
    or for comparing a feature across time periods.

    Args:
        df: DataFrame containing the column to plot.
        col: Name of the column to visualize.
        title: Chart title. An empty string (the default) yields
            ``'Distribution of {col}'``.

    Returns:
        Altair chart (600x400) that renders automatically in Jupyter.

    Example:
        >>> plot(df, 'returns', title='Return Distribution')
        >>> plot(df, 'price_change', title='Price Change Distribution')

    Note:
        Density comes from Altair's ``transform_density``; the area is drawn
        with 'basis' interpolation for a smooth curve.
    """
    chart_title = title or f'Distribution of {col}'
    area = altair.Chart(df).mark_area(opacity=0.7, interpolate='basis')
    density = area.transform_density(col, as_=[col, 'density'])
    encoded = density.encode(
        x=altair.X(f'{col}:Q', title=col),         # feature values
        y=altair.Y('density:Q', title='Density'),  # estimated density
    )
    return encoded.properties(width=600, height=400, title=chart_title)
def plot_distribution(data: pl.DataFrame, col: str, label = None, no_bins = 100):
    """
    Histogram of *col* with scale-bound zoom/pan.

    Args:
        data: DataFrame holding the column.
        col: Quantitative column to bin.
        label: Name shown in the title instead of *col* (falsy -> *col*).
        no_bins: Maximum number of bins passed to ``altair.Bin``.

    Returns:
        Configured Altair chart (600x400, scales not forced to include zero).
    """
    shown_name = label if label else col
    histogram = altair.Chart(data).mark_bar().encode(
        altair.X(f'{col}:Q', bin=altair.Bin(maxbins=no_bins)),
        y='count()',
    )
    sized = histogram.properties(
        width=600,
        height=400,
        title=f'Distribution of {shown_name}',
    )
    configured = sized.configure_scale(zero=False)
    return configured.add_params(altair.selection_interval(bind='scales'))
def plot_static_timeseries(ts: pl.DataFrame, sym: str, col: str, interval_size: str):
    """
    Draw *col* against ``datetime`` as a static matplotlib line chart and show it.

    Args:
        ts: Time series with a ``datetime`` column and *col*.
        sym: Symbol used in the title.
        col: Column to plot (also the y-axis label and legend entry).
        interval_size: Bar size used in the title (e.g. '1h').
    """
    plt.figure(figsize=(12, 6))
    ax = plt.gca()
    ax.plot(ts['datetime'], ts[col], label=col)
    ax.set_title(f'{sym} {interval_size} Bars')
    ax.set_xlabel('time')
    ax.set_ylabel(col)
    ax.legend()
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
def plot_multiple_lines(
    df: pl.DataFrame,
    cols_to_plot: List[str],
    sym: str,
    width: int = 15,
    height: int = 6,
    xlabel_unit: str = "Time Step"
):
    """
    Plot multiple columns from a Polars DataFrame on the same axes using Matplotlib.

    The x-axis is the row position (0 .. len(df) - 1), not a datetime column.
    Columns that are missing from *df* are skipped with a printed warning.

    Parameters:
    -----------
    df : polars.DataFrame
        The Polars DataFrame containing the columns to plot.
    cols_to_plot : list[str]
        A list of column names to plot (e.g., ['log_return', 'mean']).
    sym : str
        A symbol or identifier for the series (used in the title).
    width : int, default 15
        Width of the plot in inches.
    height : int, default 6
        Height of the plot in inches.
    xlabel_unit : str, default 'Time Step'
        Label for the X-axis (the numerical index).
    """
    # 1. Numerical index for the x-axis
    x_index = np.arange(len(df))
    # 2. Figure size (width/height in inches)
    plt.figure(figsize=(width, height))
    # 3. One line per requested column
    for col in cols_to_plot:
        if col in df.columns:
            y_values = df[col].to_numpy()
            plt.plot(x_index, y_values, label=col)
        else:
            print(f"Warning: Column '{col}' not found in DataFrame.")
    # 4. Title lists every requested column (found or not)
    title_cols = ', '.join(cols_to_plot)
    plt.title(f'{sym} Series: {title_cols}')
    plt.xlabel(xlabel_unit)
    plt.ylabel('Value')  # Generic Y-label since multiple series are plotted
    plt.legend(loc='best')
    plt.grid(True, linestyle=':', alpha=0.6)
    # Prevent labels from being cut off
    plt.tight_layout()
    plt.show()
def plot_dyn_timeseries(ts: pl.DataFrame, sym: str, col: str, time_interval: str ):
    """
    Interactive Altair line chart of *col* over ``datetime``, with tooltips.

    Two interval selections are bound to the scales. One is restricted to the
    x encoding and the other to the y encoding, so each axis can be
    zoomed/panned. ``add_params`` replaces the deprecated ``add_selection``,
    consistent with ``plot_distribution``.

    Args:
        ts: Time series with a ``datetime`` column and *col*.
        sym: Symbol used in the title.
        col: Column plotted on the y-axis.
        time_interval: Bar size used in the title (e.g. '1h').

    Returns:
        Configured Altair chart (800x400, scales not forced to include zero).
    """
    return altair.Chart(ts).mark_line(tooltip=True).encode(
        x="datetime",
        y=col
    ).properties(
        width=800,
        height=400,
        title=f"{sym} {time_interval} {col}"
    ).configure_scale(zero=False).add_params(
        altair.selection_interval(bind='scales', encodings=['x']),  # Only zoom x-axis
        altair.selection_interval(bind='scales', encodings=['y'])   # Only zoom y-axis
    )
def to_tensor(x, dtype=None) -> torch.Tensor:
    """
    Convert any object exposing ``to_numpy()`` (Polars/pandas frame or series) to a tensor.

    Args:
        x: Object with a ``to_numpy()`` method.
        dtype: Target torch dtype; ``None`` means ``torch.float32``.

    Returns:
        A new tensor (``torch.tensor`` copies the data).
    """
    array = x.to_numpy()
    target_dtype = torch.float32 if dtype is None else dtype
    return torch.tensor(array, dtype=target_dtype)
# ============================================================================
# MODEL ANALYSIS
# ============================================================================
def print_model_complexity_ratio(m1, m1_name, m2, m2_name):
    """
    Print how the parameter count of *m2* compares with that of *m1*.

    Args:
        m1: Baseline model.
        m1_name: Display name of the baseline.
        m2: Compared model.
        m2_name: Display name of the compared model.
    """
    m1_params = total_model_params(m1)
    m2_params = total_model_params(m2)
    print("Complexity Comparison:")
    if m1_params == 0:
        # A parameter-free baseline would otherwise raise ZeroDivisionError.
        print(f"\t{m1_name} has no parameters; ratio is undefined")
    else:
        complexity_ratio = m2_params / m1_params
        print(f"\t{m2_name} has {complexity_ratio:.1f}x more parameters than {m1_name}")
    print(f"\tParametric difference: {m2_params - m1_params:,} additional parameters")
def total_model_params(model: nn.Module) -> int:
    """Total number of scalar parameters in *model*, trainable and frozen alike."""
    count = 0
    for tensor in model.parameters():
        count += tensor.numel()
    return count
def print_model_info(model: torch.nn.Module, model_name: str) -> None:
    """
    Print a model's architecture together with total/trainable/frozen parameter counts.

    Useful for comparing architectures, debugging training and gauging model
    capacity. A parameter counts as trainable when ``requires_grad`` is True.
    Frozen counts (and a note) are printed only when some parameters are frozen.

    Args:
        model: PyTorch model (nn.Module).
        model_name: Descriptive name for the model (e.g., 'LSTM Predictor').

    Returns:
        None (prints to stdout).

    Example:
        >>> model = MyTradingModel(input_size=10, hidden_size=64)
        >>> print_model_info(model, 'Trading LSTM v1')
    """
    counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(n for n, _ in counts)
    trainable_params = sum(n for n, trainable in counts if trainable)
    rule = '=' * 60

    lines = [
        f"\n{rule}",
        f"{model_name}",
        rule,
        "\nArchitecture:",
        f" {model}",
        "\nParameter Count:",
        f" Total parameters: {total_params:,}",
        f" Trainable parameters: {trainable_params:,}",
    ]
    frozen_params = total_params - trainable_params
    if frozen_params != 0:
        lines.append(f" Frozen parameters: {frozen_params:,}")
        lines.append(f"\n ⚠️ Note: {frozen_params:,} parameters are frozen")
    lines.append(f"{rule}\n")
    # One print of the newline-joined lines emits exactly what one print per
    # line would.
    print("\n".join(lines))
def _prefix_cols(df, prefix):
return df.rename({col: f"{prefix}_{col}" for col in df.columns})
def _prefix_close_ts(trades, time_interval, prefix):
    """
    Build OHLC bars from *trades* and prefix every column with ``{prefix}_``.

    Despite the name, all OHLC columns are prefixed, including ``datetime``.
    """
    bars = ohlc_timeseries(trades, time_interval)
    return _prefix_cols(bars, prefix)
def compare_ts_corr(x_df, x_prefix, y_df, y_prefix, time_interval, col = 'close'):
    """
    Pearson correlation of one OHLC field between two trade streams.

    Both streams are aggregated into bars of *time_interval*. The bars are then
    matched on their timestamps (inner join), so a bar missing from either
    stream drops that timestamp instead of shifting all later rows out of
    alignment.

    Args:
        x_df: Trades of the first instrument (``datetime``, ``price``).
        x_prefix: Prefix for the first instrument's columns.
        y_df: Trades of the second instrument.
        y_prefix: Prefix for the second instrument's columns (should differ from *x_prefix*).
        time_interval: Bar size (Polars duration string).
        col: OHLC field to correlate ('open', 'high', 'low' or 'close').

    Returns:
        The correlation as a Python scalar.
    """
    x_col, y_col = f'{x_prefix}_{col}', f'{y_prefix}_{col}'
    x_key, y_key = f'{x_prefix}_datetime', f'{y_prefix}_datetime'
    x_bars = _prefix_close_ts(x_df, time_interval, x_prefix)
    y_bars = _prefix_close_ts(y_df, time_interval, y_prefix)
    # Join keys must share a dtype (e.g. the same datetime time unit).
    y_bars = y_bars.with_columns(pl.col(y_key).cast(x_bars.schema[x_key]))
    joined_ts = x_bars.join(y_bars, left_on=x_key, right_on=y_key, how="inner")
    return joined_ts.select(pl.corr(x_col, y_col)).item()
def log_return_col(col: str) -> str:
    """Name of the log-return column derived from *col* (``{col}_log_return``)."""
    return "{}_log_return".format(col)
def log_return(col: str, shift_size: int = 1) -> pl.Expr:
    """
    Expression for ``log(col_t / col_{t - shift_size})`` named by ``log_return_col``.

    For a positive *shift_size*, the first *shift_size* rows are null.
    """
    series = pl.col(col)
    gross_return = series / series.shift(shift_size)
    return gross_return.log().alias(log_return_col(col))
def lag_cols(col: str, forecast_horizon: int, no_lags: int) -> List[pl.Expr]:
    """
    Build lag expressions ``{col}_lag_1`` ... ``{col}_lag_{no_lags}``.

    Lag *i* is *col* shifted by ``forecast_horizon * i`` rows, so consecutive
    lags are one forecast horizon apart.

    Args:
        col: Column to lag.
        forecast_horizon: Rows per lag step (an int; it multiplies the lag index).
        no_lags: Number of lag expressions to build.

    Returns:
        List of ``no_lags`` Polars expressions.
    """
    return [
        pl.col(col).shift(forecast_horizon * i).alias(f'{col}_lag_{i}')
        for i in range(1, no_lags + 1)
    ]
def add_lags(df: pl.DataFrame, col: str, max_no_lags: int, forecast_step: int) -> pl.DataFrame:
    """
    Append lag columns ``{col}_lag_1`` ... ``{col}_lag_{max_no_lags}`` to *df*.

    Lag *i* is *col* shifted by ``i * forecast_step`` rows; the expressions
    come from ``lag_cols``.
    """
    lag_exprs = lag_cols(col, forecast_step, max_no_lags)
    return df.with_columns(lag_exprs)
def batch_train_reg(
    model: nn.Module,
    X_train,
    X_test,
    y_train,
    y_test,
    no_epochs: int,
    criterion=None,
    optimizer=None,
    logging=True,
    lr=None
):
    """
    Train a regression model with full-batch steps, then predict on a test set.

    Args:
        model: Module mapping feature tensors to predictions.
        X_train: Training features.
        X_test: Test features.
        y_train: Training targets (shape compatible with ``model(X_train)``).
        y_test: Test targets.
        no_epochs: Number of full-batch optimizer steps (0 skips training).
        criterion: Loss function; defaults to ``nn.L1Loss()``.
        optimizer: Optimizer over ``model.parameters()``. Defaults to LBFGS
            with strong-Wolfe line search. Any ``LBFGS`` instance is driven
            through a closure; other optimizers take a plain ``step()``.
        logging: Print architecture, progress, learned parameters and losses.
        lr: Learning rate for the DEFAULT optimizer only (ignored when
            *optimizer* is given). ``None`` keeps LBFGS's step size of 1.

    Returns:
        ``model(X_test)`` computed under ``torch.no_grad()``. The model is left
        in eval mode.
    """
    if criterion is None:
        criterion = nn.L1Loss()
    if optimizer is None:
        # strong_wolfe line search makes LBFGS more stable.
        optimizer = optim.LBFGS(
            model.parameters(),
            lr=1 if lr is None else lr,
            line_search_fn='strong_wolfe',
            tolerance_grad=1e-7,
            tolerance_change=1e-9
        )

    if logging:
        print(f"\nModel parameters: {sum(p.numel() for p in model.parameters())}")
        print("Model architecture:")
        for name, param in model.named_parameters():
            print(f"  {name}: {param.shape} ({param.numel()} params)")
        print("\nTraining model...")

    # Enable training-mode behaviour (dropout, batch-norm) even if an earlier
    # call left the model in eval mode.
    model.train()
    train_loss = None
    log_tick_size = max(no_epochs // 10, 1)  # avoid modulo by zero

    if isinstance(optimizer, optim.LBFGS):
        # LBFGS may re-evaluate the objective several times per step.
        def closure():
            optimizer.zero_grad()
            loss = criterion(model(X_train), y_train)
            loss.backward()
            return loss

        for epoch in range(no_epochs):
            optimizer.step(closure)
            with torch.no_grad():
                train_loss = criterion(model(X_train), y_train).item()
            if logging and (epoch + 1) % log_tick_size == 0:
                print(f"Epoch [{epoch+1}/{no_epochs}], Loss: {train_loss:.6f}")
    else:
        # SGD/Adam-style loop
        for epoch in range(no_epochs):
            optimizer.zero_grad()
            loss = criterion(model(X_train), y_train)
            loss.backward()
            optimizer.step()
            train_loss = loss.item()
            if logging and (epoch + 1) % log_tick_size == 0:
                print(f"Epoch [{epoch+1}/{no_epochs}], Loss: {train_loss:.6f}")

    if logging:
        print("\nLearned parameters:")
        for name, param in model.named_parameters():
            if param.requires_grad:
                # detach().cpu() so this also works for GPU parameters
                print(f"{name}:\n{param.detach().cpu().numpy()}")

    model.eval()
    with torch.no_grad():
        y_hat = model(X_test)
        test_loss = criterion(y_hat, y_test)
    if logging:
        # train_loss stays None when no_epochs == 0
        train_loss_str = "n/a" if train_loss is None else f"{train_loss:.6f}"
        print(f'\nTest Loss: {test_loss.item():.6f}, Train Loss: {train_loss_str}')
    return y_hat
def timeseries_train_test_split(df: pl.DataFrame, features, target, test_size=0.25) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Drop rows with nulls, convert to float32 tensors and split chronologically.

    Args:
        df: Source table.
        features: Feature column name(s) for X.
        target: Target column name. y is reshaped to a column vector (n, 1).
        test_size: Fraction of trailing rows used for testing, in (0, 1).

    Returns:
        (X_train, X_test, y_train, y_test)
    """
    complete = df.drop_nulls()
    features_t = to_tensor(complete[features])
    target_t = to_tensor(complete[target]).reshape(-1, 1)
    return (*timeseries_split(features_t, test_size), *timeseries_split(target_t, test_size))
def timeseries_split(t, test_size=0.25):
    """
    Split a sliceable sequence into leading (train) and trailing (test) parts.

    No shuffling is done, so chronological order is preserved.

    Parameters
    ----------
    t : torch.Tensor, np.ndarray, polars.DataFrame or any sliceable sequence
        Time-ordered data.
    test_size : float, default 0.25
        Fraction of rows for the test part; must lie strictly between 0 and 1.
        The train length is ``int(len(t) * (1 - test_size))`` (truncated).

    Returns
    -------
    train, test : same type as t

    Raises
    ------
    ValueError
        If test_size is not strictly between 0 and 1 (NaN included).
    """
    if not 0 < test_size < 1:
        raise ValueError(f"test_size must be between 0 and 1 (got {test_size})")
    cut = int(len(t) * (1 - test_size))
    head, tail = slice(None, cut), slice(cut, None)
    return t[head], t[tail]
def plot_column(df, col_name, figsize=(15, 6), title=None, xlabel='Index'):
    """
    Line chart of one column of a Polars DataFrame.

    Uses Polars' built-in ``Series.plot.line()`` (not matplotlib) and applies
    ``.properties(width=800, height=400, title=...)`` to the result. The result
    is presumably an Altair chart, since ``.properties`` is an Altair API;
    this depends on the installed Polars plotting backend.

    Parameters:
    -----------
    df : polars.DataFrame
        The Polars DataFrame.
    col_name : str
        Name of the column to plot.
    figsize : tuple, default (15, 6)
        Currently unused; kept for signature compatibility.
    title : str, optional
        Chart title. If None, uses col_name.
    xlabel : str, default 'Index'
        Currently unused; kept for signature compatibility.

    Returns:
    --------
    The chart produced by ``Series.plot.line()`` with size and title applied.
    """
    if title is None:
        title = col_name
    chart = df[col_name].plot.line()
    return chart.properties(
        width=800,
        height=400,
        title=title
    )
def plot_columns(df, col_name, figsize=(15, 6), title=None, xlabel='Index'):
    """
    Line chart of one column of a Polars DataFrame; identical to ``plot_column``.

    Parameters:
    -----------
    df : polars.DataFrame
        The Polars DataFrame.
    col_name : str
        Name of the column to plot.
    figsize : tuple, default (15, 6)
        Currently unused; kept for signature compatibility.
    title : str, optional
        Chart title. If None, uses col_name.
    xlabel : str, default 'Index'
        Currently unused; kept for signature compatibility.
    """
    return plot_column(df, col_name, figsize=figsize, title=title, xlabel=xlabel)
def model_trade_results(y_true, y_pred) -> pl.DataFrame:
    """
    Generate trade-level results from model predictions.

    Each row is one trade taking position ``sign(y_pred)``. Columns:
        y_pred, y_true        -- squeezed inputs
        is_won                -- sign(y_pred) == sign(y_true)
        position              -- sign(y_pred)
        trade_log_return      -- position * y_true
        equity_curve          -- cumulative sum of trade_log_return
        drawdown_log_return   -- equity_curve minus its running maximum
    """
    position = pl.col('y_pred').sign()
    frame = pl.DataFrame({'y_pred': y_pred.squeeze(), 'y_true': y_true.squeeze()})
    frame = frame.with_columns(
        (position == pl.col('y_true').sign()).alias('is_won'),
        position.alias('position'),
    )
    frame = frame.with_columns((pl.col('position') * pl.col('y_true')).alias('trade_log_return'))
    frame = frame.with_columns(pl.col('trade_log_return').cum_sum().alias('equity_curve'))
    running_peak = pl.col('equity_curve').cum_max()
    return frame.with_columns((pl.col('equity_curve') - running_peak).alias('drawdown_log_return'))
def add_tx_fee(trades: pl.DataFrame, tx_fee: float, name: str):
    """
    Append ``tx_fee_{name}``: fee rate *tx_fee* charged on both exit and entry value.

    Requires ``entry_trade_value`` and ``exit_trade_value`` columns.
    """
    exit_fee = pl.col('exit_trade_value') * tx_fee
    entry_fee = pl.col('entry_trade_value') * tx_fee
    return trades.with_columns((exit_fee + entry_fee).alias(f"tx_fee_{name}"))
def add_tx_fees(trades: pl.DataFrame, maker_fee: float, taker_fee: float):
    """Append ``tx_fee_maker`` then ``tx_fee_taker`` columns via ``add_tx_fee``."""
    for fee, label in ((maker_fee, 'maker'), (taker_fee, 'taker')):
        trades = add_tx_fee(trades, fee, label)
    return trades
def add_tx_fees_log(trades: pl.DataFrame, maker_fee, taker_fee):
    """
    Append fee-adjusted log returns and their cumulative (log-space) equity curves.

    Adds ``trade_log_return_net_maker/taker`` = ``trade_log_return + log(fee)``
    and ``equity_curve_net_maker/taker`` = running sum of those columns.

    NOTE(review): ``np.log(fee)`` is added as-is. That only makes sense if the
    arguments are retention factors (e.g. ``1 - 0.001``). With fee *rates* as
    used by ``add_tx_fees`` (e.g. ``0.001``), each trade would lose ~6.9 in
    log space, and the intended form is probably ``log(1 - fee)``. Confirm
    against callers.
    """
    log_maker, log_taker = np.log(maker_fee), np.log(taker_fee)
    with_net = trades.with_columns(
        (pl.col('trade_log_return') + log_maker).alias('trade_log_return_net_maker'),
        (pl.col('trade_log_return') + log_taker).alias('trade_log_return_net_taker'),
    )
    return with_net.with_columns(
        pl.col('trade_log_return_net_maker').cum_sum().alias('equity_curve_net_maker'),
        pl.col('trade_log_return_net_taker').cum_sum().alias('equity_curve_net_taker'),
    )
def eval_model_performance(y_actual, y_pred, feature_names: List[str], target_name: str, annualized_rate: float) -> Dict[str, object]:
    """
    Calculate performance metrics for a directional trading model.

    Each prediction opens a position ``sign(y_pred)`` whose log return is
    ``sign(y_pred) * y_actual`` (see ``model_trade_results``).

    Args:
        y_actual: Realized target values for the evaluated period.
        y_pred: Model predictions aligned with *y_actual*.
        feature_names: Feature names, joined with ',' in the output.
        target_name: Name of the target column.
        annualized_rate: Factor multiplied into the per-trade Sharpe ratio
            (e.g. from ``sharpe_annualization_factor``).

    Returns:
        Dict with features, target, no_trades, win_rate, avg_win, avg_loss,
        best_trade, worst_trade, ev, std, total_log_return, compound_return,
        max_drawdown, equity_trough, equity_peak and sharpe (annualized).
        avg_win/avg_loss are None when there are no winning/losing trades.
    """
    trade_results = model_trade_results(y_actual, y_pred)
    trade_returns = trade_results['trade_log_return']
    equity = trade_results['equity_curve']

    accuracy = trade_results['is_won'].mean()
    avg_win = trade_results.filter(pl.col('is_won'))['trade_log_return'].mean()
    avg_loss = trade_results.filter(~pl.col('is_won'))['trade_log_return'].mean()
    # mean() of an empty selection is None (e.g. every trade won). Its weight
    # in the expected value is then 0, so use 0.0 instead of raising TypeError.
    expected_value = accuracy * (avg_win or 0.0) + (1 - accuracy) * (avg_loss or 0.0)

    max_drawdown = (equity - equity.cum_max()).min()

    std = trade_returns.std()
    # std() is None for fewer than two trades; report a zero Sharpe then, as
    # for zero volatility.
    sharpe = trade_returns.mean() / std if std is not None and std > 0 else 0
    annualized_sharpe = sharpe * annualized_rate

    total_log_return = trade_returns.sum()
    return {
        'features': ','.join(list(feature_names)),
        'target': target_name,
        'no_trades': len(trade_results),
        'win_rate': accuracy,
        'avg_win': avg_win,
        'avg_loss': avg_loss,
        'best_trade': trade_returns.max(),
        'worst_trade': trade_returns.min(),
        'ev': expected_value,
        'std': std,
        'total_log_return': total_log_return,
        'compound_return': np.exp(total_log_return),
        'max_drawdown': max_drawdown,
        'equity_trough': equity.min(),
        'equity_peak': equity.max(),
        'sharpe': annualized_sharpe,
    }
def train_reg_model(df: pl.DataFrame, features: List[str], target: str, model: nn.Module, annualized_rate, test_size=0.25, loss = None, optimizer = None, no_epochs = None, log = False, lr = None):
    """
    Chronologically split *df*, train *model* on the head and predict the tail.

    Args:
        df: Feature table (no null handling is done here).
        features: Feature column names.
        target: Target column name.
        model: Module to train in place.
        annualized_rate: Accepted for signature parity with
            ``benchmark_reg_model``; not used here.
        test_size: Fraction of trailing rows used for testing.
        loss: Criterion forwarded to ``batch_train_reg``.
        optimizer: Optimizer forwarded to ``batch_train_reg``.
        no_epochs: Training steps (default 6000).
        log: Enable ``batch_train_reg`` logging.
        lr: Learning rate for ``batch_train_reg``'s default optimizer.

    Returns:
        The test-set predictions returned by ``batch_train_reg``.
    """
    df_train, df_test = timeseries_split(df, test_size=test_size)
    if no_epochs is None:
        no_epochs = 6000
    X_train, y_train = to_tensor(df_train[features]), to_tensor(df_train[target]).reshape(-1, 1)
    X_test, y_test = to_tensor(df_test[features]), to_tensor(df_test[target]).reshape(-1, 1)
    y_hat = batch_train_reg(model, X_train, X_test, y_train, y_test, no_epochs, criterion=loss, optimizer=optimizer, lr=lr, logging=log)
    return y_hat
def benchmark_reg_model(df: pl.DataFrame, features: List[str], target: str, model: nn.Module, annualized_rate, test_size=0.25, loss = None, optimizer = None, no_epochs = None, log = False, lr = None):
    """
    Train *model* on the chronological head of *df* and score it on the tail.

    The model must expose a ``.linear`` layer (``get_linear_params`` reads it).

    Returns:
        The ``eval_model_performance`` metrics dict plus 'weights' and 'biases'
        (string representations of the learned linear parameters).
    """
    train_part, test_part = timeseries_split(df, test_size=test_size)
    epochs = 6000 if no_epochs is None else no_epochs
    X_train = torch.tensor(train_part[features].to_numpy(), dtype=torch.float32)
    y_train = torch.tensor(train_part[target].to_numpy(), dtype=torch.float32).reshape(-1, 1)
    X_test = torch.tensor(test_part[features].to_numpy(), dtype=torch.float32)
    y_test = torch.tensor(test_part[target].to_numpy(), dtype=torch.float32).reshape(-1, 1)
    y_hat = batch_train_reg(model, X_train, X_test, y_train, y_test, epochs, loss, optimizer, lr=lr, logging=log)
    perf = eval_model_performance(y_test, y_hat, features, target, annualized_rate)
    weights, biases = get_linear_params(model)
    perf.update(weights=str(weights), biases=str(biases))
    return perf
def learn_model_trades(df: pl.DataFrame, features: List[str], target: str, model: nn.Module, test_size=0.25, loss = None, optimizer = None, no_epochs = None, log = False, lr = None):
    """
    Drop null rows, train on the chronological head and return test-set trade results.

    Returns:
        The ``model_trade_results`` frame for the test split.
    """
    clean = df.drop_nulls()
    train_part, test_part = timeseries_split(clean, test_size=test_size)
    epochs = 6000 if no_epochs is None else no_epochs

    def as_xy(part):
        # Features as float32 matrix, target as a float32 column vector.
        X = torch.tensor(part[features].to_numpy(), dtype=torch.float32)
        y = torch.tensor(part[target].to_numpy(), dtype=torch.float32).reshape(-1, 1)
        return X, y

    X_train, y_train = as_xy(train_part)
    X_test, y_test = as_xy(test_part)
    y_hat = batch_train_reg(model, X_train, X_test, y_train, y_test, epochs, criterion=loss, optimizer=optimizer, lr=lr, logging=log)
    return model_trade_results(y_test, y_hat)
def learn_model_trade_pnl(df: pl.DataFrame, features: List[str], target: str, model: nn.Module, test_size=0.25, loss = None, optimizer = None, no_epochs = None, log = False, lr = None):
    """
    Train *model* and return the trade-level results on the test split.

    Same pipeline as ``learn_model_trades``, which it delegates to: null rows
    are dropped, the data is split chronologically, the model is trained with
    ``batch_train_reg`` and the predictions are turned into trades with
    ``model_trade_results``. Null rows would otherwise become NaN tensors and
    poison the loss.

    Returns:
        The ``model_trade_results`` frame for the test split.
    """
    return learn_model_trades(
        df, features, target, model,
        test_size=test_size, loss=loss, optimizer=optimizer,
        no_epochs=no_epochs, log=log, lr=lr,
    )
def set_seed(seed):
    """
    Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and force determinism.

    Also disables cuDNN benchmarking and enables
    ``torch.use_deterministic_algorithms(True)``. From then on, PyTorch raises
    an error for operations that lack a deterministic implementation.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
def init_weights(m):
    """
    Deterministically initialise ``nn.Linear`` layers; other modules are ignored.

    Intended for ``model.apply(init_weights)``. The global torch RNG is
    re-seeded with 42 before every Linear layer, so same-shaped layers get
    identical Xavier-uniform weights. Biases (if any) are zeroed. Note that
    this resets the global torch RNG as a side effect.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.manual_seed(42)  # same initialisation on every call
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
def get_linear_params(model: nn.Module) -> tuple[np.ndarray, float]:
    """
    Return ``(weights, bias)`` of ``model.linear``.

    The weights are flattened into a 1-D NumPy copy. The bias is converted to a
    Python float via ``.item()``, which requires a single-output layer.
    """
    layer = model.linear

    def as_numpy(tensor):
        return tensor.detach().cpu().numpy()

    return as_numpy(layer.weight).flatten(), as_numpy(layer.bias).item()
def add_log_return_features(df: pl.DataFrame, col: str, forecast_horizon: int, max_no_lags = None) -> pl.DataFrame:
    """
    Add the log return of *col* over *forecast_horizon* rows, plus optional lags of it.

    Args:
        df: Input table.
        col: Price-like column to transform.
        forecast_horizon: Row distance used for the return and between lags.
        max_no_lags: Number of lag columns of the log return (None/0 = none).

    Returns:
        *df* with ``{col}_log_return`` and, if requested,
        ``{col}_log_return_lag_1`` ... ``{col}_log_return_lag_{max_no_lags}``.
    """
    if max_no_lags is None:
        max_no_lags = 0
    df = df.with_columns(log_return(col, forecast_horizon))
    if max_no_lags > 0:
        # Lag the column just created for *col*. Hard-coding 'close' here
        # failed, or lagged the wrong series, for any other column.
        df = add_lags(df, log_return_col(col), max_no_lags, forecast_horizon)
    return df
def benchmark_linear_models(ts: pl.DataFrame, target: str, feature_pool: List[str], annualized_rate: float, max_no_features: int = 1, no_epochs = 200, loss = None, test_size=0.25) -> pl.DataFrame:
    """
    Train and score a linear model for every feature subset up to a given size.

    For each combination of 1..max_no_features features from *feature_pool*,
    a fresh ``models.LinearModel`` is built, initialised with ``init_weights``,
    and evaluated with ``benchmark_reg_model``.

    Args:
        ts: Feature table; rows with any null are dropped first.
        target: Target column name.
        feature_pool: Candidate feature columns.
        annualized_rate: Sharpe annualization factor, e.g. from
            ``sharpe_annualization_factor`` (a float).
        max_no_features: Largest subset size to try.
        no_epochs: Training steps per model.
        loss: Criterion forwarded to training (None = ``batch_train_reg`` default).
        test_size: Fraction of trailing rows used for evaluation.

    Returns:
        One row of metrics per feature subset, sorted by 'sharpe' descending.

    Raises:
        ValueError: No feature subsets were generated (empty pool or
            max_no_features < 1).
    """
    import models  # presumably the project-local module defining LinearModel

    ts = ts.drop_nulls()
    feature_sets = itertools.chain.from_iterable(
        itertools.combinations(feature_pool, k) for k in range(1, max_no_features + 1)
    )
    benchmarks = []
    for features in feature_sets:
        m = models.LinearModel(len(features))
        m.apply(init_weights)
        benchmarks.append(benchmark_reg_model(ts, list(features), target, m, annualized_rate, no_epochs=no_epochs, loss=loss, test_size=test_size))
    if not benchmarks:
        # Sorting an empty frame by 'sharpe' would fail with a missing-column error.
        raise ValueError("No feature combinations to benchmark: check feature_pool and max_no_features")
    benchmark = pl.DataFrame(benchmarks)
    return benchmark.sort('sharpe', descending=True)
# print out our learned params
def print_model_params(model: nn.Module):
    """Print the name and current value of every trainable parameter of *model*."""
    for name, param in model.named_parameters():
        if param.requires_grad:
            # detach().cpu() so this also works for parameters living on a GPU.
            print(f"{name}:\n{param.detach().cpu().numpy()}")
def add_model_predictions(test_trades: pl.DataFrame, model: nn.Module, features: Union[str, List[str]]) -> pl.DataFrame:
    """
    Run *model* on the feature columns of *test_trades* and append a ``y_hat`` column.

    Args:
        test_trades: Table containing the feature columns.
        model: Trained model taking a float32 (rows, n_features) tensor.
        features: A single column name or an iterable of column names.

    Returns:
        *test_trades* with the squeezed model output as ``y_hat``.
    """
    # isinstance(str) check: a tuple of names is also a valid column list.
    features = [features] if isinstance(features, str) else list(features)
    X_test = torch.tensor(test_trades[features].to_numpy(), dtype=torch.float32)
    # One forward pass without autograd bookkeeping.
    with torch.no_grad():
        preds = model(X_test)
    # squeeze() drops the (n, 1) output axis. atleast_1d keeps a single-row
    # input as a length-1 series rather than a 0-d scalar.
    values = np.atleast_1d(preds.cpu().numpy().squeeze())
    return test_trades.with_columns(pl.Series('y_hat', values))
def add_trade_log_returns(trades: pl.DataFrame, pre_trade_values: Union[List[float],npt.NDArray[np.float32]], tx_fee: float, initial_capital: float) -> pl.DataFrame:
    """
    Turn model predictions into per-trade returns, sizes, PnL and equity curves.

    Requires columns ``y_hat`` (model prediction), ``close_log_return`` and
    ``open``. Adds dir_signal, trade_log_return, cum_trade_log_return,
    pre_trade_value, post_trade_value, trade_qty, signed_trade_qty,
    trade_gross_pnl, tx_fees, trade_net_pnl, equity_curve_gross and
    equity_curve_net.

    Args:
        trades: Bars with predictions.
        pre_trade_values: Capital committed to each trade, one value per row
            of *trades* (list, NumPy array or Polars Series). Polars raises
            if the length differs from the row count.
        tx_fee: Fee rate charged on both the pre-trade and post-trade value.
        initial_capital: Starting equity for both equity curves.

    Returns:
        *trades* with the columns above appended.
    """
    # Directional signal: +1 long, -1 short, 0 flat
    trades = trades.with_columns(pl.col('y_hat').sign().alias('dir_signal'))
    # Trade log return
    trades = trades.with_columns((pl.col('close_log_return') * pl.col('dir_signal')).alias('trade_log_return'))
    # Cumulative trade log return (equity curve in log space)
    trades = trades.with_columns(pl.col('trade_log_return').cum_sum().alias('cum_trade_log_return'))
    # Materialise the per-trade capital as a real column. Plain lists and NumPy
    # arrays have no `.alias` and cannot be combined with Polars expressions.
    trades = trades.with_columns(
        pl.Series('pre_trade_value', np.asarray(pre_trade_values, dtype=np.float64))
    )
    trades = trades.with_columns(
        # post trade values
        (pl.col('pre_trade_value') * pl.col('trade_log_return').exp()).alias('post_trade_value'),
        # trade quantity
        (pl.col('pre_trade_value') / pl.col('open')).alias('trade_qty'),
    )
    trades = trades.with_columns(
        # signed trade quantities (the main output of the strategy)
        (pl.col('trade_qty') * pl.col('dir_signal')).alias('signed_trade_qty'),
        # gross pnl
        (pl.col('post_trade_value') - pl.col('pre_trade_value')).alias('trade_gross_pnl'),
        # fees on entry and exit value
        (pl.col('pre_trade_value') * tx_fee + pl.col('post_trade_value') * tx_fee).alias('tx_fees'),
    )
    trades = trades.with_columns(
        # profit after fees (net)
        (pl.col('trade_gross_pnl') - pl.col('tx_fees')).alias('trade_net_pnl')
    )
    trades = trades.with_columns(
        (initial_capital + pl.col('trade_gross_pnl').cum_sum()).alias('equity_curve_gross'),
        (initial_capital + pl.col('trade_net_pnl').cum_sum()).alias('equity_curve_net'),
    )
    return trades
def add_equity_curve(trades: pl.DataFrame, initial_capital: float, col_name: str, suffix: str) -> pl.DataFrame:
    """Append ``equity_curve_{suffix}`` = initial_capital + running sum of *col_name*."""
    running_pnl = pl.col(col_name).cum_sum()
    equity = (initial_capital + running_pnl).alias(f'equity_curve_{suffix}')
    return trades.with_columns(equity)
def add_compounding_trades(trades, capital, leverage, maker_fee, taker_fee):
    """
    Size trades on a compounding, levered position and add PnL/fee/equity columns.

    Position value after each bar is ``exp(cum_trade_log_return) * capital * leverage``.
    Each trade enters at the previous bar's position value; nulls in that
    shifted series (always the first row) are filled with
    ``capital * leverage``. Requires ``cum_trade_log_return``, ``open`` and
    ``dir_signal``.

    Adds entry_trade_value, exit_trade_value, signed_trade_qty,
    trade_gross_pnl, tx_fee_maker, tx_fee_taker, trade_net_taker_pnl,
    trade_net_maker_pnl and equity_curve_{gross,taker,maker}. The equity
    curves start from unlevered *capital*.
    """
    lev_capital = capital * leverage
    position_value = pl.col('cum_trade_log_return').exp() * lev_capital

    # entry/exit value and size of every trade
    trades = trades.with_columns(
        position_value.shift().fill_null(lev_capital).alias('entry_trade_value'),
        position_value.alias('exit_trade_value'),
    )
    trades = trades.with_columns(
        (pl.col('entry_trade_value') / pl.col('open') * pl.col('dir_signal')).alias('signed_trade_qty'),
        (pl.col('exit_trade_value') - pl.col('entry_trade_value')).alias('trade_gross_pnl'),
    )

    # fees, then net pnl for taker and maker execution
    trades = add_tx_fees(trades, maker_fee, taker_fee)
    trades = trades.with_columns(
        (pl.col('trade_gross_pnl') - pl.col('tx_fee_taker')).alias('trade_net_taker_pnl'),
        (pl.col('trade_gross_pnl') - pl.col('tx_fee_maker')).alias('trade_net_maker_pnl'),
    )

    # gross and net equity curves
    for pnl_col, suffix in (
        ('trade_gross_pnl', 'gross'),
        ('trade_net_taker_pnl', 'taker'),
        ('trade_net_maker_pnl', 'maker'),
    ):
        trades = add_equity_curve(trades, capital, pnl_col, suffix)
    return trades