Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wgsl-in, glsl-in: Short circuit || and && #2028

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions src/front/glsl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,120 @@ impl Context {
self.add_expression(Expression::Constant(constant), meta, body)
}
HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
// Logical operators must short circuit by emitting an if statement.
// Handle them as a special case.
if let BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr = op {
// Lower the lhs, then emit all expressions lowered so far.
let (mut left, left_meta) =
self.lower_expect_inner(stmt, parser, left, ExprPos::Rhs, body)?;
self.emit_restart(body);

// Lower the rhs into a special body for use in the if statement.
let mut right_body = Block::new();
let (mut right, right_meta) = self.lower_expect_inner(
stmt,
parser,
right,
ExprPos::Rhs,
&mut right_body,
)?;
self.emit_restart(&mut right_body);

// Type check and emit a conversion if necessary.
parser.typifier_grow(self, left, left_meta)?;
parser.typifier_grow(self, right, right_meta)?;
self.binary_implicit_conversion(
parser, &mut left, left_meta, &mut right, right_meta,
)?;
self.emit_end(body);

// Create a temporary local to hold the result of the operator.
let bool_ty = parser.module.types.insert(
Type {
name: None,
inner: TypeInner::Scalar {
kind: ScalarKind::Bool,
width: crate::BOOL_WIDTH,
},
},
Default::default(),
);
let local = self.locals.append(
LocalVariable {
name: None,
ty: bool_ty,
init: None,
},
meta,
);
let local_expr = self
.expressions
.append(Expression::LocalVariable(local), meta);

// Store the result of the RHS to the local.
right_body.push(
Statement::Store {
pointer: local_expr,
value: right,
},
right_meta,
);

// Create a value representing the result of the operator if it short circuits and does not evaluate the RHS.
let short_circuit_value = match op {
BinaryOperator::LogicalAnd => false,
BinaryOperator::LogicalOr => true,
_ => unreachable!(),
};
let short_circuit_constant = parser.module.constants.fetch_or_append(
Constant {
name: None,
specialization: None,
inner: crate::ConstantInner::boolean(short_circuit_value),
},
Default::default(),
);
let short_circuit_expr = self
.expressions
.append(Expression::Constant(short_circuit_constant), meta);

// Create a short circuit body, which assigns the short circuit value to the local.
let mut short_circuit_body = Block::new();
short_circuit_body.push(
Statement::Store {
pointer: local_expr,
value: short_circuit_expr,
},
meta,
);

// Add an if statement which either evaluates the RHS block or the short circuit block.
let (accept, reject) = match op {
BinaryOperator::LogicalAnd => (right_body, short_circuit_body),
BinaryOperator::LogicalOr => (short_circuit_body, right_body),
_ => unimplemented!(),
};
body.push(
Statement::If {
condition: left,
accept,
reject,
},
meta,
);

// The result of lowering this operator is just the local in which the result is stored.
self.emit_start();
let load_local_expr = self.expressions.append(
Expression::Load {
pointer: local_expr,
},
meta,
);

return Ok((Some(load_local_expr), meta));
};

let (mut left, left_meta) =
self.lower_expect_inner(stmt, parser, left, ExprPos::Rhs, body)?;
let (mut right, right_meta) =
Expand Down
124 changes: 121 additions & 3 deletions src/front/wgsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ struct ExpressionContext<'input, 'temp, 'out> {
types: &'out mut UniqueArena<crate::Type>,
constants: &'out mut Arena<crate::Constant>,
global_vars: &'out Arena<crate::GlobalVariable>,
local_vars: &'out Arena<crate::LocalVariable>,
local_vars: &'out mut Arena<crate::LocalVariable>,
arguments: &'out [crate::FunctionArgument],
functions: &'out Arena<crate::Function>,
block: &'temp mut crate::Block,
Expand Down Expand Up @@ -897,6 +897,124 @@ impl<'a> ExpressionContext<'a, '_, '_> {
Ok(accumulator)
}

