"""Parity tests for the MLX attention modules.

Each test builds the MLX and PyTorch implementations with identical weights,
feeds them the same random input, and checks that their outputs match within
a small numerical tolerance.
"""
import pytest
import numpy as np
import torch
import mlx.core as mx

from attention_mlx import TemporalAxialAttention, SpatialAxialAttention
from rotary_embedding_mlx import RotaryEmbedding as RotaryEmbedding_MLX
from attention import TemporalAxialAttention as TorchTemporalAxialAttention
from attention import SpatialAxialAttention as TorchSpatialAxialAttention
from rotary_embedding_torch import RotaryEmbedding as RotaryEmbedding_Torch


@pytest.fixture
def test_config():
    return {
        'batch_size': 2,
        'time_steps': 4,
        'height': 8,
        'width': 16,
        'dim': 64,
        'heads': 4,
        'dim_head': 16,
    }
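

# A minimal reusable sketch (a hypothetical helper, not used by the tests below,
# which write the same copy out inline): it assumes both frameworks' Linear
# layers store `weight` as (out_features, in_features), which holds for
# torch.nn.Linear and mlx.nn.Linear.
def copy_linear_params(torch_linear, mlx_linear):
    """Copy weight (and bias, if present) from a torch.nn.Linear to an mlx.nn.Linear."""
    mlx_linear.weight = mx.array(torch_linear.weight.detach().numpy())
    if torch_linear.bias is not None:
        mlx_linear.bias = mx.array(torch_linear.bias.detach().numpy())
    return mlx_linear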


def test_temporal_attention_shapes(test_config):
    B, T, H, W = (test_config['batch_size'], test_config['time_steps'],
                  test_config['height'], test_config['width'])
    dim = test_config['dim']

    # Initialize both implementations
    torch_rope = RotaryEmbedding_Torch(dim=dim // test_config['heads'])
    torch_model = TorchTemporalAxialAttention(
        dim=dim,
        heads=test_config['heads'],
        dim_head=test_config['dim_head'],
        rotary_emb=torch_rope,
    )
    mlx_rope = RotaryEmbedding_MLX(dim=dim // test_config['heads'])
    mlx_model = TemporalAxialAttention(
        dim=dim,
        heads=test_config['heads'],
        dim_head=test_config['dim_head'],
        rotary_emb=mlx_rope,
    )

    # Copy the torch parameters (including the output bias) into the MLX model
    # so both implementations run with identical weights.
    mlx_model.to_qkv.weight = mx.array(torch_model.to_qkv.weight.detach().numpy())
    mlx_model.to_out.weight = mx.array(torch_model.to_out.weight.detach().numpy())
    mlx_model.to_out.bias = mx.array(torch_model.to_out.bias.detach().numpy())

    # Create input tensors from the same numpy data
    x_np = np.random.randn(B, T, H, W, dim).astype(np.float32)
    x_mlx = mx.array(x_np)
    x_torch = torch.from_numpy(x_np)

    # Forward pass
    out_mlx = mlx_model(x_mlx)
    out_torch = torch_model(x_torch)

    # Convert outputs to numpy for comparison
    out_torch_np = out_torch.detach().numpy()

    # Check shapes
    assert out_mlx.shape == out_torch_np.shape
    # assert out_mlx.shape == (B, T, H, W, dim)

    # Check values are close (allowing for small numerical differences)
    np.testing.assert_allclose(out_mlx, out_torch_np, rtol=1e-4, atol=1e-4)
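

# Hedged debugging aid (a hypothetical helper, not called by the tests): reports
# the largest elementwise gap between the two backends, which is handy when the
# allclose checks fail and you want to see how far the outputs actually drift.
def _max_abs_diff(mlx_out, torch_out_np):
    """Return the largest absolute elementwise difference between the outputs."""
    return float(np.max(np.abs(np.array(mlx_out) - torch_out_np)))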


def test_spatial_attention_shapes(test_config):
    B, T, H, W = (test_config['batch_size'], test_config['time_steps'],
                  test_config['height'], test_config['width'])
    dim = test_config['dim']

    # Initialize both implementations
    mlx_rope = RotaryEmbedding_MLX(dim=dim // test_config['heads'] // 2, freqs_for="pixel", max_freq=256)
    mlx_model = SpatialAxialAttention(
        dim=dim,
        heads=test_config['heads'],
        dim_head=test_config['dim_head'],
        rotary_emb=mlx_rope,
    )
    torch_rope = RotaryEmbedding_Torch(dim=dim // test_config['heads'] // 2, freqs_for="pixel", max_freq=256)
    torch_model = TorchSpatialAxialAttention(
        dim=dim,
        heads=test_config['heads'],
        dim_head=test_config['dim_head'],
        rotary_emb=torch_rope,
    )

    # Copy the torch parameters (including the output bias) into the MLX model
    # so both implementations run with identical weights.
    mlx_model.to_qkv.weight = mx.array(torch_model.to_qkv.weight.detach().numpy())
    mlx_model.to_out.weight = mx.array(torch_model.to_out.weight.detach().numpy())
    mlx_model.to_out.bias = mx.array(torch_model.to_out.bias.detach().numpy())

    # Create input tensors from the same numpy data
    x_np = np.random.randn(B, T, H, W, dim).astype(np.float32)
    x_mlx = mx.array(x_np)
    x_torch = torch.from_numpy(x_np)

    # Forward pass
    out_mlx = mlx_model(x_mlx)
    out_torch = torch_model(x_torch)

    # Convert outputs to numpy for comparison
    out_torch_np = out_torch.detach().numpy()

    # Check shapes
    assert out_mlx.shape == out_torch_np.shape
    assert out_mlx.shape == (B, T, H, W, dim)

    # Check values are close (allowing for small numerical differences)
    np.testing.assert_allclose(out_mlx, out_torch_np, rtol=1e-4, atol=1e-5)
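

# Optional convenience sketch: lets the parity checks run directly with
# `python test_attention_mlx.py -v`; plain `pytest` discovery works without it.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v"]))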