Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

support annotation in python 3.8 #2881

Merged
merged 1 commit into from
Sep 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions tools/nni_annotation/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ast
import astor

from .utils import ast_Num, ast_Str

# pylint: disable=unidiomatic-typecheck

Expand Down Expand Up @@ -37,13 +38,13 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
for call in value.elts:
assert type(call) is ast.Call, 'Element in layer_choice should be function call'
call_name = astor.to_source(call).strip()
call_funcs_keys.append(ast.Str(s=call_name))
call_funcs_keys.append(ast_Str(s=call_name))
call_funcs_values.append(call.func)
assert not call.args, 'Number of args without keyword should be zero'
kw_args = []
kw_values = []
for kw in call.keywords:
kw_args.append(ast.Str(s=kw.arg))
kw_args.append(ast_Str(s=kw.arg))
kw_values.append(kw.value)
call_kwargs_values.append(ast.Dict(keys=kw_args, values=kw_values))
call_funcs = ast.Dict(keys=call_funcs_keys, values=call_funcs_values)
Expand All @@ -57,12 +58,12 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
elif k.id == 'optional_inputs':
assert not fields['optional_inputs'], 'Duplicated field: optional_inputs'
assert type(value) is ast.List, 'Value of optional_inputs should be a list'
var_names = [ast.Str(s=astor.to_source(var).strip()) for var in value.elts]
var_names = [ast_Str(s=astor.to_source(var).strip()) for var in value.elts]
optional_inputs = ast.Dict(keys=var_names, values=value.elts)
fields['optional_inputs'] = True
elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num or type(value) is ast.List, \
assert type(value) is ast_Num or type(value) is ast.List, \
'Value of optional_input_size should be a number or list'
optional_input_size = value
fields['optional_input_size'] = True
Expand All @@ -79,8 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
mutable_layer_id = 'mutable_layer_' + str(mutable_layer_cnt)
mutable_layer_cnt += 1
target_call_attr = ast.Attribute(value=ast.Name(id='nni', ctx=ast.Load()), attr='mutable_layer', ctx=ast.Load())
target_call_args = [ast.Str(s=mutable_id),
ast.Str(s=mutable_layer_id),
target_call_args = [ast_Str(s=mutable_id),
ast_Str(s=mutable_layer_id),
call_funcs,
call_kwargs]
if fields['fixed_inputs']:
Expand All @@ -93,8 +94,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
target_call_args.append(optional_input_size)
else:
target_call_args.append(ast.Dict(keys=[], values=[]))
target_call_args.append(ast.Num(n=0))
target_call_args.append(ast.Str(s=nas_mode))
target_call_args.append(ast_Num(n=0))
target_call_args.append(ast_Str(s=nas_mode))
if nas_mode in ['enas_mode', 'oneshot_mode', 'darts_mode']:
target_call_args.append(ast.Name(id='tensorflow'))
target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[])
Expand Down Expand Up @@ -151,7 +152,7 @@ def parse_nni_variable(code):
assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function'

name_str = astor.to_source(name).strip()
keyword_arg = ast.keyword(arg='name', value=ast.Str(s=name_str))
keyword_arg = ast.keyword(arg='name', value=ast_Str(s=name_str))
arg.keywords.append(keyword_arg)
if arg.func.attr == 'choice':
convert_args_to_dict(arg)
Expand All @@ -169,7 +170,7 @@ def parse_nni_function(code):
convert_args_to_dict(call, with_lambda=True)

name_str = astor.to_source(name).strip()
call.keywords[0].value = ast.Str(s=name_str)
call.keywords[0].value = ast_Str(s=name_str)

return call, funcs

