From 0dac4fe4cefc71dc3821f9f3aadd0a7299f82907 Mon Sep 17 00:00:00 2001 From: XinweiHe Date: Wed, 26 Mar 2025 06:58:47 +0000 Subject: [PATCH 1/3] update --- torch_geometric/io/sdf.py | 7 ++++++- torch_geometric/io/tu.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/torch_geometric/io/sdf.py b/torch_geometric/io/sdf.py index 8451ac28a768..c2a46441af4a 100644 --- a/torch_geometric/io/sdf.py +++ b/torch_geometric/io/sdf.py @@ -1,5 +1,6 @@ import torch +from torch_geometric import EdgeIndex from torch_geometric.data import Data from torch_geometric.io import parse_txt_array from torch_geometric.utils import coalesce, one_hot @@ -19,7 +20,11 @@ def parse_sdf(src: str) -> Data: bond_block = lines[1 + num_atoms:1 + num_atoms + num_bonds] row, col = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1 row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) - edge_index = torch.stack([row, col], dim=0) + edge_index = EdgeIndex( + torch.stack([row, col], dim=0), + is_undirected=True, + sparse_size=(num_atoms, num_atoms), + ) edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1 edge_attr = torch.cat([edge_attr, edge_attr], dim=0) edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms) diff --git a/torch_geometric/io/tu.py b/torch_geometric/io/tu.py index d997f43d3f4d..f5977d5398e2 100644 --- a/torch_geometric/io/tu.py +++ b/torch_geometric/io/tu.py @@ -4,6 +4,7 @@ import torch from torch import Tensor +from torch_geometric import EdgeIndex from torch_geometric.data import Data from torch_geometric.io import fs, read_txt_array from torch_geometric.utils import coalesce, cumsum, one_hot, remove_self_loops @@ -75,7 +76,11 @@ def read_tu_data( num_nodes = int(edge_index.max()) + 1 if x is None else x.size(0) edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes) - + edge_index = EdgeIndex( + edge_index, + is_undirected=True, + sparse_size=(num_nodes, num_nodes), + ) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) data, slices = split(data, batch) From 99763bb5cc93573221c868d8fec65cc8aca7314b Mon Sep 17 00:00:00 2001 From: XinweiHe Date: Fri, 28 Mar 2025 06:57:05 +0000 Subject: [PATCH 2/3] fix --- torch_geometric/io/sdf.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch_geometric/io/sdf.py b/torch_geometric/io/sdf.py index c2a46441af4a..9ac353ee3dce 100644 --- a/torch_geometric/io/sdf.py +++ b/torch_geometric/io/sdf.py @@ -27,7 +27,11 @@ def parse_sdf(src: str) -> Data: ) edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1 edge_attr = torch.cat([edge_attr, edge_attr], dim=0) - edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms) + edge_index, edge_attr = coalesce( + edge_index, # type: ignore + edge_attr, + num_atoms, + ) return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos) From 1d772c48f72360c1e58110e14b99f517068ed2f3 Mon Sep 17 00:00:00 2001 From: XinweiHe Date: Fri, 28 Mar 2025 07:15:35 +0000 Subject: [PATCH 3/3] fix --- torch_geometric/io/sdf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/io/sdf.py b/torch_geometric/io/sdf.py index 9ac353ee3dce..28a9c6c27c9e 100644 --- a/torch_geometric/io/sdf.py +++ b/torch_geometric/io/sdf.py @@ -27,8 +27,8 @@ def parse_sdf(src: str) -> Data: ) edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1 edge_attr = torch.cat([edge_attr, edge_attr], dim=0) - edge_index, edge_attr = coalesce( - edge_index, # type: ignore + edge_index, edge_attr = coalesce( # type: ignore + edge_index, edge_attr, num_atoms, )