Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Problem Constants Injection #54

Merged
merged 20 commits into from
Jan 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
93567a6
Add constants prediction logic (mod, yes/no) and its unittest
kyuridenamida Jan 2, 2019
90a9dc9
Support Problem constants injection with new template engine Jinja2
kyuridenamida Jan 2, 2019
683e3f3
fix issue that some test is recognized as test due to its name incorr…
kyuridenamida Jan 2, 2019
52d6f3b
Merge branch 'master' into issue/53
kyuridenamida Jan 2, 2019
164b6a6
Update atcodertools/models/problem_content.py
asi1024 Jan 2, 2019
f065a2b
Update atcodertools/models/problem_content.py
asi1024 Jan 2, 2019
edb6ca7
Update atcodertools/models/problem_content.py
asi1024 Jan 2, 2019
008e79f
Refactor the codes that were pointed out
kyuridenamida Jan 2, 2019
225483a
Merge branch 'issue/53' of github.com:kyuridenamida/atcoder-tools int…
kyuridenamida Jan 2, 2019
3271b94
import Optional to fix compile error
kyuridenamida Jan 2, 2019
b74cc2e
Remove strange heuristics to select one modulo value from multiple ca…
kyuridenamida Jan 2, 2019
23b1c15
Raise prediction error instead of returning None and write unit tests
kyuridenamida Jan 3, 2019
afb43f1
autopep
kyuridenamida Jan 3, 2019
4a019b9
Fix templates
kyuridenamida Jan 3, 2019
b544e2d
fix bug that increases false-positive by splitting samples into token…
kyuridenamida Jan 3, 2019
dc64040
add unittest only_with_no_str
kyuridenamida Jan 3, 2019
034acba
add unittest with tricky yes/no string case (+ rename test)
kyuridenamida Jan 3, 2019
6083a78
add unit test test_nested_embeddings_on_template
kyuridenamida Jan 3, 2019
acd59b5
follow pep8
kyuridenamida Jan 3, 2019
de22982
Fix ugly typo in tests
kyuridenamida Jan 3, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion atcodertools/codegen/code_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult

from abc import ABC, abstractmethod
Expand All @@ -6,7 +7,7 @@
class CodeGenerator(ABC):

@abstractmethod
def generate_code(self, prediction_result: FormatPredictionResult):
def generate_code(self, prediction_result: FormatPredictionResult, constants: ProblemConstantSet):
raise NotImplementedError


Expand Down
11 changes: 9 additions & 2 deletions atcodertools/codegen/cpp_code_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from atcodertools.codegen.code_gen_config import CodeGenConfig
from atcodertools.models.analyzer.analyzed_variable import AnalyzedVariable
from atcodertools.models.analyzer.simple_format import Pattern, SingularPattern, ParallelPattern, TwoDimensionalPattern
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult
from atcodertools.models.predictor.variable import Variable
from atcodertools.codegen.code_generator import CodeGenerator
Expand Down Expand Up @@ -28,14 +29,20 @@ def __init__(self, template: str, config: CodeGenConfig = CodeGenConfig()):
self._prediction_result = None
self._config = config

def generate_code(self, prediction_result: FormatPredictionResult):
def generate_code(self, prediction_result: FormatPredictionResult,
constants: ProblemConstantSet = ProblemConstantSet()):
if prediction_result is None:
raise NoPredictionResultGiven
self._prediction_result = prediction_result

return render(self._template,
formal_arguments=self._formal_arguments(),
actual_arguments=self._actual_arguments(),
input_part=self._input_part())
input_part=self._input_part(),
mod=constants.mod,
yes_str=constants.yes_str,
no_str=constants.no_str,
)

def _input_part(self):
lines = []
Expand Down
29 changes: 25 additions & 4 deletions atcodertools/codegen/template_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import string
import re
import warnings

import jinja2


def _substitute(s, reps):
Expand All @@ -17,14 +20,32 @@ def _substitute(s, reps):
sep = ('\n' + m.group(1)) if m.group(1).strip() == '' else '\n'

cr[m.group(2)] = sep.join(reps[m.group(2)])
i += m.end() # continue past last processed replaceable token
i += m.end() # continue past last processed replaceable token
return t.substitute(cr) # we can now substitute


def render(s, **args):
def render(template, **kwargs):
if "${" in template:
# If the template is old, render with the old engine.
# This logic is for backward compatibility
warnings.warn(
"The old template engine with ${} is deprecated. Please use the new Jinja2 template engine.", UserWarning)

return old_render(template, **kwargs)
else:
return render_by_jinja(template, **kwargs)


