Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Allow impls on primitive types #847

Merged
merged 10 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions crates/nargo/tests/target_tests_data/pass/import/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
mod import;
use dep::std;

use crate::import::hello;

fn main(x : Field, y : Field) {
let _k = std::hash::pedersen([x]);
let _k = dep::std::hash::pedersen([x]);
let _l = hello(x);

constrain x != import::hello(y);
Expand Down
2 changes: 1 addition & 1 deletion crates/nargo/tests/test_data/7_function/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn test2(z: Field, t: u32 ) {

fn pow(base: Field, exponent: Field) -> Field {
let mut r = 1 as Field;
let b = std::field::to_le_bits(exponent, 32 as u32);
let b = exponent.to_le_bits(32 as u32);
for i in 1..33 {
r = r*r;
r = (b[32-i] as Field) * (r * base) + (1 - b[32-i] as Field) * r;
Expand Down
5 changes: 2 additions & 3 deletions crates/nargo/tests/test_data/9_conditional/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ fn test4() -> [u32; 4] {
}

fn main(a: u32, mut c: [u32; 4], x: [u8; 5], result: pub [u8; 32]){

// Regression test for issue #547
// Warning: it must be kept at the start of main
let arr: [u8; 2] = [1, 2];
Expand All @@ -49,7 +48,7 @@ fn main(a: u32, mut c: [u32; 4], x: [u8; 5], result: pub [u8; 32]){
//Issue reported in #421
if a == c[0] {
constrain c[0] == 0;
} else {
} else {
if a == c[1] {
constrain c[1] == 0;
} else {
Expand All @@ -64,7 +63,7 @@ fn main(a: u32, mut c: [u32; 4], x: [u8; 5], result: pub [u8; 32]){
let as_bits_hardcode_1 = [1, 0];
let mut c1 = 0;
for i in 0..2 {
let mut as_bits = std::field::to_le_bits(arr[i] as Field, 2);
let mut as_bits = (arr[i] as Field).to_le_bits(2);
c1 = c1 + as_bits[0] as Field;

if i == 0 {
Expand Down
8 changes: 4 additions & 4 deletions crates/nargo/tests/test_data/array_len/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use dep::std;

fn len_plus_1<T>(array: [T]) -> Field {
std::array::len(array) + 1
array.len() + 1
}

fn add_lens<T>(a: [T], b: [Field]) -> Field {
std::array::len(a) + std::array::len(b)
a.len() + b.len()
}

fn nested_call(b: [Field]) -> Field {
Expand All @@ -19,13 +19,13 @@ fn main(len3: [u8; 3], len4: [Field; 4]) {
constrain nested_call(len4) == 5;

// std::array::len returns a comptime value
constrain len4[std::array::len(len3)] == 4;
constrain len4[len3.len()] == 4;

// test for std::array::sort
let mut unsorted = len3;
unsorted[0] = len3[1];
unsorted[1] = len3[0];
constrain unsorted[0] > unsorted[1];
let sorted = std::array::sort(unsorted);
let sorted = unsorted.sort();
constrain sorted[0] < sorted[1];
}
2 changes: 1 addition & 1 deletion crates/nargo/tests/test_data/hash_to_field/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ use dep::std;

fn main(input : Field) -> pub Field {
std::hash::hash_to_field([input])
}
}
10 changes: 5 additions & 5 deletions crates/nargo/tests/test_data/higher-order-functions/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ fn main() -> pub Field {
/// Test the array functions in std::array
fn test_array_functions() {
let myarray: [i32; 3] = [1, 2, 3];
constrain std::array::any(myarray, |n| n > 2);
constrain myarray.any(|n| n > 2);

let evens: [i32; 3] = [2, 4, 6];
constrain std::array::all(evens, |n| n > 1);
constrain evens.all(|n| n > 1);

constrain std::array::fold(evens, 0, |a, b| a + b) == 12;
constrain std::array::reduce(evens, |a, b| a + b) == 12;
constrain evens.fold(0, |a, b| a + b) == 12;
constrain evens.reduce(|a, b| a + b) == 12;

let descending = std::array::sort_via(myarray, |a, b| a > b);
let descending = myarray.sort_via(|a, b| a > b);
constrain descending == [3, 2, 1];
}

Expand Down
2 changes: 1 addition & 1 deletion crates/nargo/tests/test_data/strings/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ struct Test {
a: Field,
b: Field,
c: [Field; 2],
}
}
4 changes: 2 additions & 2 deletions crates/nargo/tests/test_data/struct_inputs/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main(x : Field, y : pub myStruct, z: pub foo::bar::barStruct, a: pub foo::foo

check_inner_struct(a, z);

for i in 0..std::array::len(struct_from_bar.array) {
for i in 0 .. struct_from_bar.array.len() {
constrain struct_from_bar.array[i] == z.array[i];
}
constrain z.val == struct_from_bar.val;
Expand All @@ -30,7 +30,7 @@ fn main(x : Field, y : pub myStruct, z: pub foo::bar::barStruct, a: pub foo::foo

fn check_inner_struct(a: foo::fooStruct, z: foo::bar::barStruct) {
constrain a.bar_struct.val == z.val;
for i in 0..std::array::len(a.bar_struct.array) {
for i in 0.. a.bar_struct.array.len() {
constrain a.bar_struct.array[i] == z.array[i];
}
}
4 changes: 2 additions & 2 deletions crates/nargo/tests/test_data/to_le_bytes/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use dep::std;

fn main(x : Field) -> pub [u8; 4] {
// The result of this byte array will be little-endian
let byte_array = std::field::to_le_bytes(x, 31);
let byte_array = x.to_le_bytes(31);
let mut first_four_bytes = [0; 4];
for i in 0..4 {
first_four_bytes[i] = byte_array[i];
Expand All @@ -11,4 +11,4 @@ fn main(x : Field) -> pub [u8; 4] {
// We were incorrectly mapping our output array from bit decomposition functions during acir generation
first_four_bytes[3] = byte_array[31];
first_four_bytes
}
}
11 changes: 8 additions & 3 deletions crates/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,14 @@ fn collect_impls(
}
}
} else if typ != Type::Error {
let span = *span;
let error = DefCollectorErrorKind::NonStructTypeInImpl { span };
errors.push(error.into_file_diagnostic(unresolved.file_id))
if true
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
/* in std crate */
{
} else {
let span = *span;
let error = DefCollectorErrorKind::NonStructTypeInImpl { span };
errors.push(error.into_file_diagnostic(unresolved.file_id))
}
}
}
}
Expand Down
17 changes: 10 additions & 7 deletions crates/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,16 @@ fn lookup_method(

// In the future we could support methods for non-struct types if we have a context
// (in the interner?) essentially resembling HashMap<Type, Methods>
other => {
errors.push(TypeCheckError::Unstructured {
span: interner.expr_span(expr_id),
msg: format!("Type '{other}' must be a struct type to call methods on it"),
});
None
}
other => match interner.lookup_primitive_method(other, method_name) {
Some(method_id) => Some(method_id),
None => {
errors.push(TypeCheckError::Unstructured {
span: interner.expr_span(expr_id),
msg: format!("No method named '{method_name}' found for type '{other}'",),
});
None
}
},
}
}

Expand Down
69 changes: 61 additions & 8 deletions crates/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,11 @@ pub struct NodeInterner {

delayed_type_checks: Vec<TypeCheckFn>,

// A map from a struct type and method name to a function id for the method
// along with any generic on the struct it may require. E.g. if the impl is
// only for `impl Foo<String>` rather than all Foo, the generics will be `vec![String]`.
struct_methods: HashMap<(StructId, String), (Vec<Type>, FuncId)>,
/// A map from a struct type and method name to a function id for the method.
struct_methods: HashMap<(StructId, String), FuncId>,

/// Methods on primitive types defined in the stdlib.
primitive_methods: HashMap<(TypeMethodKey, String), FuncId>,
}

type TypeCheckFn = Box<dyn FnOnce() -> Result<(), TypeCheckError>>;
Expand Down Expand Up @@ -241,6 +242,7 @@ impl Default for NodeInterner {
language: Language::R1CS,
delayed_type_checks: vec![],
struct_methods: HashMap::new(),
primitive_methods: HashMap::new(),
};

// An empty block expression is used often, we add this into the `node` on startup
Expand Down Expand Up @@ -585,16 +587,67 @@ impl NodeInterner {
method_id: FuncId,
) -> Option<FuncId> {
match self_type {
Type::Struct(struct_type, generics) => {
Type::Struct(struct_type, _generics) => {
let key = (struct_type.borrow().id, method_name);
self.struct_methods.insert(key, (generics.clone(), method_id)).map(|(_, id)| id)
self.struct_methods.insert(key, method_id)
}
Type::Error => None,

other => {
let key = get_type_method_key(self_type).unwrap_or_else(|| {
unreachable!("Cannot add a method to the unsupported type '{}'", other)
});
self.primitive_methods.insert((key, method_name), method_id)
}
other => unreachable!("Tried adding method to non-struct type '{}'", other),
}
}

/// Search by name for a method on the given struct
pub fn lookup_method(&self, id: StructId, method_name: &str) -> Option<FuncId> {
self.struct_methods.get(&(id, method_name.to_owned())).map(|(_, id)| *id)
self.struct_methods.get(&(id, method_name.to_owned())).copied()
}

/// Looks up a given method name on the given primitive type.
pub fn lookup_primitive_method(&self, typ: &Type, method_name: &str) -> Option<FuncId> {
get_type_method_key(typ)
.and_then(|key| self.primitive_methods.get(&(key, method_name.to_owned())).copied())
}
}

/// These are the primitive type variants that we support adding methods to
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
enum TypeMethodKey {
/// Fields and integers share methods for ease of use. These methods may still
/// accept only fields or integers, it is just that their names may not clash.
FieldOrInt,
Array,
Bool,
String,
Unit,
Tuple,
Function,
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
}

fn get_type_method_key(typ: &Type) -> Option<TypeMethodKey> {
use TypeMethodKey::*;
let typ = typ.follow_bindings();
match &typ {
Type::FieldElement(_) => Some(FieldOrInt),
Type::Array(_, _) => Some(Array),
Type::Integer(_, _, _) => Some(FieldOrInt),
Type::PolymorphicInteger(_, _) => Some(FieldOrInt),
Type::Bool(_) => Some(Bool),
Type::String(_) => Some(String),
Type::Unit => Some(Unit),
Type::Tuple(_) => Some(Tuple),
Type::Function(_, _) => Some(Function),

// We do not support adding methods to these types
Type::TypeVariable(_)
| Type::NamedGeneric(_, _)
| Type::Forall(_, _)
| Type::Constant(_)
| Type::Error
| Type::Struct(_, _) => None,
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
}
}
18 changes: 7 additions & 11 deletions crates/noirc_frontend/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use std::sync::atomic::{AtomicU32, Ordering};
use crate::token::{Keyword, Token};
use crate::{ast::ImportStatement, Expression, NoirStruct};
use crate::{
BlockExpression, CallExpression, ExpressionKind, ForExpression, Ident, IndexExpression,
LetStatement, NoirFunction, NoirImpl, Path, PathKind, Pattern, Recoverable, Statement,
BlockExpression, ExpressionKind, ForExpression, Ident, IndexExpression, LetStatement,
MethodCallExpression, NoirFunction, NoirImpl, Path, PathKind, Pattern, Recoverable, Statement,
UnresolvedType,
};

Expand Down Expand Up @@ -371,19 +371,15 @@ impl ForRange {
expression: array,
});

let ident = |name: &str| Ident::new(name.to_string(), array_span);

// std::array::len(array)
// array.len()
let segments = vec![array_ident];
let array_ident =
ExpressionKind::Variable(Path { segments, kind: PathKind::Plain });

let segments = vec![ident("std"), ident("array"), ident("len")];
let func_ident = ExpressionKind::Variable(Path { segments, kind: PathKind::Dep });

let end_range = ExpressionKind::Call(Box::new(CallExpression {
func: Box::new(Expression::new(func_ident, array_span)),
arguments: vec![Expression::new(array_ident.clone(), array_span)],
let end_range = ExpressionKind::MethodCall(Box::new(MethodCallExpression {
object: Expression::new(array_ident.clone(), array_span),
method_name: Ident::new("len".to_string(), array_span),
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
arguments: vec![],
}));
let end_range = Expression::new(end_range, array_span);

Expand Down
Loading