forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CPUFallback.h
50 lines (42 loc) · 2.35 KB
/
CPUFallback.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#pragma once
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/Metaprogramming.h>
#include <torch/library.h>

#include <type_traits>
#include <utility>
namespace at { namespace native {
// This function implements a boxed fallback to CPU.
// External backends can add their own custom logging on top of it to customize their own CPU fallbacks.
// Only the declaration lives in this header (marked TORCH_API); the definition is elsewhere.
// NOTE(review): `stack` presumably carries the boxed arguments in and the results out -
// confirm against the definition.
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
// This is a helper function that backends can use to directly call their boxed CPU fallback
// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
//
// Primary template: deliberately empty. It only exists so that the partial specialization
// on a function type `ReturnType(ParameterTypes...)` can be declared; that specialization
// is the one that provides the static `call` member.
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, class ReturnType, class... ParameterTypes>
struct _call_fallback_fn final {};
// Partial specialization of _call_fallback_fn selected when the third template argument
// is a function type `ReturnType(ParameterTypes...)`.
//
// Template parameters:
//   fallback_fn    - the boxed kernel function to invoke.
//   Op             - operator descriptor; must expose `schema`, `name` and `overload_name`.
//   ReturnType,
//   ParameterTypes - the unboxed signature; the static_asserts below require it to match
//                    the signature described by `Op::schema`.
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, class ReturnType, class... ParameterTypes>
struct _call_fallback_fn<fallback_fn, Op, ReturnType(ParameterTypes...)> final {
  static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<typename Op::schema>::return_type>::value,
      "Return type mismatch");
  static_assert(std::is_same<guts::typelist::typelist<ParameterTypes...>, typename guts::infer_function_traits_t<typename Op::schema>::parameter_types>::value,
      "Parameter types mismatch");

  // Looks up the operator named by Op::name / Op::overload_name in the dispatcher
  // (findSchemaOrThrow, so a missing schema is reported by the dispatcher), then calls
  // `fallback_fn` through BoxedKernelWrapper with the given unboxed arguments and
  // returns the wrapper's result.
  static ReturnType call(ParameterTypes... args) {
    // The explicit casts to a concrete (non-dependent) type keep this call expression
    // non-type-dependent, so `.typed<...>()` can be written without the `template`
    // disambiguator.
    const auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow(
            static_cast<const char*>(Op::name),
            static_cast<const char*>(Op::overload_name))
        .typed<ReturnType (ParameterTypes...)>();
    return c10::impl::BoxedKernelWrapper<ReturnType (ParameterTypes...)>::call(
        c10::KernelFunction::make_boxed_function<fallback_fn>,
        nullptr,
        op,
        c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
        // Forward each argument with its declared type: reference parameters are passed
        // through unchanged, by-value parameters are moved instead of copied a second time.
        std::forward<ParameterTypes>(args)...
    );
  }
};
// Convenience alias: supplies `Op::schema` as the signature argument of _call_fallback_fn,
// so callers only name the boxed fallback function and the operator type and then invoke
//   call_fallback_fn<fallback_fn, Op>::call(args...);
// `call` is only available when `Op::schema` is a function type (otherwise the empty
// primary template is selected).
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, typename Op::schema>;
} // namespace native
} // namespace at