def old_render(template, **kwargs):
# This render function used to be used before version 1.0.3
new_args = {}

for k, v in args.items():
for k, v in kwargs.items():
new_args[k] = v if isinstance(v, list) else [v]

return _substitute(s, new_args)
return _substitute(template, new_args)


def render_by_jinja(template, **kwargs):
template = jinja2.Template(template)
return template.render(**kwargs) + "\n"
Empty file.
99 changes: 99 additions & 0 deletions atcodertools/constprediction/constants_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import logging
import re
from typing import Tuple, Optional

from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
from bs4 import BeautifulSoup

from atcodertools.models.problem_content import ProblemContent, InputFormatDetectionError, SampleDetectionError


class YesNoPredictionFailedError(Exception):
pass


class MultipleModCandidatesError(Exception):

def __init__(self, cands):
self.cands = cands


MOD_ANCHORS = ["余り", "あまり", "mod", "割っ", "modulo"]

MOD_STRATEGY_RE_LIST = [
re.compile("([0-9]+).?.?.?で割った"),
re.compile("modu?l?o?[^0-9]?[^0-9]?[^0-9]?([0-9]+)")
]


def is_mod_context(sentence):
for kw in MOD_ANCHORS:
if kw in sentence:
return True
return False


def predict_modulo(html: str) -> Optional[int]:
def normalize(sentence):
return sentence.replace('\\', '').replace("{", "").replace("}", "").replace(",", "").replace(" ", "").replace(
"10^9+7", "1000000007").lower().strip()

soup = BeautifulSoup(html, "html.parser")
sentences = soup.get_text().split("\n")
sentences = [normalize(s) for s in sentences if is_mod_context(s)]

mod_cands = set()

for s in sentences:
for regexp in MOD_STRATEGY_RE_LIST:
m = regexp.search(s)
if m is not None:
extracted_val = int(m.group(1))
mod_cands.add(extracted_val)

if len(mod_cands) == 0:
return None

if len(mod_cands) == 1:
return list(mod_cands)[0]

raise MultipleModCandidatesError(mod_cands)


def predict_yes_no(html: str) -> Tuple[Optional[str], Optional[str]]:
kyuridenamida marked this conversation as resolved.
Show resolved Hide resolved
try:
outputs = set()
for sample in ProblemContent.from_html(html).get_samples():
for x in sample.get_output().split("\n"):
outputs.add(x.strip())
except (InputFormatDetectionError, SampleDetectionError) as e:
raise YesNoPredictionFailedError(e)

yes_kws = ["yes", "possible"]
no_kws = ["no", "impossible"]

yes_str = None
no_str = None
for val in outputs:
if val.lower() in yes_kws:
yes_str = val
if val.lower() in no_kws:
no_str = val

return yes_str, no_str


def predict_constants(html: str) -> ProblemConstantSet:
try:
yes_str, no_str = predict_yes_no(html)
except YesNoPredictionFailedError:
yes_str = no_str = None

try:
mod = predict_modulo(html)
except MultipleModCandidatesError as e:
logging.warning("Modulo prediction failed -- "
"two or more candidates {} are detected as modulo values".format(e.cands))
mod = None

return ProblemConstantSet(mod=mod, yes_str=yes_str, no_str=no_str)
8 changes: 6 additions & 2 deletions atcodertools/fileutils/create_contest_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

from atcodertools.codegen.code_generator import CodeGenerator
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet
from atcodertools.models.sample import Sample
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult

Expand All @@ -11,8 +12,11 @@ def _make_text_file(file_path, text):
f.write(text)


def create_code_from_prediction_result(result: FormatPredictionResult, code_generator: CodeGenerator, file_path: str):
_make_text_file(file_path, code_generator.generate_code(result))
def create_code_from(result: FormatPredictionResult,
constants: ProblemConstantSet,
code_generator: CodeGenerator,
file_path: str):
_make_text_file(file_path, code_generator.generate_code(result, constants))


def create_example(example: Sample, in_example_name: str, out_example_name: str):
Expand Down
Empty file.
11 changes: 11 additions & 0 deletions atcodertools/models/constpred/problem_constant_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@


class ProblemConstantSet:
def __init__(self,
mod: int = None,
yes_str: str = None,
no_str: str = None,
):
self.mod = mod
self.yes_str = yes_str
self.no_str = no_str
24 changes: 16 additions & 8 deletions atcodertools/models/problem_content.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Optional

from bs4 import BeautifulSoup

Expand Down Expand Up @@ -36,13 +36,17 @@ class InputFormatDetectionError(Exception):

