Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Improve annotation #138

Merged
merged 2 commits into from
Sep 29, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 307 additions & 0 deletions src/nni_manager/yarn.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/smartparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,6 @@ def _get_param(func, name):
del frame # see official doc
module = inspect.getmodulename(filename)
if name is None:
name = '#{:d}'.format(lineno)
name = '__line{:d}'.format(lineno)
key = '{}/{}/{}'.format(module, name, func)
return trial.get_parameter(key)
4 changes: 2 additions & 2 deletions src/sdk/pynni/tests/test_smartparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ class SmartParamTestCase(TestCase):
def setUp(self):
params = {
'test_smartparam/choice1/choice': 2,
'test_smartparam/#{:d}/uniform'.format(lineno1): '5',
'test_smartparam/__line{:d}/uniform'.format(lineno1): '5',
'test_smartparam/func/function_choice': 1,
'test_smartparam/#{:d}/function_choice'.format(lineno2): 0
'test_smartparam/__line{:d}/function_choice'.format(lineno2): 0
}
nni.trial._params = { 'parameter_id': 'test_trial', 'parameters': params }

Expand Down
1,304 changes: 1,304 additions & 0 deletions src/webui/yarn.lock

Large diffs are not rendered by default.

Empty file modified test/naive/nnictl
100644 → 100755
Empty file.
Empty file modified test/naive/nnimanager
100644 → 100755
Empty file.
15 changes: 13 additions & 2 deletions tools/nni_annotation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _generate_file_search_space(path, module):

def expand_annotations(src_dir, dst_dir):
"""Expand annotations in user code.
Return dst_dir if annotation detected; return src_dir if not.
src_dir: directory path of user code (str)
dst_dir: directory to place generated files (str)
"""
Expand All @@ -77,6 +78,8 @@ def expand_annotations(src_dir, dst_dir):
if dst_dir[-1] == '/':
dst_dir = dst_dir[:-1]

annotated = False

for src_subdir, dirs, files in os.walk(src_dir):
assert src_subdir.startswith(src_dir)
dst_subdir = src_subdir.replace(src_dir, dst_dir, 1)
Expand All @@ -86,17 +89,25 @@ def expand_annotations(src_dir, dst_dir):
src_path = os.path.join(src_subdir, file_name)
dst_path = os.path.join(dst_subdir, file_name)
if file_name.endswith('.py'):
_expand_file_annotations(src_path, dst_path)
annotated |= _expand_file_annotations(src_path, dst_path)
else:
shutil.copyfile(src_path, dst_path)

for dir_name in dirs:
os.makedirs(os.path.join(dst_subdir, dir_name), exist_ok=True)

return dst_dir if annotated else src_dir

def _expand_file_annotations(src_path, dst_path):
with open(src_path) as src, open(dst_path, 'w') as dst:
try:
dst.write(code_generator.parse(src.read()))
annotated_code = code_generator.parse(src.read())
if annotated_code is None:
shutil.copyfile(src_path, dst_path)
return False
dst.write(annotated_code)
return True

except Exception as exc: # pylint: disable=broad-except
if exc.args:
raise RuntimeError(src_path + ' ' + '\n'.join(exc.args))
Expand Down
14 changes: 10 additions & 4 deletions tools/nni_annotation/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class Transformer(ast.NodeTransformer):
def __init__(self):
self.stack = []
self.last_line = 0
self.annotated = False

def visit(self, node):
if isinstance(node, (ast.expr, ast.stmt)):
Expand Down Expand Up @@ -190,8 +191,9 @@ def visit(self, node):

def _visit_string(self, node):
string = node.value.s

if not string.startswith('@nni.'):
if string.startswith('@nni.'):
self.annotated = True
else:
return node # not an annotation, ignore it

if string.startswith('@nni.report_intermediate_result(') \
Expand All @@ -216,19 +218,23 @@ def _visit_children(self, node):

def parse(code):
"""Annotate user code.
Return annotated code (str).
Return annotated code (str) if annotation detected; return None if not.
code: original user code (str)
"""
try:
ast_tree = ast.parse(code)
except Exception:
raise RuntimeError('Bad Python code')

