Skip to content

Commit

Permalink
Support Problem Constants Injection (#54)
Browse files Browse the repository at this point in the history
* Add constants prediction logic (mod, yes/no) and its unit test

* Support Problem constants injection with new template engine Jinja2

* Raise prediction error instead of returning None and write unit tests

* Fix templates

* add unittest only_with_no_str

* add a unit test with tricky yes/no string case (+ rename test)

* add unit test test_nested_embeddings_on_template
  • Loading branch information
kyuridenamida authored Jan 4, 2019
1 parent e67a223 commit 8a710a3
Show file tree
Hide file tree
Showing 52 changed files with 826 additions and 80 deletions.
3 changes: 2 additions & 1 deletion atcodertools/codegen/code_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult

from abc import ABC, abstractmethod
Expand All @@ -6,7 +7,7 @@
class CodeGenerator(ABC):

@abstractmethod
def generate_code(self, prediction_result: FormatPredictionResult):
def generate_code(self, prediction_result: FormatPredictionResult, constants: ProblemConstantSet):
raise NotImplementedError


Expand Down
11 changes: 9 additions & 2 deletions atcodertools/codegen/cpp_code_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from atcodertools.codegen.code_gen_config import CodeGenConfig
from atcodertools.models.analyzer.analyzed_variable import AnalyzedVariable
from atcodertools.models.analyzer.simple_format import Pattern, SingularPattern, ParallelPattern, TwoDimensionalPattern
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult
from atcodertools.models.predictor.variable import Variable
from atcodertools.codegen.code_generator import CodeGenerator
Expand Down Expand Up @@ -28,14 +29,20 @@ def __init__(self, template: str, config: CodeGenConfig = CodeGenConfig()):
self._prediction_result = None
self._config = config

def generate_code(self, prediction_result: FormatPredictionResult):
def generate_code(self, prediction_result: FormatPredictionResult,
constants: ProblemConstantSet = ProblemConstantSet()):
if prediction_result is None:
raise NoPredictionResultGiven
self._prediction_result = prediction_result

return render(self._template,
formal_arguments=self._formal_arguments(),
actual_arguments=self._actual_arguments(),
input_part=self._input_part())
input_part=self._input_part(),
mod=constants.mod,
yes_str=constants.yes_str,
no_str=constants.no_str,
)

def _input_part(self):
lines = []
Expand Down
29 changes: 25 additions & 4 deletions atcodertools/codegen/template_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import string
import re
import warnings

import jinja2


def _substitute(s, reps):
Expand All @@ -17,14 +20,32 @@ def _substitute(s, reps):
sep = ('\n' + m.group(1)) if m.group(1).strip() == '' else '\n'

cr[m.group(2)] = sep.join(reps[m.group(2)])
i += m.end() # continue past last processed replaceable token
i += m.end() # continue past last processed replaceable token
return t.substitute(cr) # we can now substitute


def render(s, **args):
def render(template, **kwargs):
    """Render ``template`` with ``kwargs`` using the matching template engine.

    A template containing ``${`` is treated as an old-style template: a
    ``UserWarning`` is emitted and the old engine is used. Every other
    template is rendered by Jinja2.
    """
    uses_old_syntax = "${" in template
    if not uses_old_syntax:
        return render_by_jinja(template, **kwargs)

    # Old-style template: kept working for backward compatibility.
    warnings.warn(
        "The old template engine with ${} is deprecated. Please use the new Jinja2 template engine.", UserWarning)
    return old_render(template, **kwargs)


def old_render(template, **kwargs):
# This render function used to be used before version 1.0.3
new_args = {}

for k, v in args.items():
for k, v in kwargs.items():
new_args[k] = v if isinstance(v, list) else [v]

return _substitute(s, new_args)
return _substitute(template, new_args)


def render_by_jinja(template, **kwargs):
    """Render ``template`` with Jinja2 and append a newline to the result."""
    rendered = jinja2.Template(template).render(**kwargs)
    return rendered + "\n"
Empty file.
99 changes: 99 additions & 0 deletions atcodertools/constprediction/constants_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import logging
import re
from typing import Tuple, Optional

from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
from bs4 import BeautifulSoup

from atcodertools.models.problem_content import ProblemContent, InputFormatDetectionError, SampleDetectionError


class YesNoPredictionFailedError(Exception):
    """Signals that the yes/no output strings could not be predicted."""


class MultipleModCandidatesError(Exception):
    """Signals that more than one modulo candidate was found.

    The candidates are kept on ``cands`` so callers can report them.
    """

    def __init__(self, cands):
        super().__init__(cands)
        self.cands = cands


# Lower-case keywords whose presence marks a sentence as talking about a modulus.
MOD_ANCHORS = ["余り", "あまり", "mod", "割っ", "modulo"]

# Patterns that capture the modulus value (group 1) from a normalized sentence.
MOD_STRATEGY_RE_LIST = [
    re.compile(r"([0-9]+).?.?.?で割った"),
    re.compile(r"modu?l?o?[^0-9]?[^0-9]?[^0-9]?([0-9]+)")
]


