Skip to content

Commit

Permalink
validators are owned by subtasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Virv12 committed Apr 2, 2024
1 parent ea0716f commit de92cff
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 149 deletions.
2 changes: 1 addition & 1 deletion src/tools/find_bad_case/dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ pub fn patch_task_for_batch(
let testcase = TestcaseInfo::new(
testcase_id,
input_generator,
task.input_validator_generator.generate(Some(0)),
testcase_template.output_generator.clone(),
);

Expand All @@ -97,6 +96,7 @@ pub fn patch_task_for_batch(
max_score: 100.0,
testcases,
is_default: false,
input_validator: task.input_validator_generator.generate(Some(0)),
..Default::default()
};
task.subtasks.insert(0, subtask);
Expand Down
3 changes: 2 additions & 1 deletion task-maker-format/src/ioi/dag/input_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ pub const TM_VALIDATION_FILE_NAME: &str = "tm_validation_file";

/// An input file validator is responsible for checking that the input file follows the format and
/// constraints defined by the task.
#[derive(Debug, Clone, Serialize, Deserialize, TypeScriptify)]
#[derive(Debug, Clone, Serialize, Deserialize, TypeScriptify, Default)]
pub enum InputValidator {
/// Skip the validation and assume the input file is valid.
#[default]
AssumeValid,
/// Use a custom command to check if the input file is valid. The command should exit with
/// non-zero return code if and only if the input is invalid.
Expand Down
219 changes: 86 additions & 133 deletions task-maker-format/src/ioi/format/italian_yaml/cases_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,6 @@ where
self.result.push(TaskInputEntry::Testcase(TestcaseInfo::new(
self.testcase_id,
generator,
self.get_validator(&variables)
.context("Cannot get testcase validator")?,
(self.get_output_gen)(self.testcase_id),
)));
self.testcase_id += 1;
Expand Down Expand Up @@ -428,15 +426,28 @@ where
/// Parse a `:VAL` command.
fn parse_val(&mut self, line: Pair) -> Result<(), Error> {
let line: Vec<_> = line.into_inner().collect();
CasesGen::<OutGen>::process_gen_val(
line,
&self.task_dir,
self.subtask_id,
&mut self.default_validator,
&mut self.current_validator,
&mut self.validators,
"validator",
)?;
if line.len() == 1 {
let val = self
.get_validator(&self.get_auto_variables())
.context("Failed to get validator")?;

// set validator for the last subtask
let Some(TaskInputEntry::Subtask(subtask)) = self.result.last_mut() else {
bail!("The validator must be set after a subtask");
};
subtask.input_validator = val;
} else {
CasesGen::<OutGen>::process_gen_val(
line,
&self.task_dir,
self.subtask_id,
&mut self.default_validator,
&mut self.current_validator,
&mut self.validators,
"validator",
)?;
}

Ok(())
}

Expand Down Expand Up @@ -530,10 +541,11 @@ where
.as_deref()
.map(|s| s.chars().filter(|&c| c != ' ' && c != '\t').collect());
self.subtask_name = name.clone();
self.subtask_id += 1;
self.result.push(TaskInputEntry::Subtask(
#[allow(deprecated)]
SubtaskInfo {
id: self.subtask_id,
id: self.subtask_id - 1,
name,
description,
max_score: score,
Expand All @@ -545,10 +557,10 @@ where
)
.ok(),
is_default: false,
input_validator: self.get_validator(&self.get_auto_variables())?,
..Default::default()
},
));
self.subtask_id += 1;
Ok(())
}