Expand All @@ -180,12 +181,12 @@ def convert_args_to_dict(call, with_lambda=False):
"""
keys, values = list(), list()
for arg in call.args:
if type(arg) in [ast.Str, ast.Num]:
if type(arg) in [ast_Str, ast_Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg_value = ast_Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
keys.append(arg_value)
values.append(arg)
Expand All @@ -209,7 +210,7 @@ def test_variable_equal(node1, node2):
return False
if isinstance(node1, ast.AST):
for k, v in vars(node1).items():
if k in ('lineno', 'col_offset', 'ctx'):
if k in ('lineno', 'col_offset', 'ctx', 'end_lineno', 'end_col_offset'):
continue
if not test_variable_equal(v, getattr(node2, k)):
return False
Expand Down Expand Up @@ -282,7 +283,7 @@ def visit(self, node):
annotation = self.stack[-1]

# this is a standalone string, may be an annotation
if type(node) is ast.Expr and type(node.value) is ast.Str:
if type(node) is ast.Expr and type(node.value) is ast_Str:
# must not annotate an annotation string
assert annotation is None, 'Annotating an annotation'
return self._visit_string(node)
Expand All @@ -306,7 +307,7 @@ def _visit_string(self, node):
if string.startswith('@nni.training_update'):
expr = parse_annotation(string[1:])
call_node = expr.value
call_node.args.insert(0, ast.Str(s=self.nas_mode))
call_node.args.insert(0, ast_Str(s=self.nas_mode))
return expr

if string.startswith('@nni.report_intermediate_result') \
Expand Down
10 changes: 6 additions & 4 deletions tools/nni_annotation/search_space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import astor

from .utils import ast_Num, ast_Str

# pylint: disable=unidiomatic-typecheck


Expand Down Expand Up @@ -44,7 +46,7 @@ def generate_mutable_layer_search_space(self, args):
self.search_space[key]['_value'][mutable_layer] = {
'layer_choice': [k.s for k in args[2].keys],
'optional_inputs': [k.s for k in args[5].keys],
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
'optional_input_size': args[6].n if isinstance(args[6], ast_Num) else [args[6].elts[0].n, args[6].elts[1].n]
}

def visit_Call(self, node): # pylint: disable=invalid-name
Expand Down Expand Up @@ -73,7 +75,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name
# there is a `name` argument
assert len(node.keywords) == 1, 'Smart parameter has keyword argument other than "name"'
assert node.keywords[0].arg == 'name', 'Smart paramater\'s keyword argument is not "name"'
assert type(node.keywords[0].value) is ast.Str, 'Smart parameter\'s name must be string literal'
assert type(node.keywords[0].value) is ast_Str, 'Smart parameter\'s name must be string literal'
name = node.keywords[0].value.s
specified_name = True
else:
Expand All @@ -86,7 +88,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name
# we will use keys in the dict as the choices, which is generated by code_generator according to the args given by user
assert len(node.args) == 1, 'Smart parameter has arguments other than dict'
# check if it is a number or a string and get its value accordingly
args = [key.n if type(key) is ast.Num else key.s for key in node.args[0].keys]
args = [key.n if type(key) is ast_Num else key.s for key in node.args[0].keys]
else:
# arguments of other functions must be literal number
assert all(isinstance(ast.literal_eval(astor.to_source(arg)), numbers.Real) for arg in node.args), \
Expand All @@ -95,7 +97,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name

key = self.module_name + '/' + name + '/' + func
# store key in ast.Call
node.keywords.append(ast.keyword(arg='key', value=ast.Str(s=key)))
node.keywords.append(ast.keyword(arg='key', value=ast_Str(s=key)))

if func == 'function_choice':
func = 'choice'
Expand Down
22 changes: 12 additions & 10 deletions tools/nni_annotation/specific_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import astor
from nni_cmd.common_utils import print_warning

from .utils import ast_Num, ast_Str

# pylint: disable=unidiomatic-typecheck

para_cfg = None
Expand Down Expand Up @@ -134,7 +136,7 @@ def parse_nni_variable(code):
assert arg.func.value.id == 'nni', 'nni.variable value is not a NNI function'

name_str = astor.to_source(name).strip()
keyword_arg = ast.keyword(arg='name', value=ast.Str(s=name_str))
keyword_arg = ast.keyword(arg='name', value=ast_Str(s=name_str))
arg.keywords.append(keyword_arg)
if arg.func.attr == 'choice':
convert_args_to_dict(arg)
Expand All @@ -152,7 +154,7 @@ def parse_nni_function(code):
convert_args_to_dict(call, with_lambda=True)

name_str = astor.to_source(name).strip()
call.keywords[0].value = ast.Str(s=name_str)
call.keywords[0].value = ast_Str(s=name_str)

return call, funcs

Expand All @@ -163,12 +165,12 @@ def convert_args_to_dict(call, with_lambda=False):
"""
keys, values = list(), list()
for arg in call.args:
if type(arg) in [ast.Str, ast.Num]:
if type(arg) in [ast_Str, ast_Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg_value = ast_Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
keys.append(arg_value)
values.append(arg)
Expand All @@ -192,7 +194,7 @@ def test_variable_equal(node1, node2):
return False
if isinstance(node1, ast.AST):
for k, v in vars(node1).items():
if k in ('lineno', 'col_offset', 'ctx'):
if k in ('lineno', 'col_offset', 'ctx', 'end_lineno', 'end_col_offset'):
continue
if not test_variable_equal(v, getattr(node2, k)):
return False
Expand Down Expand Up @@ -264,7 +266,7 @@ def visit(self, node):
annotation = self.stack[-1]

# this is a standalone string, may be an annotation
if type(node) is ast.Expr and type(node.value) is ast.Str:
if type(node) is ast.Expr and type(node.value) is ast_Str:
# must not annotate an annotation string
assert annotation is None, 'Annotating an annotation'
return self._visit_string(node)
Expand All @@ -290,23 +292,23 @@ def _visit_string(self, node):
"Please remove this line in the trial code."
print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Get next parameter here...')], keywords=[]))
args=[ast_Str(s='Get next parameter here...')], keywords=[]))

if string.startswith('@nni.training_update'):
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Training update here...')], keywords=[]))
args=[ast_Str(s='Training update here...')], keywords=[]))

if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
args=[ast_Str(s='nni.report_intermediate_result: '), arg], keywords=[]))

if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
args=[ast_Str(s='nni.report_final_result: '), arg], keywords=[]))

if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno)
Expand Down
15 changes: 15 additions & 0 deletions tools/nni_annotation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import ast
from sys import version_info


if version_info >= (3, 8):
ast_Num = ast_Str = ast_Bytes = ast_NameConstant = ast_Ellipsis = ast.Constant
else:
ast_Num = ast.Num
ast_Str = ast.Str
ast_Bytes = ast.Bytes
ast_NameConstant = ast.NameConstant
ast_Ellipsis = ast.Ellipsis