fn parse_binary_short_circuit_op(
&mut self,
lexer: &mut Lexer<'a>,
classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>,
mut parser: impl FnMut(
&mut Lexer<'a>,
ExpressionContext<'a, '_, '_>,
) -> Result<TypedExpression, Error<'a>>,
) -> Result<TypedExpression, Error<'a>> {
let start = lexer.current_byte_offset() as u32;
let mut accumulator = parser(lexer, self.reborrow())?;
while let Some(op) = classifier(lexer.peek().0) {
let _ = lexer.next();

// Apply load rule to the lhs.
let left = self.apply_load_rule(accumulator);

// Emit all previous expressions, and prepare a body for the rhs.
self.block.extend(self.emitter.finish(self.expressions));
let mut rhs_body = crate::Block::new();
let mut rhs_ctx = self.reborrow();
rhs_ctx.block = &mut rhs_body;
rhs_ctx.emitter.start(rhs_ctx.expressions);

// Parse the rhs using the rhs body.
let unloaded_right = parser(lexer, rhs_ctx)?;
let end = lexer.current_byte_offset() as u32;
let span = NagaSpan::new(start, end);
let right = self.apply_load_rule(unloaded_right);
rhs_body.extend(self.emitter.finish(self.expressions));

// Create a temporary local to store the result of the operator.
let local_ty = self.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
},
},
Default::default(),
);
let local = self.local_vars.append(
crate::LocalVariable {
name: None,
ty: local_ty,
init: None,
},
span,
);

// Assign the rhs expression to the local variable inside the rhs body.
let local_expr = self
.expressions
.append(crate::Expression::LocalVariable(local), span);

rhs_body.push(
crate::Statement::Store {
pointer: local_expr,
value: right,
},
span,
);

// Make a short circuit body which assigns the local variable to the short circuit value.
let short_circuit_value = match op {
crate::BinaryOperator::LogicalAnd => false,
crate::BinaryOperator::LogicalOr => true,
_ => unreachable!(),
};

let short_circuit_constant = self.constants.fetch_or_append(
crate::Constant {
name: None,
specialization: None,
inner: ConstantInner::boolean(short_circuit_value),
},
Default::default(),
);
let short_circuit_expr = self
.expressions
.append(crate::Expression::Constant(short_circuit_constant), span);

let mut short_circuit_body = crate::Block::new();
short_circuit_body.push(
crate::Statement::Store {
pointer: local_expr,
value: short_circuit_expr,
},
span,
);

// Append an if statement. This implements the short circuiting behavior, by deciding whether to
// evaluate the rhs based on the value of the lhs.
let (accept, reject) = match op {
crate::BinaryOperator::LogicalAnd => (rhs_body, short_circuit_body),
crate::BinaryOperator::LogicalOr => (short_circuit_body, rhs_body),
_ => unreachable!(),
};
self.block.push(
crate::Statement::If {
condition: left,
accept,
reject,
},
span,
);

// Prepare to resume parsing expressions.
self.emitter.start(self.expressions);
accumulator = TypedExpression {
handle: local_expr,
is_reference: true,
};
}
Ok(accumulator)
}

fn parse_binary_splat_op(
&mut self,
lexer: &mut Lexer<'a>,
Expand Down Expand Up @@ -2683,15 +2801,15 @@ impl Parser {
) -> Result<(TypedExpression, Span), Error<'a>> {
self.push_scope(Scope::GeneralExpr, lexer);
// logical_or_expression
let handle = context.parse_binary_op(
let handle = context.parse_binary_short_circuit_op(
lexer,
|token| match token {
Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr),
_ => None,
},
// logical_and_expression
|lexer, mut context| {
context.parse_binary_op(
context.parse_binary_short_circuit_op(
lexer,
|token| match token {
Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd),
Expand Down
12 changes: 12 additions & 0 deletions tests/in/glsl/short_circuit.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#version 460 core

bool a() { return true; }
bool b() { return true; }
bool c() { return true; }
bool d() { return true; }

void main() {
bool out1 = a() || b() || c();
bool out2 = a() && b() && c();
bool out3 = (a() || b()) && (c() || d());
}
2 changes: 2 additions & 0 deletions tests/in/short_circuit.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
(
)
11 changes: 11 additions & 0 deletions tests/in/short_circuit.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
fn a() -> bool { return true; }
fn b() -> bool { return true; }
fn c() -> bool { return true; }
fn d() -> bool { return true; }

@compute @workgroup_size(1)
fn main() {
_ = a() || b() || c();
_ = a() && b() && c();
_ = (a() || b()) && (c() || d());
}
16 changes: 14 additions & 2 deletions tests/out/glsl/operators.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,22 @@ float constructors() {
}

void logical() {
bool local = false;
bool local_1 = false;
bool unnamed_11 = (!true);
bvec2 unnamed_12 = not(bvec2(true));
bool unnamed_13 = (true || false);
bool unnamed_14 = (true && false);
if (true) {
local = true;
} else {
local = false;
}
bool unnamed_13 = local;
if (true) {
local_1 = false;
} else {
local_1 = false;
}
bool unnamed_14 = local_1;
bool unnamed_15 = (true || false);
bvec3 unnamed_16 = bvec3(bvec3(true).x || bvec3(false).x, bvec3(true).y || bvec3(false).y, bvec3(true).z || bvec3(false).z);
bool unnamed_17 = (true && false);
Expand Down
Loading