Skip to content

Commit

Permalink
feat!: Bit shift is restricted to u8 right operand (#4907)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #4841

## Summary\*
bit shift now requires the number of bits to be u8.


## Additional Context



## Documentation\*

Check one:
- [ ] No documentation needed.
- [X] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [ ] I have tested the changes locally.
- [ ] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <tom@tomfren.ch>
  • Loading branch information
guipublic and TomAFrench authored Apr 26, 2024
1 parent 1ec9cdc commit c4b0369
Show file tree
Hide file tree
Showing 20 changed files with 108 additions and 92 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish-nargo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ permissions:

jobs:
build-apple-darwin:
runs-on: macos-latest
runs-on: macos-12
env:
CROSS_CONFIG: ${{ github.workspace }}/.github/Cross.toml
NIGHTLY_RELEASE: ${{ inputs.tag == '' }}
Expand Down
4 changes: 3 additions & 1 deletion acvm-repo/brillig_vm/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ pub(crate) fn evaluate_binary_int_op(
}
}
})?;
let rhs = rhs.expect_integer_with_bit_size(bit_size).map_err(|err| match err {
let rhs_bit_size =
if op == &BinaryIntOp::Shl || op == &BinaryIntOp::Shr { 8 } else { bit_size };
let rhs = rhs.expect_integer_with_bit_size(rhs_bit_size).map_err(|err| match err {
MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => {
BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: value_bit_size,
Expand Down
14 changes: 10 additions & 4 deletions compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1279,8 +1279,11 @@ impl<'block> BrilligBlock<'block> {
dfg: &DataFlowGraph,
result_variable: SingleAddrVariable,
) {
let binary_type =
type_of_binary_operation(dfg[binary.lhs].get_type(), dfg[binary.rhs].get_type());
let binary_type = type_of_binary_operation(
dfg[binary.lhs].get_type(),
dfg[binary.rhs].get_type(),
binary.operator,
);

let left = self.convert_ssa_single_addr_value(binary.lhs, dfg);
let right = self.convert_ssa_single_addr_value(binary.rhs, dfg);
Expand Down Expand Up @@ -1766,7 +1769,7 @@ impl<'block> BrilligBlock<'block> {
}

/// Returns the type of the operation considering the types of the operands
pub(crate) fn type_of_binary_operation(lhs_type: &Type, rhs_type: &Type) -> Type {
pub(crate) fn type_of_binary_operation(lhs_type: &Type, rhs_type: &Type, op: BinaryOp) -> Type {
match (lhs_type, rhs_type) {
(_, Type::Function) | (Type::Function, _) => {
unreachable!("Functions are invalid in binary operations")
Expand All @@ -1782,12 +1785,15 @@ pub(crate) fn type_of_binary_operation(lhs_type: &Type, rhs_type: &Type) -> Type
}
// If both sides are numeric type, then we expect their types to be
// the same.
(Type::Numeric(lhs_type), Type::Numeric(rhs_type)) => {
(Type::Numeric(lhs_type), Type::Numeric(rhs_type))
if op != BinaryOp::Shl && op != BinaryOp::Shr =>
{
assert_eq!(
lhs_type, rhs_type,
"lhs and rhs types in a binary operation are always the same but got {lhs_type} and {rhs_type}"
);
Type::Numeric(*lhs_type)
}
_ => lhs_type.clone(),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,6 @@ impl BrilligContext {
result: SingleAddrVariable,
operation: BrilligBinaryOp,
) {
assert!(
lhs.bit_size == rhs.bit_size,
"Not equal bit size for lhs and rhs: lhs {}, rhs {}",
lhs.bit_size,
rhs.bit_size
);
let is_field_op = lhs.bit_size == FieldElement::max_num_bits();
let expected_result_bit_size =
BrilligContext::binary_result_bit_size(operation, lhs.bit_size);
Expand Down
7 changes: 6 additions & 1 deletion compiler/noirc_evaluator/src/ssa/function_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,12 @@ impl FunctionBuilder {
) -> ValueId {
let lhs_type = self.type_of_value(lhs);
let rhs_type = self.type_of_value(rhs);
assert_eq!(lhs_type, rhs_type, "ICE - Binary instruction operands must have the same type");
if operator != BinaryOp::Shl && operator != BinaryOp::Shr {
assert_eq!(
lhs_type, rhs_type,
"ICE - Binary instruction operands must have the same type"
);
}
let instruction = Instruction::Binary(Binary { lhs, rhs, operator });
self.insert_instruction(instruction, None).first()
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl Context<'_> {
} else {
// we use a predicate to nullify the result in case of overflow
let bit_size_var =
self.numeric_constant(FieldElement::from(bit_size as u128), typ.clone());
self.numeric_constant(FieldElement::from(bit_size as u128), Type::unsigned(8));
let overflow = self.insert_binary(rhs, BinaryOp::Lt, bit_size_var);
let predicate = self.insert_cast(overflow, typ.clone());
// we can safely cast to unsigned because overflow_checks prevent bit-shift with a negative value
Expand Down
30 changes: 5 additions & 25 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ impl<'a> FunctionContext<'a> {
self.insert_safe_cast(result, result_type, location)
}
BinaryOpKind::ShiftLeft | BinaryOpKind::ShiftRight => {
self.check_shift_overflow(result, rhs, bit_size, location, true)
self.check_shift_overflow(result, rhs, bit_size, location)
}
_ => unreachable!("operator {} should not overflow", operator),
}
Expand Down Expand Up @@ -408,7 +408,7 @@ impl<'a> FunctionContext<'a> {
}
}

self.check_shift_overflow(result, rhs, bit_size, location, false);
self.check_shift_overflow(result, rhs, bit_size, location);
}

_ => unreachable!("operator {} should not overflow", operator),
Expand All @@ -430,32 +430,12 @@ impl<'a> FunctionContext<'a> {
rhs: ValueId,
bit_size: u32,
location: Location,
is_signed: bool,
) -> ValueId {
let one = self.builder.numeric_constant(FieldElement::one(), Type::bool());
let rhs = if is_signed {
self.insert_safe_cast(rhs, Type::unsigned(bit_size), location)
} else {
rhs
};
// Bit-shift with a negative number is an overflow
if is_signed {
// We compute the sign of rhs.
let half_width = self.builder.numeric_constant(
FieldElement::from(2_i128.pow(bit_size - 1)),
Type::unsigned(bit_size),
);
let sign = self.builder.insert_binary(rhs, BinaryOp::Lt, half_width);
self.builder.set_location(location).insert_constrain(
sign,
one,
Some("attempt to bit-shift with overflow".to_owned().into()),
);
}
assert!(self.builder.current_function.dfg.type_of_value(rhs) == Type::unsigned(8));

let max = self
.builder
.numeric_constant(FieldElement::from(bit_size as i128), Type::unsigned(bit_size));
let max =
self.builder.numeric_constant(FieldElement::from(bit_size as i128), Type::unsigned(8));
let overflow = self.builder.insert_binary(rhs, BinaryOp::Lt, max);
self.builder.set_location(location).insert_constrain(
overflow,
Expand Down
5 changes: 4 additions & 1 deletion compiler/noirc_frontend/src/hir/type_check/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub enum TypeCheckError {
FieldModulo { span: Span },
#[error("Fields cannot be compared, try casting to an integer first")]
FieldComparison { span: Span },
#[error("The bit count in a bit-shift operation must fit in a u8, try casting the right hand side into a u8 first")]
InvalidShiftSize { span: Span },
#[error("The number of bits to use for this bitwise operation is ambiguous. Either the operand's type or return type should be specified")]
AmbiguousBitWidth { span: Span },
#[error("Error with additional context")]
Expand Down Expand Up @@ -234,7 +236,8 @@ impl From<TypeCheckError> for Diagnostic {
| TypeCheckError::UnconstrainedReferenceToConstrained { span }
| TypeCheckError::UnconstrainedSliceReturnToConstrained { span }
| TypeCheckError::NonConstantSliceLength { span }
| TypeCheckError::StringIndexAssign { span } => {
| TypeCheckError::StringIndexAssign { span }
| TypeCheckError::InvalidShiftSize { span } => {
Diagnostic::simple_error(error.to_string(), String::new(), span)
}
TypeCheckError::PublicReturnType { typ, span } => Diagnostic::simple_error(
Expand Down
30 changes: 28 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use iter_extended::vecmap;
use noirc_errors::Span;

use crate::ast::{BinaryOpKind, UnaryOp};
use crate::ast::{BinaryOpKind, IntegerBitSize, UnaryOp};
use crate::macros_api::Signedness;
use crate::{
hir::{resolution::resolver::verify_mutable_reference, type_check::errors::Source},
hir_def::{
Expand Down Expand Up @@ -1129,11 +1130,30 @@ impl<'interner> TypeChecker<'interner> {
if let TypeBinding::Bound(binding) = &*int.borrow() {
return self.infix_operand_type_rules(binding, op, other, span);
}

if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight {
self.unify(
rhs_type,
&Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight),
|| TypeCheckError::InvalidShiftSize { span },
);
let use_impl = if lhs_type.is_numeric() {
let integer_type = Type::polymorphic_integer(self.interner);
self.bind_type_variables_for_infix(lhs_type, op, &integer_type, span)
} else {
true
};
return Ok((lhs_type.clone(), use_impl));
}
let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((other.clone(), use_impl))
}
(Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => {
if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight {
if *sign_y != Signedness::Unsigned || *bit_width_y != IntegerBitSize::Eight {
return Err(TypeCheckError::InvalidShiftSize { span });
}
return Ok((Integer(*sign_x, *bit_width_x), false));
}
if sign_x != sign_y {
return Err(TypeCheckError::IntegerSignedness {
sign_x: *sign_x,
Expand Down Expand Up @@ -1165,6 +1185,12 @@ impl<'interner> TypeChecker<'interner> {
(Bool, Bool) => Ok((Bool, false)),

(lhs, rhs) => {
if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight {
if rhs == &Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight) {
return Ok((lhs.clone(), true));
}
return Err(TypeCheckError::InvalidShiftSize { span });
}
self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource {
expected: lhs.clone(),
actual: rhs.clone(),
Expand Down
4 changes: 2 additions & 2 deletions docs/docs/noir/concepts/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ sidebar_position: 3
| ^ | XOR two private input types together | Types must be integer |
| & | AND two private input types together | Types must be integer |
| \| | OR two private input types together | Types must be integer |
| \<\< | Left shift an integer by another integer amount | Types must be integer |
| >> | Right shift an integer by another integer amount | Types must be integer |
| \<\< | Left shift an integer by another integer amount | Types must be integer, shift must be u8 |
| >> | Right shift an integer by another integer amount | Types must be integer, shift must be u8 |
| ! | Bitwise not of a value | Type must be integer or boolean |
| \< | returns a bool if one value is less than the other | Upper bound must have a known bit size |
| \<= | returns a bool if one value is less than or equal to the other | Upper bound must have a known bit size |
Expand Down
28 changes: 14 additions & 14 deletions noir_stdlib/src/ops.nr
Original file line number Diff line number Diff line change
Expand Up @@ -126,30 +126,30 @@ impl BitXor for i64 { fn bitxor(self, other: i64) -> i64 { self ^ other } }

// docs:start:shl-trait
trait Shl {
fn shl(self, other: Self) -> Self;
fn shl(self, other: u8) -> Self;
}
// docs:end:shl-trait

impl Shl for u32 { fn shl(self, other: u32) -> u32 { self << other } }
impl Shl for u64 { fn shl(self, other: u64) -> u64 { self << other } }
impl Shl for u32 { fn shl(self, other: u8) -> u32 { self << other } }
impl Shl for u64 { fn shl(self, other: u8) -> u64 { self << other } }
impl Shl for u8 { fn shl(self, other: u8) -> u8 { self << other } }
impl Shl for u1 { fn shl(self, other: u1) -> u1 { self << other } }
impl Shl for u1 { fn shl(self, other: u8) -> u1 { self << other } }

impl Shl for i8 { fn shl(self, other: i8) -> i8 { self << other } }
impl Shl for i32 { fn shl(self, other: i32) -> i32 { self << other } }
impl Shl for i64 { fn shl(self, other: i64) -> i64 { self << other } }
impl Shl for i8 { fn shl(self, other: u8) -> i8 { self << other } }
impl Shl for i32 { fn shl(self, other: u8) -> i32 { self << other } }
impl Shl for i64 { fn shl(self, other: u8) -> i64 { self << other } }

// docs:start:shr-trait
trait Shr {
fn shr(self, other: Self) -> Self;
fn shr(self, other: u8) -> Self;
}
// docs:end:shr-trait

impl Shr for u64 { fn shr(self, other: u64) -> u64 { self >> other } }
impl Shr for u32 { fn shr(self, other: u32) -> u32 { self >> other } }
impl Shr for u64 { fn shr(self, other: u8) -> u64 { self >> other } }
impl Shr for u32 { fn shr(self, other: u8) -> u32 { self >> other } }
impl Shr for u8 { fn shr(self, other: u8) -> u8 { self >> other } }
impl Shr for u1 { fn shr(self, other: u1) -> u1 { self >> other } }
impl Shr for u1 { fn shr(self, other: u8) -> u1 { self >> other } }

impl Shr for i8 { fn shr(self, other: i8) -> i8 { self >> other } }
impl Shr for i32 { fn shr(self, other: i32) -> i32 { self >> other } }
impl Shr for i64 { fn shr(self, other: i64) -> i64 { self >> other } }
impl Shr for i8 { fn shr(self, other: u8) -> i8 { self >> other } }
impl Shr for i32 { fn shr(self, other: u8) -> i32 { self >> other } }
impl Shr for i64 { fn shr(self, other: u8) -> i64 { self >> other } }
2 changes: 1 addition & 1 deletion noir_stdlib/src/sha512.nr
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// 64 bytes.
// Internal functions act on 64-bit unsigned integers for simplicity.
// Auxiliary mappings; names as in FIPS PUB 180-4
fn rotr64(a: u64, b: u64) -> u64 // 64-bit right rotation
fn rotr64(a: u64, b: u8) -> u64 // 64-bit right rotation
{
// None of the bits overlap between `(a >> b)` and `(a << (64 - b))`
// Addition is then equivalent to OR, with fewer constraints.
Expand Down
12 changes: 6 additions & 6 deletions noir_stdlib/src/uint128.nr
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ impl BitXor for U128 {
}

impl Shl for U128 {
fn shl(self, other: U128) -> U128 {
assert(other < U128::from_u64s_le(128,0), "attempt to shift left with overflow");
let exp_bits = other.lo.to_be_bits(7);
fn shl(self, other: u8) -> U128 {
assert(other < 128, "attempt to shift left with overflow");
let exp_bits = (other as Field).to_be_bits(7);

let mut r: Field = 2;
let mut y: Field = 1;
Expand All @@ -271,9 +271,9 @@ impl Shl for U128 {
}

impl Shr for U128 {
fn shr(self, other: U128) -> U128 {
assert(other < U128::from_u64s_le(128,0), "attempt to shift right with overflow");
let exp_bits = other.lo.to_be_bits(7);
fn shr(self, other: u8) -> U128 {
assert(other < 128, "attempt to shift right with overflow");
let exp_bits = (other as Field).to_be_bits(7);

let mut r: Field = 2;
let mut y: Field = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ fn main() {
assert((x | y) == or(x, y));
// TODO SSA => ACIR has some issues with xor ops
assert(check_xor(x, y, 4));
assert((x >> y) == shr(x, y));
assert((x << y) == shl(x, y));
assert((x >> y as u8) == shr(x, y as u8));
assert((x << y as u8) == shl(x, y as u8));
}

unconstrained fn add(x: u32, y: u32) -> u32 {
Expand Down Expand Up @@ -67,11 +67,11 @@ unconstrained fn check_xor(x: u32, y: u32, result: u32) -> bool {
(x ^ y) == result
}

unconstrained fn shr(x: u32, y: u32) -> u32 {
unconstrained fn shr(x: u32, y: u8) -> u32 {
x >> y
}

unconstrained fn shl(x: u32, y: u32) -> u32 {
unconstrained fn shl(x: u32, y: u8) -> u32 {
x << y
}

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn main(x: u64) {
//regression for 3481
assert(x << 63 == 0);

assert_eq((1 as u64) << (32 as u64), 0x0100000000);
assert_eq((1 as u64) << 32, 0x0100000000);
}

fn regression_2250() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
fn main(x: u64, y: u64) {
fn main(x: u64, y: u8) {
// runtime shifts on compile-time known values
assert(64 << y == 128);
assert(64 >> y == 32);
Expand All @@ -11,10 +11,10 @@ fn main(x: u64, y: u64) {
let mut b: i8 = x as i8;
assert(b << 1 == -128);
assert(b >> 2 == 16);
assert(b >> a == 32);
assert(b >> y == 32);
a = -a;
assert(a << 7 == -128);
assert(a << -a == -2);
assert(a << y == -2);

assert(x >> x == 0);
assert(x >> (x as u8) == 0);
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
unconstrained fn main(x: u64, y: u64) {
unconstrained fn main(x: u64, y: u8) {
// runtime shifts on compile-time known values
assert(64 << y == 128);
assert(64 >> y == 32);
assert(64 as u32 << y == 128);
assert(64 as u32 >> y == 32);
// runtime shifts on runtime values
assert(x << y == 128);
assert(x >> y == 32);
Expand All @@ -11,10 +11,10 @@ unconstrained fn main(x: u64, y: u64) {
let mut b: i8 = x as i8;
assert(b << 1 == -128);
assert(b >> 2 == 16);
assert(b >> a == 32);
assert(b >> y == 32);
a = -a;
assert(a << 7 == -128);
assert(a << -a == -2);
assert(a << y == -2);

assert(x >> x == 0);
assert(x >> (x as u8) == 0);
}
Loading

0 comments on commit c4b0369

Please sign in to comment.