diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 336eecd5f..afe465877 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2513,7 +2513,7 @@ def to_dataloader( Parameters ---------- - train : bool, optional, default=Trze + train : bool, optional, default=True whether dataloader is used for training (True) or prediction (False). Will shuffle and drop last batch if True. Defaults to True. batch_size : int, optional, default=64 diff --git a/pytorch_forecasting/utils/_utils.py b/pytorch_forecasting/utils/_utils.py index af93006cf..364c74d68 100644 --- a/pytorch_forecasting/utils/_utils.py +++ b/pytorch_forecasting/utils/_utils.py @@ -272,7 +272,33 @@ def concat_sequences( if isinstance(sequences[0], rnn.PackedSequence): return rnn.pack_sequence(sequences, enforce_sorted=False) elif isinstance(sequences[0], torch.Tensor): - return torch.cat(sequences, dim=1) + if sequences[0].ndim > 1: + first_lens = [xi.shape[1] for xi in sequences] + max_first_len = max(first_lens) + if max_first_len > min(first_lens): + sequences = [ + ( + xi + if xi.shape[1] == max_first_len + else torch.cat( + [ + xi, + torch.full( + ( + xi.shape[0], + max_first_len - xi.shape[1], + *xi.shape[2:], + ), + float("nan"), + device=xi.device, + ), + ], + dim=1, + ) + ) + for xi in sequences + ] + return torch.cat(sequences, dim=0) elif isinstance(sequences[0], (tuple, list)): return tuple( concat_sequences([sequences[ii][i] for ii in range(len(sequences))]) diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 24c249bd5..5cbf51909 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -203,10 +203,19 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs) return_index=True, return_x=True, return_y=True, - fast_dev_run=True, + fast_dev_run=2, trainer_kwargs=trainer_kwargs, ) pred_len = len(predictions.index) + if isinstance(predictions.output, torch.Tensor): + assert ( + predictions.output.shape == predictions.y[0].shape + ), "shape of predictions should match shape of targets" + else: + for i in range(len(predictions.output)): + assert ( + predictions.output[i].shape == predictions.y[0][i].shape + ), "shape of predictions should match shape of targets" # check that output is of correct shape def check(x):