-
Notifications
You must be signed in to change notification settings - Fork 63
/
module_test.go
79 lines (67 loc) · 1.26 KB
/
module_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package cu
import (
"path/filepath"
"testing"
"unsafe"
)
// TestModule loads testdata/module_test.ptx, launches its testMemset kernel
// on a device buffer of N float32 values that was initialised from a zeroed
// host slice, copies the buffer back and checks that the first N/2 elements
// equal 42 while the remaining elements are still 0.
//
// The test is skipped (not silently passed) when no CUDA device is available.
func TestModule(t *testing.T) {
	devices, err := NumDevices()
	if err != nil {
		t.Skipf("unable to query CUDA devices: %v", err)
	}
	if devices == 0 {
		t.Skip("No Devices Found")
	}

	ctx, err := Device(0).MakeContext(SchedAuto)
	if err != nil {
		t.Fatalf("creating context on device 0: %v", err)
	}
	defer ctx.Destroy()

	mod, err := Load(filepath.Join("testdata", "module_test.ptx"))
	if err != nil {
		t.Fatalf("loading module: %v", err)
	}
	defer mod.Unload()

	f, err := mod.Function("testMemset")
	if err != nil {
		t.Fatalf("looking up testMemset: %v", err)
	}

	const N = 1000
	N4 := 4 * int64(N) // size in bytes of N float32 values
	a := make([]float32, N)

	A, err := MemAlloc(N4)
	if err != nil {
		t.Fatalf("allocating %d bytes of device memory: %v", N4, err)
	}
	defer MemFree(A)

	// Upload the zeroed host slice so the device buffer starts out all zero.
	aptr := unsafe.Pointer(&a[0])
	if err = MemcpyHtoD(A, aptr, N4); err != nil {
		t.Fatalf("copying host to device: %v", err)
	}

	// The Go types of these kernel arguments matter: the launch reads each
	// argument through the pointer stored in args.
	var value float32 = 42
	// NOTE(review): n is a Go int (64-bit on most platforms); confirm it
	// matches the width of the corresponding parameter in module_test.ptx.
	n := N / 2

	block := 128
	grid := DivUp(N, block)
	shmem := 0
	args := []unsafe.Pointer{unsafe.Pointer(&A), unsafe.Pointer(&value), unsafe.Pointer(&n)}
	if err = f.Launch(grid, 1, 1, block, 1, 1, shmem, Stream{}, args); err != nil {
		t.Fatalf("launching testMemset: %v", err)
	}

	if err = MemcpyDtoH(aptr, A, N4); err != nil {
		t.Fatalf("copying device to host: %v", err)
	}

	// checkRange reports the first element in a[lo:hi] that differs from want,
	// so a broken kernel yields one readable error instead of N anonymous ones.
	checkRange := func(lo, hi int, want float32) {
		for i := lo; i < hi; i++ {
			if a[i] != want {
				t.Errorf("a[%d] = %v, want %v", i, a[i], want)
				return
			}
		}
	}
	checkRange(0, N/2, 42)
	checkRange(N/2, N, 0)
}
// DivUp returns x divided by y rounded towards positive infinity, i.e. the
// ceiling of x/y. In this file it computes how many blocks of size y are
// needed to cover x elements (the kernel launch grid size).
//
// It returns 0 for x == 0 and is correct for negative operands. Like ordinary
// integer division it panics when y == 0.
func DivUp(x, y int) int {
	q, r := x/y, x%y
	// Go's / truncates toward zero, which already rounds up when the exact
	// quotient is negative. Only bump q when there is a remainder and the
	// exact quotient is positive (r, which has x's sign, matches y's sign).
	if r != 0 && (r > 0) == (y > 0) {
		q++
	}
	return q
}