forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_autograd_functions.py
260 lines (220 loc) · 9.33 KB
/
gen_autograd_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
# Generates C++ autograd functions for the derivatives of ATen operations
#
# This writes two pairs of files:
# Functions.h/cpp: subclasses of autograd::Node
# python_functions.h/cpp: Python bindings for the above classes
#
import os
import re
from .utils import nested_dict, CodeTemplate, write
from .gen_autograd import VIEW_FUNCTIONS
from .utils import IDENT_REGEX
# Declaration of the generated autograd Node subclass `${op}`. It is collected
# into `autograd_function_declarations` by gen_autograd_functions, presumably
# for the .h template. `${superclass}` is Node or TraceableFunction (chosen in
# process_function). release_variables() runs the optional `${thread_lock}`
# followed by the collected release statements. `${saved_variables}` and
# `${saved_list_sizes}` become data members of the struct.
FUNCTION_DECLARATION = CodeTemplate("""\
struct TORCH_API ${op} : public ${superclass} {
using ${superclass}::${superclass};
variable_list apply(variable_list&& grads) override;
std::string name() const override { return "${op}"; }
void release_variables() override {
${thread_lock}
${release_variables}
}
${will_release_variables}
${saved_variables}
${saved_list_sizes}
};
""")

# Extra members spliced into `${will_release_variables}`. process_function
# emits this only when a derivative formula references `retain_variables`.
# The override flips the flag to false.
WILL_RELEASE_VARIABLES = CodeTemplate("""\
bool retain_variables = true;
void will_release_variables() override {
retain_variables = false;
}
""")

# Out-of-line definition of `${op}::apply`. It does the following in order:
# takes the optional lock, runs the asserts, assigns each differentiable input
# an index range, allocates grad_inputs with gen.size() slots, runs the
# generated body, and returns grad_inputs.
FUNCTION_DEFINITION = CodeTemplate("""\
variable_list ${op}::apply(variable_list&& grads) {
${thread_lock}
${asserts}
IndexRangeGenerator gen;
${compute_index_ranges}
variable_list grad_inputs(gen.size());
${body}
return grad_inputs;
}
""")

# Per-function snippet for the Python bindings: a static PyTypeObject plus a
# call to addClass<${op}> (presumably registering the Python type; addClass
# itself is defined outside this file).
PY_FUNCTION_DEFINITION = CodeTemplate("""\
static PyTypeObject ${op}Class;
addClass<${op}>(${op}Class, "${op}");
""")
# std::array<bool, n> with one should_compute_output(...) flag per input of a
# multi-output formula. It is emitted only when the formula text mentions
# `grad_input_mask`. `${masks}` is a list of lines, each ending in a comma.
GRAD_INPUT_MASK = CodeTemplate("""\
auto grad_input_mask = std::array<bool, ${n}>{
${masks}
};\
""")

# Backward code for a formula that produces the gradient of a single input.
# The formula is evaluated only when that input's gradient is requested, and
# the result is copied into the input's `${name}_ix` range of grad_inputs.
DERIVATIVE_SINGLE = CodeTemplate("""\
if (should_compute_output({ ${name}_ix })) {
auto grad_result = ${derivative};
copy_range(grad_inputs, ${name}_ix, grad_result);
}
""")

# Copies element `${i}` of a multi-output formula's tuple result into the
# grad_inputs range of the i-th input, but only if that gradient is requested.
DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate("""\
if (should_compute_output({ ${name}_ix })) {
copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
}
""")