Expand All @@ -572,8 +584,6 @@ where
self.result.push(TaskInputEntry::Testcase(TestcaseInfo::new(
self.testcase_id,
InputGenerator::StaticFile(path),
self.get_validator(&self.get_auto_variables())
.context("Cannot get testcase validator")?,
(self.get_output_gen)(self.testcase_id),
)));
self.testcase_id += 1;
Expand All @@ -592,16 +602,11 @@ where
for arg in &validator.args {
// variables may (and should!) start with `$`, remove it before accessing
// the `variables` map.
let arg = if let Some(rest) = arg.strip_prefix('$') {
rest
} else {
arg.as_str()
};
if let Some(value) = variables.get(arg) {
args.push(value.clone());
} else {
let arg = arg.strip_prefix('$').unwrap_or(arg);
let Some(value) = variables.get(arg) else {
bail!("Unknown variable in validator arguments: ${}", arg);
}
};
args.push(value.clone());
}
args
};
Expand Down Expand Up @@ -641,7 +646,6 @@ where
let mut vars = HashMap::new();
vars.insert("INPUT".to_string(), TM_VALIDATION_FILE_NAME.to_string());
vars.insert("ST_NUM".to_string(), (self.subtask_id - 1).to_string());
vars.insert("TC_NUM".to_string(), self.testcase_id.to_string());
if let Some(name) = &self.subtask_name {
vars.insert("ST_NAME".to_string(), name.clone());
}
Expand Down Expand Up @@ -1285,6 +1289,63 @@ mod tests {
}
}

#[test]
fn test_add_subtask_with_default_val() {
let gen = TestHelper::new()
.add_file("gen/generator.py")
.add_file("gen/val.py")
.cases_gen(
":GEN gen gen/generator.py\n:VAL default gen/val.py\n:SUBTASK 42\n:RUN gen 4 5 6",
)
.unwrap();
assert_eq!(gen.subtask_id, 1);
assert_eq!(gen.testcase_id, 1);
assert_eq!(gen.result.len(), 2);
let subtask = &gen.result[0];
let TaskInputEntry::Subtask(subtask) = subtask else {
panic!("Expecting a subtask, got: {:?}", subtask);
};
assert_eq!(subtask.id, 0);
if let InputValidator::Custom(_, args) = &subtask.input_validator {
assert_eq!(args.len(), 2);
assert_eq!(args[1], "0");
} else {
panic!(
"Expecting an AssumeValid but got: {:?}",
subtask.input_validator
);
}
}

#[test]
fn test_subtask_validator_args_custom() {
let gen = TestHelper::new()
.add_file("gen/generator.py")
.add_file("gen/val.py")
.cases_gen(":GEN default gen/generator.py N M seed\n:VAL default gen/val.py $INPUT $ST_NUM\n:SUBTASK 42\n1 2 3")
.unwrap();
assert_eq!(gen.subtask_id, 1);
assert_eq!(gen.testcase_id, 1);
assert_eq!(gen.result.len(), 2);
let subtask = &gen.result[0];
let TaskInputEntry::Subtask(subtask) = subtask else {
panic!("Expecting a subtask, got: {:?}", subtask);
};
assert_eq!(subtask.id, 0);
if let InputValidator::Custom(source, args) = &subtask.input_validator {
assert_eq!(source.name(), "val.py");
assert_eq!(
args,
&vec![TM_VALIDATION_FILE_NAME, "0"]
);
} else {
panic!(
"Expecting a custom validator, got: {:?}",
subtask.input_validator
);
}
}

/**********************
* : COPY
*********************/
Expand All @@ -1308,13 +1369,6 @@ mod tests {
testcase.input_generator
);
}
if let InputValidator::AssumeValid = testcase.input_validator {
} else {
panic!(
"Expecting an AssumeValid but got: {:?}",
testcase.input_validator
);
}
} else {
panic!("Expecting a testcase, got: {:?}", testcase);
}
Expand Down Expand Up @@ -1358,50 +1412,6 @@ mod tests {
testcase.input_generator
);
}
if let InputValidator::AssumeValid = testcase.input_validator {
} else {
panic!(
"Expecting an AssumeValid but got: {:?}",
testcase.input_validator
);
}
} else {
panic!("Expecting a testcase, got: {:?}", testcase);
}
}

#[test]
fn test_add_run_with_val() {
let gen = TestHelper::new()
.add_file("gen/generator.py")
.add_file("gen/val.py")
.cases_gen(
":GEN gen gen/generator.py\n:VAL default gen/val.py\n:SUBTASK 42\n:RUN gen 4 5 6",
)
.unwrap();
assert_eq!(gen.subtask_id, 1);
assert_eq!(gen.testcase_id, 1);
assert_eq!(gen.result.len(), 2);
let testcase = &gen.result[1];
if let TaskInputEntry::Testcase(testcase) = testcase {
assert_eq!(testcase.id, 0);
if let InputGenerator::Custom(_, args) = &testcase.input_generator {
assert_eq!(args, &vec!["4", "5", "6"]);
} else {
panic!(
"Expecting a custom generator, got: {:?}",
testcase.input_generator
);
}
if let InputValidator::Custom(_, args) = &testcase.input_validator {
assert_eq!(args.len(), 2);
assert_eq!(args[1], "0");
} else {
panic!(
"Expecting an AssumeValid but got: {:?}",
testcase.input_validator
);
}
} else {
panic!("Expecting a testcase, got: {:?}", testcase);
}
Expand Down Expand Up @@ -1566,63 +1576,6 @@ mod tests {
}
}

