diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 69d5223293..72081bd736 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -78,7 +78,7 @@ impl ExpressionConstnessTracker { } } -#[derive(Clone, Debug, PartialEq, thiserror::Error)] +#[derive(Clone, Debug, thiserror::Error)] pub enum ConstantEvaluatorError { #[error("Constants cannot access function arguments")] FunctionArg, @@ -144,6 +144,8 @@ pub enum ConstantEvaluatorError { RemainderByZero, #[error("RHS of shift operation is greater than or equal to 32")] ShiftedMoreThan32Bits, + #[error(transparent)] + Literal(#[from] crate::valid::LiteralError), } impl<'a> ConstantEvaluator<'a> { @@ -266,18 +268,18 @@ impl<'a> ConstantEvaluator<'a> { Ok(self.constants[c].init) } Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { - Ok(self.register_evaluated_expr(expr.clone(), span)) + self.register_evaluated_expr(expr.clone(), span) } Expression::Compose { ty, ref components } => { let components = components .iter() .map(|component| self.check_and_get(*component)) .collect::, _>>()?; - Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span)) + self.register_evaluated_expr(Expression::Compose { ty, components }, span) } Expression::Splat { size, value } => { let value = self.check_and_get(value)?; - Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span)) + self.register_evaluated_expr(Expression::Splat { size, value }, span) } Expression::AccessIndex { base, index } => { let base = self.check_and_get(base)?; @@ -391,7 +393,7 @@ impl<'a> ConstantEvaluator<'a> { ty, components: vec![value; size as usize], }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } Expression::ZeroValue(ty) => { let inner = match self.types[ty].inner { @@ -400,7 +402,7 @@ impl<'a> ConstantEvaluator<'a> { }; let res_ty = self.types.insert(Type { name: None, inner }, span); let expr = Expression::ZeroValue(res_ty); - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } _ => Err(ConstantEvaluatorError::SplatScalarOnly), } @@ -432,11 +434,11 @@ impl<'a> ConstantEvaluator<'a> { Expression::ZeroValue(ty) => { let dst_ty = get_dst_ty(ty)?; let expr = Expression::ZeroValue(dst_ty); - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } Expression::Splat { value, .. } => { let expr = Expression::Splat { size, value }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } Expression::Compose { ty, ref components } => { let dst_ty = get_dst_ty(ty)?; @@ -464,7 +466,7 @@ impl<'a> ConstantEvaluator<'a> { ty: dst_ty, components: swizzled_components, }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } _ => Err(ConstantEvaluatorError::SwizzleVectorOnly), } @@ -561,7 +563,7 @@ impl<'a> ConstantEvaluator<'a> { _ => return Err(ConstantEvaluatorError::InvalidMathArg), }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } fn math_clamp( @@ -666,7 +668,7 @@ impl<'a> ConstantEvaluator<'a> { _ => return Err(ConstantEvaluatorError::InvalidMathArg), }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } fn array_length( @@ -680,7 +682,7 @@ impl<'a> ConstantEvaluator<'a> { TypeInner::Array { size, .. } => match size { crate::ArraySize::Constant(len) => { let expr = Expression::Literal(Literal::U32(len.get())); - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } crate::ArraySize::Dynamic => { Err(ConstantEvaluatorError::ArrayLengthDynamic) @@ -718,7 +720,7 @@ impl<'a> ConstantEvaluator<'a> { self.types.insert(Type { name: None, inner }, span) } }; - Ok(self.register_evaluated_expr(Expression::ZeroValue(ty), span)) + self.register_evaluated_expr(Expression::ZeroValue(ty), span) } } Expression::Splat { size, value } => { @@ -784,7 +786,7 @@ impl<'a> ConstantEvaluator<'a> { Literal::zero(kind, width) .ok_or(ConstantEvaluatorError::TypeNotConstructible)?, ); - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } TypeInner::Vector { size, kind, width } => { let scalar_ty = self.types.insert( @@ -799,7 +801,7 @@ impl<'a> ConstantEvaluator<'a> { ty, components: vec![el; size as usize], }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } TypeInner::Matrix { columns, @@ -822,7 +824,7 @@ impl<'a> ConstantEvaluator<'a> { ty, components: vec![el; columns as usize], }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } TypeInner::Array { base, @@ -834,7 +836,7 @@ impl<'a> ConstantEvaluator<'a> { ty, components: vec![el; size.get() as usize], }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } TypeInner::Struct { ref members, .. } => { let types: Vec<_> = members.iter().map(|m| m.ty).collect(); @@ -843,7 +845,7 @@ impl<'a> ConstantEvaluator<'a> { components.push(self.eval_zero_value_impl(ty, span)?); } let expr = Expression::Compose { ty, components }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } _ => Err(ConstantEvaluatorError::TypeNotConstructible), } @@ -929,7 +931,7 @@ impl<'a> ConstantEvaluator<'a> { _ => return Err(ConstantEvaluatorError::InvalidCastArg), }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } fn unary_op( @@ -973,7 +975,7 @@ impl<'a> ConstantEvaluator<'a> { _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg), }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } fn binary_op( @@ -1109,7 +1111,7 @@ impl<'a> ConstantEvaluator<'a> { _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), }; - Ok(self.register_evaluated_expr(expr, span)) + self.register_evaluated_expr(expr, span) } /// Deep copy `expr` from `expressions` into `self.expressions`. @@ -1128,17 +1130,17 @@ impl<'a> ConstantEvaluator<'a> { match expressions[expr] { ref expr @ (Expression::Literal(_) | Expression::Constant(_) - | Expression::ZeroValue(_)) => Ok(self.register_evaluated_expr(expr.clone(), span)), + | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span), Expression::Compose { ty, ref components } => { let mut components = components.clone(); for component in &mut components { *component = self.copy_from(*component, expressions)?; } - Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span)) + self.register_evaluated_expr(Expression::Compose { ty, components }, span) } Expression::Splat { size, value } => { let value = self.copy_from(value, expressions)?; - Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span)) + self.register_evaluated_expr(Expression::Splat { size, value }, span) } _ => { log::debug!("copy_from: SubexpressionsAreNotConstant"); @@ -1147,8 +1149,15 @@ impl<'a> ConstantEvaluator<'a> { } } - fn register_evaluated_expr(&mut self, expr: Expression, span: Span) -> Handle { - // TODO: use the validate_literal function from https://github.com/gfx-rs/naga/pull/2508 here + fn register_evaluated_expr( + &mut self, + expr: Expression, + span: Span, + ) -> Result, ConstantEvaluatorError> { + match expr { + Expression::Literal(literal) => crate::valid::validate_literal(literal)?, + _ => {} + } if let Some(FunctionLocalData { ref mut emitter, @@ -1164,14 +1173,14 @@ impl<'a> ConstantEvaluator<'a> { let h = self.expressions.append(expr, span); emitter.start(self.expressions); expression_constness.insert(h); - h + Ok(h) } else { let h = self.expressions.append(expr, span); expression_constness.insert(h); - h + Ok(h) } } else { - self.expressions.append(expr, span) + Ok(self.expressions.append(expr, span)) } } } diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 0432e0b00a..064f700047 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1568,7 +1568,7 @@ impl super::Validator { } } -fn validate_literal(literal: crate::Literal) -> Result<(), LiteralError> { +pub fn validate_literal(literal: crate::Literal) -> Result<(), LiteralError> { let is_nan = match literal { crate::Literal::F64(v) => v.is_nan(), crate::Literal::F32(v) => v.is_nan(), diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 6175aa0945..8c065bb159 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -24,6 +24,7 @@ use std::ops; use crate::span::{AddSpan as _, WithSpan}; pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; pub use compose::ComposeError; +pub use expression::{validate_literal, LiteralError}; pub use expression::{ConstExpressionError, ExpressionError}; pub use function::{CallError, FunctionError, LocalVariableError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError};