diff --git a/torch_geometric/io/sdf.py b/torch_geometric/io/sdf.py index 8451ac28a768..28a9c6c27c9e 100644 --- a/torch_geometric/io/sdf.py +++ b/torch_geometric/io/sdf.py @@ -1,5 +1,6 @@ import torch +from torch_geometric import EdgeIndex from torch_geometric.data import Data from torch_geometric.io import parse_txt_array from torch_geometric.utils import coalesce, one_hot @@ -19,10 +20,18 @@ def parse_sdf(src: str) -> Data: bond_block = lines[1 + num_atoms:1 + num_atoms + num_bonds] row, col = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1 row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) - edge_index = torch.stack([row, col], dim=0) + edge_index = EdgeIndex( + torch.stack([row, col], dim=0), + is_undirected=True, + sparse_size=(num_atoms, num_atoms), + ) edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1 edge_attr = torch.cat([edge_attr, edge_attr], dim=0) - edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms) + edge_index, edge_attr = coalesce( # type: ignore + edge_index, + edge_attr, + num_atoms, + ) return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos) diff --git a/torch_geometric/io/tu.py b/torch_geometric/io/tu.py index d997f43d3f4d..f5977d5398e2 100644 --- a/torch_geometric/io/tu.py +++ b/torch_geometric/io/tu.py @@ -4,6 +4,7 @@ import torch from torch import Tensor +from torch_geometric import EdgeIndex from torch_geometric.data import Data from torch_geometric.io import fs, read_txt_array from torch_geometric.utils import coalesce, cumsum, one_hot, remove_self_loops @@ -75,7 +76,11 @@ def read_tu_data( num_nodes = int(edge_index.max()) + 1 if x is None else x.size(0) edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes) - + edge_index = EdgeIndex( + edge_index, + is_undirected=True, + sparse_size=(num_nodes, num_nodes), + ) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) data, slices = split(data, batch)