Describe the bug
When the last batch of a dataloader is smaller than the batch size, running model.predict with return_y=True raises:
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size XXX but got size XX for tensor number XXXX in the list.
The error persists even when drop_last=True is set on the dataloader, but it disappears when return_y=False is set.
I understand that during inference the model is meant to generate predictions for all batches regardless of their size, yet it is strange that predict refuses to return the original y whenever a batch's size does not match the others.
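The message comes from `torch.cat`, which requires all tensors to agree in every dimension except the one being concatenated. The sketch below is a generic PyTorch illustration with made-up shapes (four full batches of 8, a final batch of 3, a horizon of 5); it is an assumption about the mechanism, not pytorch-forecasting's internal code. Joining per-batch targets along dim 0 (the batch dimension) works, while joining them along dim 1 reproduces the same class of error.

```python
import torch

# Hypothetical per-batch targets: four full batches of 8 samples and a
# smaller final batch of 3, each with a horizon of 5 steps (shapes assumed).
batches = [torch.zeros(8, 5) for _ in range(4)] + [torch.zeros(3, 5)]

# Concatenating along the batch dimension tolerates a smaller last batch.
y_all = torch.cat(batches, dim=0)
print(y_all.shape)  # torch.Size([35, 5])

# Concatenating along dim 1 requires dim 0 (the batch size) to match in
# every tensor, so the smaller last batch raises a RuntimeError.
try:
    torch.cat(batches, dim=1)
except RuntimeError as err:
    print(err)
```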
To Reproduce
Expected behavior
We have 24 batches, three of which are of size 7, and the last one is of size 3, which causes the error.
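Until predict handles this, one possible workaround is to skip return_y and collect the targets directly from the dataloader, concatenating them along the batch dimension. The sketch assumes a single target and that each batch is an `(x, (y, weight))` pair, the structure unpacked in pytorch-forecasting's tutorials; `collect_targets` is a hypothetical helper, and a plain list of fake batches stands in for the real dataloader so the snippet runs on its own.

```python
import torch


def collect_targets(dataloader):
    """Concatenate the targets of every batch along the batch dimension.

    Assumes each batch is an (x, (y, weight)) pair with a single target
    tensor y; adjust the unpacking if your dataloader yields another shape.
    """
    return torch.cat([y for _, (y, _weight) in dataloader], dim=0)


# Hypothetical stand-in for a real dataloader: three batches of 8 samples
# and a smaller final batch of 3, each target with a horizon of 5 steps.
fake_loader = [
    ({"encoder_cont": torch.randn(n, 10, 2)}, (torch.randn(n, 5), None))
    for n in (8, 8, 8, 3)
]

actuals = collect_targets(fake_loader)
print(actuals.shape)  # torch.Size([27, 5])
```

Because the concatenation runs along dim 0, a smaller final batch only changes the total row count and never triggers the size mismatch.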
Additional context
baseline() is only one example. Other models appear to have the same issue, since they inherit the predict method from the same base class.
The same issue is also mentioned here:
#1320
#1509
Versions
System:
python: 3.11.11 (main, Dec 4 2024, 08:55:07) [GCC 11.4.0]
executable: /usr/bin/python3
machine: Linux-6.1.85+-x86_64-with-glibc2.35
Python dependencies:
pip: 24.1.2
sktime: 0.36.1
sklearn: 1.6.1
skbase: 0.12.2
numpy: 2.0.2
scipy: 1.14.1
pandas: 2.2.2
matplotlib: 3.10.0
joblib: 1.4.2
numba: 0.60.0
statsmodels: 0.14.4
pmdarima: None
statsforecast: None
tsfresh: None
tslearn: None
torch: 2.6.0+cu124
tensorflow: 2.18.0