0% found this document useful (0 votes)
24 views

GNN 01 Intro

This document summarizes a notebook that trains a graph convolutional network (GCN) on a node classification task. It imports the Karate Club network dataset, defines a GCN model, and trains it for 201 epochs (numbered 0–200), reaching 100% accuracy. The loss and accuracy are computed on all 34 nodes (the dataset's train_mask is not used), so this is training accuracy, not accuracy on a held-out test set. Key steps include loading the dataset, defining the GCN architecture, training with cross-entropy loss and Adam optimization, and monitoring loss and accuracy over epochs.

Uploaded by

vitormeriat
Copyright
© All Rights Reserved
Available Formats
Download as PDF, TXT or read online on Scribd
0% found this document useful (0 votes)
24 views

GNN 01 Intro

This document summarizes a notebook that trains a graph convolutional network (GCN) on a node classification task. It imports the Karate Club network dataset, defines a GCN model, and trains it for 201 epochs (numbered 0–200), reaching 100% accuracy. The loss and accuracy are computed on all 34 nodes (the dataset's train_mask is not used), so this is training accuracy, not accuracy on a held-out test set. Key steps include loading the dataset, defining the GCN architecture, training with cross-entropy loss and Adam optimization, and monitoring loss and accuracy over epochs.

Uploaded by

vitormeriat
Copyright
© All Rights Reserved
Available Formats
Download as PDF, TXT or read online on Scribd
You are on page 1/ 8

11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory

Graph Convolutional Networks


Chapter 1 of the Graph Neural Network Course

❤️ Created by @maximelabonne.
Companion notebook to execute the code from the following article: https://mlabonne.github.io/blog/intrognn/

1 !pip -q install torch_geometric


2
3 import torch
4 import numpy as np
5 import networkx as nx
6 import matplotlib.pyplot as plt

from torch_geometric.datasets import KarateClub

# Load the KarateClub dataset that ships with PyTorch Geometric.
dataset = KarateClub()

# Dataset summary: its repr, a separator, then graph/feature/class counts.
print(dataset)
print('------------')
print(
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
    sep='\n',
)

KarateClub()
------------
Number of graphs: 1
Number of features: 34
Number of classes: 4

# Summarize the first (and only) graph stored in the dataset.
print(f'Graph: {dataset[0]}')

Graph: Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

# KarateClub holds a single graph; every later cell works on it directly.
data = dataset[0]

# Node feature matrix: its shape first, then its contents.
print(f'x = {data.x.shape}', data.x, sep='\n')

x = torch.Size([34, 34])
tensor([[1., 0., 0., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
[0., 0., 1., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 1., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 0., 0., 1.]])

# Edge index tensor (2 rows x one column per edge): shape, then values.
print(f'edge_index = {data.edge_index.shape}', data.edge_index, sep='\n')

edge_index = torch.Size([2, 156])


tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3,
3, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7,
7, 7, 8, 8, 8, 8, 8, 9, 9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21,
21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27,
27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31,
31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33,
33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
[ 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 17, 19, 21, 31, 0, 2,
3, 7, 13, 17, 19, 21, 30, 0, 1, 3, 7, 8, 9, 13, 27, 28, 32, 0,
1, 2, 7, 12, 13, 0, 6, 10, 0, 6, 10, 16, 0, 4, 5, 16, 0, 1,
2, 3, 0, 2, 30, 32, 33, 2, 33, 0, 4, 5, 0, 0, 3, 0, 1, 2,
3, 33, 32, 33, 32, 33, 5, 6, 0, 1, 32, 33, 0, 1, 33, 32, 33, 0,
1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33, 2, 23,
24, 33, 2, 31, 33, 23, 26, 32, 33, 1, 8, 32, 33, 0, 24, 25, 28, 32,
33, 2, 8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33, 8, 9, 13, 14, 15,
18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])

https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 1/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory

from torch_geometric.utils import to_dense_adj

# Densify the edge list into an integer adjacency matrix. to_dense_adj
# returns one matrix per graph, so index [0] picks out the only graph.
A = to_dense_adj(data.edge_index)[0].numpy().astype(int)
print(f'A = {A.shape}', A, sep='\n')

A = (34, 34)
[[0 1 1 ... 1 0 0]
[1 0 1 ... 0 0 0]
[1 1 0 ... 0 1 0]
...
[1 0 0 ... 0 1 1]
[0 0 1 ... 1 0 1]
[0 0 0 ... 1 1 0]]

# Node labels: shape, then one class index per node.
print(f'y = {data.y.shape}', data.y, sep='\n')

y = torch.Size([34])
tensor([1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0,
2, 2, 0, 0, 2, 0, 0, 2, 0, 0])

# Boolean per-node mask: shape, then the flags themselves.
print(f'train_mask = {data.train_mask.shape}', data.train_mask, sep='\n')

train_mask = torch.Size([34])
tensor([ True, False, False, False, True, False, False, False, True, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, False, False, False, False, False,
False, False, False, False])

# Basic structural checks on the graph, one result per line.
print(
    f'Edges are directed: {data.is_directed()}',
    f'Graph has isolated nodes: {data.has_isolated_nodes()}',
    f'Graph has loops: {data.has_self_loops()}',
    sep='\n',
)

Edges are directed: False


Graph has isolated nodes: False
Graph has loops: False

from torch_geometric.utils import to_networkx

# Draw the graph with every node coloured by its ground-truth class (data.y).
G = to_networkx(data, to_undirected=True)
plt.figure(figsize=(12, 12))
plt.axis('off')
layout = nx.spring_layout(G, seed=0)  # fixed seed -> reproducible layout
nx.draw_networkx(
    G,
    pos=layout,
    with_labels=True,
    node_size=800,
    node_color=data.y,
    cmap="hsv",
    vmin=-2,
    vmax=3,
    width=0.8,
    edge_color="grey",
    font_size=14,
)
plt.show()

https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 2/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory

from torch.nn import Linear
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    """One GCN layer with ReLU, followed by a linear classification head.

    Args:
        num_features: input feature size per node. Defaults to
            ``dataset.num_features`` (the dataset loaded earlier in the notebook).
        num_classes: number of output classes. Defaults to
            ``dataset.num_classes``.
        hidden_channels: width of the GCN output (the node embedding). The
            default of 3 lets the embeddings be plotted in 3-D later on.

    ``forward`` returns ``(h, z)``: the post-ReLU node embeddings ``h`` and
    the unnormalized class scores ``z``.
    """

    def __init__(self, num_features=None, num_classes=None, hidden_channels=3):
        super().__init__()
        # Fall back to the loaded dataset so that a bare ``GCN()`` keeps working.
        if num_features is None:
            num_features = dataset.num_features
        if num_classes is None:
            num_classes = dataset.num_classes
        self.gcn = GCNConv(num_features, hidden_channels)
        self.out = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        h = self.gcn(x, edge_index).relu()
        z = self.out(h)
        return h, z


model = GCN()
print(model)

GCN(
(gcn): GCNConv(34, 3)
(out): Linear(in_features=3, out_features=4, bias=True)
)

https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 3/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory
# Cross-entropy on the class scores, minimised with Adam (learning rate 0.02).
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=2e-2)

def accuracy(pred_y, y):
    """Return the share of predictions that match the labels.

    Counts the positions where ``pred_y`` equals ``y`` and divides by
    ``len(y)``; with tensor inputs the result is a 0-dim tensor.
    """
    n_correct = pred_y.eq(y).sum()
    return n_correct / len(y)

# Per-epoch history used by the animation cells. Values are detached or
# converted to Python floats before being stored: appending `h` and `loss`
# as-is would keep every epoch's autograd graph alive for the whole run.
embeddings = []
losses = []
accuracies = []
outputs = []

# Training loop: epochs 0..200 inclusive.
# NOTE: the loss and accuracy use every node's label (data.train_mask is not
# applied), so the reported accuracy is training accuracy, not a test score.
for epoch in range(201):
    # Clear gradients
    optimizer.zero_grad()

    # Forward pass
    h, z = model(data.x, data.edge_index)

    # Calculate loss function
    loss = criterion(z, data.y)

    # Calculate accuracy
    acc = accuracy(z.argmax(dim=1), data.y)

    # Compute gradients
    loss.backward()

    # Tune parameters
    optimizer.step()

    # Store data for animations (graph-free copies)
    embeddings.append(h.detach())
    losses.append(loss.item())
    accuracies.append(float(acc))
    outputs.append(z.argmax(dim=1))

    # Print metrics every 10 epochs
    if epoch % 10 == 0:
        print(f'Epoch {epoch:>3} | Loss: {loss:.2f} | Acc: {acc*100:.2f}%')

Epoch 0 | Loss: 1.40 | Acc: 41.18%


Epoch 10 | Loss: 1.21 | Acc: 47.06%
Epoch 20 | Loss: 1.02 | Acc: 67.65%
Epoch 30 | Loss: 0.80 | Acc: 73.53%
Epoch 40 | Loss: 0.59 | Acc: 73.53%
Epoch 50 | Loss: 0.39 | Acc: 94.12%
Epoch 60 | Loss: 0.23 | Acc: 97.06%
Epoch 70 | Loss: 0.13 | Acc: 100.00%
Epoch 80 | Loss: 0.07 | Acc: 100.00%
Epoch 90 | Loss: 0.05 | Acc: 100.00%
Epoch 100 | Loss: 0.03 | Acc: 100.00%
Epoch 110 | Loss: 0.02 | Acc: 100.00%
Epoch 120 | Loss: 0.02 | Acc: 100.00%
Epoch 130 | Loss: 0.02 | Acc: 100.00%
Epoch 140 | Loss: 0.01 | Acc: 100.00%
Epoch 150 | Loss: 0.01 | Acc: 100.00%
Epoch 160 | Loss: 0.01 | Acc: 100.00%
Epoch 170 | Loss: 0.01 | Acc: 100.00%
Epoch 180 | Loss: 0.01 | Acc: 100.00%
Epoch 190 | Loss: 0.01 | Acc: 100.00%
Epoch 200 | Loss: 0.01 | Acc: 100.00%

https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 4/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory
1 %%capture
2 from IPython.display import HTML
3 from matplotlib import animation
4 plt.rcParams["animation.bitrate"] = 3000
5
6 def animate(i):
7 G = to_networkx(data, to_undirected=True)
8 nx.draw_networkx(G,
9 pos=nx.spring_layout(G, seed=0),
10 with_labels=True,
11 node_size=800,
12 node_color=outputs[i],
13 cmap="hsv",
14 vmin=-2,
15 vmax=3,
16 width=0.8,
17 edge_color="grey",
18 font_size=14
19 )
20 plt.title(f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%',
21 fontsize=18, pad=20)
22
23 fig = plt.figure(figsize=(12, 12))
24 plt.axis('off')
25
26 anim = animation.FuncAnimation(fig, animate, \
27 np.arange(0, 200, 10), interval=500, repeat=True)
28 html = HTML(anim.to_html5_video())

# Render the animation video produced above (`display` is IPython's builtin).
display(html)

https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 5/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory

# `h` is the embedding tensor from the last iteration of the training loop.
print(f'Final embeddings = {h.shape}', h, sep='\n')

Final embeddings = torch.Size([34, 3])


tensor([[1.9099e+00, 2.3584e+00, 7.4027e-01],
[2.6203e+00, 2.7997e+00, 0.0000e+00],
[2.2567e+00, 2.2962e+00, 6.4663e-01],
[2.0802e+00, 2.8785e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 2.9694e+00],
[0.0000e+00, 0.0000e+00, 3.3817e+00],
[0.0000e+00, 1.5008e-04, 3.4246e+00],
[1.7593e+00, 2.4292e+00, 2.4551e-01],
[1.9757e+00, 6.1032e-01, 1.8986e+00],
[1.7770e+00, 1.9950e+00, 6.7018e-01],
[0.0000e+00, 1.1683e-04, 2.9738e+00],
[1.8988e+00, 2.0512e+00, 2.6225e-01],
[1.7081e+00, 2.3618e+00, 1.9609e-01],
[1.8303e+00, 2.1591e+00, 3.5906e-01],
[2.0755e+00, 2.7468e-01, 1.9804e+00],
[1.9676e+00, 3.7185e-01, 2.0011e+00],
[0.0000e+00, 0.0000e+00, 3.4787e+00],
[1.6945e+00, 2.0350e+00, 1.9789e-01],
[1.9808e+00, 3.2633e-01, 2.1349e+00],
[1.7846e+00, 1.9585e+00, 4.8021e-01],
[2.0420e+00, 2.7512e-01, 1.9810e+00],
[1.7665e+00, 2.1357e+00, 4.0325e-01],
[1.9870e+00, 3.3886e-01, 2.0421e+00],
[2.0614e+00, 5.1042e-01, 2.4872e+00],
[1.8381e-01, 2.1094e+00, 2.2035e+00],
[1.8858e-01, 2.0701e+00, 2.1601e+00],
[2.2553e+00, 4.1764e-01, 2.0231e+00],
[1.6532e+00, 8.6745e-01, 2.2131e+00],
[2.4265e-01, 2.1862e+00, 1.6104e+00],
[2.5709e+00, 4.6342e-02, 2.3627e+00],
[2.1778e+00, 4.4730e-01, 2.0077e+00],
[3.8906e-02, 2.3443e+00, 1.9195e+00],
[3.0748e+00, 0.0000e+00, 3.0789e+00],
[3.4316e+00, 1.9716e-01, 2.5231e+00]], grad_fn=<ReluBackward0>)

# Plot the final node embeddings in 3-D, coloured by true class.
# `h` comes from the last training epoch, not epoch 0; the per-epoch history
# lives in `embeddings`.
embed = h.detach().cpu().numpy()

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')
ax.patch.set_alpha(0)
plt.tick_params(left=False,
                bottom=False,
                labelleft=False,
                labelbottom=False)
# One axis per embedding dimension (the GCN's hidden width is 3).
ax.scatter(embed[:, 0], embed[:, 1], embed[:, 2],
           s=200, c=data.y, cmap="hsv", vmin=-2, vmax=3)

plt.show()

https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 6/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory

1 %%capture
2
3 def animate(i):
4 embed = embeddings[i].detach().cpu().numpy()
5 ax.clear()
6 ax.scatter(embed[:, 0], embed[:, 1], embed[:, 2],
7 s=200, c=data.y, cmap="hsv", vmin=-2, vmax=3)
8 plt.title(f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%',
9 fontsize=18, pad=40)
10
11 fig = plt.figure(figsize=(12, 12))
12 plt.axis('off')
13 ax = fig.add_subplot(projection='3d')
14 plt.tick_params(left=False,
15 bottom=False,
16 labelleft=False,
17 labelbottom=False)
18
19 anim = animation.FuncAnimation(fig, animate, \
20 np.arange(0, 200, 10), interval=800, repeat=True)
21 html = HTML(anim.to_html5_video())

# Show the 3-D embedding animation built in the previous cell.
display(html)

https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 7/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory

https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 8/8

You might also like