transformer = Transformer()
try:
Transformer().visit(ast_tree)
transformer.visit(ast_tree)
except AssertionError as exc:
raise RuntimeError('%d: %s' % (ast_tree.last_line, exc.args[0]))

if not transformer.annotated:
return None

last_future_import = -1
import_nni = ast.Import(names=[ast.alias(name='nni', asname=None)])
nodes = ast_tree.body
Expand Down
2 changes: 1 addition & 1 deletion tools/nni_annotation/search_space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def visit_Call(self, node): # pylint: disable=invalid-name
else:
# generate the missing name automatically
assert len(node.args) > 0, 'Smart parameter expression has no argument'
name = '#' + str(node.args[-1].lineno)
name = '__line' + str(node.args[-1].lineno)
specified_name = False

if func in ('choice', 'function_choice'):
Expand Down
9 changes: 8 additions & 1 deletion tools/nni_annotation/test_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import json
import os
import shutil
import tempfile
from unittest import TestCase, main


Expand All @@ -43,12 +44,18 @@ def test_search_space_generator(self):
self.assertEqual(search_space, json.load(f))

def test_code_generator(self):
expand_annotations('testcase/usercode', '_generated')
code_dir = expand_annotations('testcase/usercode', '_generated')
self.assertEqual(code_dir, '_generated')
self._assert_source_equal('testcase/annotated/mnist.py', '_generated/mnist.py')
self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/dir/simple.py')
with open('testcase/usercode/nonpy.txt') as src, open('_generated/nonpy.txt') as dst:
assert src.read() == dst.read()

def test_annotation_detecting(self):
dir_ = 'testcase/usercode/non_annotation'
code_dir = expand_annotations(dir_, tempfile.mkdtemp())
self.assertEqual(code_dir, dir_)

def _assert_source_equal(self, src1, src2):
with open(src1) as f1, open(src2) as f2:
ast1 = ast.dump(ast.parse(f1.read()))
Expand Down
5 changes: 5 additions & 0 deletions tools/nni_annotation/testcase/annotated/non_annotation/bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import nni

def bar():
"""I'm doc string"""
return nni.report_final_result(0)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
print('hello')
6 changes: 3 additions & 3 deletions tools/nni_annotation/testcase/searchspace.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
"_type": "choice",
"_value": [ 0, 1, 2, 3 ]
},
"handwrite/#5/function_choice": {
"handwrite/__line5/function_choice": {
"_type": "choice",
"_value": [ 0, 1, 2 ]
},
"handwrite/#8/qlognormal": {
"handwrite/__line8/qlognormal": {
"_type": "qlognormal",
"_value": [ 1.2, 3, 4.5 ]
},
"handwrite/#13/choice": {
"handwrite/__line13/choice": {
"_type": "choice",
"_value": [ 0, 1 ]
},
Expand Down
5 changes: 5 additions & 0 deletions tools/nni_annotation/testcase/usercode/non_annotation/bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import nni

def bar():
"""I'm doc string"""
return nni.report_final_result(0)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
print('hello')
13 changes: 5 additions & 8 deletions tools/nnicmd/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from subprocess import Popen, PIPE, call
import tempfile
from nni_annotation import *
import random
from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response
from .url_utils import cluster_metadata_url, experiment_url
Expand Down Expand Up @@ -189,13 +188,11 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No
nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation
if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), 'nni', 'annotation', ''.join(random.sample(string.ascii_letters + string.digits, 8)))
if os.path.isdir(path):
shutil.rmtree(path)
os.makedirs(path)
expand_annotations(experiment_config['trial']['codeDir'], path)
experiment_config['trial']['codeDir'] = path
search_space = generate_search_space(experiment_config['trial']['codeDir'])
path = os.path.join(tempfile.gettempdir(), 'nni', 'annotation')
path = tempfile.mkdtemp(dir=path)
code_dir = expand_annotations(experiment_config['trial']['codeDir'], path)
experiment_config['trial']['codeDir'] = code_dir
search_space = generate_search_space(code_dir)
experiment_config['searchSpace'] = json.dumps(search_space)
assert search_space, ERROR_INFO % 'Generated search space is empty'
elif experiment_config.get('searchSpacePath'):
Expand Down