def is_mod_context(sentence):
    """Return True if ``sentence`` contains any of the ``MOD_ANCHORS`` keywords.

    The comparison is case-insensitive: the anchors are lower-case, so the
    sentence is lower-cased before matching. Without this, sentences such as
    "Modulo 1000000007" or "MOD 998244353" would not be recognised.
    """
    lowered = sentence.lower()
    return any(kw in lowered for kw in MOD_ANCHORS)


def predict_modulo(html: str) -> Optional[int]:
    """Predict the modulus mentioned in a problem statement.

    The text of ``html`` is split into lines; every line accepted by
    ``is_mod_context`` is normalized and searched with each pattern in
    ``MOD_STRATEGY_RE_LIST``.

    Returns:
        The single detected value, or ``None`` when nothing is detected.

    Raises:
        MultipleModCandidatesError: when two or more distinct values are found.
    """
    def normalize(sentence):
        # Remove backslashes, braces, commas and spaces, then spell out the
        # common "10^9+7" notation so the digit patterns can capture it.
        for junk in ("\\", "{", "}", ",", " "):
            sentence = sentence.replace(junk, "")
        return sentence.replace("10^9+7", "1000000007").lower().strip()

    text = BeautifulSoup(html, "html.parser").get_text()

    candidates = set()
    for line in text.split("\n"):
        # The anchor check is made on the raw line, before normalization.
        if not is_mod_context(line):
            continue
        cleaned = normalize(line)
        for pattern in MOD_STRATEGY_RE_LIST:
            found = pattern.search(cleaned)
            if found is not None:
                candidates.add(int(found.group(1)))

    if len(candidates) > 1:
        raise MultipleModCandidatesError(candidates)
    if not candidates:
        return None
    return next(iter(candidates))


def predict_yes_no(html: str) -> Tuple[Optional[str], Optional[str]]:
    """Predict the "yes" and "no" answer strings from the sample outputs.

    Every stripped line of every sample output is collected; a line whose
    lower-cased form is a yes keyword ("yes", "possible") or a no keyword
    ("no", "impossible") is returned in its original casing.

    Returns:
        A ``(yes_str, no_str)`` pair; either element is ``None`` when no
        matching output line exists.

    Raises:
        YesNoPredictionFailedError: when the samples cannot be extracted
            from ``html``.
    """
    try:
        samples = ProblemContent.from_html(html).get_samples()
        outputs = {
            line.strip()
            for sample in samples
            for line in sample.get_output().split("\n")
        }
    except (InputFormatDetectionError, SampleDetectionError) as e:
        raise YesNoPredictionFailedError(e)

    yes_kws = ["yes", "possible"]
    no_kws = ["no", "impossible"]

    yes_str = no_str = None
    for candidate in outputs:
        lowered = candidate.lower()
        if lowered in yes_kws:
            yes_str = candidate
        if lowered in no_kws:
            no_str = candidate

    return yes_str, no_str


def predict_constants(html: str) -> ProblemConstantSet:
    """Predict all problem constants (modulus, yes/no strings) from ``html``.

    A failed prediction never propagates: a failed yes/no prediction yields
    ``None`` for both strings, and an ambiguous modulus is logged as a
    warning and yields ``None`` for ``mod``.
    """
    try:
        yes_no = predict_yes_no(html)
    except YesNoPredictionFailedError:
        yes_no = (None, None)
    yes_str, no_str = yes_no

    try:
        mod = predict_modulo(html)
    except MultipleModCandidatesError as e:
        # Ambiguous modulus: report the candidates and fall back to "unknown".
        logging.warning("Modulo prediction failed -- "
                        "two or more candidates {} are detected as modulo values".format(e.cands))
        mod = None

    return ProblemConstantSet(mod=mod, yes_str=yes_str, no_str=no_str)
8 changes: 6 additions & 2 deletions atcodertools/fileutils/create_contest_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

from atcodertools.codegen.code_generator import CodeGenerator
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
from atcodertools.models.sample import Sample
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult

Expand All @@ -11,8 +12,11 @@ def _make_text_file(file_path, text):
f.write(text)


def create_code_from_prediction_result(result: FormatPredictionResult, code_generator: CodeGenerator, file_path: str):
_make_text_file(file_path, code_generator.generate_code(result))
def create_code_from(result: FormatPredictionResult,
                     constants: ProblemConstantSet,
                     code_generator: CodeGenerator,
                     file_path: str):
    """Generate code from ``result`` and ``constants`` and write it to ``file_path``."""
    generated_code = code_generator.generate_code(result, constants)
    _make_text_file(file_path, generated_code)


def create_example(example: Sample, in_example_name: str, out_example_name: str):
Expand Down
Empty file.
11 changes: 11 additions & 0 deletions atcodertools/models/constpred/problem_constant_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@


class ProblemConstantSet:
    """Holds the constants predicted for a problem.

    Each attribute defaults to ``None`` when no value is supplied.
    """

    def __init__(self,
                 mod: int = None,
                 yes_str: str = None,
                 no_str: str = None,
                 ):
        # mod: modulus value; yes_str / no_str: affirmative / negative answer strings.
        self.mod, self.yes_str, self.no_str = mod, yes_str, no_str
