Skip to content

Commit a73e90a

Browse files
committed
update
1 parent afffee9 commit a73e90a

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

torch_geometric/io/sdf.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22

3+
from torch_geometric import EdgeIndex
34
from torch_geometric.data import Data
45
from torch_geometric.io import parse_txt_array
56
from torch_geometric.utils import coalesce, one_hot
@@ -19,7 +20,11 @@ def parse_sdf(src: str) -> Data:
1920
bond_block = lines[1 + num_atoms:1 + num_atoms + num_bonds]
2021
row, col = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1
2122
row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
22-
edge_index = torch.stack([row, col], dim=0)
23+
edge_index = EdgeIndex(
24+
torch.stack([row, col], dim=0),
25+
is_undirected=True,
26+
sparse_size=(num_atoms, num_atoms),
27+
)
2328
edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1
2429
edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
2530
edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms)

torch_geometric/io/tu.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from torch import Tensor
66

7+
from torch_geometric import EdgeIndex
78
from torch_geometric.data import Data
89
from torch_geometric.io import fs, read_txt_array
910
from torch_geometric.utils import coalesce, cumsum, one_hot, remove_self_loops
@@ -75,7 +76,11 @@ def read_tu_data(
7576
num_nodes = int(edge_index.max()) + 1 if x is None else x.size(0)
7677
edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
7778
edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes)
78-
79+
edge_index = EdgeIndex(
80+
edge_index,
81+
is_undirected=True,
82+
sparse_size=(num_nodes, num_nodes),
83+
)
7984
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
8085
data, slices = split(data, batch)
8186

0 commit comments

Comments
 (0)