Skip to content

Commit

Permalink
Allow some function equality comparison (#1978)
Browse files Browse the repository at this point in the history
* Allow some function equality comparison

Relax the initial restriction on comparing functions, which is that
comparing a function to anything else error out. In practice comparing a
function to a value of some other type isn't a problem, and can be
useful to allow patterns like `if x != null then ... else ...` without
having first to check if `x` is a function, because that would fail at
runtime.

Instead, we only forbid comparison between two function-like values
(functions, match expressions and custom contracts) and between two
opaque foreign values.

* Fix typo in code comment

* Additional tests for relaxed function comparison

Check the relaxed function comparison, which errors out when trying to
compare two functions, but can compare a function to a value of a
different type (by always returning false).

Tests that functions can't be compared were already there, but this
commit adds tests that function can be compared to other values.
  • Loading branch information
yannham authored Jul 1, 2024
1 parent a3bdf62 commit 89639ca
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 48 deletions.
54 changes: 36 additions & 18 deletions core/src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,12 @@ pub enum EvalError {
label: label::Label,
call_stack: CallStack,
},
/// A non-equatable term was compared for equality.
EqError { eq_pos: TermPos, term: RichTerm },
/// Two non-equatable terms of the same type (e.g. functions) were compared for equality.
IncomparableValues {
eq_pos: TermPos,
left: RichTerm,
right: RichTerm,
},
/// A value didn't match any branch of a `match` expression at runtime. This is a specialized
/// version of [Self::NonExhaustiveMatch] when all branches are enum patterns. In this case,
/// the error message is more informative than the generic one.
Expand Down Expand Up @@ -1365,28 +1369,42 @@ impl IntoDiagnostics<FileId> for EvalError {
.with_message(format!("{format} parse error: {msg}"))
.with_labels(labels)]
}
EvalError::EqError { eq_pos, term: t } => {
let label = format!(
"an argument has type {}, which cannot be compared for equality",
t.term
EvalError::IncomparableValues {
eq_pos,
left,
right,
} => {
let mut labels = Vec::new();

if let Some(span) = eq_pos.as_opt_ref() {
labels.push(primary(span).with_message("in this equality comparison"));
}

// Push the label for the right or left argument and return the type of said
// argument.
let mut push_label = |prefix: &str, term: &RichTerm| -> String {
let type_of = term
.term
.type_of()
.unwrap_or_else(|| String::from("<unevaluated>")),
);
.unwrap_or_else(|| String::from("<unevaluated>"));

let labels = match eq_pos {
TermPos::Original(pos) | TermPos::Inherited(pos) if eq_pos != t.pos => {
vec![
primary(&pos).with_message(label),
secondary_term(&t, files)
.with_message("problematic argument evaluated to this"),
]
}
_ => vec![primary_term(&t, files).with_message(label)],
labels.push(
secondary_term(term, files)
.with_message(format!("{prefix} argument has type {type_of}")),
);

type_of
};

let left_type = push_label("left", &left);
let right_type = push_label("right", &right);

vec![Diagnostic::error()
.with_message("cannot compare values for equality")
.with_labels(labels)]
.with_labels(labels)
.with_notes(vec![format!(
"A {left_type} can't be meaningfully compared with a {right_type}"
)])]
}
EvalError::NonExhaustiveEnumMatch {
expected,
Expand Down
48 changes: 31 additions & 17 deletions core/src/eval/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3325,7 +3325,25 @@ impl RecPriority {
/// # Return
///
/// If the comparison is successful, returns a bool indicating whether the values were equal,
/// otherwise returns an [`EvalError`] indicating that the values cannot be compared.
/// otherwise returns an [`EvalError`] indicating that the values cannot be compared (typically two
/// functions).
///
/// # Uncomparable values
///
/// Comparing two functions is undecidable. Even in simple cases, it's not trivial to handle an
/// approximation (functions might capture free variables, you'd need to take eta-conversion into
/// account to equate e.g. `fun x => x` and `fun y => y`, etc.).
///
/// Thus, by default, comparing a function to something else always returns `false`. However, this
/// breaks the reflexivity property of equality, which users might rightfully rely on, because `fun
/// x => x` isn't equal to itself. Also, comparing two functions is probably never intentional nor
/// meaningful: thus we error out when trying to compare two functions. We still allow comparing
/// functions to something else, because it's useful to have tests like `if value == 1` or `if
/// value == null` typically in contracts without having to defensively check that `value` is a
/// function.
///
/// The same reasoning applies to foreign values (which we don't want to compare for security
/// reasons, at least right now, not because we can't).
fn eq<C: Cache>(
cache: &mut C,
c1: Closure,
Expand Down Expand Up @@ -3531,22 +3549,18 @@ fn eq<C: Cache>(
}
}
}
(Term::Fun(i, rt), _) => Err(EvalError::EqError {
eq_pos: pos_op,
term: RichTerm::new(Term::Fun(i, rt), pos1),
}),
(_, Term::Fun(i, rt)) => Err(EvalError::EqError {
eq_pos: pos_op,
term: RichTerm::new(Term::Fun(i, rt), pos2),
}),
(Term::ForeignId(v), _) => Err(EvalError::EqError {
eq_pos: pos_op,
term: RichTerm::new(Term::ForeignId(v), pos1),
}),
(_, Term::ForeignId(v)) => Err(EvalError::EqError {
eq_pos: pos_op,
term: RichTerm::new(Term::ForeignId(v), pos2),
}),
// Function-like terms and foreign id can't be compared together.
(
t1 @ (Term::Fun(..) | Term::Match(_) | Term::CustomContract(_)),
t2 @ (Term::Fun(..) | Term::Match(_) | Term::CustomContract(_)),
)
| (t1 @ Term::ForeignId(_), t2 @ Term::ForeignId(_)) => {
Err(EvalError::IncomparableValues {
eq_pos: pos_op,
left: RichTerm::new(t1, pos1),
right: RichTerm::new(t2, pos2),
})
}
(_, _) => Ok(EqResult::Bool(false)),
}
}
Expand Down
5 changes: 4 additions & 1 deletion core/src/eval/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,10 @@ fn foreign_id() {
RichTerm::from(Term::ForeignId(43)),
RichTerm::from(Term::ForeignId(42)),
);
assert_matches!(eval_no_import(t_eq), Err(EvalError::EqError { .. }));
assert_matches!(
eval_no_import(t_eq),
Err(EvalError::IncomparableValues { .. })
);

