diff --git a/benches/accum_conv.rs b/benches/accum_conv.rs
index 3e40c502b..13be04c21 100644
--- a/benches/accum_conv.rs
+++ b/benches/accum_conv.rs
@@ -72,6 +72,7 @@ impl Circuit<Fp> for MyCircuit {
                 Box::new(PolyOp::Conv {
                     padding: vec![(0, 0)],
                     stride: vec![1; 2],
+                    group: 1,
                 }),
             )
             .unwrap();
diff --git a/examples/conv2d_mnist/main.rs b/examples/conv2d_mnist/main.rs
index 866773b47..2825c1b59 100644
--- a/examples/conv2d_mnist/main.rs
+++ b/examples/conv2d_mnist/main.rs
@@ -205,6 +205,7 @@ where
         let op = PolyOp::Conv {
             padding: vec![(PADDING, PADDING); 2],
             stride: vec![STRIDE; 2],
+            group: 1,
         };
         let x = config
             .layer_config
diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs
index f957bf800..766365e75 100644
--- a/src/circuit/ops/layouts.rs
+++ b/src/circuit/ops/layouts.rs
@@ -3023,7 +3023,7 @@ pub fn sumpool<
-/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2]).unwrap();
+/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[0, 32, 0, 32, 0, 6, 0, 12, 0, 4, 0, 8, 0, 4, 0, 8, 0, 0, 0, 3, 0, 0, 0, 2]), &[1, 2, 3, 4]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 ///
@@ -3171,7 +3171,7 @@ pub fn max_pool<
-/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2]).unwrap();
+/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[6, 14, 4, 2, 17, 21, 0, 1, 5]), &[1, 1, 3, 3]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 ///
@@ -3184,7 +3184,7 @@ pub fn max_pool<
-/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2]).unwrap();
+/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[17]), &[1, 1, 1, 1]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 ///
@@ -3197,7 +3197,7 @@ pub fn max_pool<
-/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2]).unwrap();
+/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[10, 4, 0, 3]), &[1, 1, 2, 2]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 ///
@@ -3209,7 +3209,7 @@ pub fn max_pool<
-/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2]).unwrap();
+/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[6, 2, 12, 4, 2, 10, 4, 20, 0, 0, 3, 1, 0, 0, 1, 5]), &[1, 1, 4, 4]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 ///
@@ -3221,7 +3221,7 @@ pub fn max_pool<
-/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2]).unwrap();
+/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[0, 0]), &[1, 1, 2, 1]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 ///
@@ -3233,7 +3233,7 @@ pub fn max_pool<
-/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2]).unwrap();
+/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 1, 4, 3]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 ///
@@ -3244,7 +3244,7 @@ pub fn max_pool<
-/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2]).unwrap();
+/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 /// let x = ValTensor::from_i64_tensor(Tensor::<i64>::new(
@@ -3259,7 +3259,7 @@ pub fn max_pool<
-/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2]).unwrap();
+/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[55, 58, 66, 69]), &[1, 1, 2, 2]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 ///
@@ -3279,6 +3279,7 @@ pub fn deconv<
     padding: &[(usize, usize)],
     output_padding: &[usize],
     stride: &[usize],
+    num_groups: usize,
 ) -> Result<ValTensor<F>, CircuitError> {
     let has_bias = inputs.len() == 3;
     let (image, kernel) = (&inputs[0], &inputs[1]);
@@ -3364,6 +3365,7 @@ pub fn deconv<
         &conv_input,
         &vec![(0, 0); conv_dim],
         &vec![1; conv_dim],
+        num_groups,
     )?;
 
     Ok(output)
@@ -3395,7 +3397,7 @@ pub fn deconv<
 /// Some(&[0]),
 /// &[1],
 /// ).unwrap());
-/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap();
+/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[31, 16, 8, 26]), &[1, 1, 2, 2]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 ///
@@ -3413,7 +3415,7 @@ pub fn deconv<
 /// &[2],
 /// ).unwrap());
 ///
-/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap();
+/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 2).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[1, 2, 2, 2]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 ///
@@ -3431,7 +3433,7 @@ pub fn deconv<
 /// &[4],
 /// ).unwrap());
 ///
-/// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap();
+/// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
 /// let expected = Tensor::<i64>::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[1, 4, 2, 2]).unwrap();
 /// assert_eq!(result.get_int_evals().unwrap(), expected);
 /// ```
@@ -3450,6 +3452,7 @@ pub fn conv<
     values: &[ValTensor<F>],
     padding: &[(usize, usize)],
     stride: &[usize],
+    num_groups: usize,
 ) -> Result<ValTensor<F>, CircuitError> {
     let has_bias = values.len() == 3;
     let (mut image, mut kernel) = (values[0].clone(), values[1].clone());
@@ -3480,6 +3483,11 @@ pub fn conv<
         region.increment(*assigned_len.iter().max().unwrap());
     }
 
+    // if image is 3d add a dummy batch dimension
+    if image.dims().len() == kernel.dims().len() - 1 {
+        image.reshape(&[1, image.dims()[0], image.dims()[1], image.dims()[2]])?;
+    }
+
     let image_dims = image.dims();
     let kernel_dims = kernel.dims();
@@ -3513,10 +3521,17 @@ pub fn conv<
     log::debug!("slides: {:?}", slides);
 
-    let num_groups = input_channels / kernel_dims[1];
     let input_channels_per_group = input_channels / num_groups;
     let output_channels_per_group = output_channels / num_groups;
 
+    if output_channels_per_group == 0 || input_channels_per_group == 0 {
+        return Err(TensorError::DimMismatch(format!(
+            "Given groups={}, expected input channels and output channels to be divisible by groups, but got input_channels={}, output_channels={}",
+            num_groups, input_channels, output_channels
+        ))
+        .into());
+    }
+
     log::debug!(
         "num_groups: {}, input_channels_per_group: {}, output_channels_per_group: {}",
         num_groups,
@@ -3524,14 +3539,6 @@ pub fn conv<
         output_channels_per_group
     );
 
-    if output_channels_per_group == 0 {
-        return Err(TensorError::DimMismatch(format!(
-            "Given groups={}, expected kernel to be at least {} at dimension 0 but got {} instead",
-            num_groups, num_groups, output_channels_per_group
-        ))
-        .into());
-    }
-
     let num_outputs =
         batch_size * num_groups * output_channels_per_group * slides.iter().product::<usize>();
 
diff --git a/src/circuit/ops/poly.rs b/src/circuit/ops/poly.rs
index 892bc36a0..8603dc4ee 100644
--- a/src/circuit/ops/poly.rs
+++ b/src/circuit/ops/poly.rs
@@ -33,6 +33,7 @@ pub enum PolyOp {
     Conv {
         padding: Vec<(usize, usize)>,
         stride: Vec<usize>,
+        group: usize,
     },
     Downsample {
         axis: usize,
@@ -43,6 +44,7 @@ pub enum PolyOp {
         padding: Vec<(usize, usize)>,
         output_padding: Vec<usize>,
         stride: Vec<usize>,
+        group: usize,
     },
     Add,
     Sub,
@@ -148,17 +150,25 @@ impl<
             PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
             PolyOp::Prod { .. } => "PROD".into(),
             PolyOp::Pow(_) => "POW".into(),
-            PolyOp::Conv { stride, padding } => {
-                format!("CONV (stride={:?}, padding={:?})", stride, padding)
+            PolyOp::Conv {
+                stride,
+                padding,
+                group,
+            } => {
+                format!(
+                    "CONV (stride={:?}, padding={:?}, group={})",
+                    stride, padding, group
+                )
             }
             PolyOp::DeConv {
                 stride,
                 padding,
                 output_padding,
+                group,
             } => {
                 format!(
-                    "DECONV (stride={:?}, padding={:?}, output_padding={:?})",
-                    stride, padding, output_padding
+                    "DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
+                    stride, padding, output_padding, group
                 )
             }
             PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -212,9 +222,18 @@ impl<
             PolyOp::Prod { axes, .. } => {
                 layouts::prod_axes(config, region, values[..].try_into()?, axes)?
             }
-            PolyOp::Conv { padding, stride } => {
-                layouts::conv(config, region, values[..].try_into()?, padding, stride)?
-            }
+            PolyOp::Conv {
+                padding,
+                stride,
+                group,
+            } => layouts::conv(
+                config,
+                region,
+                values[..].try_into()?,
+                padding,
+                stride,
+                *group,
+            )?,
             PolyOp::GatherElements { dim, constant_idx } => {
                 if let Some(idx) = constant_idx {
                     tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
@@ -261,6 +280,7 @@ impl<
                 padding,
                 output_padding,
                 stride,
+                group,
             } => layouts::deconv(
                 config,
                 region,
@@ -268,6 +288,7 @@ impl<
                 padding,
                 output_padding,
                 stride,
+                *group,
             )?,
             PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
             PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs
index f1db70bc7..b0af969b0 100644
--- a/src/circuit/tests.rs
+++ b/src/circuit/tests.rs
@@ -1050,6 +1050,7 @@ mod conv {
                 Box::new(PolyOp::Conv {
                     padding: vec![(1, 1); 2],
                     stride: vec![2; 2],
+                    group: 1,
                 }),
             )
             .map_err(|_| Error::Synthesis)
@@ -1200,6 +1201,7 @@ mod conv_col_ultra_overflow {
                 Box::new(PolyOp::Conv {
                     padding: vec![(1, 1); 2],
                     stride: vec![2; 2],
+                    group: 1,
                 }),
             )
             .map_err(|_| Error::Synthesis)
@@ -1345,6 +1347,7 @@ mod conv_relu_col_ultra_overflow {
                 Box::new(PolyOp::Conv {
                     padding: vec![(1, 1); 2],
                     stride: vec![2; 2],
+                    group: 1,
                 }),
             )
             .map_err(|_| Error::Synthesis);
diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs
index e30408de0..0b26f1f21 100644
--- a/src/graph/utilities.rs
+++ b/src/graph/utilities.rs
@@ -283,10 +283,7 @@ pub fn new_op_from_onnx(
         .flat_map(|x| x.out_scales())
         .collect::<Vec<_>>();
 
-    let input_dims = inputs
-        .iter()
-        .flat_map(|x| x.out_dims())
-        .collect::<Vec<_>>();
+    let input_dims = inputs.iter().flat_map(|x| x.out_dims()).collect::<Vec<_>>();
 
     let mut replace_const = |scale: crate::Scale,
                              index: usize,
@@ -1192,7 +1189,13 @@ pub fn new_op_from_onnx(
             }
         }
 
-        SupportedOp::Linear(PolyOp::Conv { padding, stride })
+        let group = conv_node.group;
+
+        SupportedOp::Linear(PolyOp::Conv {
+            padding,
+            stride,
+            group,
+        })
     }
     "Not" => SupportedOp::Linear(PolyOp::Not),
     "And" => SupportedOp::Linear(PolyOp::And),
@@ -1247,6 +1250,7 @@ pub fn new_op_from_onnx(
             padding,
             output_padding: deconv_node.adjustments.to_vec(),
             stride,
+            group: deconv_node.group,
         })
     }
    "Downsample" => {
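
For reference, a minimal sketch (not part of the patch) of how the new `group` field on `PolyOp::Conv` is exercised after this change. The shapes in the comments are illustrative assumptions; the hard requirement, per the `DimMismatch` check added in `layouts::conv`, is that `group` divides both the input and the output channel count:

// Hypothetical call site: a grouped convolution with two channel groups.
// With group = 2, an input of shape [1, 4, H, W] pairs with a kernel of shape
// [out_channels, 4 / 2, kh, kw]; both channel counts must be divisible by 2.
let op = PolyOp::Conv {
    padding: vec![(1, 1); 2], // symmetric padding on both spatial axes
    stride: vec![1; 2],       // unit stride in height and width
    group: 2,                 // group: 1 reproduces the old ungrouped behaviour
};

ONNX `Conv`/`ConvTranspose` nodes carry this attribute natively (`conv_node.group`, `deconv_node.group` above), so the importer now forwards it instead of inferring the group count from the kernel shape.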