GNN 01 Intro
GNN 01 Intro
ipynb - Colaboratory
❤️ Created by @maximelabonne.
Companion notebook to execute the code from the following article: https://mlabonne.github.io/blog/intrognn/
KarateClub()
------------
Number of graphs: 1
Number of features: 34
Number of classes: 4
# Pull the first graph object out of the dataset.
data = dataset[0]

# Show data.x: its shape on the first line, its contents below.
print(f'x = {data.x.shape}', data.x, sep='\n')
x = torch.Size([34, 34])
tensor([[1., 0., 0., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
[0., 0., 1., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 1., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 0., 0., 1.]])
# Show data.edge_index: its shape on the first line, its contents below.
print(f'edge_index = {data.edge_index.shape}', data.edge_index, sep='\n')
https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 1/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory
A = (34, 34)
[[0 1 1 ... 1 0 0]
[1 0 1 ... 0 0 0]
[1 1 0 ... 0 1 0]
...
[1 0 0 ... 0 1 1]
[0 0 1 ... 1 0 1]
[0 0 0 ... 1 1 0]]
# Show data.y (the per-node labels used as training targets below):
# shape on the first line, contents below.
print(f'y = {data.y.shape}', data.y, sep='\n')
y = torch.Size([34])
tensor([1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0,
2, 2, 0, 0, 2, 0, 0, 2, 0, 0])
# Show data.train_mask: its shape on the first line, its contents below.
print(f'train_mask = {data.train_mask.shape}', data.train_mask, sep='\n')
train_mask = torch.Size([34])
tensor([ True, False, False, False, True, False, False, False, True, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, False, False, False, False, False,
False, False, False, False])
https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 2/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory
GCN(
(gcn): GCNConv(34, 3)
(out): Linear(in_features=3, out_features=4, bias=True)
)
https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 3/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory
# Loss used by the training loop: cross-entropy between the model's class
# scores and the integer labels.
criterion = torch.nn.CrossEntropyLoss()

# Adam over every parameter of the model, with a fixed learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.02)
3
def accuracy(pred_y, y):
    """Return the fraction of positions where ``pred_y`` equals ``y``.

    The element-wise matches are summed and divided by ``len(y)``, so the
    result has whatever type that division yields for the inputs (a 0-dim
    tensor when both arguments are tensors).
    """
    matches = pred_y == y
    total = len(y)
    return matches.sum() / total
7
# Per-epoch history, consumed by the animation cells further down.
embeddings = []
losses = []
accuracies = []
outputs = []

# Training loop
# NOTE(review): loss and accuracy are computed over every node; data.train_mask
# is not applied anywhere in this loop. Confirm that this is intended.
for epoch in range(201):
    # Clear gradients left over from the previous step
    optimizer.zero_grad()

    # Forward pass: h is the embedding that gets recorded and printed later,
    # z holds the class scores fed to the loss
    h, z = model(data.x, data.edge_index)

    # Calculate loss function
    loss = criterion(z, data.y)

    # Predicted class per node; computed once and reused below
    pred = z.argmax(dim=1)

    # Calculate accuracy
    acc = accuracy(pred, data.y)

    # Compute gradients
    loss.backward()

    # Tune parameters
    optimizer.step()

    # Store data for animations. Keep a detached embedding and plain Python
    # floats, so the history lists do not hold on to each epoch's autograd
    # graph.
    embeddings.append(h.detach())
    losses.append(float(loss))
    accuracies.append(float(acc))
    outputs.append(pred)

    # Print metrics every 10 epochs
    if epoch % 10 == 0:
        print(f'Epoch {epoch:>3} | Loss: {loss:.2f} | Acc: {acc*100:.2f}%')
https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 4/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory
1 %%capture
from IPython.display import HTML
from matplotlib import animation

plt.rcParams["animation.bitrate"] = 3000

# The graph and its node positions are identical for every frame (the layout
# is seeded), so build them once instead of on every call to animate().
graph = to_networkx(data, to_undirected=True)
layout = nx.spring_layout(graph, seed=0)

fig = plt.figure(figsize=(12, 12))
graph_ax = plt.gca()
graph_ax.set_axis_off()


def animate(i):
    """Draw the graph with nodes coloured by the predictions stored for epoch ``i``.

    ``i`` indexes the ``outputs``, ``losses`` and ``accuracies`` history lists.
    """
    # Start from an empty axes so artists from earlier frames do not pile up;
    # clearing re-enables the axis frame, so switch it off again.
    graph_ax.clear()
    graph_ax.set_axis_off()
    nx.draw_networkx(
        graph,
        pos=layout,
        ax=graph_ax,
        with_labels=True,
        node_size=800,
        node_color=outputs[i],
        cmap="hsv",
        vmin=-2,
        vmax=3,
        width=0.8,
        edge_color="grey",
        font_size=14,
    )
    graph_ax.set_title(
        f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%',
        fontsize=18,
        pad=20,
    )


# NOTE(review): frames stop at epoch 190, so the last recorded epoch (200) is
# never shown. Confirm whether that is intended.
anim = animation.FuncAnimation(
    fig,
    animate,
    frames=np.arange(0, 200, 10),
    interval=500,
    repeat=True,
)
html = HTML(anim.to_html5_video())
1 display(html)
https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 5/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory
# Report the embedding tensor h left by the last training iteration:
# shape on the first line, contents below.
print(f'Final embeddings = {h.shape}', h, sep='\n')
https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 6/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory
1 %%capture
2
def animate(i):
    """Redraw the 3D scatter of the embedding recorded at index ``i``.

    The first three columns of the stored embedding are used as the x, y and
    z coordinates; points are coloured by ``data.y``.
    """
    points = embeddings[i].detach().cpu().numpy()
    caption = f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%'

    # Wipe the previous frame before plotting the new one.
    ax.clear()
    ax.scatter(
        points[:, 0],
        points[:, 1],
        points[:, 2],
        s=200,
        c=data.y,
        cmap="hsv",
        vmin=-2,
        vmax=3,
    )
    plt.title(caption, fontsize=18, pad=40)


# Figure setup; the order of these calls is kept exactly as in the original.
fig = plt.figure(figsize=(12, 12))
plt.axis('off')
ax = fig.add_subplot(projection='3d')
plt.tick_params(
    left=False,
    bottom=False,
    labelleft=False,
    labelbottom=False,
)

# One frame for every tenth index from 0 up to (not including) 200.
anim = animation.FuncAnimation(
    fig,
    animate,
    frames=np.arange(0, 200, 10),
    interval=800,
    repeat=True,
)
html = HTML(anim.to_html5_video())
1 display(html)
https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 7/8
11/26/23, 11:33 PM 1. Graph Convolutional Networks.ipynb - Colaboratory
https://colab.research.google.com/drive/1ZugveUjRrbSNwUbryeKJN2wyhGFRCw0q?usp=sharing#scrollTo=xBDrky2VV17_ 8/8