Skip to content

Commit a10ba71

Browse files
authored
Better error message for CatToNumTransform (#394)
Fixes #388
1 parent 2f1a876 commit a10ba71

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

test/transforms/test_cat_to_num_transform.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ def test_cat_to_num_transform_on_categorical_only_dataset(with_nan):
6565
out.col_names_dict[stype.numerical]) == ((dataset.num_classes - 1) *
6666
total_cols))
6767

68+
tensor_frame.feat_dict[stype.categorical] += 1
69+
with pytest.raises(RuntimeError, match="contains new category"):
70+
# An informative error is raised when the input contains a new category
71+
out = transform(tensor_frame)
72+
6873

6974
@pytest.mark.parametrize('task_type', [
7075
TaskType.MULTICLASS_CLASSIFICATION, TaskType.REGRESSION,

torch_frame/transforms/cat_to_num_transform.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,12 @@ def _forward(self, tf: TensorFrame) -> TensorFrame:
123123
count = torch.tensor(self.col_stats[col_name][StatType.COUNT][1],
124124
device=tf.device)
125125
feat = tensor[:, i]
126-
v = torch.index_select(count, 0, feat).unsqueeze(1).repeat(
127-
1, num_classes - 1)
126+
max_cat = feat.max()
127+
if max_cat >= len(count):
128+
raise RuntimeError(
129+
f"{col_name} contains new category {max_cat} not seen "
130+
f"during fit stage.")
131+
v = count[feat].unsqueeze(1).repeat(1, num_classes - 1)
128132
transformed_tensor[:, i * (num_classes - 1):(i + 1) *
129133
(num_classes - 1)] = ((v + target_mean) /
130134
(self.data_size + 1))

0 commit comments

Comments
 (0)