class ProblemContent:

def __init__(self, input_format_text: str = None, samples: List[Sample] = None):
def __init__(self, input_format_text: Optional[str] = None,
samples: Optional[List[Sample]] = None,
original_html: Optional[str] = None,
):
self.samples = samples
self.input_format_text = input_format_text
self.original_html = original_html

@classmethod
def from_html(cls, html: str = None):
res = ProblemContent()
def from_html(cls, html: str):
res = ProblemContent(original_html=html)
soup = BeautifulSoup(html, "html.parser")
res.input_format_text, res.samples = res._extract_input_format_and_samples(
soup)
Expand Down Expand Up @@ -81,13 +85,17 @@ def _extract_input_format_and_samples(soup) -> Tuple[str, List[Sample]]:
if len(input_tags) != len(output_tags):
raise SampleDetectionError

res = [Sample(normalize(in_tag.text), normalize(out_tag.text))
for in_tag, out_tag in zip(input_tags, output_tags)]
try:
res = [Sample(normalize(in_tag.text), normalize(out_tag.text))
for in_tag, out_tag in zip(input_tags, output_tags)]

if input_format_tag is None:
raise InputFormatDetectionError

if input_format_tag is None:
input_format_text = normalize(input_format_tag.text)
except AttributeError:
raise InputFormatDetectionError

input_format_text = normalize(input_format_tag.text)
return input_format_text, res

@staticmethod
Expand Down
18 changes: 11 additions & 7 deletions atcodertools/tools/envgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from atcodertools.codegen.code_gen_config import CodeGenConfig
from atcodertools.codegen.cpp_code_generator import CppCodeGenerator
from atcodertools.codegen.java_code_generator import JavaCodeGenerator
from atcodertools.fileutils.create_contest_file import create_examples, create_code_from_prediction_result
from atcodertools.constprediction.constants_prediction import predict_constants
from atcodertools.fileutils.create_contest_file import create_examples, \
create_code_from
from atcodertools.models.problem_content import InputFormatDetectionError, SampleDetectionError
from atcodertools.client.atcoder import AtCoderClient, Contest, LoginError
from atcodertools.fmtprediction.predict_format import FormatPredictor, NoPredictionResultError, \
Expand Down Expand Up @@ -101,20 +103,22 @@ def emit_info(text):
new_path))

try:
result = FormatPredictor().predict(content)

with open(template_code_path, "r") as f:
template = f.read()

if lang == "cpp":
gen_class = CppCodeGenerator
elif lang == "java":
gen_class = JavaCodeGenerator
else:
raise NotImplementedError("only supporting cpp and java")

create_code_from_prediction_result(
with open(template_code_path, "r") as f:
template = f.read()

result = FormatPredictor().predict(content)
constants = predict_constants(content.original_html)

create_code_from(
result,
constants,
gen_class(template, config),
code_file_path)
emit_info(
Expand Down
15 changes: 9 additions & 6 deletions atcodertools/tools/templates/cpp/template_success.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
#include <bits/stdc++.h>
using namespace std;

void solve(${formal_arguments}){

{% if mod is not none %}const long long MOD = {{ mod }};{% endif %}
{% if yes_str is not none %}const string YES = "{{ yes_str }}";{% endif %}
{% if no_str is not none %}const string NO = "{{ no_str }}";{% endif %}

void solve({{ formal_arguments }}){

}

int main(){
${input_part}
solve(${actual_arguments});
int main(){
{{input_part}}
solve({{ actual_arguments }});
return 0;
}

2 changes: 1 addition & 1 deletion atcodertools/tools/templates/java/template_failure.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
class Main{
public static void main(String[] args) throws Exception{
final Scanner sc = new Scanner(System.in);

}
}

22 changes: 13 additions & 9 deletions atcodertools/tools/templates/java/template_success.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import java.io.*;
import java.util.*;

class Main{
public static void main(String[] args) throws Exception{
final Scanner sc = new Scanner(System.in);
${input_part}
solve(${actual_arguments});
}
static void solve(${formal_arguments}){

class Main {
{% if mod %static final long MOD = {{ mod }};{% endif %}
{% if yes_str is not none %}static final String YES = "{{ yes_str }}";{% endif %}
{% if no_str is not none %}static final String NO = "{{ no_str }}";{% endif %}

public static void main(String[] args) throws Exception {
final Scanner sc = new Scanner(System.in);
{{ input_part }}
solve({{ actual_arguments }});
}
}

static void solve({{ formal_arguments }}){

}
}
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
beautifulsoup4
requests
colorama
toml
toml
jinja2
Loading