Skip to content

Commit

Permalink
fix codes
Browse files Browse the repository at this point in the history
Signed-off-by: whateveraname <12011319@mail.sustech.edu.cn>
  • Loading branch information
whateveraname committed Feb 3, 2024
1 parent cc33815 commit 7fef203
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 203 deletions.
115 changes: 50 additions & 65 deletions crates/service/src/algorithms/clustering/elkan_k_means.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use crate::prelude::*;
use crate::utils::cells::SyncUnsafeCell;
use crate::utils::vec2::Vec2;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;
use std::ops::{Index, IndexMut};
use std::sync::atomic::{AtomicUsize, Ordering};

pub struct ElkanKMeans<S: G> {
dims: u16,
Expand All @@ -30,36 +26,33 @@ impl<S: G> ElkanKMeans<S> {

let mut rand = StdRng::from_entropy();
let mut centroids = Vec2::new(dims, c);
let lowerbound = SyncUnsafeCell::new(Square::new(n, c));
let mut lowerbound = Square::new(n, c);
let mut upperbound = vec![F32::zero(); n];
let mut assign = vec![0usize; n];

centroids[0].copy_from_slice(&samples[rand.gen_range(0..n)]);

let weight = SyncUnsafeCell::new(vec![F32::infinity(); n]);
let mut weight = vec![F32::infinity(); n];
let mut dis = vec![F32::zero(); n];
for i in 0..c {
let mut sum = F32::zero();
(0..n).into_par_iter().for_each(|j| {
let dis = S::elkan_k_means_distance(&samples[j], &centroids[i]);
unsafe {
(&mut *lowerbound.get())[(j, i)] = dis;
}
if dis * dis < weight.get_ref()[j] {
unsafe {
(&mut *weight.get())[j] = dis * dis;
}
}
dis.par_iter_mut().enumerate().for_each(|(j, x)| {
*x = S::elkan_k_means_distance(&samples[j], &centroids[i]);
});
for j in 0..n {
sum += weight.get_ref()[j];
lowerbound[(j, i)] = dis[j];
if dis[j] * dis[j] < weight[j] {
weight[j] = dis[j] * dis[j];
}
sum += weight[j];
}
if i + 1 == c {
break;
}
let index = 'a: {
let mut choice = sum * rand.gen_range(0.0..1.0);
for j in 0..(n - 1) {
choice -= weight.get_ref()[j];
choice -= weight[j];
if choice <= F32::zero() {
break 'a j;
}
Expand All @@ -73,7 +66,7 @@ impl<S: G> ElkanKMeans<S> {
let mut minimal = F32::infinity();
let mut target = 0;
for j in 0..c {
let dis = lowerbound.get_ref()[(i, j)];
let dis = lowerbound[(i, j)];
if dis < minimal {
minimal = dis;
target = j;
Expand All @@ -82,7 +75,6 @@ impl<S: G> ElkanKMeans<S> {
assign[i] = target;
upperbound[i] = minimal;
}
let lowerbound = lowerbound.get_ref().clone();

Self {
dims,
Expand Down Expand Up @@ -131,26 +123,31 @@ impl<S: G> ElkanKMeans<S> {
let dims = self.dims;
let samples = &self.samples;
let rand = &mut self.rand;
let assign = &mut self.assign;
let centroids = &mut self.centroids;
let change = AtomicUsize::new(0);
let lowerbound = &mut self.lowerbound;
let upperbound = &mut self.upperbound;
let mut change = 0;
let n = samples.len();
if n <= c {
return self.quick_centroids();
}

// Step 1
let dist0 = SyncUnsafeCell::new(Square::new(c, c));
let mut dist0 = Square::new(c, c);
let mut sp = vec![F32::zero(); c];
(0..c).into_par_iter().for_each(|i| {
for j in i + 1..c {
let dis = S::elkan_k_means_distance(&centroids[i], &centroids[j]) * 0.5;
unsafe {
(&mut *dist0.get())[(i, j)] = dis;
(&mut *dist0.get())[(j, i)] = dis;
}
dist0.v.par_iter_mut().enumerate().for_each(|(ii, v)| {
let i = ii / c;
let j = ii % c;
if i <= j {
*v = S::elkan_k_means_distance(&centroids[i], &centroids[j]) * 0.5;
}
});
let dist0 = dist0.get_ref().clone();
for i in 1..c {
for j in 0..i - 1 {
dist0[(i, j)] = dist0[(j, i)];
}
}
for i in 0..c {
let mut minimal = F32::infinity();
for j in 0..c {
Expand All @@ -165,54 +162,43 @@ impl<S: G> ElkanKMeans<S> {
sp[i] = minimal;
}

let assign = SyncUnsafeCell::new(self.assign.clone());
let lowerbound = SyncUnsafeCell::new(self.lowerbound.clone());
let upperbound = SyncUnsafeCell::new(self.upperbound.clone());

(0..n).into_par_iter().for_each(|i| {
// Step 2
if upperbound.get_ref()[i] <= sp[assign.get_ref()[i]] {
return;
let mut dis = vec![F32::zero(); n];
dis.par_iter_mut().enumerate().for_each(|(i, x)| {
if upperbound[i] > sp[assign[i]] {
*x = S::elkan_k_means_distance(&samples[i], &centroids[assign[i]]);
}
let mut minimal =
S::elkan_k_means_distance(&samples[i], &centroids[assign.get_ref()[i]]);
unsafe {
(&mut *lowerbound.get())[(i, assign.get_ref()[i])] = minimal;
(&mut *upperbound.get())[i] = minimal;
});
for i in 0..n {
// Step 2
if upperbound[i] <= sp[assign[i]] {
continue;
}
let mut minimal = dis[i];
lowerbound[(i, assign[i])] = minimal;
upperbound[i] = minimal;
// Step 3
for j in 0..c {
if j == assign.get_ref()[i] {
if j == assign[i] {
continue;
}
if upperbound.get_ref()[i] <= lowerbound.get_ref()[(i, j)] {
if upperbound[i] <= lowerbound[(i, j)] {
continue;
}
if upperbound.get_ref()[i] <= dist0[(assign.get_ref()[i], j)] {
if upperbound[i] <= dist0[(assign[i], j)] {
continue;
}
if minimal > lowerbound.get_ref()[(i, j)]
|| minimal > dist0[(assign.get_ref()[i], j)]
{
if minimal > lowerbound[(i, j)] || minimal > dist0[(assign[i], j)] {
let dis = S::elkan_k_means_distance(&samples[i], &centroids[j]);
unsafe {
(&mut *lowerbound.get())[(i, j)] = dis;
}
lowerbound[(i, j)] = dis;
if dis < minimal {
minimal = dis;
unsafe {
(&mut *assign.get())[i] = j;
(&mut *upperbound.get())[i] = dis;
}
change.fetch_add(1, Ordering::Relaxed);
assign[i] = j;
upperbound[i] = dis;
change += 1;
}
}
}
});

self.assign = assign.get_ref().clone();
self.lowerbound = lowerbound.get_ref().clone();
self.upperbound = upperbound.get_ref().clone();
}

// Step 4, 7
let old = std::mem::replace(centroids, Vec2::new(dims, c));
Expand Down Expand Up @@ -277,15 +263,14 @@ impl<S: G> ElkanKMeans<S> {
self.upperbound[i] += dist1[self.assign[i]];
}

change.load(Ordering::Relaxed) == 0
change == 0
}

pub fn finish(self) -> Vec2<S> {
self.centroids
}
}

#[derive(Clone)]
pub struct Square {
x: usize,
y: usize,
Expand Down
Loading

0 comments on commit 7fef203

Please sign in to comment.