-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Problem Constants Injection (#54)
* Add constants prediction logic (mod, yes/no) and its unit test * Support Problem constants injection with new template engine Jinja2 * Raise prediction error instead of returning None and write unit tests * Fix templates * add unittest only_with_no_str * add a unit test with tricky yes/no string case (+ rename test) * add unit test test_nested_embeddings_on_template
- Loading branch information
1 parent
e67a223
commit 8a710a3
Showing
52 changed files
with
826 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import logging | ||
import re | ||
from typing import Tuple, Optional | ||
|
||
from atcodertools.models.constpred.problem_constant_set import ProblemConstantSet | ||
from bs4 import BeautifulSoup | ||
|
||
from atcodertools.models.problem_content import ProblemContent, InputFormatDetectionError, SampleDetectionError | ||
|
||
|
||
class YesNoPredictionFailedError(Exception): | ||
pass | ||
|
||
|
||
class MultipleModCandidatesError(Exception): | ||
|
||
def __init__(self, cands): | ||
self.cands = cands | ||
|
||
|
||
MOD_ANCHORS = ["余り", "あまり", "mod", "割っ", "modulo"] | ||
|
||
MOD_STRATEGY_RE_LIST = [ | ||
re.compile("([0-9]+).?.?.?で割った"), | ||
re.compile("modu?l?o?[^0-9]?[^0-9]?[^0-9]?([0-9]+)") | ||
] | ||
|
||
|
||
def is_mod_context(sentence): | ||
for kw in MOD_ANCHORS: | ||
if kw in sentence: | ||
return True | ||
return False | ||
|
||
|
||
def predict_modulo(html: str) -> Optional[int]: | ||
def normalize(sentence): | ||
return sentence.replace('\\', '').replace("{", "").replace("}", "").replace(",", "").replace(" ", "").replace( | ||
"10^9+7", "1000000007").lower().strip() | ||
|
||
soup = BeautifulSoup(html, "html.parser") | ||
sentences = soup.get_text().split("\n") | ||
sentences = [normalize(s) for s in sentences if is_mod_context(s)] | ||
|
||
mod_cands = set() | ||
|
||
for s in sentences: | ||
for regexp in MOD_STRATEGY_RE_LIST: | ||
m = regexp.search(s) | ||
if m is not None: | ||
extracted_val = int(m.group(1)) | ||
mod_cands.add(extracted_val) | ||
|
||
if len(mod_cands) == 0: | ||
return None | ||
|
||
if len(mod_cands) == 1: | ||
return list(mod_cands)[0] | ||
|
||
raise MultipleModCandidatesError(mod_cands) | ||
|
||
|
||
def predict_yes_no(html: str) -> Tuple[Optional[str], Optional[str]]: | ||
try: | ||
outputs = set() | ||
for sample in ProblemContent.from_html(html).get_samples(): | ||
for x in sample.get_output().split("\n"): | ||
outputs.add(x.strip()) | ||
except (InputFormatDetectionError, SampleDetectionError) as e: | ||
raise YesNoPredictionFailedError(e) | ||
|
||
yes_kws = ["yes", "possible"] | ||
no_kws = ["no", "impossible"] | ||
|
||
yes_str = None | ||
no_str = None | ||
for val in outputs: | ||
if val.lower() in yes_kws: | ||
yes_str = val | ||
if val.lower() in no_kws: | ||
no_str = val | ||
|
||
return yes_str, no_str | ||
|
||
|
||
def predict_constants(html: str) -> ProblemConstantSet: | ||
try: | ||
yes_str, no_str = predict_yes_no(html) | ||
except YesNoPredictionFailedError: | ||
yes_str = no_str = None | ||
|
||
try: | ||
mod = predict_modulo(html) | ||
except MultipleModCandidatesError as e: | ||
logging.warning("Modulo prediction failed -- " | ||
"two or more candidates {} are detected as modulo values".format(e.cands)) | ||
mod = None | ||
|
||
return ProblemConstantSet(mod=mod, yes_str=yes_str, no_str=no_str) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
|
||
|
||
class ProblemConstantSet: | ||
def __init__(self, | ||
mod: int = None, | ||
yes_str: str = None, | ||
no_str: str = None, | ||
): | ||
self.mod = mod | ||
self.yes_str = yes_str | ||
self.no_str = no_str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,16 @@ | ||
#include <bits/stdc++.h> | ||
using namespace std; | ||
|
||
void solve(${formal_arguments}){ | ||
|
||
{% if mod is not none %}const long long MOD = {{ mod }};{% endif %} | ||
{% if yes_str is not none %}const string YES = "{{ yes_str }}";{% endif %} | ||
{% if no_str is not none %}const string NO = "{{ no_str }}";{% endif %} | ||
|
||
void solve({{ formal_arguments }}){ | ||
|
||
} | ||
|
||
int main(){ | ||
${input_part} | ||
solve(${actual_arguments}); | ||
int main(){ | ||
{{input_part}} | ||
solve({{ actual_arguments }}); | ||
return 0; | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,18 @@ | ||
import java.io.*; | ||
import java.util.*; | ||
|
||
class Main{ | ||
public static void main(String[] args) throws Exception{ | ||
final Scanner sc = new Scanner(System.in); | ||
${input_part} | ||
solve(${actual_arguments}); | ||
} | ||
static void solve(${formal_arguments}){ | ||
|
||
class Main { | ||
{% if mod %static final long MOD = {{ mod }};{% endif %} | ||
{% if yes_str is not none %}static final String YES = "{{ yes_str }}";{% endif %} | ||
{% if no_str is not none %}static final String NO = "{{ no_str }}";{% endif %} | ||
|
||
public static void main(String[] args) throws Exception { | ||
final Scanner sc = new Scanner(System.in); | ||
{{ input_part }} | ||
solve({{ actual_arguments }}); | ||
} | ||
} | ||
|
||
static void solve({{ formal_arguments }}){ | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
beautifulsoup4 | ||
requests | ||
colorama | ||
toml | ||
toml | ||
jinja2 |
Oops, something went wrong.