#[test]
fn test_testcase_validator_args_default() {
let gen = TestHelper::new()
.add_file("gen/generator.py")
.add_file("gen/val.py")
.cases_gen(":GEN default gen/generator.py\n:VAL default gen/val.py\n:SUBTASK 42\n1 2 3")
.unwrap();
assert_eq!(gen.subtask_id, 1);
assert_eq!(gen.testcase_id, 1);
assert_eq!(gen.result.len(), 2);
let testcase = &gen.result[1];
if let TaskInputEntry::Testcase(testcase) = testcase {
assert_eq!(testcase.id, 0);
if let InputValidator::Custom(source, args) = &testcase.input_validator {
assert_eq!(source.name(), "val.py");
assert_eq!(args, &vec![TM_VALIDATION_FILE_NAME, "0"]);
} else {
panic!(
"Expecting a custom validator, got: {:?}",
testcase.input_validator
);
}
} else {
panic!("Expecting a testcase, got: {:?}", testcase);
}
}

#[test]
fn test_testcase_validator_args_custom() {
let gen = TestHelper::new()
.add_file("gen/generator.py")
.add_file("gen/val.py")
.cases_gen(":GEN default gen/generator.py N M seed\n:VAL default gen/val.py $N $M $seed $INPUT $TC_NUM $ST_NUM\n:SUBTASK 42\n1 2 3")
.unwrap();
assert_eq!(gen.subtask_id, 1);
assert_eq!(gen.testcase_id, 1);
assert_eq!(gen.result.len(), 2);
let testcase = &gen.result[1];
if let TaskInputEntry::Testcase(testcase) = testcase {
assert_eq!(testcase.id, 0);
if let InputValidator::Custom(source, args) = &testcase.input_validator {
assert_eq!(source.name(), "val.py");
assert_eq!(
args,
&vec!["1", "2", "3", TM_VALIDATION_FILE_NAME, "0", "0"]
);
} else {
panic!(
"Expecting a custom validator, got: {:?}",
testcase.input_validator
);
}
} else {
panic!("Expecting a testcase, got: {:?}", testcase);
}
}

#[test]
fn test_testcase_valid_constraints() {
let gen = TestHelper::new().add_file("gen/generator.py").cases_gen(
Expand Down
3 changes: 1 addition & 2 deletions task-maker-format/src/ioi/format/italian_yaml/gen_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ where
)
.ok(),
is_default: false,
input_validator: get_validator(Some(subtask_id)),
..Default::default()
}));
subtask_id += 1;
Expand Down Expand Up @@ -143,7 +144,6 @@ where
entries.push(TaskInputEntry::Testcase(TestcaseInfo::new(
testcase_count,
InputGenerator::StaticFile(task_dir.join(what)),
get_validator(Some(subtask_id - 1)),
get_output_gen(testcase_count),
)));
testcase_count += 1;
Expand All @@ -162,7 +162,6 @@ where
entries.push(TaskInputEntry::Testcase(TestcaseInfo::new(
testcase_count,
InputGenerator::Custom(generator.clone(), cmd),
get_validator(Some(subtask_id - 1)),
output_generator,
)));
testcase_count += 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ where
name: Some("static-testcases".into()),
max_score: 100.0,
is_default: true,
input_validator: (self.get_validator)(Some(0)),
..Default::default()
}));
}
Expand All @@ -50,7 +51,6 @@ where
Some(TaskInputEntry::Testcase(TestcaseInfo::new(
id,
InputGenerator::StaticFile(path),
(self.get_validator)(Some(0)),
(self.get_output_gen)(id),
)))
} else {
Expand Down
Loading

0 comments on commit de92cff

Please sign in to comment.