Skip to content

Commit

Permalink
학습된 모델을 명령창에서 실행해볼 수 있는 데모 스크립트 추가 #30
Browse files Browse the repository at this point in the history
  • Loading branch information
krikit committed Feb 10, 2019
1 parent 3097e6e commit 5753e17
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 7 deletions.
2 changes: 1 addition & 1 deletion rsc/bin/compile_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _load_cfg_rsc(rsc_src: str, model_size: str) -> Tuple[Namespace, Resource]:
for key, val in cfg_dic.items():
setattr(cfg, key, val)
cwd = os.path.realpath(os.getcwd())
train_dir = os.path.realpath('{}/..'.format(rsc_src))
train_dir = os.path.realpath('{}/../../train'.format(rsc_src))
if cwd != train_dir:
os.chdir(train_dir)
rsc = Resource(cfg)
Expand Down
9 changes: 4 additions & 5 deletions src/main/python/khaiii/train/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,18 @@ class PosTagger:
"""
part-of-speech tagger
"""
def __init__(self, rsc_src: str):
def __init__(self, model_dir: str):
"""
Args:
rsc_src: resource dir
model_dir: model dir
"""
cfg_dict = json.load(open('{}/config.json'.format(rsc_src), 'r', encoding='UTF-8'))
cfg_dict = json.load(open('{}/config.json'.format(model_dir), 'r', encoding='UTF-8'))
self.cfg = Namespace()
for key, val in cfg_dict.items():
setattr(self.cfg, key, val)
self.cfg.rsc_src = rsc_src
self.rsc = Resource(self.cfg)
self.model = CnnModel(self.cfg, self.rsc)
self.model.load('{}/model.state'.format(rsc_src))
self.model.load('{}/model.state'.format(model_dir))
self.model.eval()

def tag_raw(self, raw_sent: str, enable_restore: bool = True) -> PosSentTensor:
Expand Down
2 changes: 1 addition & 1 deletion train/bin/pickle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _load_resource(cfg: Namespace, rsc_src: str) -> Resource:
Resource object
"""
cwd = os.path.realpath(os.getcwd())
train_dir = os.path.realpath('{}/..'.format(rsc_src))
train_dir = os.path.realpath('{}/../../train'.format(rsc_src))
if cwd != train_dir:
os.chdir(train_dir)
rsc = Resource(cfg)
Expand Down
76 changes: 76 additions & 0 deletions train/bin/tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


"""
command line part-of-speech tagger demo
__author__ = 'Jamie (jamie.lim@kakaocorp.com)'
__copyright__ = 'Copyright (C) 2019-, Kakao Corp. All rights reserved.'
"""


###########
# imports #
###########
from argparse import ArgumentParser, Namespace
import logging
import os
import sys

from khaiii.train.tagger import PosTagger


#############
# functions #
#############
def run(args: Namespace):
    """
    tag every line read from standard input and print the result to standard output
    Args:
        args:  program arguments (``args.model_dir`` is handed to PosTagger)
    """
    tagger = PosTagger(args.model_dir)
    for line_no, raw_line in enumerate(sys.stdin, start=1):
        # progress report once per 100,000 input lines
        hundred_k, remainder = divmod(line_no, 100000)
        if remainder == 0:
            logging.info('%d00k-th line..', hundred_k)
        sentence = raw_line.rstrip('\r\n')
        if sentence:
            tagged_sent = tagger.tag_raw(sentence)
            for word in tagged_sent.pos_tagged_words:
                # one output line per word: raw word, TAB, morphemes joined by ' + '
                print(word.raw, end='\t')
                print(' + '.join(map(str, word.pos_tagged_morphs)))
        # a blank line closes every sentence; an empty input line yields only this
        print()


########
# main #
########
def main():
    """
    main function processes only argument parsing

    Parses the command line, sets ``CUDA_VISIBLE_DEVICES`` from ``--gpu-num``,
    optionally redirects standard input/output to the files given by
    ``--input``/``--output``, configures logging and then calls ``run``.
    Files opened for redirection are closed and the original streams are
    restored before returning, even when ``run`` raises.
    """
    parser = ArgumentParser(description='command line part-of-speech tagger demo')
    parser.add_argument('-m', '--model-dir', help='model dir', metavar='DIR', required=True)
    parser.add_argument('--input', help='input file <default: stdin>', metavar='FILE')
    parser.add_argument('--output', help='output file <default: stdout>', metavar='FILE')
    parser.add_argument('--gpu-num', help='GPU number to use', metavar='INT', type=int, default=0)
    parser.add_argument('--debug', help='enable debug', action='store_true')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_num)

    # remember the real streams so that redirected files can be closed afterwards
    orig_stdin, orig_stdout = sys.stdin, sys.stdout
    try:
        if args.input:
            sys.stdin = open(args.input, 'r', encoding='UTF-8')
        if args.output:
            sys.stdout = open(args.output, 'w', encoding='UTF-8')
        if args.debug:
            logging.basicConfig(level=logging.DEBUG)
        else:
            logging.basicConfig(level=logging.INFO)

        run(args)
    finally:
        # close the output first so buffered results are flushed to disk
        if sys.stdout is not orig_stdout:
            sys.stdout.close()
            sys.stdout = orig_stdout
        if sys.stdin is not orig_stdin:
            sys.stdin.close()
            sys.stdin = orig_stdin


# run the demo only when executed as a script, not when imported as a module
if __name__ == '__main__':
    main()

0 comments on commit 5753e17

Please sign in to comment.