# Backward code for a formula producing gradients for several inputs at once.
# It runs if any of those inputs needs a gradient. The optional
# `${grad_input_mask}` tells the formula which ones, and `${copy_ranges}`
# scatters the tuple result.
DERIVATIVE_MULTI = CodeTemplate("""\
if (should_compute_output({ ${idx_ranges} })) {
${grad_input_mask}
auto grad_result = ${derivative};
${copy_ranges}
}
""")
# These functions have a backward that cannot be traced, so their backward
# functions must be traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backward, see
# https://github.com/pytorch/pytorch/issues/4250
# process_function checks `func['name'] in UNTRACEABLE_FUNCTIONS`. Matching
# functions get plain `Node` as their superclass instead of
# `TraceableFunction`. This is an alias (not a copy) of VIEW_FUNCTIONS.
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def gen_autograd_functions_lib(out, autograd_functions, template_path):
    """Generate Functions.h / Functions.cpp (the C++ autograd Node subclasses)."""
    gen_autograd_functions(out, autograd_functions, template_path,
                           file_basename="Functions")
def gen_autograd_functions_python(out, autograd_functions, template_path):
    """Generate python_functions.h / python_functions.cpp (Python bindings)."""
    gen_autograd_functions(out, autograd_functions, template_path,
                           file_basename="python_functions")
def gen_autograd_functions(out, autograd_functions, template_path, file_basename):
    """Render ``<file_basename>.h`` and ``<file_basename>.cpp`` from templates.

    Each entry of ``autograd_functions`` is turned into a template
    environment by ``process_function``. That environment is then expanded
    through three per-function templates:

    - the C++ struct declaration (a subclass of ``torch::autograd::Node``),
    - the out-of-line ``apply`` definition,
    - the Python class initializer.

    All three resulting lists are passed to both the ``.h`` and the ``.cpp``
    template found in ``template_path``. Presumably each template only
    references the keys it needs.

    Args:
        out: output location, passed through unchanged to ``write``.
        autograd_functions: iterable of function-description dicts.
        template_path: directory holding the ``.h``/``.cpp`` templates.
        file_basename: shared stem of the template and output files,
            e.g. "Functions" or "python_functions".
    """
    # (top-level env key, per-function template), expanded in this order for
    # every function.
    per_function_templates = (
        ('autograd_function_declarations', FUNCTION_DECLARATION),
        ('autograd_function_definitions', FUNCTION_DEFINITION),
        ('py_function_initializers', PY_FUNCTION_DEFINITION),
    )
    top_env = {
        'autograd_function_definitions': [],
        'autograd_function_declarations': [],
        'py_function_initializers': [],
    }
    for func in autograd_functions:
        func_env = process_function(func)
        for env_key, template in per_function_templates:
            top_env[env_key].append(template.substitute(func_env))

    for extension in ('.h', '.cpp'):
        filename = file_basename + extension
        template_file = os.path.join(template_path, filename)
        write(out, filename, CodeTemplate.from_file(template_file), top_env)
