Skip to content

Commit

Permalink
Simplify the one-hot implementation, support arbitrary rank. (hugging…
Browse files Browse the repository at this point in the history
…face#1514)

* Simplify the one-hot implementation, support arbitrary rank.

* More cleanup.
  • Loading branch information
LaurentMazare authored Jan 1, 2024
1 parent 41614b4 commit 135ae5f
Showing 1 changed file with 38 additions and 181 deletions.
219 changes: 38 additions & 181 deletions candle-nn/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ use candle::{bail, DType, Result, Tensor, WithDType};
/// # Bails
///
/// This method bails if:
/// - The input tensor has a rank greater than 3.
/// - One of the index value is less than -1.
/// - One of the index value is greater than or equal to the depth value.
/// - The input data type is not `U8`, `U32`, or `I64`.
Expand All @@ -91,195 +90,44 @@ pub fn one_hot<D: WithDType>(
on_value: D,
off_value: D,
) -> Result<Tensor> {
let dtype = indices.dtype();
let rank = indices.rank();

match rank {
0 => {
let mut v = vec![off_value; depth];
match dtype {
DType::U8 => {
let vi = indices.to_vec0::<u8>()?;
set_usize_value(vi as usize, 0, depth, &mut v, on_value)?;
}
DType::U32 => {
let vi = indices.to_vec0::<u32>()?;
set_usize_value(vi as usize, 0, depth, &mut v, on_value)?;
}
DType::I64 => {
let vi = indices.to_vec0::<i64>()?;
set_int64_value(vi, 0, depth, &mut v, on_value)?;
}
d => unsupported_dtype(d)?,
};
Tensor::from_vec(v, (depth,), indices.device())
}
1 => {
let dim1 = indices.dims1()?;
let mut v = vec![off_value; depth * dim1];

match dtype {
DType::U8 => {
let indices = indices.to_vec1::<i64>()?;
for (i, &index) in indices.iter().enumerate() {
set_usize_value(index as usize, i * depth, depth, &mut v, on_value)?;
}
}
DType::U32 => {
let indices = indices.to_vec1::<i64>()?;
for (i, &index) in indices.iter().enumerate() {
set_usize_value(index as usize, i * depth, depth, &mut v, on_value)?;
}
}
DType::I64 => {
let indices = indices.to_vec1::<i64>()?;
for (i, &index) in indices.iter().enumerate() {
set_int64_value(index, i * depth, depth, &mut v, on_value)?;
}
}
d => unsupported_dtype(d)?,
};
Tensor::from_vec(v, (dim1, depth), indices.device())
let mut target_shape = indices.dims().to_vec();
target_shape.push(depth);
let indices = indices.flatten_all()?;
let mut out = vec![off_value; depth * indices.elem_count()];
match indices.dtype() {
DType::U8 => {
let indices = indices.to_vec1::<u8>()?;
for (i, &index) in indices.iter().enumerate() {
set_at_index(index, i * depth, depth, &mut out, on_value)?;
}
}
2 => {
let (dim1, dim2) = indices.dims2()?;
let mut v = vec![off_value; depth * dim1 * dim2];
let idx = |i: usize, j: usize, depth: usize, dim2: usize| -> usize {
i * depth * dim2 + j * depth
};
let iter = (0..dim1).flat_map(|i| (0..dim2).map(move |j| (i, j)));
match dtype {
DType::U8 => {
let index = indices.to_vec2::<u8>()?;
for (i, j) in iter {
set_usize_value(
index[i][j] as usize,
idx(i, j, depth, dim2),
depth,
&mut v,
on_value,
)?;
}
}
DType::U32 => {
let index = indices.to_vec2::<u32>()?;
for (i, j) in iter {
set_usize_value(
index[i][j] as usize,
idx(i, j, depth, dim2),
depth,
&mut v,
on_value,
)?;
}
}
DType::I64 => {
let index = indices.to_vec2::<i64>()?;
for (i, j) in iter {
set_int64_value(
index[i][j],
idx(i, j, depth, dim2),
depth,
&mut v,
on_value,
)?;
}
}
d => unsupported_dtype(d)?,
};
Tensor::from_vec(v, (dim1, dim2, depth), indices.device())
DType::U32 => {
let indices = indices.to_vec1::<u32>()?;
for (i, &index) in indices.iter().enumerate() {
set_at_index(index, i * depth, depth, &mut out, on_value)?;
}
}
3 => {
let (dim1, dim2, dim3) = indices.dims3()?;
let mut v = vec![off_value; depth * dim1 * dim2 * dim3];
let idx =
|i: usize, j: usize, k: usize, depth: usize, dim2: usize, dim3: usize| -> usize {
i * depth * dim2 * dim3 + j * depth * dim3 + k * depth
};
let iter = (0..dim1)
.flat_map(|i| (0..dim2).flat_map(move |j| (0..dim3).map(move |k| (i, j, k))));
match dtype {
DType::U8 => {
let index = indices.to_vec3::<u8>()?;
for (i, j, k) in iter {
set_usize_value(
index[i][j][k] as usize,
idx(i, j, k, depth, dim2, dim3),
depth,
&mut v,
on_value,
)?;
}
}
DType::U32 => {
let index = indices.to_vec3::<u32>()?;
for (i, j, k) in iter {
set_usize_value(
index[i][j][k] as usize,
idx(i, j, k, depth, dim2, dim3),
depth,
&mut v,
on_value,
)?;
}
}
DType::I64 => {
let index = indices.to_vec3::<i64>()?;
for (i, j, k) in iter {
set_int64_value(
index[i][j][k],
idx(i, j, k, depth, dim2, dim3),
depth,
&mut v,
on_value,
)?;
}
}
d => unsupported_dtype(d)?,
};
Tensor::from_vec(v, (dim1, dim2, dim3, depth), indices.device())
DType::I64 => {
let indices = indices.to_vec1::<i64>()?;
for (i, &index) in indices.iter().enumerate() {
set_at_index(index, i * depth, depth, &mut out, on_value)?;
}
}
_ => {
bail!("one_hot: rank {} is not supported.", rank)
dtype => {
bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64")
}
}
}

fn unsupported_dtype(dtype: DType) -> Result<()> {
bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64")
}

// Set unsigned usize index values to the given value.
fn set_usize_value<D: WithDType>(
value: usize,
idx: usize,
depth: usize,
v: &mut Vec<D>,
on_value: D,
) -> Result<()> {
if value >= depth {
bail!("one_hot: index value {value} exceeds depth {depth}")
}
let idx = idx + value;
if idx >= v.len() {
bail!("one_hot: index out of bounds {idx}, len {}", v.len());
}
v[idx] = on_value;
Ok(())
};
Tensor::from_vec(out, target_shape, indices.device())
}

// Set signed integer index values to the given value.
// Signed integer values are only permitted for `-1` values.
// Otherwise, the value must be positive for unsigned usize values.
// This method will only case i64 values to usize values if the value is positive,
// otherwise the method will bail.
fn set_int64_value<D: WithDType>(
value: i64,
idx: usize,
fn set_at_index<D: WithDType, I: Into<i64>>(
value: I,
offset: usize,
depth: usize,
v: &mut Vec<D>,
on_value: D,
) -> Result<()> {
let value = value.into();
// Skip for an entire row of off_values
if value == -1 {
return Ok(());
Expand All @@ -289,5 +137,14 @@ fn set_int64_value<D: WithDType>(
"one_hot: invalid negative index value {value}, expected a positive index value or -1"
);
}
set_usize_value(value as usize, idx, depth, v, on_value)
let value = value as usize;
if value >= depth {
bail!("one_hot: index value {value} exceeds depth {depth}")
}
let idx = offset + value;
if idx >= v.len() {
bail!("one_hot: index out of bounds {idx}, len {}", v.len());
}
v[idx] = on_value;
Ok(())
}

0 comments on commit 135ae5f

Please sign in to comment.