forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathindirect_assert_helper.py
50 lines (36 loc) · 1.13 KB
/
indirect_assert_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import sys
import torch
def first_arg(x, y):
    """Return ``x[y]``: index *x* along its first dimension using *y*."""
    return x[y]
def second_arg(x, y):
    """Return ``x[:, y]``: keep the first dimension whole, index the second with *y*."""
    return x[:, y]
def same_pm_one(x, y):
    """Return ``x[y + 1, y - 1]``.

    The same index *y* drives both dimensions: shifted up by one for the
    first dimension and down by one for the second.
    """
    return x[y + 1, y - 1]
def same_pp_one(x, y):
    """Return ``x[y + 1, y + 1]``.

    The same index *y*, shifted up by one, is used for both the first and
    the second dimension of *x*.
    """
    return x[y + 1, y + 1]
def store(x, y, z):
    """Write *z* into *x* in place at ``x[y + 1, y + 1]``.

    The same index *y*, shifted up by one, addresses both the first and the
    second dimension of *x*. Nothing is returned; *x* is mutated.
    """
    x[y + 1, y + 1] = z
# Names of the module-level functions that may be selected from the command line.
_FN_NAMES = ("first_arg", "second_arg", "same_pm_one", "same_pp_one", "store")


def _parse_bool(label, text):
    """Map the literal string "True"/"False" to a bool; reject anything else."""
    if text not in ("True", "False"):
        raise ValueError(f"{label} must be 'True' or 'False', got {text!r}")
    return text == "True"


def _parse_cli(argv):
    """Validate the command line and return ``(fn_name, shape_x, dynamic_shapes)``.

    Args:
        argv: ``[program, fn_name, dims, dyn_shape, one_size]``.

    Returns:
        The selected function name, the shape (list of ints) for the input
        tensor ``x``, and whether to compile with dynamic shapes.

    Raises:
        ValueError: on a wrong argument count or any unrecognized value.
            Explicit raises are used instead of ``assert`` so the checks
            still run under ``python -O``.
    """
    if len(argv) != 5:
        raise ValueError(
            "expected argv of the form "
            "[program, fn_name, dims, dyn_shape, one_size], "
            f"got {len(argv)} entries"
        )
    _, fn_name, dims, dyn_shape, one_size = argv

    if fn_name not in _FN_NAMES:
        raise ValueError(f"fn_name must be one of {_FN_NAMES}, got {fn_name!r}")
    is_one_size = _parse_bool("one_size", one_size)
    if dims not in ("2", "3"):
        raise ValueError(f"dims must be '2' or '3', got {dims!r}")

    shape_x = [3, 2, 4] if dims == "3" else [3, 2]
    if is_one_size:
        if fn_name != "first_arg":
            raise ValueError(
                "only first_arg can be tested for a special case of 1-size tensor"
            )
        # Special case: the indexed (first) dimension has a single element.
        shape_x[0] = 1

    dynamic_shapes = _parse_bool("dyn_shape", dyn_shape)
    return fn_name, shape_x, dynamic_shapes


def main(argv):
    """Compile the selected indexing function and run it once on CUDA tensors.

    ``x`` is random with shape ``[3, 2]`` or ``[3, 2, 4]`` (first dimension 1
    in the one-size case) and ``y`` is ``arange(4)``, so the index values
    reach past the dimensions they address.
    NOTE(review): presumably the caller runs this script in a subprocess to
    observe the resulting indexing failure -- confirm against the caller.

    Args:
        argv: full argument vector, normally ``sys.argv``.

    Raises:
        ValueError: if the arguments are invalid (raised before any tensor
            is created).
    """
    fn_name, shape_x, dynamic_shapes = _parse_cli(argv)

    x = torch.randn(shape_x, device="cuda")
    y = torch.arange(4, device="cuda")

    # fn_name was checked against _FN_NAMES, so this lookup can only reach
    # the indexing functions defined in this module.
    fn = torch.compile(dynamic=dynamic_shapes)(globals()[fn_name])

    if fn_name == "store":
        # One value per index in y, times whatever trailing dims x has.
        shape = (y.numel(),) + x.shape[2:]
        z = torch.randn(shape, device="cuda")
        fn(x, y, z)
    else:
        fn(x, y)


if __name__ == "__main__":
    main(sys.argv)