1#ifndef KMEANS_HARTIGAN_WONG_HPP
2#define KMEANS_HARTIGAN_WONG_HPP
9#include "sanisizer/sanisizer.hpp"
13#include "QuickSearch.hpp"
15#include "compute_centroids.hpp"
16#include "is_edge_case.hpp"
61namespace RefineHartiganWong_internal {
/**
 * Records when a cluster was last modified, so the transfer passes can decide
 * which clusters are "live" for a given observation.
 *
 * The time of the last modification is a pair (iteration, observation).
 * Non-negative iterations are quick-transfer sweep indices, as passed to
 * set_quick(). The negative values are sentinels for "the current
 * optimal-transfer pass", "the previous optimal-transfer pass" and "nothing
 * recent".
 *
 * @tparam Index_ Integer type of the observation indices.
 */
template<typename Index_>
class UpdateHistory {
private:
    // Observation at which the cluster was last modified.
    Index_ my_last_observation = 0;

    static constexpr int current_optimal_transfer = -1;
    static constexpr int previous_optimal_transfer = -2;
    static constexpr int ancient_history = -3;

    // Nothing has been recorded yet, so start out as ancient history.
    int my_last_iteration = ancient_history;

public:
    /**
     * Ages the history before a new optimal-transfer pass.
     * - A modification made during a quick-transfer sweep becomes a
     *   previous-pass modification at `total_obs`. That is after every
     *   observation, so the cluster is live for the whole next pass.
     * - A modification made during the current optimal-transfer pass becomes a
     *   previous-pass modification and keeps its observation index.
     * - Anything older becomes ancient history.
     *
     * @param total_obs Number of observations.
     */
    void reset(const Index_ total_obs) {
        if (my_last_iteration > current_optimal_transfer) {
            my_last_observation = total_obs;
            my_last_iteration = previous_optimal_transfer;
        } else if (my_last_iteration == current_optimal_transfer) {
            // Keep my_last_observation: it still marks where the change happened.
            my_last_iteration = previous_optimal_transfer;
        } else {
            my_last_iteration = ancient_history;
        }
    }

    /**
     * Records a modification at observation `obs` during the current
     * optimal-transfer pass.
     *
     * @param obs Observation index.
     */
    void set_optimal(const Index_ obs) {
        my_last_observation = obs;
        my_last_iteration = current_optimal_transfer;
    }

    /**
     * Records a modification at observation `obs` during quick-transfer sweep `iter`.
     *
     * @param iter Sweep index. It should be non-negative so that it cannot
     * collide with the negative sentinels.
     * @param obs Observation index.
     */
    void set_quick(const int iter, const Index_ obs) {
        my_last_iteration = iter;
        my_last_observation = obs;
    }

    /**
     * @param iter Iteration code (sweep index or sentinel).
     * @param obs Observation index.
     * @return Whether the last modification happened strictly after
     * observation `obs` of iteration `iter`. Any modification in a later
     * iteration counts.
     */
    bool changed_after(const int iter, const Index_ obs) const {
        if (my_last_iteration == iter) {
            return my_last_observation > obs;
        }
        return my_last_iteration > iter;
    }

    /**
     * @param iter Iteration code (sweep index or sentinel).
     * @param obs Observation index.
     * @return Same as changed_after(), but a modification at observation
     * `obs` itself also counts.
     */
    bool changed_after_or_at(const int iter, const Index_ obs) const {
        if (my_last_iteration == iter) {
            return my_last_observation >= obs;
        }
        return my_last_iteration > iter;
    }

    /**
     * @param obs Observation index.
     * @return Whether the cluster is live for `obs`. It is live if it was
     * modified after `obs` in the previous optimal-transfer pass, during the
     * quick-transfer stage since then, or anywhere in the current pass.
     */
    bool is_live(const Index_ obs) const {
        return changed_after(previous_optimal_transfer, obs);
    }
};
194template<
typename Float_,
typename Index_,
typename Cluster_>
197 std::vector<Cluster_> best_destination_cluster;
198 std::vector<Index_> cluster_sizes;
200 std::vector<Float_> loss_multiplier;
201 std::vector<Float_> gain_multiplier;
202 std::vector<Float_> wcss_loss;
204 std::vector<UpdateHistory<Index_> > update_history;
206 Index_ optra_steps_since_last_transfer = 0;
209 Workspace(Index_ nobs, Cluster_ ncenters) :
211 best_destination_cluster(sanisizer::cast<decltype(I(best_destination_cluster.size()))>(nobs)),
212 cluster_sizes(sanisizer::cast<decltype(I(cluster_sizes.size()))>(ncenters)),
213 loss_multiplier(sanisizer::cast<decltype(I(loss_multiplier.size()))>(ncenters)),
214 gain_multiplier(sanisizer::cast<decltype(I(gain_multiplier.size()))>(ncenters)),
215 wcss_loss(sanisizer::cast<decltype(I(wcss_loss.size()))>(nobs)),
216 update_history(sanisizer::cast<decltype(I(update_history.size()))>(ncenters))
/**
 * Squared Euclidean distance between an observation and a cluster center.
 *
 * @tparam Data_ Type of the observation values.
 * @tparam Float_ Floating-point type of the center and of the result.
 * @param data Pointer to `ndim` observation values.
 * @param center Pointer to `ndim` center coordinates.
 * @param ndim Number of dimensions.
 * @return Sum over `d` of `(data[d] - center[d])^2`, computed in `Float_`.
 * Zero when `ndim == 0`.
 */
template<typename Data_, typename Float_>
Float_ squared_distance_from_cluster(const Data_* const data, const Float_* const center, const std::size_t ndim) {
    Float_ output = 0;
    for (std::size_t d = 0; d < ndim; ++d) {
        // Convert first so the arithmetic happens in Float_ whatever Data_ is.
        const Float_ delta = static_cast<Float_>(data[d]) - center[d];
        output += delta * delta;
    }
    return output;
}
230template<
class Matrix_,
typename Cluster_,
typename Float_>
231void find_closest_two_centers(
233 const Cluster_ ncenters,
234 const Float_*
const centers,
235 Cluster_*
const best_cluster,
236 std::vector<Cluster_>& best_destination_cluster,
239 const auto ndim = data.num_dimensions();
243 const internal::QuickSearch<Float_, Cluster_> index(ndim, ncenters, centers);
245 const auto nobs = data.num_observations();
246 parallelize(nthreads, nobs, [&](
const int,
const decltype(I(nobs)) start,
const decltype(I(nobs)) length) ->
void {
247 auto matwork = data.new_extractor(start, length);
248 for (
decltype(I(start)) obs = start, end = start + length; obs < end; ++obs) {
249 const auto optr = matwork->get_observation();
250 const auto res2 = index.find2(optr);
251 best_cluster[obs] = res2.first;
252 best_destination_cluster[obs] = res2.second;
// Stand-in for the loss multiplier n / (n - 1) of a cluster with n <= 1
// members, where that ratio would divide by zero (n == 1) or be meaningless
// (n == 0). It is used in transfer_point() and run().
// Both optimal_transfer() and quick_transfer() skip observations whose cluster
// has cluster_sizes[l1] == 1. So, as far as this excerpt shows, the value is
// never read when deciding whether to move an observation out of such a
// cluster.
// NOTE(review): the returned value is on a line not visible in this excerpt.
// It is presumably a large finite constant; confirm it cannot overflow Float_
// when multiplied by a squared distance.
257template<
typename Float_>
258constexpr Float_ big_number() {
// Moves observation `obs_id` from cluster `l1` to cluster `l2` and updates the
// workspace in place:
// - the two affected centers are updated incrementally (one point removed from
//   the mean of l1, one point added to the mean of l2) instead of being
//   recomputed from all their members;
// - cluster sizes and per-cluster multipliers are refreshed with the same
//   formulas run() uses at initialization: gain = n / (n + 1) and
//   loss = n / (n - 1), or big_number<Float_>() when n <= 1;
// - best_cluster[obs_id] becomes l2, and the observation's recorded candidate
//   destination becomes its old cluster l1.
// NOTE(review): the declarator (function name) and the obs_id/l1/l2 parameters
// are on lines not visible in this excerpt. Both callers invoke this as
// transfer_point(ndim, obs_ptr, obs, l1, l2, centers, best_cluster, work).
262template<
typename Data_,
typename Index_,
typename Cluster_,
typename Float_>
264 const std::size_t ndim,
265 const Data_*
const obs_ptr,
269 Float_*
const centers,
270 Cluster_*
const best_cluster,
271 Workspace<Float_, Index_, Cluster_>& work)
// Sizes before the move (al1, al2) and after it (alw = al1 - 1, alt = al2 + 1),
// held as Float_ for the divisions below. Both callers only transfer out of l1
// when cluster_sizes[l1] != 1, so alw should be non-zero.
275 const Float_ al1 = work.cluster_sizes[l1], alw = al1 - 1;
276 const Float_ al2 = work.cluster_sizes[l2], alt = al2 + 1;
// Center coordinates are addressed with nd_offset(d, ndim, cluster), presumably
// d + ndim * cluster, so each center is stored as ndim contiguous values.
278 for (
decltype(I(ndim)) d = 0; d < ndim; ++d) {
279 const Float_ oval = obs_ptr[d];
280 auto& c1 = centers[sanisizer::nd_offset<std::size_t>(d, ndim, l1)];
// Mean of n points with one point x removed: (c * n - x) / (n - 1).
281 c1 = (c1 * al1 - oval) / alw;
282 auto& c2 = centers[sanisizer::nd_offset<std::size_t>(d, ndim, l2)];
// Mean of n points with one point x added: (c * n + x) / (n + 1).
283 c2 = (c2 * al2 + oval) / alt;
286 --work.cluster_sizes[l1];
287 ++work.cluster_sizes[l2];
// l1 now has alw members and l2 has alt members.
// Note that alw / al1 == alw / (alw + 1) and alt / al2 == alt / (alt - 1).
289 work.gain_multiplier[l1] = alw / al1;
290 work.loss_multiplier[l1] = (alw > 1 ? alw / (alw - 1) : big_number<Float_>());
291 work.loss_multiplier[l2] = alt / al2;
292 work.gain_multiplier[l2] = alt / (alt + 1);
294 best_cluster[obs_id] = l2;
295 work.best_destination_cluster[obs_id] = l1;
// One optimal-transfer pass (the OPTRA step of Hartigan-Wong) over every
// observation, in index order. For each observation whose current cluster l1 is
// not a singleton (cluster_sizes[l1] != 1):
//  - wcss_loss[obs] = squared distance to center l1 * loss_multiplier[l1],
//    i.e. the decrease in WCSS from removing the observation from l1;
//  - wcss_gain = squared distance to a candidate center * that cluster's
//    gain_multiplier, i.e. the increase in WCSS from adding the observation
//    there. It is minimized over the candidates, starting from
//    best_destination_cluster[obs];
//  - if no candidate beats the loss (wcss_gain >= wcss_loss), l2 is stored as
//    the new best_destination_cluster[obs]. Otherwise both clusters are stamped
//    in their UpdateHistory and the observation is moved with transfer_point().
// NOTE(review): the last parameter, the opening brace of the body and the
// return statements are on lines not visible in this excerpt. The body reads a
// flag named `all_live`; run() passes `(iter == 0)` in that position and stores
// the result as `finished`.
305template<
class Matrix_,
typename Cluster_,
typename Float_>
306bool optimal_transfer(
307 const Matrix_& data, Workspace<Float_, Index<Matrix_>, Cluster_>& work,
308 const Cluster_ ncenters,
309 Float_*
const centers,
310 Cluster_*
const best_cluster,
313 const auto nobs = data.num_observations();
314 const auto ndim = data.num_dimensions();
315 auto matwork = data.new_extractor();
317 for (
decltype(I(nobs)) obs = 0; obs < nobs; ++obs) {
// Counts observations visited since the last transfer. The count persists in
// the workspace across calls, and run() also resets it after a quick transfer.
318 ++work.optra_steps_since_last_transfer;
320 const auto l1 = best_cluster[obs];
321 if (work.cluster_sizes[l1] != 1) {
322 const auto obs_ptr = matwork->get_observation(obs);
334 auto& wcss_loss = work.wcss_loss[obs];
335 const auto l1_ptr = centers + sanisizer::product_unsafe<std::size_t>(ndim, l1);
336 wcss_loss = squared_distance_from_cluster(obs_ptr, l1_ptr, ndim) * work.loss_multiplier[l1];
// Start from the stored candidate destination. original_l2 is kept so that it
// is not evaluated a second time in the scans below.
339 auto l2 = work.best_destination_cluster[obs];
340 const auto original_l2 = l2;
341 const auto l2_ptr = centers + sanisizer::product_unsafe<std::size_t>(ndim, l2);
342 auto wcss_gain = squared_distance_from_cluster(obs_ptr, l2_ptr, ndim) * work.gain_multiplier[l2];
// Evaluates cluster `cen` as a destination and keeps it if it is cheaper than
// the best so far. NOTE(review): the line that records `cen` (presumably into
// l2) is not visible in this excerpt.
344 const auto update_destination_cluster = [&](
const Cluster_ cen) ->
void {
345 auto cen_ptr = centers + sanisizer::product_unsafe<std::size_t>(ndim, cen);
346 auto candidate = squared_distance_from_cluster(obs_ptr, cen_ptr, ndim) * work.gain_multiplier[cen];
347 if (candidate < wcss_gain) {
348 wcss_gain = candidate;
// Candidate scan. If all_live is set, or l1 itself is live for this
// observation, every other cluster is a candidate. Otherwise only clusters that
// are themselves live are candidates. "Live" (UpdateHistory::is_live) means the
// cluster was modified after this observation in the previous optimal-transfer
// pass, during the quick-transfer stage since then, or anywhere in this pass.
366 if (all_live || work.update_history[l1].is_live(obs)) {
367 for (Cluster_ cen = 0; cen < ncenters; ++cen) {
368 if (cen != l1 && cen != original_l2) {
369 update_destination_cluster(cen);
373 for (Cluster_ cen = 0; cen < ncenters; ++cen) {
374 if (cen != l1 && cen != original_l2 && work.update_history[cen].is_live(obs)) {
375 update_destination_cluster(cen);
// No improvement: remember the best candidate for the quick-transfer stage.
381 if (wcss_gain >= wcss_loss) {
382 work.best_destination_cluster[obs] = l2;
// Transfer branch: reset the step counter and stamp both clusters as modified
// at this observation in the current optimal-transfer pass.
384 work.optra_steps_since_last_transfer = 0;
385 work.update_history[l1].set_optimal(obs);
386 work.update_history[l2].set_optimal(obs);
387 transfer_point(ndim, obs_ptr, obs, l1, l2, centers, best_cluster, work);
// nobs consecutive observations without a transfer, possibly spanning the end
// of the previous call. The handling is not visible in this excerpt; it is
// presumably an early return reporting convergence.
393 if (work.optra_steps_since_last_transfer == nobs) {
// Quick-transfer stage (the QTRAN step of Hartigan-Wong). Makes up to
// `quick_iterations` sweeps over the observations. For each observation whose
// cluster l1 is not a singleton, only a move to its stored
// best_destination_cluster[obs] (l2) is considered. Distances are recomputed
// only when l1 or l2 changed since the observation was last examined.
// Returns {had_transfer, flag}. Of the visible returns, the early exit taken
// after nobs consecutive observations without a transfer gives flag = false,
// and the final return gives flag = true.
// NOTE(review): the `data` parameter, the opening brace of the body, the
// update(s) of had_transfer and the closing braces are on lines not visible in
// this excerpt.
412template<
class Matrix_,
typename Cluster_,
typename Float_>
413std::pair<bool, bool> quick_transfer(
415 Workspace<Float_, Index<Matrix_>, Cluster_>& work,
416 Float_*
const centers,
417 Cluster_*
const best_cluster,
418 const int quick_iterations)
420 bool had_transfer =
false;
422 const auto nobs = data.num_observations();
423 const auto ndim = data.num_dimensions();
424 auto matwork = data.new_extractor();
// Counts consecutive observations visited without a transfer, across sweeps.
426 decltype(I(nobs)) steps_since_last_quick_transfer = 0;
428 for (
int it = 0; it < quick_iterations; ++it) {
// Sweep indices are non-negative. In UpdateHistory terms, the "previous sweep"
// of it == 0 is therefore the current optimal-transfer pass (-1).
429 const int prev_it = it - 1;
431 for (
decltype(I(nobs)) obs = 0; obs < nobs; ++obs) {
432 ++steps_since_last_quick_transfer;
433 const auto l1 = best_cluster[obs];
435 if (work.cluster_sizes[l1] != 1) {
// Observation values are fetched lazily, only when a distance is needed.
// NOTE(review): nullptr would be the modern spelling of NULL here.
436 decltype(I(matwork->get_observation(obs))) obs_ptr = NULL;
445 auto& history1 = work.update_history[l1];
// Refresh the cached removal loss only if l1 changed at or after this
// observation in the previous sweep, or at any point in this sweep.
446 if (history1.changed_after_or_at(prev_it, obs)) {
447 const auto l1_ptr = centers + sanisizer::product_unsafe<std::size_t>(l1, ndim);
448 obs_ptr = matwork->get_observation(obs);
449 work.wcss_loss[obs] = squared_distance_from_cluster(obs_ptr, l1_ptr, ndim) * work.loss_multiplier[l1];
456 const auto l2 = work.best_destination_cluster[obs];
457 auto& history2 = work.update_history[l2];
// Re-evaluate the move only if l1 or l2 changed strictly after this
// observation in the previous sweep, or at any point in this sweep.
458 if (history1.changed_after(prev_it, obs) || history2.changed_after(prev_it, obs)) {
459 if (obs_ptr == NULL) {
460 obs_ptr = matwork->get_observation(obs);
462 const auto l2_ptr = centers + sanisizer::product_unsafe<std::size_t>(l2, ndim);
463 const auto wcss_gain = squared_distance_from_cluster(obs_ptr, l2_ptr, ndim) * work.gain_multiplier[l2];
// Move when joining l2 increases the WCSS by less than leaving l1 decreases it.
465 if (wcss_gain < work.wcss_loss[obs]) {
467 steps_since_last_quick_transfer = 0;
468 history1.set_quick(it, obs);
469 history2.set_quick(it, obs);
470 transfer_point(ndim, obs_ptr, obs, l1, l2, centers, best_cluster, work);
// nobs consecutive observations without a transfer: stop sweeping early.
475 if (steps_since_last_quick_transfer == nobs) {
478 return std::make_pair(had_transfer,
false);
483 return std::make_pair(had_transfer,
true);
// Refines `centers` and `clusters` in place with the Hartigan-Wong algorithm.
// Inputs for which internal::is_edge_case(nobs, ncenters) holds are handed to
// internal::process_edge_case(). Otherwise, this method:
// - assigns each observation to its closest center, keeping the second-closest
//   as its candidate destination;
// - initializes cluster sizes, centroids and the gain/loss multipliers;
// - runs an iteration loop that alternates optimal_transfer() and
//   quick_transfer(), recomputing centroids after the quick-transfer stage.
// Returns Details built from the final cluster sizes, the iteration count
// `iter` and the status code `ifault`.
// NOTE(review): the enclosing class head, the declarations of `iter` and
// `ifault`, the loop header and several branch bodies are on lines not visible
// in this excerpt.
526template<
typename Index_,
typename Data_,
typename Cluster_,
typename Float_,
class Matrix_ = Matrix<Index_, Data_> >
555 Details<Index_> run(
const Matrix_& data,
const Cluster_ ncenters, Float_*
const centers, Cluster_*
const clusters)
const {
556 const auto nobs = data.num_observations();
557 if (internal::is_edge_case(nobs, ncenters)) {
558 return internal::process_edge_case(data, ncenters, centers, clusters);
561 RefineHartiganWong_internal::Workspace<Float_, Index_, Cluster_> work(nobs, ncenters);
// clusters[obs] receives the closest center; best_destination_cluster[obs]
// receives the second-closest.
563 RefineHartiganWong_internal::find_closest_two_centers(data, ncenters, centers, clusters, work.best_destination_cluster, my_options.
num_threads);
564 for (Index_ obs = 0; obs < nobs; ++obs) {
565 ++work.cluster_sizes[clusters[obs]];
567 internal::compute_centroids(data, ncenters, centers, clusters, work.cluster_sizes);
// For a cluster with n members, gain = n / (n + 1) scales a squared distance
// into the WCSS increase from adding a point, and loss = n / (n - 1) scales it
// into the WCSS decrease from removing one. For n <= 1 the latter would divide
// by zero (n == 1) or be meaningless (n == 0), so big_number<Float_>() is used.
569 for (Cluster_ cen = 0; cen < ncenters; ++cen) {
570 const Float_ num = work.cluster_sizes[cen];
571 work.gain_multiplier[cen] = num / (num + 1);
572 work.loss_multiplier[cen] = (num > 1 ? num / (num - 1) : RefineHartiganWong_internal::big_number<Float_>());
// On the first iteration every cluster is treated as live (all_live).
578 const bool finished = RefineHartiganWong_internal::optimal_transfer(data, work, ncenters, centers, clusters, (iter == 0));
583 const auto quick_status = RefineHartiganWong_internal::quick_transfer(
597 internal::compute_centroids(data, ncenters, centers, clusters, work.cluster_sizes);
// quick_status = {had_transfer, flag}. The flag's handling is not visible here;
// it presumably signals a quick-transfer convergence failure (cf. the
// quit_on_quick_transfer_convergence_failure option).
599 if (quick_status.second) {
// Any quick transfer restarts the optimal-transfer step counter.
612 if (quick_status.first) {
613 work.optra_steps_since_last_transfer = 0;
// Age every cluster's update history, presumably before the next
// optimal-transfer pass.
616 for (Cluster_ c = 0; c < ncenters; ++c) {
617 work.update_history[c].reset(nobs);
627 return Details(std::move(work.cluster_sizes), iter, ifault);
Report detailed clustering statistics.
Interface for k-means refinement.
Implements the Hartigan-Wong algorithm for k-means clustering.
Definition RefineHartiganWong.hpp:527
RefineHartiganWongOptions & get_options()
Definition RefineHartiganWong.hpp:547
RefineHartiganWong(RefineHartiganWongOptions options)
Definition RefineHartiganWong.hpp:532
RefineHartiganWong()=default
Interface for k-means refinement algorithms.
Definition Refine.hpp:30
virtual Details< Index_ > run(const Matrix_ &data, Cluster_ num_centers, Float_ *centers, Cluster_ *clusters) const =0
Perform k-means clustering.
Definition compute_wcss.hpp:16
void parallelize(const int num_workers, const Task_ num_tasks, Run_ run_task_range)
Definition parallelize.hpp:28
Utilities for parallelization.
Additional statistics from the k-means algorithm.
Definition Details.hpp:20
Options for RefineHartiganWong.
Definition RefineHartiganWong.hpp:30
bool quit_on_quick_transfer_convergence_failure
Definition RefineHartiganWong.hpp:49
int max_iterations
Definition RefineHartiganWong.hpp:35
int num_threads
Definition RefineHartiganWong.hpp:55
int max_quick_transfer_iterations
Definition RefineHartiganWong.hpp:41