forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCAsmUtils.cuh
145 lines (126 loc) · 3.34 KB
/
THCAsmUtils.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#ifndef THC_ASM_UTILS_INC
#define THC_ASM_UTILS_INC
// Collection of direct PTX functions
// Primary template for bit-field helpers, keyed on the word type T.
// It is intentionally empty: the usable getBitfield/setBitfield members live
// in the explicit specializations that follow (unsigned int and uint64_t).
template <typename T>
struct Bitfield {};
// Bit-field extract/insert helpers for 32-bit words.
// On NVIDIA builds these map to the PTX bfe.u32 / bfi.b32 instructions; on
// HIP builds they are emulated with shifts and masks.
template <>
struct Bitfield<unsigned int> {
  // Extracts `len` bits of `val` starting at bit `pos` and returns them
  // right-aligned (zero-extended).
  static __device__ __forceinline__
  unsigned int getBitfield(unsigned int val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
    // Only the low 8 bits of pos/len are honoured, matching the 0..255
    // operand range PTX documents for bfe.
    pos &= 0xff;
    len &= 0xff;

    // A C++ shift by >= 32 on a 32-bit operand is undefined, so the
    // out-of-word cases are handled explicitly instead of being shifted.
    if (pos >= 32) {
      return 0u;  // the field starts past the top bit: nothing to extract
    }
    const unsigned int m = (len >= 32) ? ~0u : ((1u << len) - 1u);
    return (val >> pos) & m;
#else
    unsigned int ret;
    asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
    return ret;
#endif
  }

  // Returns `val` with the `len` bits starting at bit `pos` replaced by the
  // low `len` bits of `toInsert`. All other bits of `val` are preserved.
  static __device__ __forceinline__
  unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
    pos &= 0xff;
    len &= 0xff;

    // Same undefined-shift guards as getBitfield.
    if (pos >= 32) {
      return val;  // the field lies entirely outside the word
    }
    unsigned int m = (len >= 32) ? ~0u : ((1u << len) - 1u);
    toInsert &= m;
    toInsert <<= pos;
    m <<= pos;  // bits shifted past the top are dropped, clipping the field

    return (val & ~m) | toInsert;
#else
    unsigned int ret;
    // Operand order for bfi is (inserted bits, base word, pos, len).
    asm("bfi.b32 %0, %1, %2, %3, %4;" :
        "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
    return ret;
#endif
  }
};
// Bit-field extract/insert helpers for 64-bit words.
// On NVIDIA builds these map to the PTX bfe.u64 / bfi.b64 instructions; on
// HIP builds they are emulated with shifts and masks.
template <>
struct Bitfield<uint64_t> {
  // Extracts `len` bits of `val` starting at bit `pos` and returns them
  // right-aligned (zero-extended).
  static __device__ __forceinline__
  uint64_t getBitfield(uint64_t val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
    // Only the low 8 bits of pos/len are honoured, matching the 0..255
    // operand range PTX documents for bfe.
    pos &= 0xff;
    len &= 0xff;

    // A C++ shift by >= 64 on a 64-bit operand is undefined, so the
    // out-of-word cases are handled explicitly.
    if (pos >= 64) {
      return 0;  // the field starts past the top bit: nothing to extract
    }
    // The mask must be built in 64-bit arithmetic: a 32-bit `1u << len`
    // cannot represent fields of 32 bits or more.
    const uint64_t one = 1;
    const uint64_t m = (len >= 64) ? ~static_cast<uint64_t>(0)
                                   : ((one << len) - one);
    return (val >> pos) & m;
#else
    uint64_t ret;
    asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
    return ret;
#endif
  }

  // Returns `val` with the `len` bits starting at bit `pos` replaced by the
  // low `len` bits of `toInsert`. All other bits of `val` are preserved.
  static __device__ __forceinline__
  uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
    pos &= 0xff;
    len &= 0xff;

    // Same undefined-shift guards and 64-bit mask as getBitfield.
    if (pos >= 64) {
      return val;  // the field lies entirely outside the word
    }
    const uint64_t one = 1;
    uint64_t m = (len >= 64) ? ~static_cast<uint64_t>(0)
                             : ((one << len) - one);
    toInsert &= m;
    toInsert <<= pos;
    m <<= pos;  // bits shifted past the top are dropped, clipping the field

