forked from HipGraph/FusedMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fusedMM_internal.h
90 lines (80 loc) · 2.74 KB
/
fusedMM_internal.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#ifndef FUSED_SDMM_SPMM_INTERNAL_H
#define FUSED_SDMM_SPMM_INTERNAL_H
#ifdef __cplusplus
extern "C"
{
#endif
/*
* NOTE:
* This header file is meant for internal use to implement the general kernel,
* not intended for the user. See fusedMM.h for function prototypes
*/
/*
* Function pointer types for user-defined functions
*/
/*
 * NOTE(review): the VOP/ROP/SOP/VSC/AOP prefixes presumably name the
 * operation stages described in fusedMM.h -- confirm there. The comments
 * below only restate what each signature shows.
 */

/* VOP: two const inputs (lhs, rhs) with their dimensions, plus a writable
 * output pointer with its own dimension; returns int. */
typedef int (*FP_VOP_UDEF_FUNC)(
   INDEXTYPE lhs_dim,
   const VALUETYPE *lhs,
   INDEXTYPE rhs_dim,
   const VALUETYPE *rhs,
   INDEXTYPE out_dim,
   VALUETYPE *out);

/* ROP: two const inputs (lhs, rhs) with their dimensions, plus a writable
 * output pointer that carries no dimension argument; returns int. */
typedef int (*FP_ROP_UDEF_FUNC)(
   INDEXTYPE lhs_dim,
   const VALUETYPE *lhs,
   INDEXTYPE rhs_dim,
   const VALUETYPE *rhs,
   VALUETYPE *out);

/* SOP: one input passed by value and a writable output pointer;
 * returns int. */
typedef int (*FP_SOP_UDEF_FUNC)(
   VALUETYPE val,
   VALUETYPE *out);

/* VSC: one const input (rhs) with its dimension, a by-value scal argument,
 * and a writable output pointer with its own dimension; returns int. */
typedef int (*FP_VSC_UDEF_FUNC)(
   INDEXTYPE rhs_dim,
   const VALUETYPE *rhs,
   VALUETYPE scal,
   INDEXTYPE out_dim,
   VALUETYPE *out);

/* AOP: one const input (rhs) with its dimension and a writable output
 * pointer with its own dimension; returns int. */
typedef int (*FP_AOP_UDEF_FUNC)(
   INDEXTYPE rhs_dim,
   const VALUETYPE *rhs,
   INDEXTYPE out_dim,
   VALUETYPE *out);
/*
* Function pointer types for each stage of operations
*/
/*
 * NOTE(review): each type below has the same parameter list as the
 * FP_*_UDEF_FUNC type of the same prefix declared earlier in this header.
 * The meaning of the VOP/ROP/SOP/VSC/AOP prefixes is presumably given in
 * fusedMM.h -- confirm there.
 */

/* VOP stage: const lhs and rhs inputs with dimensions, writable out with
 * its dimension; returns int. */
typedef int (*FP_VOP_FUNC)(
   INDEXTYPE lhs_dim,
   const VALUETYPE *lhs,
   INDEXTYPE rhs_dim,
   const VALUETYPE *rhs,
   INDEXTYPE out_dim,
   VALUETYPE *out);

/* ROP stage: const lhs and rhs inputs with dimensions, writable out without
 * a dimension argument; returns int. */
typedef int (*FP_ROP_FUNC)(
   INDEXTYPE lhs_dim,
   const VALUETYPE *lhs,
   INDEXTYPE rhs_dim,
   const VALUETYPE *rhs,
   VALUETYPE *out);

/* SOP stage: by-value input and writable out pointer; returns int. */
typedef int (*FP_SOP_FUNC)(
   VALUETYPE val,
   VALUETYPE *out);

/* VSC stage: const rhs input with dimension, by-value scal, writable out
 * with its dimension; returns int. */
typedef int (*FP_VSC_FUNC)(
   INDEXTYPE rhs_dim,
   const VALUETYPE *rhs,
   VALUETYPE scal,
   INDEXTYPE out_dim,
   VALUETYPE *out);

/* AOP stage: const rhs input with dimension, writable out with its
 * dimension; returns int. */
typedef int (*FP_AOP_FUNC)(
   INDEXTYPE rhs_dim,
   const VALUETYPE *rhs,
   INDEXTYPE out_dim,
   VALUETYPE *out);
/*
* DEFAULT IMPLEMENTATIONS OF THE USER-DEFINED FUNCTIONS
* Dummy functions: each one always returns an error. They are used when the
* user has not implemented a function but a UDEF operation is requested in the
* message
*/
#ifndef VOP_UDEF_IMPL
/*
 * Placeholder VOP_UDEF_FUNC, compiled only when VOP_UDEF_IMPL is not defined.
 * Every argument is ignored; the function unconditionally returns
 * FUSEDMM_UNDEFINED_USER_FUNCTION.
 *
 * NOTE(review): this and the other dummy functions in this header are
 * definitions with external linkage. Including the header from more than one
 * translation unit would presumably produce duplicate-symbol link errors --
 * confirm it is included by a single source file only.
 */
int VOP_UDEF_FUNC(INDEXTYPE lhs_dim, const VALUETYPE *lhs,
                  INDEXTYPE rhs_dim, const VALUETYPE *rhs,
                  INDEXTYPE out_dim, VALUETYPE *out)
{
   /* The placeholder deliberately uses none of its arguments. */
   (void)lhs_dim;
   (void)lhs;
   (void)rhs_dim;
   (void)rhs;
   (void)out_dim;
   (void)out;

   return FUSEDMM_UNDEFINED_USER_FUNCTION;
}
#endif
#ifndef ROP_UDEF_IMPL
/*
 * Placeholder ROP_UDEF_FUNC, compiled only when ROP_UDEF_IMPL is not defined.
 * Every argument is ignored; the function unconditionally returns
 * FUSEDMM_UNDEFINED_USER_FUNCTION.
 */
int ROP_UDEF_FUNC(INDEXTYPE lhs_dim, const VALUETYPE *lhs,
                  INDEXTYPE rhs_dim, const VALUETYPE *rhs,
                  VALUETYPE *out)
{
   /* The placeholder deliberately uses none of its arguments. */
   (void)lhs_dim;
   (void)lhs;
   (void)rhs_dim;
   (void)rhs;
   (void)out;

   return FUSEDMM_UNDEFINED_USER_FUNCTION;
}
#endif
#ifndef SOP_UDEF_IMPL
/*
 * Placeholder SOP_UDEF_FUNC, compiled only when SOP_UDEF_IMPL is not defined.
 * Both arguments are ignored; the function unconditionally returns
 * FUSEDMM_UNDEFINED_USER_FUNCTION.
 */
int SOP_UDEF_FUNC(VALUETYPE val,
                  VALUETYPE *out)
{
   /* The placeholder deliberately uses none of its arguments. */
   (void)val;
   (void)out;

   return FUSEDMM_UNDEFINED_USER_FUNCTION;
}
#endif
#ifndef VSC_UDEF_IMPL
/*
 * Placeholder VSC_UDEF_FUNC, compiled only when VSC_UDEF_IMPL is not defined.
 * Every argument is ignored; the function unconditionally returns
 * FUSEDMM_UNDEFINED_USER_FUNCTION.
 */
int VSC_UDEF_FUNC(INDEXTYPE rhs_dim, const VALUETYPE *rhs,
                  VALUETYPE scal,
                  INDEXTYPE out_dim, VALUETYPE *out)
{
   /* The placeholder deliberately uses none of its arguments. */
   (void)rhs_dim;
   (void)rhs;
   (void)scal;
   (void)out_dim;
   (void)out;

   return FUSEDMM_UNDEFINED_USER_FUNCTION;
}
#endif
#ifndef AOP_UDEF_IMPL /* dummy definition */
/*
 * Placeholder AOP_UDEF_FUNC, compiled only when AOP_UDEF_IMPL is not defined.
 * Every argument is ignored; the function unconditionally returns
 * FUSEDMM_UNDEFINED_USER_FUNCTION.
 */
int AOP_UDEF_FUNC(INDEXTYPE rhs_dim, const VALUETYPE *rhs,
                  INDEXTYPE out_dim, VALUETYPE *out)
{
   /* The placeholder deliberately uses none of its arguments. */
   (void)rhs_dim;
   (void)rhs;
   (void)out_dim;
   (void)out;

   return FUSEDMM_UNDEFINED_USER_FUNCTION;
}
#endif
#ifdef __cplusplus
} // extern "C"
#endif
#endif /* end of FUSED_SDMM_SPMM_INTERNAL_H */