// Opaque values cannot be merged (even if they're equal, since they can't get compared for equality).
let t_merge = mk_term::op2(
Expand Down
7 changes: 7 additions & 0 deletions core/tests/integration/inputs/contracts/equating_fn_match.ncl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# test.type = 'error'
#
# [test.metadata]
# error = 'EvalError::IncomparableValues'
let g = fun x => x + 1 in
let h = match { 0 => 0 } in
g == h
6 changes: 0 additions & 6 deletions core/tests/integration/inputs/contracts/equating_fns_lhs.ncl

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# test.type = 'error'
#
# [test.metadata]
# error = 'EvalError::EqError'
# error = 'EvalError::IncomparableValues'
let g = fun x => x + 1 in
"a" == g
g == g
5 changes: 5 additions & 0 deletions core/tests/integration/inputs/core/eq.ncl
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,10 @@
'Left 1 != 'Left 2,
'Left 2 != 'Right 2,
'Up [1,2,3] == 'Up [0+1,1+1,2+1],

# Functions can be compared to non-functions

(fun x => x) != 1,
((+) 1) != null,
]
|> std.test.assert_all
8 changes: 4 additions & 4 deletions core/tests/integration/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ enum Expectation {
#[serde(tag = "error", content = "expectation")]
enum ErrorExpectation {
// TODO: can we somehow unify this with the `Display` impl below?
#[serde(rename = "EvalError::EqError")]
EvalEqError,
#[serde(rename = "EvalError::IncomparableValues")]
EvalIncomparableValues,
#[serde(rename = "EvalError::Other")]
EvalOther,
#[serde(rename = "EvalError::UnaryPrimopTypeError")]
Expand Down Expand Up @@ -230,7 +230,7 @@ impl PartialEq<Error> for ErrorExpectation {
Error::EvalError(EvalError::IllegalPolymorphicTailAccess { .. }),
)
| (EvalTypeError, Error::EvalError(EvalError::TypeError(..)))
| (EvalEqError, Error::EvalError(EvalError::EqError { .. }))
| (EvalIncomparableValues, Error::EvalError(EvalError::IncomparableValues { .. }))
| (EvalNAryPrimopTypeError, Error::EvalError(EvalError::NAryPrimopTypeError { .. }))
| (
EvalUnaryPrimopTypeError,
Expand Down Expand Up @@ -393,7 +393,7 @@ impl std::fmt::Display for ErrorExpectation {
ImportIoError => "ImportError::IoError".to_owned(),
EvalBlameError => "EvalError::BlameError".to_owned(),
EvalTypeError => "EvalError::TypeError".to_owned(),
EvalEqError => "EvalError::EqError".to_owned(),
EvalIncomparableValues => "EvalError::IncomparableValues".to_owned(),
EvalOther => "EvalError::Other".to_owned(),
EvalMergeIncompatibleArgs => "EvalError::MergeIncompatibleArgs".to_owned(),
EvalNAryPrimopTypeError => "EvalError::NAryPrimopTypeError".to_owned(),
Expand Down

0 comments on commit 89639ca

Please sign in to comment.