    return (val & ~m) | toInsert;
#else
    uint64_t ret;
    // Operand order for bfi is (inserted bits, base word, pos, len).
    asm("bfi.b64 %0, %1, %2, %3, %4;" :
        "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
    return ret;
#endif
  }
};
// Returns the calling thread's lane index within its warp (wavefront on HIP).
__device__ __forceinline__ int getLaneId() {
#if defined(__HIP_PLATFORM_HCC__)
  return __lane_id();
#else
  int laneId;
  // In an asm template that has operands, a literal '%' must be written as
  // "%%"; a single '%' introduces an operand reference such as %0. The
  // special register is therefore spelled %%laneid, like the %%lanemask_*
  // reads elsewhere in this file.
  asm("mov.s32 %0, %%laneid;" : "=r"(laneId));
  return laneId;
#endif
}
// Mask with one bit set for every lane whose index is lower than the
// caller's. The HIP build returns a 64-bit mask, the NVIDIA build a 32-bit one.
#if defined(__HIP_PLATFORM_HCC__)
__device__ __forceinline__ unsigned long long int getLaneMaskLt() {
  // (1 << lane) - 1 sets exactly the bits below the caller's own lane bit.
  const std::uint64_t below = (1ull << getLaneId()) - 1ull;
  return below;
}
#else
__device__ __forceinline__ unsigned getLaneMaskLt() {
  unsigned below;
  // Read the mask straight from the PTX special register.
  asm("mov.u32 %0, %%lanemask_lt;" : "=r"(below));
  return below;
}
#endif
// Mask with one bit set for every lane whose index is lower than or equal to
// the caller's. The HIP build returns a 64-bit mask, the NVIDIA build a
// 32-bit one.
#if defined(__HIP_PLATFORM_HCC__)
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
  // Keep the low (lane + 1) bits by shifting an all-ones word right by the
  // number of bits that must be cleared.
  const auto wordBits = sizeof(std::uint64_t) * CHAR_BIT;
  const std::uint64_t upTo = UINT64_MAX >> (wordBits - (getLaneId() + 1));
  return upTo;
}
#else
__device__ __forceinline__ unsigned getLaneMaskLe() {
  unsigned upTo;
  // Read the mask straight from the PTX special register.
  asm("mov.u32 %0, %%lanemask_le;" : "=r"(upTo));
  return upTo;
}
#endif
// Mask with one bit set for every lane whose index is higher than the
// caller's. The HIP build returns a 64-bit mask, the NVIDIA build a 32-bit one.
#if defined(__HIP_PLATFORM_HCC__)
__device__ __forceinline__ unsigned long long int getLaneMaskGt() {
  const std::uint64_t upTo = getLaneMaskLe();
  // A zero "less or equal" mask is passed through unchanged; otherwise the
  // result is its bitwise complement.
  if (!upTo) {
    return upTo;
  }
  return ~upTo;
}
#else
__device__ __forceinline__ unsigned getLaneMaskGt() {
  unsigned above;
  // Read the mask straight from the PTX special register.
  asm("mov.u32 %0, %%lanemask_gt;" : "=r"(above));
  return above;
}
#endif
// Mask with one bit set for every lane whose index is higher than or equal to
// the caller's. The HIP build returns a 64-bit mask, the NVIDIA build a
// 32-bit one.
#if defined(__HIP_PLATFORM_HCC__)
__device__ __forceinline__ unsigned long long int getLaneMaskGe() {
  // Complement of the "strictly lower lanes" mask.
  const std::uint64_t below = getLaneMaskLt();
  return ~below;
}
#else
__device__ __forceinline__ unsigned getLaneMaskGe() {
  unsigned atOrAbove;
  // Read the mask straight from the PTX special register.
  asm("mov.u32 %0, %%lanemask_ge;" : "=r"(atOrAbove));
  return atOrAbove;
}
#endif
#endif // THC_ASM_UTILS_INC