Skip to content

Commit ca95b8f

Browse files
committed
fix: patch skorch.utils.to_tensor()
1 parent df8ecc4 commit ca95b8f

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

torch_frame/utils/skorch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
import skorch.utils
2+
3+
# TODO: make it more safe
4+
old_to_tensor = skorch.utils.to_tensor
5+
6+
def to_tensor(X, device, accept_sparse=False):
7+
if isinstance(X, TensorFrame):
8+
return X
9+
return old_to_tensor(X, device, accept_sparse)
10+
11+
skorch.utils.to_tensor = to_tensor
12+
import importlib
13+
importlib.reload(skorch.net)
14+
115
from typing import Any
216

317
import pandas as pd

0 commit comments

Comments
 (0)