代码片段记录12

1️⃣[CUDA time]

要正确测量代码在 CUDA 上的运行时间，需要在开始计时前和结束计时前各调用一次 torch.cuda.synchronize()：CUDA 核函数是异步启动的，Python 调用在任务入队后就会返回，不同步的话测到的只是启动开销，而不是实际的 GPU 计算时间。

import torch
import time

def pairwise_dot_bmm(a, b):
    """Return ``out[n, i, j] = sum_k a[n, i, k] * b[n, j, k]`` via one batched matmul.

    ``a`` and ``b`` must be 3-D (``torch.bmm`` requirement), with matching batch
    size and last dimension; ``b`` is transposed so its last dim is contracted.
    """
    return torch.bmm(a, b.transpose(1, 2))


def pairwise_dot_broadcast(a, b):
    """Compute the same result as ``pairwise_dot_bmm`` by explicit broadcasting.

    ``a.unsqueeze(2) * b.unsqueeze(1)`` materializes a 4-D intermediate of shape
    (batch, a.size(1), b.size(1), dim) before the reduction, which is why this
    variant needs far more memory than the batched matmul.
    """
    return (a.unsqueeze(2) * b.unsqueeze(1)).sum(-1)


def cuda_timed(fn, *args, warmup=1):
    """Run ``fn(*args)`` on the GPU and return ``(result, seconds, peak_bytes)``.

    CUDA kernels are launched asynchronously: the Python call returns once the
    work is queued. ``torch.cuda.synchronize()`` before starting and before
    stopping the clock makes the interval cover the actual GPU work.

    Args:
        fn: callable performing the CUDA work to measure.
        *args: positional arguments forwarded to ``fn``.
        warmup: number of untimed calls made first, so one-time lazy
            initialization on the first use of an operation is not charged to
            the measurement.

    Returns:
        The result of the timed call, the elapsed wall-clock seconds, and the
        peak bytes allocated by PyTorch during that call (peak stats are reset
        just before it, so tensors already live, e.g. the inputs, are included).
    """
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    # max_memory_allocated() is cumulative; reset so the peak belongs to this call only.
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    result = fn(*args)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return result, elapsed, torch.cuda.max_memory_allocated()


if torch.cuda.is_available():
    a = torch.randint(high=1000, size=(20, 200, 256)).double().cuda()
    b = torch.randint(high=1000, size=(20, 200, 256)).double().cuda()

    M, elapsed, peak = cuda_timed(pairwise_dot_bmm, a, b)
    print("bmm", elapsed)
    print("max_mem", peak)

    N, elapsed, peak = cuda_timed(pairwise_dot_broadcast, a, b)
    print("element-wise", elapsed)
    print("max_mem", peak)

    # Entries are sums of 256 products of integers < 1000, i.e. up to ~2.6e8.
    # That exceeds float32's exact-integer range (2**24), so different summation
    # orders in bmm vs. sum() can round differently; in float64 (exact up to
    # 2**53) every partial sum is exact and the two results match exactly.
    print("output difference (should be 0)", (N - M).abs().max())
    print("In single precision this can fail because of the size of the tensors.")
    print("Using double should always work")
else:
    print("CUDA is not available; this timing demo needs a GPU.")