
[BUG] concat_sequences Stacks Batches as Time Steps for Single-Step Predictions #1808

Open
jarrodconnolly opened this issue Apr 4, 2025 · 0 comments
Labels
bug Something isn't working


@jarrodconnolly

Describe the bug
We’re predicting 1 step ahead (`max_prediction_length=1`) with `TemporalFusionTransformer.predict(return_y=True)` on 128 rows (2 batches of 64).

We expect `y` to have shape `(128, 1)`, with 1 actual per row. Instead, it has shape `(64, 2)`: the 2 batches are stacked as time steps.
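The 128 rows follow from the repro data: each of the 2 groups has 68 time steps, and each sample needs 4 encoder steps plus 1 prediction step. The sketch below counts those windows, assuming `min_encoder_length` and `min_prediction_length` keep their defaults (equal to the maximum lengths), so only full-length windows are cut. `windows_per_group` is a hypothetical helper for illustration, not a pytorch-forecasting function.

```python
import math


def windows_per_group(n_steps, max_encoder_length, max_prediction_length):
    """Hypothetical helper: number of full-length windows in one group,
    assuming every window spans exactly encoder + prediction steps."""
    window = max_encoder_length + max_prediction_length
    return max(n_steps - window + 1, 0)


# Group sizes and window lengths taken from the repro in this issue.
group_lengths = {"A": 68, "B": 68}
per_group = {
    g: windows_per_group(n, max_encoder_length=4, max_prediction_length=1)
    for g, n in group_lengths.items()
}
total = sum(per_group.values())
n_batches = math.ceil(total / 64)  # batch_size=64, final batch kept
print(per_group, total, n_batches)  # {'A': 64, 'B': 64} 128 2
```

With `batch_size=64`, those 128 samples form exactly 2 full batches, which is why the stock output ends up as `(64, 2)` rather than `(128, 1)`.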

To Reproduce

```python
import pandas as pd
import torch
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer

# 136 rows, 2 groups of 68 each
data = pd.DataFrame({
    "TimeIdx": list(range(68)) + list(range(68)),
    "Symbol": ["A"] * 68 + ["B"] * 68,
    "Close": [1.0] * 136
})
data["Timestamp"] = pd.to_datetime(data["TimeIdx"], unit="h")

# Dataset
ts = TimeSeriesDataSet(
    data,
    time_idx="TimeIdx",
    target="Close",
    group_ids=["Symbol"],
    max_encoder_length=4,
    max_prediction_length=1,
    static_categoricals=["Symbol"]
)
dl = ts.to_dataloader(batch_size=64, train=False)

# Model
tft = TemporalFusionTransformer.from_dataset(ts, hidden_size=8, output_size=1)
preds = tft.predict(dl, return_y=True, mode="raw")
print("Stock y Shape:", preds.y[0].shape)  # (64, 2) - bug

# Manual concat
y_batches = [batch[1][0] for batch in dl]  # List of (64, 1)
y_fixed = torch.cat(y_batches, dim=0)
print("Fixed y Shape:", y_fixed.shape)     # (128, 1) - expected
assert y_fixed.shape == (128, 1), "Manual dim=0 works; stock concat_sequences fails"
```

Expected behavior

Output:
- Stock `y` shape: `(64, 2)`. This is wrong: the 2 batches are treated as 2 time steps.
- Fixed `y` shape: `(128, 1)`, from a manual `torch.cat(..., dim=0)`. This is the expected result.

Additional context
`PredictCallback.on_predict_epoch_end` calls `concat_sequences` in `utils/_utils.py`, which uses `torch.cat(..., dim=1)`.
For 2 batches of shape `(64, 1)`, `dim=1` stacks them horizontally (along time) into `(64, 2)`.
For single-step predictions it should use `dim=0`, stacking them vertically (along rows) into `(128, 1)`.
Multi-step predictions (`max_prediction_length > 1`) need `dim=1`, but that case does not apply here.
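The shape difference comes down to which dimension `torch.cat` joins. The sketch below reproduces it with plain tensors of shape `(64, 1)`, then shows a hypothetical `concat_batches_dim0` helper that recurses into tuples the way a `concat_sequences`-style function must, but joins tensors along `dim=0`. It only targets the single-step case in this report and assumes each batch's `y` is a `(target, weight)` tuple whose weight may be `None`; neither the helper nor that assumption is taken from pytorch-forecasting itself.

```python
import torch

# Two single-step target batches, batch-first: (batch_size, max_prediction_length).
batch_1 = torch.zeros(64, 1)
batch_2 = torch.ones(64, 1)

# dim=1: the batches land side by side, as if they were 2 time steps.
stacked_as_time = torch.cat([batch_1, batch_2], dim=1)
print(stacked_as_time.shape)  # torch.Size([64, 2])

# dim=0: the batches are appended as new rows, one row per sample.
stacked_as_rows = torch.cat([batch_1, batch_2], dim=0)
print(stacked_as_rows.shape)  # torch.Size([128, 1])


def concat_batches_dim0(batches):
    """Hypothetical helper (not part of pytorch-forecasting): concatenate
    per-batch outputs along the batch dimension, recursing into tuples such
    as (target, weight) and passing None entries through unchanged."""
    first = batches[0]
    if first is None:
        return None
    if isinstance(first, torch.Tensor):
        return torch.cat(batches, dim=0)
    if isinstance(first, (tuple, list)):
        return tuple(
            concat_batches_dim0([b[i] for b in batches]) for i in range(len(first))
        )
    raise TypeError(f"unsupported batch element type: {type(first)!r}")


# Assumed per-batch y structure: (target, weight) with no weight column.
y_list = [(batch_1, None), (batch_2, None)]
target, weight = concat_batches_dim0(y_list)
print(target.shape, weight)  # torch.Size([128, 1]) None
```

Joining along `dim=0` also tolerates a smaller final batch, because `torch.cat` only requires the dimensions that are not being joined to match.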

Related: #1752, #1509 and #1320 report similar `dim=1` issues.

Versions

```python
from pytorch_forecasting import show_versions; show_versions()
```

```
System:
python: 3.12.9 (main, Feb 12 2025, 14:50:50) [Clang 19.1.6 ]
executable: /home/username/code/rich/.venv/bin/python
machine: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35

Python dependencies:
pip: None
pytorch-forecasting: 1.3.0
torch: 2.6.0
lightning: 2.5.1
numpy: 2.2.4
scipy: 1.15.2
pandas: 2.2.3
cpflows: None
matplotlib: None
optuna: None
optuna-integration: None
pytorch_optimizer: None
scikit-learn: 1.6.1
scikit-base: None
statsmodels: None
```

jarrodconnolly added the bug (Something isn't working) label on Apr 4, 2025.
github-project-automation bot moved this to Needs triage & validation in Bugfixing - pytorch-forecasting on Apr 4, 2025.