24 changes: 16 additions & 8 deletions atcodertools/models/problem_content.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Optional

from bs4 import BeautifulSoup

Expand Down Expand Up @@ -36,13 +36,17 @@ class InputFormatDetectionError(Exception):

class ProblemContent:

def __init__(self, input_format_text: str = None, samples: List[Sample] = None):
def __init__(self, input_format_text: Optional[str] = None,
samples: Optional[List[Sample]] = None,
original_html: Optional[str] = None,
):
self.samples = samples
self.input_format_text = input_format_text
self.original_html = original_html

@classmethod
def from_html(cls, html: str = None):
res = ProblemContent()
def from_html(cls, html: str):
res = ProblemContent(original_html=html)
soup = BeautifulSoup(html, "html.parser")
res.input_format_text, res.samples = res._extract_input_format_and_samples(
soup)
Expand Down Expand Up @@ -81,13 +85,17 @@ def _extract_input_format_and_samples(soup) -> Tuple[str, List[Sample]]:
if len(input_tags) != len(output_tags):
raise SampleDetectionError

res = [Sample(normalize(in_tag.text), normalize(out_tag.text))
for in_tag, out_tag in zip(input_tags, output_tags)]
try:
res = [Sample(normalize(in_tag.text), normalize(out_tag.text))
for in_tag, out_tag in zip(input_tags, output_tags)]

if input_format_tag is None:
raise InputFormatDetectionError

if input_format_tag is None:
input_format_text = normalize(input_format_tag.text)
except AttributeError:
raise InputFormatDetectionError

input_format_text = normalize(input_format_tag.text)
return input_format_text, res

@staticmethod
Expand Down
18 changes: 11 additions & 7 deletions atcodertools/tools/envgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from atcodertools.codegen.code_gen_config import CodeGenConfig
from atcodertools.codegen.cpp_code_generator import CppCodeGenerator
from atcodertools.codegen.java_code_generator import JavaCodeGenerator
from atcodertools.fileutils.create_contest_file import create_examples, create_code_from_prediction_result
from atcodertools.constprediction.constants_prediction import predict_constants
from atcodertools.fileutils.create_contest_file import create_examples, \
create_code_from
from atcodertools.models.problem_content import InputFormatDetectionError, SampleDetectionError
from atcodertools.client.atcoder import AtCoderClient, Contest, LoginError
from atcodertools.fmtprediction.predict_format import FormatPredictor, NoPredictionResultError, \
Expand Down Expand Up @@ -101,20 +103,22 @@ def emit_info(text):
new_path))

try:
result = FormatPredictor().predict(content)

with open(template_code_path, "r") as f:
template = f.read()

if lang == "cpp":
gen_class = CppCodeGenerator
elif lang == "java":
gen_class = JavaCodeGenerator
else:
raise NotImplementedError("only supporting cpp and java")

create_code_from_prediction_result(
with open(template_code_path, "r") as f:
template = f.read()

result = FormatPredictor().predict(content)
constants = predict_constants(content.original_html)

create_code_from(
result,
constants,
gen_class(template, config),
code_file_path)
emit_info(
Expand Down
15 changes: 9 additions & 6 deletions atcodertools/tools/templates/cpp/template_success.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
#include <bits/stdc++.h>
using namespace std;

void solve(${formal_arguments}){

{% if mod is not none %}const long long MOD = {{ mod }};{% endif %}
{% if yes_str is not none %}const string YES = "{{ yes_str }}";{% endif %}
{% if no_str is not none %}const string NO = "{{ no_str }}";{% endif %}

void solve({{ formal_arguments }}){

}

int main(){
${input_part}
solve(${actual_arguments});
int main(){
{{input_part}}
solve({{ actual_arguments }});
return 0;
}

2 changes: 1 addition & 1 deletion atcodertools/tools/templates/java/template_failure.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
class Main{
public static void main(String[] args) throws Exception{
final Scanner sc = new Scanner(System.in);

}
}

22 changes: 13 additions & 9 deletions atcodertools/tools/templates/java/template_success.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import java.io.*;
import java.util.*;

class Main{
public static void main(String[] args) throws Exception{
final Scanner sc = new Scanner(System.in);
${input_part}
solve(${actual_arguments});
}
static void solve(${formal_arguments}){

class Main {
{#- Each constant is declared only when a value was predicted for it (i.e. it is not none). #}
{% if mod is not none %}static final long MOD = {{ mod }};{% endif %}
{% if yes_str is not none %}static final String YES = "{{ yes_str }}";{% endif %}
{% if no_str is not none %}static final String NO = "{{ no_str }}";{% endif %}

public static void main(String[] args) throws Exception {
final Scanner sc = new Scanner(System.in);
{{ input_part }}
solve({{ actual_arguments }});
}
}

static void solve({{ formal_arguments }}){

}
}
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
beautifulsoup4
requests
colorama
toml
toml
jinja2
Loading

0 comments on commit 8a710a3

Please sign in to comment.