diff --git a/pytorch_forecasting/data/__init__.py b/pytorch_forecasting/data/__init__.py index 301c8394d..17be285a0 100644 --- a/pytorch_forecasting/data/__init__.py +++ b/pytorch_forecasting/data/__init__.py @@ -13,10 +13,11 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler -from pytorch_forecasting.data.timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries import TimeSeries, TimeSeriesDataSet __all__ = [ "TimeSeriesDataSet", + "TimeSeries", "NaNLabelEncoder", "GroupNormalizer", "TorchNormalizer", diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py new file mode 100644 index 000000000..c8252014d --- /dev/null +++ b/pytorch_forecasting/data/data_module.py @@ -0,0 +1,709 @@ +####################################################################################### +# Disclaimer: This data-module is still work in progress and experimental, please +# use with care. This data-module is a basic skeleton of how the data-handling pipeline +# may look like in the future. +# This is D2 layer that will handle the preprocessing and data loaders. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + +from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn + +from lightning.pytorch import LightningDataModule +from sklearn.preprocessing import RobustScaler, StandardScaler +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_forecasting.data.encoders import ( + EncoderNormalizer, + NaNLabelEncoder, + TorchNormalizer, +) +from pytorch_forecasting.data.timeseries import TimeSeries +from pytorch_forecasting.utils._coerce import _coerce_to_dict + +NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] + + +class EncoderDecoderTimeSeriesDataModule(LightningDataModule): + """ + Lightning DataModule for processing time series data in an encoder-decoder format. + + This module handles preprocessing, splitting, and batching of time series data + for use in deep learning models. It supports categorical and continuous features, + various scalers, and automatic target normalization. + + Parameters + ---------- + time_series_dataset : TimeSeries + The dataset containing time series data. + max_encoder_length : int, default=30 + Maximum length of the encoder input sequence. + min_encoder_length : Optional[int], default=None + Minimum length of the encoder input sequence. + Defaults to `max_encoder_length` if not specified. + max_prediction_length : int, default=1 + Maximum length of the decoder output sequence. + min_prediction_length : Optional[int], default=None + Minimum length of the decoder output sequence. + Defaults to `max_prediction_length` if not specified. + min_prediction_idx : Optional[int], default=None + Minimum index from which predictions start. + allow_missing_timesteps : bool, default=False + Whether to allow missing timesteps in the dataset. + add_relative_time_idx : bool, default=False + Whether to add a relative time index feature. + add_target_scales : bool, default=False + Whether to add target scaling information. + add_encoder_length : Union[bool, str], default="auto" + Whether to include encoder length information. + target_normalizer : + Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None], + default="auto" + Normalizer for the target variable. 
If "auto", uses `RobustScaler`. + + categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None + Dictionary of categorical encoders. + + scalers : + Optional[Dict[str, Union[StandardScaler, RobustScaler, + TorchNormalizer, EncoderNormalizer]]], default=None + Dictionary of feature scalers. + + randomize_length : Union[None, Tuple[float, float], bool], default=False + Whether to randomize input sequence length. + batch_size : int, default=32 + Batch size for DataLoader. + num_workers : int, default=0 + Number of workers for DataLoader. + train_val_test_split : tuple, default=(0.7, 0.15, 0.15) + Proportions for train, validation, and test dataset splits. + """ + + def __init__( + self, + time_series_dataset: TimeSeries, + max_encoder_length: int = 30, + min_encoder_length: Optional[int] = None, + max_prediction_length: int = 1, + min_prediction_length: Optional[int] = None, + min_prediction_idx: Optional[int] = None, + allow_missing_timesteps: bool = False, + add_relative_time_idx: bool = False, + add_target_scales: bool = False, + add_encoder_length: Union[bool, str] = "auto", + target_normalizer: Union[ + NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None + ] = "auto", + categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None, + scalers: Optional[ + Dict[ + str, + Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer], + ] + ] = None, + randomize_length: Union[None, Tuple[float, float], bool] = False, + batch_size: int = 32, + num_workers: int = 0, + train_val_test_split: tuple = (0.7, 0.15, 0.15), + ): + + self.time_series_dataset = time_series_dataset + self.max_encoder_length = max_encoder_length + self.min_encoder_length = min_encoder_length + self.max_prediction_length = max_prediction_length + self.min_prediction_length = min_prediction_length + self.min_prediction_idx = min_prediction_idx + self.allow_missing_timesteps = allow_missing_timesteps + self.add_relative_time_idx = add_relative_time_idx + self.add_target_scales = add_target_scales + self.add_encoder_length = add_encoder_length + self.randomize_length = randomize_length + self.target_normalizer = target_normalizer + self.categorical_encoders = categorical_encoders + self.scalers = scalers + self.batch_size = batch_size + self.num_workers = num_workers + self.train_val_test_split = train_val_test_split + + warn( + "TimeSeries is part of an experimental rework of the " + "pytorch-forecasting data layer, " + "scheduled for release with v2.0.0. " + "The API is not stable and may change without prior warning. " + "For beta testing, but not for stable production use. 
" + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + + super().__init__() + + # handle defaults and derived attributes + if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": + self._target_normalizer = RobustScaler() + else: + self._target_normalizer = target_normalizer + + self.time_series_metadata = time_series_dataset.get_metadata() + self._min_prediction_length = min_prediction_length or max_prediction_length + self._min_encoder_length = min_encoder_length or max_encoder_length + self._categorical_encoders = _coerce_to_dict(categorical_encoders) + self._scalers = _coerce_to_dict(scalers) + + self.categorical_indices = [] + self.continuous_indices = [] + self._metadata = None + + for idx, col in enumerate(self.time_series_metadata["cols"]["x"]): + if self.time_series_metadata["col_type"].get(col) == "C": + self.categorical_indices.append(idx) + else: + self.continuous_indices.append(idx) + + # overwrite __init__ params for upwards compatibility with AS PRs + # todo: should we avoid this and ensure classes are dataclass-like? + self.min_prediction_length = self._min_prediction_length + self.min_encoder_length = self._min_encoder_length + self.categorical_encoders = self._categorical_encoders + self.scalers = self._scalers + self.target_normalizer = self._target_normalizer + + def _prepare_metadata(self): + """Prepare metadata for model initialisation. + + Returns + ------- + dict + dictionary containing the following keys: + + * ``encoder_cat``: Number of categorical variables in the encoder. + Computed as ``len(self.categorical_indices)``, which counts the + categorical feature indices. + * ``encoder_cont``: Number of continuous variables in the encoder. + Computed as ``len(self.continuous_indices)``, which counts the + continuous feature indices. + * ``decoder_cat``: Number of categorical variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "C"(categorical) and col_known == "K" (known) + * ``decoder_cont``: Number of continuous variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "F"(continuous) and col_known == "K"(known) + * ``target``: Number of target variables. + Computed as ``len(self.time_series_metadata["cols"]["y"])``, which + gives the number of output target columns.. + * ``static_categorical_features``: Number of static categorical features + Computed by filtering ``self.time_series_metadata["cols"]["st"]`` + (static features) where col_type == "C" (categorical). + * ``static_continuous_features``: Number of static continuous features + Computed as difference of + ``len(self.time_series_metadata["cols"]["st"])`` (static features) + and static_categorical_features that gives static continuous feature + * ``max_encoder_length``: maximum encoder length + Taken directly from `self.max_encoder_length`. + * ``max_prediction_length``: maximum prediction length + Taken directly from `self.max_prediction_length`. + * ``min_encoder_length``: minimum encoder length + Taken directly from `self.min_encoder_length`. + * ``min_prediction_length``: minimum prediction length + Taken directly from `self.min_prediction_length`. 
+ """ + encoder_cat_count = len(self.categorical_indices) + encoder_cont_count = len(self.continuous_indices) + + decoder_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "C" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + decoder_cont_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "F" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + + target_count = len(self.time_series_metadata["cols"]["y"]) + metadata = { + "encoder_cat": encoder_cat_count, + "encoder_cont": encoder_cont_count, + "decoder_cat": decoder_cat_count, + "decoder_cont": decoder_cont_count, + "target": target_count, + } + if self.time_series_metadata["cols"]["st"]: + static_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["st"] + if self.time_series_metadata["col_type"].get(col) == "C" + ] + ) + static_cont_count = ( + len(self.time_series_metadata["cols"]["st"]) - static_cat_count + ) + + metadata["static_categorical_features"] = static_cat_count + metadata["static_continuous_features"] = static_cont_count + else: + metadata["static_categorical_features"] = 0 + metadata["static_continuous_features"] = 0 + + metadata.update( + { + "max_encoder_length": self.max_encoder_length, + "max_prediction_length": self.max_prediction_length, + "min_encoder_length": self._min_encoder_length, + "min_prediction_length": self._min_prediction_length, + } + ) + + return metadata + + @property + def metadata(self): + """Compute metadata for model initialization. + + This property returns a dictionary containing the shapes and key information + related to the time series model. The metadata includes: + + * ``encoder_cat``: Number of categorical variables in the encoder. + * ``encoder_cont``: Number of continuous variables in the encoder. + * ``decoder_cat``: Number of categorical variables in the decoder that are + known in advance. + * ``decoder_cont``: Number of continuous variables in the decoder that are + known in advance. + * ``target``: Number of target variables. + + If static features are present, the following keys are added: + + * ``static_categorical_features``: Number of static categorical features + * ``static_continuous_features``: Number of static continuous features + + It also contains the following information: + + * ``max_encoder_length``: maximum encoder length + * ``max_prediction_length``: maximum prediction length + * ``min_encoder_length``: minimum encoder length + * ``min_prediction_length``: minimum prediction length + """ + if self._metadata is None: + self._metadata = self._prepare_metadata() + return self._metadata + + def _preprocess_data(self, series_idx: torch.Tensor) -> List[Dict[str, Any]]: + """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. + + Preprocessing steps + -------------------- + + * Converts target (`y`) and features (`x`) to `torch.float32`. + * Masks time points that are at or before the cutoff time. + * Splits features into categorical and continuous subsets based on + predefined indices. + + + TODO: add scalers, target normalizers etc. 
+ """ + sample = self.time_series_dataset[series_idx] + + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] + + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) + + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) + + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) + + # TODO: add scalers, target normalizers etc. + + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) + + return { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, + } + + class _ProcessedEncoderDecoderDataset(Dataset): + """PyTorch Dataset for processed encoder-decoder time series data. + + Parameters + ---------- + dataset : TimeSeries + The base time series dataset that provides access to raw data and metadata. + data_module : EncoderDecoderTimeSeriesDataModule + The data module handling preprocessing and metadata configuration. + windows : List[Tuple[int, int, int, int]] + List of window tuples containing + (series_idx, start_idx, enc_length, pred_length). + add_relative_time_idx : bool, default=False + Whether to include relative time indices. + """ + + def __init__( + self, + dataset: TimeSeries, + data_module: "EncoderDecoderTimeSeriesDataModule", + windows: List[Tuple[int, int, int, int]], + add_relative_time_idx: bool = False, + ): + self.dataset = dataset + self.data_module = data_module + self.windows = windows + self.add_relative_time_idx = add_relative_time_idx + + def __len__(self): + return len(self.windows) + + def __getitem__(self, idx): + """Retrieve a processed time series window for dataloader input. + + x : dict + Dictionary containing model inputs: + + * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features) + Categorical features for the encoder. + * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features) + Continuous features for the encoder. + * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features) + Categorical features for the decoder. + * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features) + Continuous features for the decoder. + * ``encoder_lengths`` : tensor of shape (1,) + Length of the encoder sequence. + * ``decoder_lengths`` : tensor of shape (1,) + Length of the decoder sequence. + * ``decoder_target_lengths`` : tensor of shape (1,) + Length of the decoder target sequence. + * ``groups`` : tensor of shape (1,) + Group identifier for the time series instance. + * ``encoder_time_idx`` : tensor of shape (enc_length,) + Time indices for the encoder sequence. + * ``decoder_time_idx`` : tensor of shape (pred_length,) + Time indices for the decoder sequence. + * ``target_scale`` : tensor of shape (1,) + Scaling factor for the target values. + * ``encoder_mask`` : tensor of shape (enc_length,) + Boolean mask indicating valid encoder time points. + * ``decoder_mask`` : tensor of shape (pred_length,) + Boolean mask indicating valid decoder time points. 
+ + If static features are present, the following keys are added: + + * ``static_categorical_features`` : tensor of shape + (1, n_static_cat_features), optional + Static categorical features, if available. + * ``static_continuous_features`` : tensor of shape (1, 0), optional + Placeholder for static continuous features (currently empty). + + y : tensor of shape ``(pred_length, n_targets)`` + Target values for the decoder sequence. + """ + series_idx, start_idx, enc_length, pred_length = self.windows[idx] + data = self.data_module._preprocess_data(series_idx) + + end_idx = start_idx + enc_length + pred_length + encoder_indices = slice(start_idx, start_idx + enc_length) + decoder_indices = slice(start_idx + enc_length, end_idx) + + target_scale = data["target"][encoder_indices] + target_scale = target_scale[~torch.isnan(target_scale)].abs().mean() + if torch.isnan(target_scale) or target_scale == 0: + target_scale = torch.tensor(1.0) + + encoder_mask = ( + data["time_mask"][encoder_indices] + if "time_mask" in data + else torch.ones(enc_length, dtype=torch.bool) + ) + decoder_mask = ( + data["time_mask"][decoder_indices] + if "time_mask" in data + else torch.zeros(pred_length, dtype=torch.bool) + ) + + encoder_cat = data["features"]["categorical"][encoder_indices] + encoder_cont = data["features"]["continuous"][encoder_indices] + + features = data["features"] + metadata = self.data_module.time_series_metadata + + known_cat_indices = [ + i + for i, col in enumerate(metadata["cols"]["x"]) + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + + known_cont_indices = [ + i + for i, col in enumerate(metadata["cols"]["x"]) + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + + cat_map = { + orig_idx: i + for i, orig_idx in enumerate(self.data_module.categorical_indices) + } + cont_map = { + orig_idx: i + for i, orig_idx in enumerate(self.data_module.continuous_indices) + } + + mapped_known_cat_indices = [ + cat_map[idx] for idx in known_cat_indices if idx in cat_map + ] + mapped_known_cont_indices = [ + cont_map[idx] for idx in known_cont_indices if idx in cont_map + ] + + decoder_cat = ( + features["categorical"][decoder_indices][:, mapped_known_cat_indices] + if mapped_known_cat_indices + else torch.zeros((pred_length, 0)) + ) + + decoder_cont = ( + features["continuous"][decoder_indices][:, mapped_known_cont_indices] + if mapped_known_cont_indices + else torch.zeros((pred_length, 0)) + ) + + x = { + "encoder_cat": encoder_cat, + "encoder_cont": encoder_cont, + "decoder_cat": decoder_cat, + "decoder_cont": decoder_cont, + "encoder_lengths": torch.tensor(enc_length), + "decoder_lengths": torch.tensor(pred_length), + "decoder_target_lengths": torch.tensor(pred_length), + "groups": data["group"], + "encoder_time_idx": torch.arange(enc_length), + "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), + "target_scale": target_scale, + "encoder_mask": encoder_mask, + "decoder_mask": decoder_mask, + } + if data["static"] is not None: + x["static_categorical_features"] = data["static"].unsqueeze(0) + x["static_continuous_features"] = torch.zeros((1, 0)) + + y = data["target"][decoder_indices] + if y.ndim == 1: + y = y.unsqueeze(-1) + + return x, y + + def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]: + """Generate sliding windows for training, validation, and testing. 
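+
+        For each selected series, one window is created for every start index at
+        which an encoder window of ``max_encoder_length`` followed by a prediction
+        window of ``max_prediction_length`` still fits into the sequence; series
+        shorter than the combined length are skipped.
+
+        Parameters
+        ----------
+        indices : torch.Tensor
+            Indices of the time series in ``time_series_dataset`` for which to
+            generate windows.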
+ + Returns + ------- + List[Tuple[int, int, int, int]] + A list of tuples, where each tuple consists of: + - ``series_idx`` : int + Index of the time series in `time_series_dataset`. + - ``start_idx`` : int + Start index of the encoder window. + - ``enc_length`` : int + Length of the encoder input sequence. + - ``pred_length`` : int + Length of the decoder output sequence. + """ + windows = [] + + for idx in indices: + series_idx = idx.item() + sample = self.time_series_dataset[series_idx] + sequence_length = len(sample["y"]) + + if sequence_length < self.max_encoder_length + self.max_prediction_length: + continue + + effective_min_prediction_idx = ( + self.min_prediction_idx + if self.min_prediction_idx is not None + else self.max_encoder_length + ) + + max_prediction_idx = sequence_length - self.max_prediction_length + 1 + + if max_prediction_idx <= effective_min_prediction_idx: + continue + + for start_idx in range( + 0, max_prediction_idx - effective_min_prediction_idx + ): + if ( + start_idx + self.max_encoder_length + self.max_prediction_length + <= sequence_length + ): + windows.append( + ( + series_idx, + start_idx, + self.max_encoder_length, + self.max_prediction_length, + ) + ) + + return windows + + def setup(self, stage: Optional[str] = None): + """Prepare the datasets for training, validation, testing, or prediction. + + Parameters + ---------- + stage : Optional[str], default=None + Specifies the stage of setup. Can be one of: + - ``"fit"`` : Prepares training and validation datasets. + - ``"test"`` : Prepares the test dataset. + - ``"predict"`` : Prepares the dataset for inference. + - ``None`` : Prepares ``fit`` datasets. + """ + total_series = len(self.time_series_dataset) + self._split_indices = torch.randperm(total_series) + + self._train_size = int(self.train_val_test_split[0] * total_series) + self._val_size = int(self.train_val_test_split[1] * total_series) + + self._train_indices = self._split_indices[: self._train_size] + self._val_indices = self._split_indices[ + self._train_size : self._train_size + self._val_size + ] + self._test_indices = self._split_indices[self._train_size + self._val_size :] + + if stage is None or stage == "fit": + if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): + self.train_windows = self._create_windows(self._train_indices) + self.val_windows = self._create_windows(self._val_indices) + + self.train_dataset = self._ProcessedEncoderDecoderDataset( + self.time_series_dataset, + self, + self.train_windows, + self.add_relative_time_idx, + ) + self.val_dataset = self._ProcessedEncoderDecoderDataset( + self.time_series_dataset, + self, + self.val_windows, + self.add_relative_time_idx, + ) + + elif stage == "test": + if not hasattr(self, "test_dataset"): + self.test_windows = self._create_windows(self._test_indices) + self.test_dataset = self._ProcessedEncoderDecoderDataset( + self.time_series_dataset, + self, + self.test_windows, + self.add_relative_time_idx, + ) + elif stage == "predict": + predict_indices = torch.arange(len(self.time_series_dataset)) + self.predict_windows = self._create_windows(predict_indices) + self.predict_dataset = self._ProcessedEncoderDecoderDataset( + self.time_series_dataset, + self, + self.predict_windows, + self.add_relative_time_idx, + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + 
self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def predict_dataloader(self): + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + @staticmethod + def collate_fn(batch): + x_batch = { + "encoder_cat": torch.stack([x["encoder_cat"] for x, _ in batch]), + "encoder_cont": torch.stack([x["encoder_cont"] for x, _ in batch]), + "decoder_cat": torch.stack([x["decoder_cat"] for x, _ in batch]), + "decoder_cont": torch.stack([x["decoder_cont"] for x, _ in batch]), + "encoder_lengths": torch.stack([x["encoder_lengths"] for x, _ in batch]), + "decoder_lengths": torch.stack([x["decoder_lengths"] for x, _ in batch]), + "decoder_target_lengths": torch.stack( + [x["decoder_target_lengths"] for x, _ in batch] + ), + "groups": torch.stack([x["groups"] for x, _ in batch]), + "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]), + "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]), + "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), + "encoder_mask": torch.stack([x["encoder_mask"] for x, _ in batch]), + "decoder_mask": torch.stack([x["decoder_mask"] for x, _ in batch]), + } + + if "static_categorical_features" in batch[0][0]: + x_batch["static_categorical_features"] = torch.stack( + [x["static_categorical_features"] for x, _ in batch] + ) + x_batch["static_continuous_features"] = torch.stack( + [x["static_continuous_features"] for x, _ in batch] + ) + + y_batch = torch.stack([y for _, y in batch]) + return x_batch, y_batch diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py new file mode 100644 index 000000000..788c08201 --- /dev/null +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -0,0 +1,15 @@ +"""Data loaders for time series data.""" + +from pytorch_forecasting.data.timeseries._timeseries import ( + TimeSeriesDataSet, + _find_end_indices, + check_for_nonfinite, +) +from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries + +__all__ = [ + "_find_end_indices", + "check_for_nonfinite", + "TimeSeriesDataSet", + "TimeSeries", +] diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py similarity index 99% rename from pytorch_forecasting/data/timeseries.py rename to pytorch_forecasting/data/timeseries/_timeseries.py index 942a49721..30fe9e0bb 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -32,6 +32,7 @@ ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler from pytorch_forecasting.utils import repr_class +from pytorch_forecasting.utils._coerce import _coerce_to_dict, _coerce_to_list from pytorch_forecasting.utils._dependencies import _check_matplotlib @@ -2663,23 +2664,3 @@ def __repr__(self) -> str: attributes=self.get_parameters(), extra_attributes=dict(length=len(self)), ) - - -def _coerce_to_list(obj): - """Coerce object to list. - - None is coerced to empty list, otherwise list constructor is used. - """ - if obj is None: - return [] - return list(obj) - - -def _coerce_to_dict(obj): - """Coerce object to dict. - - None is coerce to empty dict, otherwise deepcopy is used. 
- """ - if obj is None: - return {} - return deepcopy(obj) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py new file mode 100644 index 000000000..d5ecbcabb --- /dev/null +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -0,0 +1,323 @@ +""" +Timeseries dataset - v2 prototype. + +Beta version, experimental - use for testing but not in production. +""" + +from typing import Dict, List, Optional, Union +from warnings import warn + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from pytorch_forecasting.utils._coerce import _coerce_to_list + +####################################################################################### +# Disclaimer: This dataset class is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the data-handling pipeline may +# look like in the future. +# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion +# and turning the data to tensors. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. + This column is used to determine the sequence of samples. + If there are no missing observations, + the time index should increase by ``+1`` for each subsequent sample. + The first time_idx for each series does not necessarily + have to be ``0`` but any value is allowed. + target : str or List[str], optional, default = last column (at iloc -1) + column(s) in ``data`` denoting the forecasting target. + Can be categorical or numerical dtype. + group : List[str], optional, default = None + list of column names identifying a time series instance within ``data``. + This means that the ``group`` together uniquely identify an instance, + and ``group`` together with ``time`` uniquely identify a single observation + within a time series instance. + If ``None``, the dataset is assumed to be a single time series. + weight : str, optional, default=None + column name for weights. + If ``None``, it is assumed that there is no weight column. + num : list of str, optional, default = all columns with dtype in "fi" + list of numerical variables in ``data``, + list may also contain list of str, which are then grouped together. + cat : list of str, optional, default = all columns with dtype in "Obc" + list of categorical variables in ``data``, + list may also contain list of str, which are then grouped together + (e.g. useful for product categories). + known : list of str, optional, default = all variables + list of variables that change over time and are known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for special days or promotion categories). 
+ unknown : list of str, optional, default = no variables + list of variables that are not known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for weather categories). + static : list of str, optional, default = all variables not in known, unknown + list of variables that do not change over time, + list may also contain list of str, which are then grouped together. + """ + + def __init__( + self, + data: pd.DataFrame, + data_future: Optional[pd.DataFrame] = None, + time: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group: Optional[List[str]] = None, + weight: Optional[str] = None, + num: Optional[List[Union[str, List[str]]]] = None, + cat: Optional[List[Union[str, List[str]]]] = None, + known: Optional[List[Union[str, List[str]]]] = None, + unknown: Optional[List[Union[str, List[str]]]] = None, + static: Optional[List[Union[str, List[str]]]] = None, + ): + + self.data = data + self.data_future = data_future + self.time = time + self.target = target + self.group = group + self.weight = weight + self.num = num + self.cat = cat + self.known = known + self.unknown = unknown + self.static = static + + warn( + "TimeSeries is part of an experimental rework of the " + "pytorch-forecasting data layer, " + "scheduled for release with v2.0.0. " + "The API is not stable and may change without prior warning. " + "For beta testing, but not for stable production use. " + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + + super().__init__() + + # handle defaults, coercion, and derived attributes + self._target = _coerce_to_list(target) + self._group = _coerce_to_list(group) + self._num = _coerce_to_list(num) + self._cat = _coerce_to_list(cat) + self._known = _coerce_to_list(known) + self._unknown = _coerce_to_list(unknown) + self._static = _coerce_to_list(static) + + self.feature_cols = [ + col + for col in data.columns + if col not in [self.time] + self._group + [self.weight] + self._target + ] + if self._group: + self._groups = self.data.groupby(self._group).groups + self._group_ids = list(self._groups.keys()) + else: + self._groups = {"_single_group": self.data.index} + self._group_ids = ["_single_group"] + + self._prepare_metadata() + + # overwrite __init__ params for upwards compatibility with AS PRs + # todo: should we avoid this and ensure classes are dataclass-like? + self.group = self._group + self.target = self._target + self.num = self._num + self.cat = self._cat + self.known = self._known + self.unknown = self._unknown + self.static = self._static + + def _prepare_metadata(self): + """Prepare metadata for the dataset. + + The funcion returns metadata that contains: + + * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } + Names of columns for y, x, and static features. + List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". 
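+
+        Example (illustrative, with hypothetical column names): for a single
+        target ``"y"``, features ``"x1"`` (numeric, known) and ``"x2"``
+        (categorical, unknown), and a static numeric column ``"s1"``, the
+        resulting metadata is roughly::
+
+            {
+                "cols": {"y": ["y"], "x": ["x1", "x2", "s1"], "st": ["s1"]},
+                "col_type": {"y": "F", "x1": "F", "x2": "C", "s1": "F"},
+                "col_known": {"y": "U", "x1": "K", "x2": "U", "s1": "U"},
+            }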
+ """ + self.metadata = { + "cols": { + "y": self._target, + "x": self.feature_cols, + "st": self._static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self._target + self.feature_cols + self._static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self._cat else "F" + + self.metadata["col_known"][col] = "K" if col in self._known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + Returns + ------- + t : numpy.ndarray of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with `y`, + and `x` not ending in `f`. + + y : torch.Tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with `t`. + + x : torch.Tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with `t`. + + group : torch.Tensor of shape (n_groups,) + Group identifiers for time series instances. + + st : torch.Tensor of shape (n_static_features,) + Static features. + + cutoff_time : float or numpy.float64 + Cutoff time for the time series instance. + + Other Returns + ------------- + weights : torch.Tensor of shape (n_timepoints,), optional + Only included if weights are not `None`. + """ + time = self.time + feature_cols = self.feature_cols + _target = self._target + _known = self._known + _static = self._static + _group = self._group + _groups = self._groups + _group_ids = self._group_ids + weight = self.weight + data_future = self.data_future + + group_id = _group_ids[index] + + if _group: + mask = _groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[time].max() + + data_vals = data[time].values + data_tgt_vals = data[_target].values + data_feat_vals = data[feature_cols].values + + result = { + "t": data_vals, + "y": torch.tensor(data_tgt_vals), + "x": torch.tensor(data_feat_vals), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[_static].iloc[0].values if _static else []), + "cutoff_time": cutoff_time, + } + + if data_future is not None: + if _group: + future_mask = self.data_future.groupby(_group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + data_fut_vals = future_data[time].values + + combined_times = np.concatenate([data_vals, data_fut_vals]) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(_target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data_vals): + idx = current_time_indices[t] + x_merged[idx] = data_feat_vals[i] + y_merged[idx] = data_tgt_vals[i] + + for i, t in enumerate(data_fut_vals): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(_known): + if col in feature_cols: + feature_idx = feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": torch.tensor(x_merged, dtype=torch.float32), + "y": torch.tensor(y_merged, dtype=torch.float32), + } + ) + + if weight: + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, 
np.nan) + for i, t in enumerate(data_vals): + idx = current_time_indices[t] + weights_merged[idx] = data[weight].values[i] + + for i, t in enumerate(data_fut_vals): + if t in current_time_indices and self.weight in future_data.columns: + idx = current_time_indices[t] + weights_merged[idx] = future_data[weight].values[i] + + result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) + else: + result["weights"] = torch.tensor( + data[self.weight].values, dtype=torch.float32 + ) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. + + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata diff --git a/pytorch_forecasting/utils/_coerce.py b/pytorch_forecasting/utils/_coerce.py new file mode 100644 index 000000000..328431aa8 --- /dev/null +++ b/pytorch_forecasting/utils/_coerce.py @@ -0,0 +1,25 @@ +"""Coercion functions for various data types.""" + +from copy import deepcopy + + +def _coerce_to_list(obj): + """Coerce object to list. + + None is coerced to empty list, otherwise list constructor is used. + """ + if obj is None: + return [] + if isinstance(obj, str): + return [obj] + return list(obj) + + +def _coerce_to_dict(obj): + """Coerce object to dict. + + None is coerce to empty dict, otherwise deepcopy is used. + """ + if obj is None: + return {} + return deepcopy(obj) diff --git a/tests/test_data/test_d1.py b/tests/test_data/test_d1.py new file mode 100644 index 000000000..b32c13213 --- /dev/null +++ b/tests/test_data/test_d1.py @@ -0,0 +1,379 @@ +import numpy as np +import pandas as pd +import pytest +import torch + +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_data(): + """Create time series data for testing.""" + dates = pd.date_range(start="2023-01-01", periods=10, freq="D") + data = pd.DataFrame( + { + "timestamp": dates, + "target_value": np.sin(np.arange(10)) + 10, + "feature1": np.random.randn(10), + "feature2": np.random.randn(10), + "feature3": np.random.randn(10), + "group_id": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], + "weight": np.abs(np.random.randn(10)) + 0.1, + "static_feat": [10, 10, 10, 10, 10, 20, 20, 20, 20, 20], + } + ) + return data + + +@pytest.fixture +def future_data(): + """Create future time series data.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + data = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + "feature2": np.random.randn(5), + "feature3": np.random.randn(5), + "group_id": [1, 1, 1, 2, 2], + "weight": np.abs(np.random.randn(5)) + 0.1, + "static_feat": [10, 10, 10, 20, 20], + } + ) + return data + + +def test_init_basic(sample_data): + """Test basic initialization of TimeSeries class. + + Ensures that the class stores time, target, and correctly detects feature columns + when no group, known/unknown features, or static/weight features are specified.""" + ts = TimeSeries(data=sample_data, time="timestamp", target="target_value") + + assert ts.time == "timestamp" + assert ts.target == ["target_value"] + assert len(ts.feature_cols) == 6 # All columns except timestamp, target_value + assert len(ts) == 1 # Single group by default + + +def test_init_with_groups(sample_data): + """Test initialization with group parameter. + + Verifies that data is grouped correctly and each group is handled as a + separate time series. 
+ """ + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", group=["group_id"] + ) + + assert ts.group == ["group_id"] + assert len(ts) == 2 # Two groups (1 and 2) + assert set(ts._group_ids) == {1, 2} + + +def test_init_with_features_categorization(sample_data): + """Test feature categorization. + + Ensures that numeric, categorical, and static features are categorized and + stored correctly in metadata.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + num=["feature1", "feature2", "feature3"], + cat=[], + static=["static_feat"], + ) + + assert ts.num == ["feature1", "feature2", "feature3"] + assert ts.cat == [] + assert ts.static == ["static_feat"] + assert ts.metadata["col_type"]["feature1"] == "F" + assert ts.metadata["col_type"]["feature2"] == "F" + + +def test_init_with_known_unknown(sample_data): + """Test known and unknown features classification. + + Checks if the known and unknown feature categorization is correctly set + and stored in metadata.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + known=["feature1"], + unknown=["feature2", "feature3"], + ) + + assert ts.known == ["feature1"] + assert ts.unknown == ["feature2", "feature3"] + assert ts.metadata["col_known"]["feature1"] == "K" + assert ts.metadata["col_known"]["feature2"] == "U" + + +def test_init_with_weight(sample_data): + """Test initialization with weight parameter. + + Verifies that the weight column is stored correctly and excluded + from the feature columns.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", weight="weight" + ) + + assert ts.weight == "weight" + assert "weight" not in ts.feature_cols + + +def test_getitem_basic(sample_data): + """Test __getitem__ with basic configuration. + + Checks the output structure of a single time series without grouping, + ensuring x, y are tensors of correct shapes.""" + ts = TimeSeries(data=sample_data, time="timestamp", target="target_value") + + result = ts[0] + assert torch.is_tensor(result["y"]) + assert torch.is_tensor(result["x"]) + assert "t" in result + assert "cutoff_time" in result + assert len(result["y"]) == 10 # 10 data points + assert result["y"].shape == (10, 1) # One target variable + assert result["x"].shape[1] == 6 # Six feature columns + + +def test_getitem_with_groups(sample_data): + """Test __getitem__ with groups parameter. + + Verifies the per-group access using index and checks that each group + has the correct number of time steps.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", group=["group_id"] + ) + + # group (1) + result_g1 = ts[0] + assert len(result_g1["t"]) == 5 # 5 data points in group 1 + + # group (2) + result_g2 = ts[1] + assert len(result_g2["t"]) == 5 # 5 data points in group 2 + + +def test_getitem_with_static(sample_data): + """Test __getitem__ with static features. + + Ensures static features are included in the output and correctly + mapped per group.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + group=["group_id"], + static=["static_feat"], + ) + + result_g1 = ts[0] + result_g2 = ts[1] + + assert torch.is_tensor(result_g1["st"]) + assert result_g1["st"].item() == 10 # Static feature for group 1 + assert result_g2["st"].item() == 20 # Static feature for group 2 + + +def test_getitem_with_weight(sample_data): + """Test __getitem__ with weight parameter. 
+ + Validates that weights are correctly returned in the output and have the + expected length and type.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", weight="weight" + ) + + result = ts[0] + assert "weights" in result + assert torch.is_tensor(result["weights"]) + assert len(result["weights"]) == 10 + + +def test_with_future_data(sample_data, future_data): + """Test with future data provided. + + Verifies that future time steps are appended to the end of each group, + especially for known features.""" + ts = TimeSeries( + data=sample_data, + data_future=future_data, + time="timestamp", + target="target_value", + group=["group_id"], + known=["feature1"], + ) + + result_g1 = ts[0] # Group 1 + + assert len(result_g1["t"]) == 8 # 5 original + 3 future for group 1 + + feature1_idx = ts.feature_cols.index("feature1") + assert not torch.isnan( + result_g1["x"][-1, feature1_idx] + ) # feature1 is not NaN in last row + + +def test_future_data_with_weights(sample_data, future_data): + """Test handling of weights with future data. + + Ensures that weights from future data are combined properly and match the + time indices.""" + ts = TimeSeries( + data=sample_data, + data_future=future_data, + time="timestamp", + target="target_value", + group=["group_id"], + weight="weight", + ) + + result = ts[0] # Group 1 + assert "weights" in result + assert torch.is_tensor(result["weights"]) + assert len(result["weights"]) == len(result["t"]) + + +def test_future_data_missing_columns(sample_data): + """Test handling when future data is missing some columns. + + Verifies the handling of missing feature columns in future data by + checking NaN padding.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + incomplete_future = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + # Missing feature2, feature3 + "group_id": [1, 1, 1, 2, 2], + "weight": np.abs(np.random.randn(5)) + 0.1, + } + ) + + ts = TimeSeries( + data=sample_data, + data_future=incomplete_future, + time="timestamp", + target="target_value", + group=["group_id"], + known=["feature1"], + ) + + result = ts[0] + # Check that missing features are NaN in future timepoints + future_indices = np.where(result["t"] >= np.datetime64("2023-01-11"))[0] + feature2_idx = ts.feature_cols.index("feature2") + feature3_idx = ts.feature_cols.index("feature3") + assert torch.isnan(result["x"][future_indices[0], feature2_idx]) + assert torch.isnan(result["x"][future_indices[0], feature3_idx]) + + +def test_different_future_groups(sample_data): + """Test with future data that has different groups than original data. + + Ensures that groups present only in future data are ignored if not + in the original dataset.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + future_with_new_group = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + "feature2": np.random.randn(5), + "feature3": np.random.randn(5), + "group_id": [1, 1, 3, 3, 3], # Group 3 is new + "weight": np.abs(np.random.randn(5)) + 0.1, + "static_feat": [10, 10, 30, 30, 30], + } + ) + + ts = TimeSeries( + data=sample_data, + data_future=future_with_new_group, + time="timestamp", + target="target_value", + group=["group_id"], + ) + + # Original data has groups 1 and 2, but not 3 + assert len(ts) == 2 + assert 3 not in ts._group_ids + + +def test_multiple_targets(sample_data): + """Test handling of multiple target variables. 
+ + Verifies that multiple target columns are handled and returned + as the correct shape in the output.""" + sample_data["target_value2"] = np.cos(np.arange(10)) + 5 + + ts = TimeSeries( + data=sample_data, time="timestamp", target=["target_value", "target_value2"] + ) + + result = ts[0] + assert result["y"].shape == (10, 2) # Two target variables + + +def test_empty_groups(): + """Test handling of empty groups. + + Confirms that the class handles datasets with a single group and + no empty group errors occur.""" + data = pd.DataFrame( + { + "timestamp": pd.date_range(start="2023-01-01", periods=5, freq="D"), + "target_value": np.random.randn(5), + "group_id": [1, 1, 1, 1, 1], # Only one group + } + ) + + ts = TimeSeries( + data=data, time="timestamp", target="target_value", group=["group_id"] + ) + + assert len(ts) == 1 # Only one group + + +def test_metadata_structure(sample_data): + """Test the structure of metadata. + + Ensures the metadata dictionary includes the expected keys and + correct mappings of feature roles.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + num=["feature1", "feature2", "feature3"], + cat=[], # No categorical features + static=["static_feat"], + known=["feature1"], + unknown=["feature2", "feature3"], + ) + + metadata = ts.get_metadata() + + assert "cols" in metadata + assert "col_type" in metadata + assert "col_known" in metadata + + assert metadata["cols"]["y"] == ["target_value"] + assert set(metadata["cols"]["x"]) == { + "feature1", + "feature2", + "feature3", + "group_id", + "weight", + "static_feat", + } + assert metadata["cols"]["st"] == ["static_feat"] + + assert metadata["col_type"]["feature1"] == "F" + assert metadata["col_type"]["feature2"] == "F" + + assert metadata["col_known"]["feature1"] == "K" + assert metadata["col_known"]["feature2"] == "U" diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py new file mode 100644 index 000000000..4051b852c --- /dev/null +++ b/tests/test_data/test_data_module.py @@ -0,0 +1,464 @@ +import numpy as np +import pandas as pd +import pytest + +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_timeseries_data(): + """Create a sample time series dataset with only numerical values.""" + num_groups = 10 + seq_length = 100 + + groups = [] + times = [] + values = [] + categorical_feature = [] + continuous_feature1 = [] + continuous_feature2 = [] + known_future = [] + + for g in range(num_groups): + for t in range(seq_length): + groups.append(g) + times.append(pd.Timestamp("2020-01-01") + pd.Timedelta(days=t)) + + value = 10 + 0.1 * t + 5 * np.sin(t / 10) + g * 2 + np.random.normal(0, 1) + values.append(value) + + categorical_feature.append(np.random.choice([0, 1, 2])) + + continuous_feature1.append(np.random.normal(g, 1)) + continuous_feature2.append(value * 0.5 + np.random.normal(0, 0.5)) + + known_future.append(t % 7) + + df = pd.DataFrame( + { + "group": groups, + "time": times, + "target": values, + "cat_feat": categorical_feature, + "cont_feat1": continuous_feature1, + "cont_feat2": continuous_feature2, + "known_future": known_future, + } + ) + + time_series = TimeSeries( + data=df, + time="time", + target="target", + group=["group"], + num=["cont_feat1", "cont_feat2", "known_future"], + cat=["cat_feat"], + known=["known_future"], + ) + + return time_series + + +@pytest.fixture +def data_module(sample_timeseries_data): + 
"""Create a data module instance.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + train_val_test_split=(0.7, 0.15, 0.15), + ) + return dm + + +def test_init(sample_timeseries_data): + """Test the initialization of the data module. + + Verifies hyperparameter assignment and basic time_series_metadata creation.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=8, + ) + + assert dm.max_encoder_length == 24 + assert dm.max_prediction_length == 12 + assert dm.min_encoder_length == 24 + assert dm.min_prediction_length == 12 + assert dm.batch_size == 8 + assert dm.train_val_test_split == (0.7, 0.15, 0.15) + + assert isinstance(dm.time_series_metadata, dict) + assert "cols" in dm.time_series_metadata + + +def test_prepare_metadata(data_module): + """Test the metadata preparation method. + + Ensures that internal metadata keys are created correctly.""" + metadata = data_module._prepare_metadata() + + assert "encoder_cat" in metadata + assert "encoder_cont" in metadata + assert "decoder_cat" in metadata + assert "decoder_cont" in metadata + assert "target" in metadata + assert "max_encoder_length" in metadata + assert "max_prediction_length" in metadata + + assert metadata["max_encoder_length"] == 24 + assert metadata["max_prediction_length"] == 12 + + +def test_metadata_property(data_module): + """Test the metadata property. + + Confirms caching behavior and correct feature counts.""" + metadata = data_module.metadata + + # Should return the same object when called multiple times (caching) + assert data_module.metadata is metadata + + assert metadata["encoder_cat"] == 1 # cat_feat + assert metadata["encoder_cont"] == 3 # cont_feat1, cont_feat2, known_future + assert metadata["decoder_cat"] == 0 # No categorical features marked as known + assert metadata["decoder_cont"] == 1 # Only known_future marked as known + + +def test_setup(data_module): + """Test the setup method that prepares the datasets.""" + data_module.setup(stage="fit") + print(data_module._val_indices) + assert hasattr(data_module, "train_dataset") + assert hasattr(data_module, "val_dataset") + assert len(data_module.train_windows) > 0 + assert len(data_module.val_windows) > 0 + + data_module.setup(stage="test") + assert hasattr(data_module, "test_dataset") + assert len(data_module.test_windows) > 0 + + data_module.setup(stage="predict") + assert hasattr(data_module, "predict_dataset") + assert len(data_module.predict_windows) > 0 + + +def test_create_windows(data_module): + """Test the window creation logic. + + Validates window structure and length settings.""" + data_module.setup() + + windows = data_module._create_windows(data_module._train_indices) + + assert len(windows) > 0 + + for window in windows: + assert len(window) == 4 + assert window[2] == data_module.max_encoder_length + assert window[3] == data_module.max_prediction_length + + +def test_dataloader_creation(data_module): + """Test that dataloaders are created correctly. 
+ + Checks batch sizes and dataloader instantiation across all stages.""" + data_module.setup() + + train_loader = data_module.train_dataloader() + assert train_loader.batch_size == data_module.batch_size + assert train_loader.num_workers == data_module.num_workers + + val_loader = data_module.val_dataloader() + assert val_loader.batch_size == data_module.batch_size + + data_module.setup(stage="test") + test_loader = data_module.test_dataloader() + assert test_loader.batch_size == data_module.batch_size + + data_module.setup(stage="predict") + predict_loader = data_module.predict_dataloader() + assert predict_loader.batch_size == data_module.batch_size + + +def test_processed_dataset(data_module): + """Test the internal ProcessedEncoderDecoderDataset class. + + Verifies sample structure and tensor dimensions for encoder/decoder inputs.""" + data_module.setup() + + assert len(data_module.train_dataset) == len(data_module.train_windows) + assert len(data_module.val_dataset) == len(data_module.val_windows) + + x, y = data_module.train_dataset[0] + + required_keys = [ + "encoder_cat", + "encoder_cont", + "decoder_cat", + "decoder_cont", + "encoder_lengths", + "decoder_lengths", + "decoder_target_lengths", + "groups", + "encoder_time_idx", + "decoder_time_idx", + "target_scale", + "encoder_mask", + "decoder_mask", + ] + + for key in required_keys: + assert key in x + + assert x["encoder_cat"].shape[0] == data_module.max_encoder_length + assert x["decoder_cat"].shape[0] == data_module.max_prediction_length + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x["decoder_cat"].shape[1] == known_cat_count + assert x["decoder_cont"].shape[1] == known_cont_count + + assert y.shape[0] == data_module.max_prediction_length + + +def test_collate_fn(data_module): + """Test the collate function that combines batch samples. + + Ensures proper stacking of dictionary keys and batch outputs.""" + data_module.setup() + + batch_size = 3 + batch = [data_module.train_dataset[i] for i in range(batch_size)] + + x_batch, y_batch = data_module.collate_fn(batch) + + for key in x_batch: + assert x_batch[key].shape[0] == batch_size + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x_batch["decoder_cat"].shape[2] == known_cat_count + assert x_batch["decoder_cont"].shape[2] == known_cont_count + assert y_batch.shape[0] == batch_size + assert y_batch.shape[1] == data_module.max_prediction_length + + +def test_full_dataloader_iteration(data_module): + """Test a full iteration through the train dataloader. 
+ + Confirms batch retrieval and tensor dimensions match configuration.""" + data_module.setup() + train_loader = data_module.train_dataloader() + + batch = next(iter(train_loader)) + x_batch, y_batch = batch + + assert x_batch["encoder_cat"].shape[0] == data_module.batch_size + assert x_batch["encoder_cat"].shape[1] == data_module.max_encoder_length + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x_batch["decoder_cat"].shape[0] == data_module.batch_size + assert x_batch["decoder_cat"].shape[2] == known_cat_count + assert x_batch["decoder_cont"].shape[0] == data_module.batch_size + assert x_batch["decoder_cont"].shape[2] == known_cont_count + assert y_batch.shape[0] == data_module.batch_size + assert y_batch.shape[1] == data_module.max_prediction_length + + +def test_variable_encoder_lengths(sample_timeseries_data): + """Test with variable encoder lengths. + + Ensures random length behavior is respected and functional.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + min_encoder_length=12, + max_prediction_length=12, + batch_size=4, + randomize_length=True, + ) + + dm.setup() + assert dm.min_encoder_length == 12 + assert dm.max_encoder_length == 24 + + +def test_preprocess_data(data_module, sample_timeseries_data): + """Test the _preprocess_data method. + + Checks preprocessing output structure and alignment with raw data.""" + if not hasattr(data_module, "_split_indices"): + data_module.setup() + + series_idx = data_module._train_indices[0] + + processed = data_module._preprocess_data(series_idx) + + assert "features" in processed + assert "categorical" in processed["features"] + assert "continuous" in processed["features"] + assert "target" in processed + assert "time_mask" in processed + + original_sample = sample_timeseries_data[series_idx.item()] + expected_length = len(original_sample["y"]) + + assert processed["features"]["categorical"].shape[0] == expected_length + assert processed["features"]["continuous"].shape[0] == expected_length + assert processed["target"].shape[0] == expected_length + + +def test_with_static_features(): + """Test with static features included. 
+ + Validates static feature support in both metadata and sample input.""" + df = pd.DataFrame( + { + "group": [0, 0, 0, 1, 1, 1], + "time": pd.date_range("2020-01-01", periods=6), + "target": [1, 2, 3, 4, 5, 6], + "static_cat": [0, 0, 0, 1, 1, 1], + "static_num": [10, 10, 10, 20, 20, 20], + "feature1": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + } + ) + + ts = TimeSeries( + data=df, + time="time", + target="target", + group=["group"], + num=["feature1", "static_num"], + static=["static_cat", "static_num"], + cat=["static_cat"], + ) + + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=ts, + max_encoder_length=2, + max_prediction_length=1, + batch_size=2, + ) + + dm.setup() + + metadata = dm.metadata + assert metadata["static_categorical_features"] == 1 + assert metadata["static_continuous_features"] == 1 + + x, y = dm.train_dataset[0] + assert "static_categorical_features" in x + assert "static_continuous_features" in x + + +def test_different_train_val_test_split(sample_timeseries_data): + """Test with different train/val/test split ratios.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + train_val_test_split=(0.8, 0.1, 0.1), + ) + + dm.setup() + + total_series = len(sample_timeseries_data) + expected_train = int(0.8 * total_series) + expected_val = int(0.1 * total_series) + + assert len(dm._train_indices) == expected_train + assert len(dm._val_indices) == expected_val + assert len(dm._test_indices) == total_series - expected_train - expected_val + + +def test_multivariate_target(): + """Test with multivariate target (multiple target columns). + + Verifies correct handling of multivariate targets in data pipeline.""" + df = pd.DataFrame( + { + "group": np.repeat([0, 1], 50), + "time": np.tile(pd.date_range("2020-01-01", periods=50), 2), + "target1": np.random.normal(0, 1, 100), + "target2": np.random.normal(5, 2, 100), + "feature1": np.random.normal(0, 1, 100), + "feature2": np.random.normal(0, 1, 100), + } + ) + + ts = TimeSeries( + data=df, + time="time", + target=["target1", "target2"], + group=["group"], + num=["feature1", "feature2"], + ) + + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=ts, + max_encoder_length=10, + max_prediction_length=5, + batch_size=4, + ) + + dm.setup() + + x, y = dm.train_dataset[0] + assert y.shape[-1] == 2
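
Reviewer note: below is a minimal, illustrative sketch of how the new D1 layer (TimeSeries) and D2 layer (EncoderDecoderTimeSeriesDataModule) are intended to be wired together, modelled on the fixtures in tests/test_data/test_data_module.py. The frame, column names, and window lengths are made up for illustration, and both classes emit an experimental-API warning when instantiated.

import numpy as np
import pandas as pd

from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule
from pytorch_forecasting.data.timeseries import TimeSeries

# Toy data: two series of 50 daily observations each.
df = pd.DataFrame(
    {
        "group": np.repeat([0, 1], 50),
        "time": np.tile(pd.date_range("2020-01-01", periods=50), 2),
        "target": np.random.normal(0, 1, 100),
        "feature1": np.random.normal(0, 1, 100),
        "known_future": np.tile(np.arange(50) % 7, 2).astype(float),
    }
)

# D1: raw ingestion, grouping, and conversion to tensors.
ts = TimeSeries(
    data=df,
    time="time",
    target="target",
    group=["group"],
    num=["feature1", "known_future"],
    known=["known_future"],
)

# D2: windowing, train/val/test splitting, and batching.
dm = EncoderDecoderTimeSeriesDataModule(
    time_series_dataset=ts,
    max_encoder_length=10,
    max_prediction_length=5,
    batch_size=4,
)
dm.setup(stage="fit")

x, y = next(iter(dm.train_dataloader()))
# x["encoder_cont"]: (batch, max_encoder_length, n_cont)
# y: (batch, max_prediction_length, n_targets)
print(x["encoder_cont"].shape, y.shape)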