def process_function(func):
    """Build the CodeTemplate environment for a single autograd function.

    Collects the C++ snippets that FUNCTION_DECLARATION and
    FUNCTION_DEFINITION splice in:

    - index-range setup,
    - saved-variable member declarations,
    - the release_variables() body,
    - backward-twice asserts,
    - the optional mutex lock,
    - the optional will_release_variables() override,
    - the apply() body,
    - the superclass.

    Args:
        func: dict describing one differentiable function. This code reads
            the keys 'args_with_derivatives', 'saved_inputs',
            'saved_outputs', 'derivatives' and 'name'. The full schema is
            defined by whoever builds these dicts (not visible here).

    Returns:
        ``nested_dict(env, func)``. This is presumably a lookup that tries the
        generated ``env`` first and falls back to ``func``, so templates can
        also read keys such as ``op`` from ``func``. NOTE(review): confirm
        against ``utils.nested_dict``.
    """
    env = {}
    saved_variables = []  # C++ data-member declarations for saved state
    release_variables = []  # statements run inside release_variables()
    saved_list_sizes = []  # `size_t <name>_size_;` members for TensorList inputs
    unpack = []  # statements in apply() that unpack the saved state
    asserts = []  # checks placed in apply() right after the optional lock
    # Each differentiable input is assigned an index range `<name>_ix` in
    # grad_inputs. The range has size 1 for an ordinary input, or
    # `<name>_size_` for a TensorList. (The size member is only declared
    # here; where it gets assigned is not visible in this file.)
    env['compute_index_ranges'] = []
    for arg in func['args_with_derivatives']:
        if arg['type'] == 'TensorList':
            size = '{}_size_'.format(arg['name'])
            saved_list_sizes.append('size_t {}_size_;'.format(arg['name']))
        else:
            size = '1'
        env['compute_index_ranges'].append('auto {}_ix = gen.range({});'.format(arg['name'], size))

    def save_arg(arg, is_output):
        """Emit the member declaration, release code and unpack code for one saved arg.

        Appends to the enclosing saved_variables / release_variables /
        unpack / asserts lists. ``is_output`` is True for saved outputs and
        False for saved inputs.
        """
        name = arg['name']
        # Tensors (and Scalar outputs) live in a SavedVariable member
        # `<name>_`. release_variables() resets both its data and its grad
        # function.
        if arg['type'] == 'Tensor' or (arg['type'] == 'Scalar' and is_output):
            saved_variables.append('SavedVariable {}_;'.format(name))
            release_variables.append('{}_.reset_data();'.format(name))
            release_variables.append('{}_.reset_grad_function();'.format(name))
            # Saved outputs are unpacked with this node passed in
            # (presumably because an output's grad_fn is this node itself).
            ptr = 'shared_from_this()' if is_output else ''
            unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr))
        elif arg['type'] == 'TensorList':
            saved_variables.append('std::vector<SavedVariable> {}_;'.format(name))
            saved_variables.append('bool {}_released_ = false;'.format(name))
            # Just clear() is sufficient, we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append('{}_.clear();'.format(name))
            release_variables.append('{}_released_ = true;'.format(name))
            unpack.append('auto {} = unpack_list({}_);'.format(name, name))
            # The explicit `_released_` flag lets apply() fail with
            # ERR_BACKWARD_TWICE. Presumably it is needed because a cleared
            # vector looks the same as a genuinely empty TensorList.
            asserts.append('TORCH_CHECK(!{}_released_, ERR_BACKWARD_TWICE);'.format(name))
        elif arg['type'] == 'IntArrayRef':
            # Stored as an owning std::vector<int64_t> (presumably because
            # IntArrayRef is a non-owning view).
            saved_variables.append('std::vector<int64_t> {};'.format(name))
        elif arg['type'] == 'int64_t':
            # Zero-initialized member.
            saved_variables.append('{} {} = 0;'.format(arg['type'], name))
        else:
            # Any other type is stored as a member of that same C++ type.
            saved_variables.append('{} {};'.format(arg['type'], name))

    for arg in func['saved_inputs']:
        save_arg(arg, is_output=False)
    for arg in func['saved_outputs']:
        save_arg(arg, is_output=True)
    env['saved_variables'] = saved_variables
    env['release_variables'] = release_variables
    env['saved_list_sizes'] = saved_list_sizes
    env['asserts'] = asserts

    # lock the mutex when we release variables and in Node::apply to protect thread safety
    # see Note [Thread Safety on Autograd Node]
    # The lock is only emitted when there is saved state to release.
    if len(release_variables) > 0:
        env['thread_lock'] = "std::lock_guard<std::mutex> lock(mutex_);"
    else:
        env['thread_lock'] = ''

    # Formulas that reference `retain_variables` need the flag that the
    # generated will_release_variables() override clears.
    if uses_retain_variables(func):
        env['will_release_variables'] = WILL_RELEASE_VARIABLES.substitute()
    else:
        env['will_release_variables'] = ''

    # apply() body, in order: optional `grad` alias, then unpack statements,
    # then exactly one entry per derivative formula.
    body = []

    # Formulas may use `grad` as shorthand for grads[0].
    if uses_single_grad(func):
        body.append('auto& grad = grads[0];')

    def emit_derivative(derivative, args_with_derivatives):
        """Render the backward code for one derivative entry.

        Returns a ``(checks_any_grad_defined, code)`` tuple. The flag is
        True when ``code`` references the ``any_grad_defined`` local, which
        the caller must then declare.
        """
        formula = derivative['formula']
        var_names = derivative['var_names']
        if len(var_names) == 1:
            # Single-input formula.
            checks_any_grad_defined = False
            if 'not_implemented' not in formula:
                matching_args = [
                    arg for arg in args_with_derivatives
                    if ('name' in arg) and (arg['name'] == var_names[0])]
                if len(matching_args) == 1:
                    # We can add undefined grad support if the input variable is a Tensor
                    # The formula is short-circuited to an undefined Tensor()
                    # when any_grad_defined is false.
                    if ('simple_type' in matching_args[0].keys()) and (matching_args[0]['simple_type'] == 'Tensor'):
                        formula = 'any_grad_defined ? (' + formula + ') : Tensor()'
                        checks_any_grad_defined = True
            return (checks_any_grad_defined,
                    DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula))
        else:
            # Multi-input formula. It returns a tuple, and element i belongs
            # to var_names[i].
            if 'grad_input_mask' in formula:
                # The formula asks which inputs need gradients: pass it a
                # std::array<bool> built from should_compute_output.
                masks = ['should_compute_output({{ {}_ix }}),'.format(n) for n in var_names]
                grad_input_mask = GRAD_INPUT_MASK.substitute(masks=masks, n=len(var_names))
            else:
                grad_input_mask = ''
            idx_ranges = ', '.join("{}_ix".format(n) for n in var_names)
            copy_ranges = []
            for i, n in enumerate(var_names):
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
            return False, DERIVATIVE_MULTI.substitute(
                idx_ranges=idx_ranges, copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask)

    body.extend(unpack)
    need_any_grad_defined_var = False
    for derivative in func['derivatives']:
        checks_any_grad_defined, derivative_text = emit_derivative(derivative, func['args_with_derivatives'])
        body.append(derivative_text)
        need_any_grad_defined_var |= checks_any_grad_defined
    # Since single-output derivative formulas need to check if grads are
    # defined, only perform the check once, before all the formulas.
    # Each derivative appended exactly one entry above, so index
    # -len(derivatives) is the first derivative entry. The declaration
    # therefore lands after the grad alias and unpack statements. With no
    # derivatives the flag stays False, so the -0 (== index 0) case is never
    # reached.
    if need_any_grad_defined_var:
        body.insert(-len(func['derivatives']),
                    'bool any_grad_defined = any_variable_defined(grads);')
    env['body'] = body

    # Functions listed in UNTRACEABLE_FUNCTIONS derive from plain Node;
    # everything else derives from TraceableFunction.
    if func['name'] in UNTRACEABLE_FUNCTIONS:
        env['superclass'] = 'Node'
    else:
        env['superclass'] = 'TraceableFunction'
    return nested_dict(env, func)
def uses_ident(func, ident):
    """Return True if any derivative formula of ``func`` matches ``ident``.

    Matching applies ``re.search`` with the pattern
    ``IDENT_REGEX.format(ident)`` to each formula string. ``ident`` is
    spliced into that pattern unescaped. A ``None`` ``func``, or one with no
    derivatives, never matches.
    """
    if func is None:
        return False
    # Lazily pull each formula, then test it; any() stops at the first hit.
    formulas = (derivative['formula'] for derivative in func['derivatives'])
    return any(re.search(IDENT_REGEX.format(ident), formula) for formula in formulas)
def uses_retain_variables(func):
    """Return True if a derivative formula references ``retain_variables``.

    The match is done by uses_ident. process_function uses this result to
    decide whether to emit the will_release_variables() override.
    """
    ident = 'retain_variables'
    return uses_ident(func, ident)
def uses_single_grad(func):
    """Return True if a derivative formula references the identifier ``grad``.

    The match is done by uses_ident. process_function uses this result to
    decide whether to bind ``auto& grad = grads[0];`` in apply().
    """
    ident = 'grad'
    return uses_ident(func, ident)