From fed3fc4506602472a581145e5992d3281fe28b3b Mon Sep 17 00:00:00 2001 From: RUPESH-KUMAR01 <118011558+RUPESH-KUMAR01@users.noreply.github.com> Date: Sun, 23 Feb 2025 23:31:01 +0530 Subject: [PATCH 1/4] [BUG] pytorch-forecasting#1752 Fixing --- pytorch_forecasting/data/timeseries.py | 2 +- pytorch_forecasting/models/base/_base_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 336eecd5f..afe465877 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2513,7 +2513,7 @@ def to_dataloader( Parameters ---------- - train : bool, optional, default=Trze + train : bool, optional, default=True whether dataloader is used for training (True) or prediction (False). Will shuffle and drop last batch if True. Defaults to True. batch_size : int, optional, default=64 diff --git a/pytorch_forecasting/models/base/_base_model.py b/pytorch_forecasting/models/base/_base_model.py index 5e6c68391..d3b9a7898 100644 --- a/pytorch_forecasting/models/base/_base_model.py +++ b/pytorch_forecasting/models/base/_base_model.py @@ -387,7 +387,7 @@ def on_predict_epoch_end( if self.return_decoder_lengths: output["decoder_lengths"] = torch.cat(self._decode_lengths, dim=0) if self.return_y: - y = concat_sequences([yi[0] for yi in self._y]) + y = _torch_cat_na([yi[0] for yi in self._y]) if self._y[-1][1] is None: weight = None else: From 9124c5dab6e4fbfb7d97b6d548070a03c320a32a Mon Sep 17 00:00:00 2001 From: RUPESH-KUMAR01 <118011558+RUPESH-KUMAR01@users.noreply.github.com> Date: Wed, 26 Feb 2025 16:20:40 +0530 Subject: [PATCH 2/4] Changing back to old function with modifications --- .../models/base/_base_model.py | 2 +- pytorch_forecasting/utils/_utils.py | 28 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/base/_base_model.py b/pytorch_forecasting/models/base/_base_model.py index d3b9a7898..5e6c68391 100644 --- a/pytorch_forecasting/models/base/_base_model.py +++ b/pytorch_forecasting/models/base/_base_model.py @@ -387,7 +387,7 @@ def on_predict_epoch_end( if self.return_decoder_lengths: output["decoder_lengths"] = torch.cat(self._decode_lengths, dim=0) if self.return_y: - y = _torch_cat_na([yi[0] for yi in self._y]) + y = concat_sequences([yi[0] for yi in self._y]) if self._y[-1][1] is None: weight = None else: diff --git a/pytorch_forecasting/utils/_utils.py b/pytorch_forecasting/utils/_utils.py index af93006cf..364c74d68 100644 --- a/pytorch_forecasting/utils/_utils.py +++ b/pytorch_forecasting/utils/_utils.py @@ -272,7 +272,33 @@ def concat_sequences( if isinstance(sequences[0], rnn.PackedSequence): return rnn.pack_sequence(sequences, enforce_sorted=False) elif isinstance(sequences[0], torch.Tensor): - return torch.cat(sequences, dim=1) + if sequences[0].ndim > 1: + first_lens = [xi.shape[1] for xi in sequences] + max_first_len = max(first_lens) + if max_first_len > min(first_lens): + sequences = [ + ( + xi + if xi.shape[1] == max_first_len + else torch.cat( + [ + xi, + torch.full( + ( + xi.shape[0], + max_first_len - xi.shape[1], + *xi.shape[2:], + ), + float("nan"), + device=xi.device, + ), + ], + dim=1, + ) + ) + for xi in sequences + ] + return torch.cat(sequences, dim=0) elif isinstance(sequences[0], (tuple, list)): return tuple( concat_sequences([sequences[ii][i] for ii in range(len(sequences))]) From a9f1188495df384e4e3cc1ef41b84e1d7577cf0a Mon Sep 17 00:00:00 2001 From: RUPESH-KUMAR01 <118011558+RUPESH-KUMAR01@users.noreply.github.com> Date: Mon, 3 Mar 2025 17:32:34 +0530 Subject: [PATCH 3/4] Modifying the tests --- tests/test_models/test_temporal_fusion_transformer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 24c249bd5..5115fedd6 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -203,11 +203,15 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs) return_index=True, return_x=True, return_y=True, - fast_dev_run=True, + fast_dev_run=2, trainer_kwargs=trainer_kwargs, ) pred_len = len(predictions.index) - + if isinstance(predictions.output, torch.Tensor): + assert predictions.output.shape == predictions.y[0].shape, "shape of predictions should match shape of targets" + else: + for i in range(len(predictions.output)): + assert predictions.output[i].shape == predictions.y[0][i].shape, "shape of predictions should match shape of targets" # check that output is of correct shape def check(x): if isinstance(x, (tuple, list)): From f606d7ac9ef157a1dbbcdb58380bdbeabe49f220 Mon Sep 17 00:00:00 2001 From: RUPESH-KUMAR01 <118011558+RUPESH-KUMAR01@users.noreply.github.com> Date: Mon, 3 Mar 2025 17:53:55 +0530 Subject: [PATCH 4/4] Minor changes --- tests/test_models/test_temporal_fusion_transformer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 5115fedd6..5cbf51909 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -208,10 +208,15 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs) ) pred_len = len(predictions.index) if isinstance(predictions.output, torch.Tensor): - assert predictions.output.shape == predictions.y[0].shape, "shape of predictions should match shape of targets" + assert ( + predictions.output.shape == predictions.y[0].shape + ), "shape of predictions should match shape of targets" else: for i in range(len(predictions.output)): - assert predictions.output[i].shape == predictions.y[0][i].shape, "shape of predictions should match shape of targets" + assert ( + predictions.output[i].shape == predictions.y[0][i].shape + ), "shape of predictions should match shape of targets" + # check that output is of correct shape def check(x): if isinstance(x, (tuple, list)):