Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions pyfolio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,37 @@ def estimate_intraday(returns, positions, transactions, EOD_hour=23):
Daily net position values, resampled for intraday behavior.
"""

# Construct DataFrame of transaction amounts
txn_val = transactions.copy()
# Ensure datetime index
txn_val.index = pd.to_datetime(txn_val.index, errors="coerce")

# Remove invalid timestamps
txn_val = txn_val[txn_val.index.notnull()]

# Remove timezone if present
if getattr(txn_val.index, "tz", None) is not None:
txn_val.index = txn_val.index.tz_localize(None)

# Ensure sorted time index
txn_val = txn_val.sort_index()

# Remove duplicate columns
txn_val = txn_val.loc[:, ~txn_val.columns.duplicated()]

# Set index name safely
txn_val.index.names = ['date']

# Avoid duplicate 'date' column
if 'date' in txn_val.columns:
txn_val.drop(columns=['date'], inplace=True)

# Create date column
txn_val['date'] = txn_val.index.date

# Prevent crashes if no transactions
if txn_val.empty:
return positions

# Calculate transaction values
txn_val['value'] = txn_val.amount * txn_val.price
txn_val = txn_val.reset_index().pivot_table(
index='date', values='value',
Expand All @@ -361,8 +389,10 @@ def estimate_intraday(returns, positions, transactions, EOD_hour=23):

# Shift EOD positions to positions at start of next trading day
positions_shifted = positions.copy().shift(1).fillna(0)
starting_capital = positions.iloc[0].sum() / (1 + returns[0])
positions_shifted.cash[0] = starting_capital
starting_capital = positions.iloc[0].sum() / (1 + returns.iloc[0])
positions_shifted.loc[
positions_shifted.index[0], "cash"
] = starting_capital

# Format and add start positions to intraday position changes
txn_val.index = txn_val.index.normalize()
Expand Down Expand Up @@ -538,4 +568,4 @@ def sample_colormap(cmap_name, n_samples):
for i in np.linspace(0, 1, n_samples):
colors.append(colormap(i))

return colors
return colors