Commit 7f7c222

Graph Convolutional Network (#1163)
* Implement GCN model architecture
* Comments added to GCN example
* Test added
* Match random seed with PR screenshot
* Update index.rst
1 parent 8c16e96 commit 7f7c222

File tree (6 files changed: +322 -1 lines changed)

- .gitignore
- docs/source/index.rst
- gcn/README.md
- gcn/main.py
- gcn/requirements.txt
- run_python_examples.sh

Diff for: .gitignore (+1)

@@ -12,6 +12,7 @@ snli/results
 word_language_model/model.pt
 fast_neural_style/saved_models
 fast_neural_style/saved_models.zip
+gcn/cora/
 docs/build
 docs/venv

Diff for: docs/source/index.rst (+8)

@@ -169,3 +169,11 @@ experiment with PyTorch.

 `GO TO EXAMPLE <https://github.com/pytorch/examples/tree/main/mnist_forward_forward>`__ :opticon:`link-external`

+---
+
+Graph Convolutional Network
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This example implements the `Semi-Supervised Classification with Graph Convolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>`__ paper on the Cora dataset.
+
+`GO TO EXAMPLE <https://github.com/pytorch/examples/blob/main/gcn>`__ :opticon:`link-external`

Diff for: gcn/README.md (+39)
@@ -0,0 +1,39 @@
# Graph Convolutional Network

This repository contains an implementation of Graph Convolutional Networks (GCN) based on the paper "Semi-Supervised Classification with Graph Convolutional Networks" by Thomas N. Kipf and Max Welling.

## Overview

This project implements the GCN model proposed in the paper for semi-supervised node classification on graph-structured data. GCN leverages graph convolutions to aggregate information from neighboring nodes and learn node representations for downstream tasks. The implementation provides a flexible and efficient GCN model for graph-based machine learning tasks.
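As a rough sketch of what a single graph convolution computes (a toy illustration, not code from this example), the layer multiplies the symmetrically normalized adjacency matrix by the node features and a weight matrix:

```python
import torch
import torch.nn.functional as F

# Toy 3-node graph: A is the adjacency matrix, A_hat adds self-loops,
# and A_norm is the renormalized adjacency D^(-1/2) * A_hat * D^(-1/2).
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
A_hat = A + torch.eye(3)
D_inv_sqrt = torch.diag(A_hat.sum(dim=1).pow(-0.5))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt

H = torch.randn(3, 4)            # node features (N=3, F_in=4)
W = torch.randn(4, 2)            # learnable weights (F_in=4, F_out=2)
H_next = F.relu(A_norm @ H @ W)  # aggregated and transformed features
print(H_next.shape)              # torch.Size([3, 2])
```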
## Requirements

- Python 3.7 or higher
- PyTorch 2.0 or higher
- Requests 2.31 or higher
- NumPy 1.24 or higher

## Installation

```bash
pip install -r requirements.txt
python main.py
```

## Dataset

The implementation includes support for the Cora dataset, a standard benchmark dataset for graph-based machine learning tasks. The Cora dataset consists of scientific publications, where nodes represent papers and edges represent citation relationships. Each paper is assigned a label indicating one of seven classes. The dataset is downloaded and preprocessed automatically when the example is run.
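For reference, each line of `cora.content` holds a paper id, its 1433 binary word features, and a class label, while each line of `cora.cites` encodes one citation edge; the rows below are illustrative:

```text
# cora.content: <paper_id> <1433 binary word features> <class_label>
31336  0 0 1 ... 0  Neural_Networks
# cora.cites: <ID of cited paper> <ID of citing paper>
35  1033
```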
## Model Architecture

The GCN model architecture follows the details provided in the paper. It consists of two graph convolutional layers with ReLU activation, followed by a final softmax layer for classification. The implementation supports customizable hyperparameters such as the number of hidden units and the dropout rate.
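For instance, the two-layer model defined in `main.py` can be instantiated directly (assuming you are in the `gcn/` directory; the dimensions below are Cora's 1433 input features and 7 classes, with the paper's defaults for the rest):

```python
from main import GCN

# Two-layer GCN: 1433 input features -> 16 hidden units -> 7 classes,
# with dropout applied between the two graph convolutions.
model = GCN(input_dim=1433, hidden_dim=16, output_dim=7,
            use_bias=False, dropout_p=0.5)
print(model)
```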
## Usage

To train and evaluate the GCN model on the Cora dataset, use the following command (`--include-bias` and `--no-cuda` are boolean switches: pass the bare flag to enable it):

```bash
python main.py --epochs 200 --lr 0.01 --l2 5e-4 --dropout-p 0.5 --hidden-dim 16 --val-every 20
```
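For a quick smoke test (this is the invocation used by `run_python_examples.sh`), a single epoch with `--dry-run` runs one training pass and exits:

```bash
python main.py --epochs 1 --dry-run
```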
## Results

The model achieves a classification accuracy of 82.5% on the test set of the Cora dataset after 200 epochs of training. This result is comparable to the performance reported in the original paper. Results can vary slightly between runs because the train/val/test split is random.

## References

- Thomas N. Kipf and Max Welling. "Semi-Supervised Classification with Graph Convolutional Networks." [Link to the paper](https://arxiv.org/abs/1609.02907)
- Original paper repository: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn)

Diff for: gcn/main.py (+262)
@@ -0,0 +1,262 @@
import os
import time
import requests
import tarfile
import numpy as np
import argparse

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam


class GraphConv(nn.Module):
    """
    Graph Convolutional Layer described in "Semi-Supervised Classification with Graph Convolutional Networks".

    Given an input feature representation for each node in a graph, the Graph Convolutional Layer aims to aggregate
    information from the node's neighborhood to update its own representation. This is achieved by applying a graph
    convolutional operation that combines the features of a node with the features of its neighboring nodes.

    Mathematically, the Graph Convolutional Layer can be described as follows:

        H' = f(D^(-1/2) * A * D^(-1/2) * H * W)

    where:
        H: Input feature matrix with shape (N, F_in), where N is the number of nodes and F_in is the number of
            input features per node.
        A: Adjacency matrix of the graph with shape (N, N), representing the relationships between nodes.
        D: Diagonal degree matrix of A with shape (N, N), where D_ii is the number of edges incident to node i.
        W: Learnable weight matrix with shape (F_in, F_out), where F_out is the number of output features per node.
        f: Activation function (e.g. ReLU).
    """
    def __init__(self, input_dim, output_dim, use_bias=False):
        super(GraphConv, self).__init__()

        # Initialize the weight matrix W (in this case called `kernel`)
        self.kernel = nn.Parameter(torch.Tensor(input_dim, output_dim))
        nn.init.xavier_normal_(self.kernel)  # Initialize the weights using Xavier initialization

        # Initialize the bias (if use_bias is True)
        self.bias = None
        if use_bias:
            self.bias = nn.Parameter(torch.Tensor(output_dim))
            nn.init.zeros_(self.bias)  # Initialize the bias to zeros

    def forward(self, input_tensor, adj_mat):
        """
        Performs a graph convolution operation.

        Args:
            input_tensor (torch.Tensor): Input tensor representing node features.
            adj_mat (torch.Tensor): Adjacency matrix representing graph structure.

        Returns:
            torch.Tensor: Output tensor after the graph convolution operation.
        """

        support = torch.mm(input_tensor, self.kernel)  # Matrix multiplication between input and weight matrix
        output = torch.spmm(adj_mat, support)  # Sparse matrix multiplication between adjacency matrix and support
        # Add the bias (if bias is not None)
        if self.bias is not None:
            output = output + self.bias

        return output
class GCN(nn.Module):
    """
    Graph Convolutional Network (GCN) as described in the paper `"Semi-Supervised Classification with Graph
    Convolutional Networks" <https://arxiv.org/pdf/1609.02907.pdf>`.

    The Graph Convolutional Network is a deep learning architecture designed for semi-supervised node
    classification tasks on graph-structured data. It leverages the graph structure to learn node representations
    by propagating information through the graph using graph convolutional layers.

    The original implementation consists of two stacked graph convolutional layers. The ReLU activation function is
    applied to the hidden representations, and the Softmax activation function is applied to the output representations.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, use_bias=True, dropout_p=0.1):
        super(GCN, self).__init__()

        # Define the Graph Convolution layers
        self.gc1 = GraphConv(input_dim, hidden_dim, use_bias=use_bias)
        self.gc2 = GraphConv(hidden_dim, output_dim, use_bias=use_bias)

        # Define the dropout layer
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input_tensor, adj_mat):
        """
        Performs forward pass of the Graph Convolutional Network (GCN).

        Args:
            input_tensor (torch.Tensor): Input node feature matrix with shape (N, input_dim), where N is the number
                of nodes and input_dim is the number of input features per node.
            adj_mat (torch.Tensor): Adjacency matrix of the graph with shape (N, N), representing the relationships
                between nodes.

        Returns:
            torch.Tensor: Output tensor with shape (N, output_dim), representing the predicted class
                log-probabilities for each node.
        """

        # Perform the first graph convolutional layer
        x = self.gc1(input_tensor, adj_mat)
        x = F.relu(x)  # Apply ReLU activation function
        x = self.dropout(x)  # Apply dropout regularization

        # Perform the second graph convolutional layer
        x = self.gc2(x, adj_mat)

        # Apply log-softmax activation function for classification
        return F.log_softmax(x, dim=1)
def load_cora(path='./cora', device='cpu'):
    """
    The graph convolutional operation requires normalizing the adjacency matrix: D^(-1/2) * A * D^(-1/2). This step
    scales the adjacency matrix such that the features of neighboring nodes are weighted appropriately during
    aggregation. The steps involved in the renormalization trick are as follows:
        - Compute the degree matrix.
        - Compute the inverse square root of the degree matrix.
        - Multiply the inverse square root of the degree matrix with the adjacency matrix.
    """

    # Set the paths to the data files
    content_path = os.path.join(path, 'cora.content')
    cites_path = os.path.join(path, 'cora.cites')

    # Load data from files
    content_tensor = np.genfromtxt(content_path, dtype=np.dtype(str))
    cites_tensor = np.genfromtxt(cites_path, dtype=np.int32)

    # Process features
    features = torch.FloatTensor(content_tensor[:, 1:-1].astype(np.int32))  # Extract feature values
    scale_vector = torch.sum(features, dim=1)  # Compute sum of features for each node
    scale_vector = 1 / scale_vector  # Compute reciprocal of the sums
    scale_vector[scale_vector == float('inf')] = 0  # Handle division by zero cases
    scale_vector = torch.diag(scale_vector).to_sparse()  # Convert the scale vector to a sparse diagonal matrix
    features = scale_vector @ features  # Scale the features using the scale vector

    # Process labels
    classes, labels = np.unique(content_tensor[:, -1], return_inverse=True)  # Extract unique classes and map labels to indices
    labels = torch.LongTensor(labels)  # Convert labels to a tensor

    # Process adjacency matrix
    idx = content_tensor[:, 0].astype(np.int32)  # Extract node indices
    idx_map = {id: pos for pos, id in enumerate(idx)}  # Create a dictionary to map indices to positions

    # Map node indices to positions in the adjacency matrix
    edges = np.array(
        list(map(lambda edge: [idx_map[edge[0]], idx_map[edge[1]]],
                 cites_tensor)), dtype=np.int32)

    V = len(idx)  # Number of nodes
    E = edges.shape[0]  # Number of edges
    adj_mat = torch.sparse_coo_tensor(edges.T, torch.ones(E), (V, V), dtype=torch.int64)  # Create the initial adjacency matrix as a sparse tensor
    adj_mat = torch.eye(V) + adj_mat  # Add self-loops to the adjacency matrix

    degree_mat = torch.sum(adj_mat, dim=1)  # Compute the sum of each row in the adjacency matrix (the degree vector)
    degree_mat = torch.sqrt(1 / degree_mat)  # Compute the reciprocal square root of the degrees
    degree_mat[degree_mat == float('inf')] = 0  # Handle division by zero cases
    degree_mat = torch.diag(degree_mat).to_sparse()  # Convert the degree vector to a sparse diagonal matrix

    adj_mat = degree_mat @ adj_mat @ degree_mat  # Apply the renormalization trick

    return features.to_sparse().to(device), labels.to(device), adj_mat.to_sparse().to(device)
def train_iter(epoch, model, optimizer, criterion, input, target, mask_train, mask_val, print_every=10):
    start_t = time.time()
    model.train()
    optimizer.zero_grad()

    # Forward pass
    output = model(*input)
    loss = criterion(output[mask_train], target[mask_train])  # Compute the loss using the training mask

    loss.backward()
    optimizer.step()

    # Evaluate the model performance on training and validation sets
    loss_train, acc_train = test(model, criterion, input, target, mask_train)
    loss_val, acc_val = test(model, criterion, input, target, mask_val)

    if epoch % print_every == 0:
        # Print the training progress at specified intervals
        print(f'Epoch: {epoch:04d} ({(time.time() - start_t):.4f}s) loss_train: {loss_train:.4f} acc_train: {acc_train:.4f} loss_val: {loss_val:.4f} acc_val: {acc_val:.4f}')


def test(model, criterion, input, target, mask):
    model.eval()
    with torch.no_grad():
        output = model(*input)
        output, target = output[mask], target[mask]

        loss = criterion(output, target)
        acc = (output.argmax(dim=1) == target).float().sum() / len(target)
    return loss.item(), acc.item()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Graph Convolutional Network')
    parser.add_argument('--epochs', type=int, default=200,
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--l2', type=float, default=5e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--dropout-p', type=float, default=0.5,
                        help='dropout probability (default: 0.5)')
    parser.add_argument('--hidden-dim', type=int, default=16,
                        help='dimension of the hidden representation (default: 16)')
    parser.add_argument('--val-every', type=int, default=20,
                        help='epochs to wait before printing training and validation evaluation (default: 20)')
    parser.add_argument('--include-bias', action='store_true', default=False,
                        help='use bias term in convolutions (default: False)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    # Select the compute device (CUDA, MPS, or CPU)
    if use_cuda:
        device = torch.device('cuda')
    elif use_mps:
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    print(f'Using {device} device')

    # Download and extract the Cora dataset
    cora_url = 'https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz'
    print('Downloading dataset...')
    with requests.get(cora_url, stream=True) as tgz_file:
        with tarfile.open(fileobj=tgz_file.raw, mode='r:gz') as tgz_object:
            tgz_object.extractall()

    print('Loading dataset...')
    features, labels, adj_mat = load_cora(device=device)
    idx = torch.randperm(len(labels)).to(device)
    idx_test, idx_val, idx_train = idx[:1000], idx[1000:1500], idx[1500:]

    gcn = GCN(features.shape[1], args.hidden_dim, labels.max().item() + 1, args.include_bias, args.dropout_p).to(device)
    optimizer = Adam(gcn.parameters(), lr=args.lr, weight_decay=args.l2)
    criterion = nn.NLLLoss()

    for epoch in range(args.epochs):
        train_iter(epoch + 1, gcn, optimizer, criterion, (features, adj_mat), labels, idx_train, idx_val, args.val_every)
        if args.dry_run:
            break

    loss_test, acc_test = test(gcn, criterion, (features, adj_mat), labels, idx_test)
    print(f'Test set results: loss {loss_test:.4f} accuracy {acc_test:.4f}')

Diff for: gcn/requirements.txt (+4)

@@ -0,0 +1,4 @@
torch
torchvision
requests
numpy

Diff for: run_python_examples.sh (+8, -1)

@@ -172,6 +172,11 @@ function word_language_model() {
   python main.py --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
 }

+function gcn() {
+  start
+  python main.py --epochs 1 --dry-run || error "graph convolutional network failed"
+}
+
 function clean() {
   cd $BASE_DIR
   echo "running clean to remove cruft"
@@ -192,7 +197,8 @@ function clean() {
     super_resolution/model_epoch_1.pth \
     time_sequence_prediction/predict*.pdf \
     time_sequence_prediction/traindata.pt \
-    word_language_model/model.pt || error "couldn't clean up some files"
+    word_language_model/model.pt \
+    gcn/cora/ || error "couldn't clean up some files"

   git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image"
 }
@@ -217,6 +223,7 @@ function run_all() {
   vision_transformer
   word_language_model
   fx
+  gcn
 }

 # by default, run all examples
