diff --git a/candle-nn/src/encoding.rs b/candle-nn/src/encoding.rs index 51cb75dd78..38e2cc3bbc 100644 --- a/candle-nn/src/encoding.rs +++ b/candle-nn/src/encoding.rs @@ -77,7 +77,6 @@ use candle::{bail, DType, Result, Tensor, WithDType}; /// # Bails /// /// This method bails if: -/// - The input tensor has a rank greater than 3. /// - One of the index value is less than -1. /// - One of the index value is greater than or equal to the depth value. /// - The input data type is not `U8`, `U32`, or `I64`. @@ -91,195 +90,44 @@ pub fn one_hot( on_value: D, off_value: D, ) -> Result { - let dtype = indices.dtype(); - let rank = indices.rank(); - - match rank { - 0 => { - let mut v = vec![off_value; depth]; - match dtype { - DType::U8 => { - let vi = indices.to_vec0::()?; - set_usize_value(vi as usize, 0, depth, &mut v, on_value)?; - } - DType::U32 => { - let vi = indices.to_vec0::()?; - set_usize_value(vi as usize, 0, depth, &mut v, on_value)?; - } - DType::I64 => { - let vi = indices.to_vec0::()?; - set_int64_value(vi, 0, depth, &mut v, on_value)?; - } - d => unsupported_dtype(d)?, - }; - Tensor::from_vec(v, (depth,), indices.device()) - } - 1 => { - let dim1 = indices.dims1()?; - let mut v = vec![off_value; depth * dim1]; - - match dtype { - DType::U8 => { - let indices = indices.to_vec1::()?; - for (i, &index) in indices.iter().enumerate() { - set_usize_value(index as usize, i * depth, depth, &mut v, on_value)?; - } - } - DType::U32 => { - let indices = indices.to_vec1::()?; - for (i, &index) in indices.iter().enumerate() { - set_usize_value(index as usize, i * depth, depth, &mut v, on_value)?; - } - } - DType::I64 => { - let indices = indices.to_vec1::()?; - for (i, &index) in indices.iter().enumerate() { - set_int64_value(index, i * depth, depth, &mut v, on_value)?; - } - } - d => unsupported_dtype(d)?, - }; - Tensor::from_vec(v, (dim1, depth), indices.device()) + let mut target_shape = indices.dims().to_vec(); + target_shape.push(depth); + let indices = indices.flatten_all()?; + let mut out = vec![off_value; depth * indices.elem_count()]; + match indices.dtype() { + DType::U8 => { + let indices = indices.to_vec1::()?; + for (i, &index) in indices.iter().enumerate() { + set_at_index(index, i * depth, depth, &mut out, on_value)?; + } } - 2 => { - let (dim1, dim2) = indices.dims2()?; - let mut v = vec![off_value; depth * dim1 * dim2]; - let idx = |i: usize, j: usize, depth: usize, dim2: usize| -> usize { - i * depth * dim2 + j * depth - }; - let iter = (0..dim1).flat_map(|i| (0..dim2).map(move |j| (i, j))); - match dtype { - DType::U8 => { - let index = indices.to_vec2::()?; - for (i, j) in iter { - set_usize_value( - index[i][j] as usize, - idx(i, j, depth, dim2), - depth, - &mut v, - on_value, - )?; - } - } - DType::U32 => { - let index = indices.to_vec2::()?; - for (i, j) in iter { - set_usize_value( - index[i][j] as usize, - idx(i, j, depth, dim2), - depth, - &mut v, - on_value, - )?; - } - } - DType::I64 => { - let index = indices.to_vec2::()?; - for (i, j) in iter { - set_int64_value( - index[i][j], - idx(i, j, depth, dim2), - depth, - &mut v, - on_value, - )?; - } - } - d => unsupported_dtype(d)?, - }; - Tensor::from_vec(v, (dim1, dim2, depth), indices.device()) + DType::U32 => { + let indices = indices.to_vec1::()?; + for (i, &index) in indices.iter().enumerate() { + set_at_index(index, i * depth, depth, &mut out, on_value)?; + } } - 3 => { - let (dim1, dim2, dim3) = indices.dims3()?; - let mut v = vec![off_value; depth * dim1 * dim2 * dim3]; - let idx = - |i: usize, j: usize, k: usize, depth: usize, dim2: usize, dim3: usize| -> usize { - i * depth * dim2 * dim3 + j * depth * dim3 + k * depth - }; - let iter = (0..dim1) - .flat_map(|i| (0..dim2).flat_map(move |j| (0..dim3).map(move |k| (i, j, k)))); - match dtype { - DType::U8 => { - let index = indices.to_vec3::()?; - for (i, j, k) in iter { - set_usize_value( - index[i][j][k] as usize, - idx(i, j, k, depth, dim2, dim3), - depth, - &mut v, - on_value, - )?; - } - } - DType::U32 => { - let index = indices.to_vec3::()?; - for (i, j, k) in iter { - set_usize_value( - index[i][j][k] as usize, - idx(i, j, k, depth, dim2, dim3), - depth, - &mut v, - on_value, - )?; - } - } - DType::I64 => { - let index = indices.to_vec3::()?; - for (i, j, k) in iter { - set_int64_value( - index[i][j][k], - idx(i, j, k, depth, dim2, dim3), - depth, - &mut v, - on_value, - )?; - } - } - d => unsupported_dtype(d)?, - }; - Tensor::from_vec(v, (dim1, dim2, dim3, depth), indices.device()) + DType::I64 => { + let indices = indices.to_vec1::()?; + for (i, &index) in indices.iter().enumerate() { + set_at_index(index, i * depth, depth, &mut out, on_value)?; + } } - _ => { - bail!("one_hot: rank {} is not supported.", rank) + dtype => { + bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64") } - } -} - -fn unsupported_dtype(dtype: DType) -> Result<()> { - bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64") -} - -// Set unsigned usize index values to the given value. -fn set_usize_value( - value: usize, - idx: usize, - depth: usize, - v: &mut Vec, - on_value: D, -) -> Result<()> { - if value >= depth { - bail!("one_hot: index value {value} exceeds depth {depth}") - } - let idx = idx + value; - if idx >= v.len() { - bail!("one_hot: index out of bounds {idx}, len {}", v.len()); - } - v[idx] = on_value; - Ok(()) + }; + Tensor::from_vec(out, target_shape, indices.device()) } -// Set signed integer index values to the given value. -// Signed integer values are only permitted for `-1` values. -// Otherwise, the value must be positive for unsigned usize values. -// This method will only case i64 values to usize values if the value is positive, -// otherwise the method will bail. -fn set_int64_value( - value: i64, - idx: usize, +fn set_at_index>( + value: I, + offset: usize, depth: usize, v: &mut Vec, on_value: D, ) -> Result<()> { + let value = value.into(); // Skip for an entire row of off_values if value == -1 { return Ok(()); @@ -289,5 +137,14 @@ fn set_int64_value( "one_hot: invalid negative index value {value}, expected a positive index value or -1" ); } - set_usize_value(value as usize, idx, depth, v, on_value) + let value = value as usize; + if value >= depth { + bail!("one_hot: index value {value} exceeds depth {depth}") + } + let idx = offset + value; + if idx >= v.len() { + bail!("one_hot: index out of bounds {idx}, len {}", v.len()); + } + v[idx] = on_value; + Ok(()) }