kmeans
A C++ library for k-means
RefineHartiganWong.hpp
1#ifndef KMEANS_HARTIGAN_WONG_HPP
2#define KMEANS_HARTIGAN_WONG_HPP
3
4#include <vector>
5#include <algorithm>
6#include <numeric>
7#include <cstdint>
8
9#include "Refine.hpp"
10#include "Details.hpp"
11#include "QuickSearch.hpp"
12#include "parallelize.hpp"
13#include "compute_centroids.hpp"
14#include "is_edge_case.hpp"
15
22namespace kmeans {
23
24/**
25 * @brief Options for `RefineHartiganWong`.
26 */
27struct RefineHartiganWongOptions {
31 // Maximum number of optimal-transfer iterations. (Default values in this struct are assumed, not verified.)
32 int max_iterations = 10;
37 // Maximum number of quick-transfer iterations within each optimal-transfer iteration.
38 int max_quick_transfer_iterations = 50;
43 // Whether to stop immediately if the quick transfers fail to converge within their iteration limit.
44 bool quit_on_quick_transfer_convergence_failure = false;
49 // Number of threads to use.
50 int num_threads = 1;
51};
52
56namespace RefineHartiganWong_internal {
57
58/*
59 * The class below represents 'ncp', which has a dual interpretation in the
60 * original Fortran implementation:
61 *
62 * - In the optimal-transfer stage, NCP(L) stores the step at which cluster L
63 * was last updated. Each step is just the observation index as the optimal
64 * transfer only does one pass through the dataset.
65 * - In the quick-transfer stage, NCP(L) stores the step at which cluster L is
66 * last updated plus M (i.e., the number of observations). Here, the step
67 * corresponding to an observation will be 'M * X + obs' for some integer X
68 * >= 0, where X is the iteration of the quick transfer.
69 *
70 * Note that these two definitions bleed into each other as the NCP(L) set by
71 * optimal_transfer is still being used in the first few iterations of
72 * quick_transfer before it eventually gets written. The easiest way to
73 * interpret this is to consider the optimal transfer as "iteration -1" from
74 * the perspective of the quick transfer iterations.
75 *
76 * In short, this data structure specifies whether a cluster was modified
77 * within the last M steps. This counts steps in both optimal_transfer and
78 * quick_transfer, and considers modifications from both calls.
79 */
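/*
 * A worked example of this encoding (an illustrative aside, with made-up
 * numbers): suppose M = 10 observations. If cluster L was last modified by the
 * optimal transfer at observation 7, the Fortran code stores NCP(L) = 7. If it
 * was instead last modified at observation 4 of quick-transfer iteration X = 2,
 * the step is M * X + obs = 24 and NCP(L) = 24 + M = 34, which is why the
 * original integers must hold values up to roughly 'max_quick_iterations * M'.
 * The UpdateHistory class below stores the same information as an
 * (iteration, observation) pair instead, i.e., (-1, 7) or (2, 4) respectively.
 */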
80template<typename Index_>
81class UpdateHistory {
82private:
83 /*
84 * The problem with the original implementation is that the integers are
85 * expected to hold 'max_quick_iterations * M'. For a templated integer
86 * type, that might not be possible, so instead we split it into two
87 * vectors; one holds the last iteration at which the cluster was modified,
88 * the other holds the last observation used in the modification.
89 */
90 Index_ my_last_observation = 0;
91
 92 int my_last_iteration = init;
 93
94 static constexpr int init = -3;
95 static constexpr int unchanged = -2;
96
97public:
98 void set_unchanged() {
 99 my_last_iteration = unchanged;
100 }
101
102 // We treat the optimal_transfer as "iteration -1" here.
103 void set_optimal(Index_ obs) {
104 my_last_iteration = -1;
105 my_last_observation = obs;
106 }
107
108 // Here, iter should be from '[0, max_quick_transfer_iterations)'.
109 void set_quick(int iter, Index_ obs) {
110 my_last_iteration = iter;
111 my_last_observation = obs;
112 }
113
114public:
115 bool is_unchanged() const {
116 return my_last_iteration == unchanged;
117 }
118
119public:
120 bool changed_after(int iter, Index_ obs) const {
121 if (my_last_iteration == iter) {
122 return my_last_observation > obs;
123 } else {
124 return my_last_iteration > iter;
125 }
126 }
127
128 bool changed_after_or_at(int iter, Index_ obs) const {
129 if (my_last_iteration == iter) {
130 return my_last_observation >= obs;
131 } else {
132 return my_last_iteration > iter;
133 }
134 }
135};
136
137/*
138 * The class below represents 'live', which has a tricky interpretation.
139 *
140 * - Before each optimal transfer call, LIVE(L) stores the observation at which
141 * cluster L was updated in the _previous_ call.
142 * - During the optimal transfer call, LIVE(L) is updated to the observation at
143 * which L was updated in this call, plus M (i.e., number of observations).
144 * - After the optimal transfer call, LIVE(L) is updated by subtracting M, so
145 * that the interpretation is correct in the next call.
146 *
147 * It basically tells us whether there was a recent transfer (optimal or quick)
148 * within the last M steps of optimal_transfer. If so, the cluster is "live".
149 */
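/*
 * An illustrative walk-through of the states: if a cluster received a transfer
 * at observation 8 during the current optimal_transfer call, mark_current(8)
 * sets CURRENT_OPT with my_last_optimal_transfer = 8. A later reset(false)
 * downgrades this to PAST_OPT, so in the next optimal_transfer call the
 * cluster is live only for observations before 8. If the intervening
 * quick_transfer touched the cluster, reset(true) sets QUICK instead and the
 * cluster is treated as live for the entire next call.
 */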
150template<typename Index_>
151class LiveStatus {
152private:
153 enum class Event : uint8_t { NONE, PAST_OPT, CURRENT_OPT, QUICK, INIT };
154
155 /* The problem with the original implementation is that LIVE(L) needs to
156 * store at least 2*M, which might cause overflows in Index_. To avoid
157 * this, we split this information into two vectors:
158 *
 159 * - 'my_had_recent_transfer' specifies whether a transfer occurred in the
 160 * current optimal_transfer call, or in the immediately preceding
 161 * quick_transfer call. If this is > PAST_OPT, the cluster is definitely
 162 * live; if it is == PAST_OPT, it may or may not be live.
 163 * - 'my_last_optimal_transfer' has two interpretations:
 164 * - If 'my_had_recent_transfer == PAST_OPT', it specifies the
 165 * observation at which the last transfer occurred in the previous
 166 * optimal_transfer call. If this is greater than the current
167 * observation, the cluster is live.
168 * - If 'my_had_recent_transfer == CURRENT_OPT', it specifies the
169 * observation at which the last transfer occurred in the current
170 * optimal_transfer call.
171 * - Otherwise it is undefined and should not be used.
172 *
173 * One might think that 'LiveStatus::my_last_optimal_transfer' is redundant
174 * with 'UpdateHistory::my_last_observation', but the former only tracks
175 * optimal transfers while the latter includes quick transfers.
176 */
177 Event my_had_recent_transfer = Event::INIT;
178 Index_ my_last_optimal_transfer = 0;
179
180public:
181 bool is_live(Index_ obs) const {
182 if (my_had_recent_transfer == Event::PAST_OPT) {
183 return my_last_optimal_transfer > obs;
184 } else {
185 return my_had_recent_transfer > Event::PAST_OPT;
186 }
187 }
188
189 void mark_current(Index_ obs) {
190 my_had_recent_transfer = Event::CURRENT_OPT;
191 my_last_optimal_transfer = obs;
192 }
193
194 void reset(bool was_quick_transferred) {
195 if (was_quick_transferred) {
196 my_had_recent_transfer = Event::QUICK;
197 } else if (my_had_recent_transfer == Event::CURRENT_OPT) {
198 my_had_recent_transfer = Event::PAST_OPT;
199 } else {
200 my_had_recent_transfer = Event::NONE;
201 }
202 }
203};
204
205template<typename Float_, typename Index_, typename Cluster_>
206struct Workspace {
207 // Array arguments in the same order as supplied to R's kmns_ function.
208 std::vector<Cluster_> second_best_cluster; // i.e., ic2
209 std::vector<Index_> cluster_sizes; // i.e., nc
210
211 std::vector<Float_> loss_multiplier; // i.e., an1
212 std::vector<Float_> gain_multiplier; // i.e., an2
213 std::vector<Float_> wcss_loss; // i.e., d
214
215 std::vector<UpdateHistory<Index_> > update_history; // i.e., ncp
216 std::vector<uint8_t> was_quick_transferred; // i.e., itran
217 std::vector<LiveStatus<Index_> > live_set; // i.e., live
218
219 Index_ optra_steps_since_last_transfer = 0; // i.e., indx
220
221public:
222 Workspace(Index_ nobs, Cluster_ ncenters) :
223 // Sizes taken from the .Fortran() call in stats::kmeans().
224 second_best_cluster(nobs),
225 cluster_sizes(ncenters),
226 loss_multiplier(ncenters),
227 gain_multiplier(ncenters),
228 wcss_loss(nobs),
229
230 // All the other bits and pieces.
231 update_history(ncenters),
232 was_quick_transferred(ncenters),
233 live_set(ncenters)
234 {}
235};
236
237template<typename Data_, typename Float_, typename Dim_>
238Float_ squared_distance_from_cluster(const Data_* data, const Float_* center, Dim_ ndim) {
239 Float_ output = 0;
240 for (decltype(ndim) dim = 0; dim < ndim; ++dim, ++data, ++center) {
241 Float_ delta = static_cast<Float_>(*data) - *center; // cast to float for consistent precision regardless of Data_.
242 output += delta * delta;
243 }
244 return output;
245}
246
247template<class Matrix_, typename Cluster_, typename Float_>
248void find_closest_two_centers(const Matrix_& data, Cluster_ ncenters, const Float_* centers, Cluster_* best_cluster, std::vector<Cluster_>& second_best_cluster, int nthreads) {
249 auto ndim = data.num_dimensions();
250
251 // We assume that there are at least two centers here, otherwise we should
252 // have detected that this was an edge case in RefineHartiganWong::run.
253 internal::QuickSearch<Float_, Cluster_, decltype(ndim)> index(ndim, ncenters, centers);
254
255 auto nobs = data.num_observations();
256 typedef typename Matrix_::index_type Index_;
257 parallelize(nthreads, nobs, [&](int, Index_ start, Index_ length) -> void {
258 auto matwork = data.create_workspace(start, length);
259 for (Index_ obs = start, end = start + length; obs < end; ++obs) {
260 auto optr = data.get_observation(matwork);
261 auto res2 = index.find2(optr);
262 best_cluster[obs] = res2.first;
263 second_best_cluster[obs] = res2.second;
264 }
265 });
266}
267
268template<typename Float_>
269constexpr Float_ big_number() {
270 return 1e30; // Some very big number.
271}
272
273template<typename Dim_, typename Data_, typename Index_, typename Cluster_, typename Float_>
274void transfer_point(Dim_ ndim, const Data_* obs_ptr, Index_ obs_id, Cluster_ l1, Cluster_ l2, Float_* centers, Cluster_* best_cluster, Workspace<Float_, Index_, Cluster_>& work) { // (signature reconstructed from its call sites)
275 // Yes, casts to float are deliberate here, so that the
276 // multipliers can be computed correctly.
277 Float_ al1 = work.cluster_sizes[l1], alw = al1 - 1;
278 Float_ al2 = work.cluster_sizes[l2], alt = al2 + 1;
279
280 size_t long_ndim = ndim;
281 auto copy1 = centers + static_cast<size_t>(l1) * long_ndim; // cast to avoid overflow.
282 auto copy2 = centers + static_cast<size_t>(l2) * long_ndim;
283 for (decltype(ndim) dim = 0; dim < ndim; ++dim, ++copy1, ++copy2, ++obs_ptr) {
284 Float_ oval = *obs_ptr; // cast to float for consistent precision regardless of Data_.
285 *copy1 = (*copy1 * al1 - oval) / alw;
286 *copy2 = (*copy2 * al2 + oval) / alt;
287 }
288
289 --work.cluster_sizes[l1];
290 ++work.cluster_sizes[l2];
291
292 work.gain_multiplier[l1] = alw / al1;
293 work.loss_multiplier[l1] = (alw > 1 ? alw / (alw - 1) : big_number<Float_>());
294 work.loss_multiplier[l2] = alt / al2;
295 work.gain_multiplier[l2] = alt / (alt + 1);
296
297 best_cluster[obs_id] = l2;
298 work.second_best_cluster[obs_id] = l1;
299}
300
301/* ALGORITHM AS 136.1 APPL. STATIST. (1979) VOL.28, NO.1
302 * This is the OPtimal TRAnsfer stage.
303 * ----------------------
304 * Each point is re-assigned, if necessary, to the cluster that will induce a
305 * maximum reduction in the within-cluster sum of squares. In this stage,
306 * there is only one pass through the data.
307 */
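/*
 * In concrete terms (an illustrative restatement of the criterion used below):
 * for an observation x currently assigned to cluster L1 of size n1, removing
 * it decreases the WCSS by [n1 / (n1 - 1)] * ||x - c1||^2, while adding it to
 * a cluster L2 of size n2 increases the WCSS by [n2 / (n2 + 1)] * ||x - c2||^2.
 * These factors are the 'loss_multiplier' and 'gain_multiplier' terms in the
 * Workspace. The observation is moved to the L2 that minimizes the gain,
 * provided that this gain is strictly less than the loss.
 */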
308template<class Matrix_, typename Cluster_, typename Float_>
309bool optimal_transfer(const Matrix_& data, Workspace<Float_, typename Matrix_::index_type, Cluster_>& work, Cluster_ ncenters, Float_* centers, Cluster_* best_cluster) {
310 auto nobs = data.num_observations();
311 auto ndim = data.num_dimensions();
312 auto matwork = data.create_workspace();
313 size_t long_ndim = ndim;
314
315 for (decltype(nobs) obs = 0; obs < nobs; ++obs) {
316 ++work.optra_steps_since_last_transfer;
317
318 auto l1 = best_cluster[obs];
319 if (work.cluster_sizes[l1] != 1) {
320 auto obs_ptr = data.get_observation(obs, matwork);
321
322 // The original Fortran implementation only recomputed the WCSS
323 // loss of an observation if its cluster had experienced an optimal
324 // transfer for an earlier observation. In theory, this sounds
325 // great to avoid recomputation, but the existing WCSS loss was
326 // computed in a running fashion during the quick transfers. This
327 // makes them susceptible to accumulation of numerical errors in
328 // the centroids; even after the centroids are freshly recomputed
329 // (in the run() loop), we still have errors in the loss values.
330 // So, we simplify matters and improve accuracy by just recomputing
331 // the loss all the time, which doesn't take too much extra effort.
332 auto& wcss_loss = work.wcss_loss[obs];
333 auto l1_ptr = centers + long_ndim * static_cast<size_t>(l1); // cast to avoid overflow.
334 wcss_loss = squared_distance_from_cluster(obs_ptr, l1_ptr, ndim) * work.loss_multiplier[l1];
335
336 // Find the cluster with minimum WCSS gain.
337 auto l2 = work.second_best_cluster[obs];
338 auto original_l2 = l2;
339 auto l2_ptr = centers + long_ndim * static_cast<size_t>(l2); // cast to avoid overflow.
340 auto wcss_gain = squared_distance_from_cluster(obs_ptr, l2_ptr, ndim) * work.gain_multiplier[l2];
341
342 auto check_best_cluster = [&](Cluster_ cen) {
343 auto cen_ptr = centers + long_ndim * static_cast<size_t>(cen); // cast to avoid overflow.
344 auto candidate = squared_distance_from_cluster(obs_ptr, cen_ptr, ndim) * work.gain_multiplier[cen];
345 if (candidate < wcss_gain) {
346 wcss_gain = candidate;
347 l2 = cen;
348 }
349 };
350
351 // If the best cluster is live, we need to consider all other clusters.
352 // Otherwise, we only need to consider other live clusters for transfer.
353 auto& live1 = work.live_set[l1];
354 if (live1.is_live(obs)) {
355 for (Cluster_ cen = 0; cen < ncenters; ++cen) {
356 if (cen != l1 && cen != original_l2) {
357 check_best_cluster(cen);
358 }
359 }
360 } else {
361 for (Cluster_ cen = 0; cen < ncenters; ++cen) {
362 if (cen != l1 && cen != original_l2 && work.live_set[cen].is_live(obs)) {
363 check_best_cluster(cen);
364 }
365 }
366 }
367
368 // Deciding whether to make the transfer based on the change to the WCSS.
369 if (wcss_gain >= wcss_loss) {
370 work.second_best_cluster[obs] = l2;
371 } else {
372 work.optra_steps_since_last_transfer = 0;
373
374 live1.mark_current(obs);
375 work.live_set[l2].mark_current(obs);
376 work.update_history[l1].set_optimal(obs);
377 work.update_history[l2].set_optimal(obs);
378
379 transfer_point(ndim, obs_ptr, obs, l1, l2, centers, best_cluster, work);
380 }
381 }
382
383 // Stop if we've iterated through the entire dataset and no transfer of
384 // any kind took place, be it optimal or quick.
385 if (work.optra_steps_since_last_transfer == nobs) {
386 return true;
387 }
388 }
389
390 return false;
391}
392
393/* ALGORITHM AS 136.2 APPL. STATIST. (1979) VOL.28, NO.1
394 * This is the Quick TRANsfer stage.
395 * --------------------
396 * IC1(I) is the cluster which point I currently belongs to.
397 * IC2(I) is the cluster which point I is most likely to be transferred to.
398 *
399 * For each point I, IC1(I) & IC2(I) are switched, if necessary, to reduce
400 * within-cluster sum of squares. The cluster centres are updated after each
401 * step. In this stage, we loop through the data until no further change is to
402 * take place, or we hit an iteration limit, whichever is first.
403 */
404template<class Matrix_, typename Cluster_, typename Float_>
405std::pair<bool, bool> quick_transfer(
406 const Matrix_& data,
407 Workspace<Float_, typename Matrix_::index_type, Cluster_>& work,
408 Float_* centers,
409 Cluster_* best_cluster,
410 int quick_iterations)
411{
412 bool had_transfer = false;
413 std::fill(work.was_quick_transferred.begin(), work.was_quick_transferred.end(), 0);
414
415 auto nobs = data.num_observations();
416 auto matwork = data.create_workspace();
417 auto ndim = data.num_dimensions();
418 size_t long_ndim = data.num_dimensions();
419
420 typedef decltype(nobs) Index_;
421 Index_ steps_since_last_quick_transfer = 0; // counts steps since the last quick transfer (name reconstructed).
422
423 for (int it = 0; it < quick_iterations; ++it) {
424 int prev_it = it - 1;
425
426 for (decltype(nobs) obs = 0; obs < nobs; ++obs) {
427 ++steps_since_last_quick_transfer;
428 auto l1 = best_cluster[obs];
429
430 if (work.cluster_sizes[l1] != 1) {
431 const typename Matrix_::data_type* obs_ptr = NULL;
432
433 // Need to update the WCSS loss if the cluster was updated recently.
434 // Otherwise, we must have already updated the WCSS in a previous
435 // iteration of the outermost loop, so this can be skipped.
436 //
 437 // Note that we use changed_after_or_at; if the same
438 // observation was changed in the previous iteration of the
439 // outermost loop, its WCSS loss won't have been updated yet.
440 auto& history1 = work.update_history[l1];
441 if (history1.changed_after_or_at(prev_it, obs)) {
442 auto l1_ptr = centers + static_cast<size_t>(l1) * long_ndim; // cast to avoid overflow.
443 obs_ptr = data.get_observation(obs, matwork);
444 work.wcss_loss[obs] = squared_distance_from_cluster(obs_ptr, l1_ptr, ndim) * work.loss_multiplier[l1];
445 }
446
 447 // If neither the best nor the second-best cluster has changed
448 // after the previous iteration that we visited this
449 // observation, then there's no point reevaluating the
450 // transfer, because nothing's going to be different anyway.
451 auto l2 = work.second_best_cluster[obs];
452 auto& history2 = work.update_history[l2];
453 if (history1.changed_after(prev_it, obs) || history2.changed_after(prev_it, obs)) {
454 if (obs_ptr == NULL) {
455 obs_ptr = data.get_observation(obs, matwork);
456 }
457 auto l2_ptr = centers + static_cast<size_t>(l2) * long_ndim; // cast to avoid overflow.
458 auto wcss_gain = squared_distance_from_cluster(obs_ptr, l2_ptr, ndim) * work.gain_multiplier[l2];
459
460 if (wcss_gain < work.wcss_loss[obs]) {
461 had_transfer = true;
462 steps_since_last_quick_transfer = 0;
463
464 work.was_quick_transferred[l1] = true;
465 work.was_quick_transferred[l2] = true;
466
467 history1.set_quick(it, obs);
468 history2.set_quick(it, obs);
469
470 transfer_point(ndim, obs_ptr, obs, l1, l2, centers, best_cluster, work);
471 }
472 }
473 }
474
475 if (steps_since_last_quick_transfer == nobs) {
476 // Quit early if no transfer occurred within the past 'nobs'
477 // steps, as we've already converged for each observation.
478 return std::make_pair(had_transfer, false);
479 }
480 }
481 }
482
483 return std::make_pair(had_transfer, true);
484}
485
486}
520template<typename Matrix_ = SimpleMatrix<double, int>, typename Cluster_ = int, typename Float_ = double>
521class RefineHartiganWong : public Refine<Matrix_, Cluster_, Float_> {
522public:
523 /**
524 * @param options Further options to the Hartigan-Wong algorithm.
525 */
526 RefineHartiganWong(RefineHartiganWongOptions options) : my_options(std::move(options)) {}
527
528 /**
529 * Default constructor.
530 */
531 RefineHartiganWong() = default;
532
533private:
534 RefineHartiganWongOptions my_options;
535 typedef typename Matrix_::index_type Index_;
536
537public:
538 /**
539 * @return Options for the Hartigan-Wong algorithm,
540 * to be modified prior to calling `run()`.
541 */
542 RefineHartiganWongOptions& get_options() {
543 return my_options;
544 }
545
546public:
547 Details<Index_> run(const Matrix_& data, Cluster_ ncenters, Float_* centers, Cluster_* clusters) const {
548 auto nobs = data.num_observations();
549 if (internal::is_edge_case(nobs, ncenters)) {
550 return internal::process_edge_case(data, ncenters, centers, clusters);
551 }
552
553 RefineHartiganWong_internal::Workspace<Float_, Index_, Cluster_> work(nobs, ncenters);
554
555 RefineHartiganWong_internal::find_closest_two_centers(data, ncenters, centers, clusters, work.second_best_cluster, my_options.num_threads);
556 for (Index_ obs = 0; obs < nobs; ++obs) {
557 ++work.cluster_sizes[clusters[obs]];
558 }
559 internal::compute_centroids(data, ncenters, centers, clusters, work.cluster_sizes);
560
561 for (Cluster_ cen = 0; cen < ncenters; ++cen) {
562 Float_ num = work.cluster_sizes[cen]; // yes, cast is deliberate here so that the multipliers can be computed correctly.
563 work.gain_multiplier[cen] = num / (num + 1);
564 work.loss_multiplier[cen] = (num > 1 ? num / (num - 1) : RefineHartiganWong_internal::big_number<Float_>());
565 }
566
567 int iter = 0;
568 int ifault = 0;
569 while ((++iter) <= my_options.max_iterations) {
570 bool finished = RefineHartiganWong_internal::optimal_transfer(data, work, ncenters, centers, clusters);
571 if (finished) {
572 break;
573 }
574
575 auto quick_status = RefineHartiganWong_internal::quick_transfer(
576 data,
577 work,
578 centers,
579 clusters,
580 my_options.max_quick_transfer_iterations
581 );
582
583 // Recomputing the centroids to avoid accumulation of numerical
584 // errors after many transfers (e.g., adding a whole bunch of
585 // values and then subtracting them again leaves behind some
586 // cancellation error). Note that we don't have to do this if
587 // 'finished = true' as this means that there was no transfer of
588 // any kind in the final pass through the dataset.
589 internal::compute_centroids(data, ncenters, centers, clusters, work.cluster_sizes);
590
591 if (quick_status.second) { // Hit the quick transfer iteration limit.
592 if (my_options.quit_on_quick_transfer_convergence_failure) {
593 ifault = 4;
594 break;
595 }
596 } else {
597 // If quick transfer converged and there are only two clusters,
598 // there is no need to re-enter the optimal transfer stage.
599 if (ncenters == 2) {
600 break;
601 }
602 }
603
604 if (quick_status.first) { // At least one quick transfer was performed.
605 work.optra_steps_since_last_transfer = 0;
606 }
607
608 for (auto& u : work.update_history) {
609 u.set_unchanged();
610 }
611
612 for (Cluster_ c = 0; c < ncenters; ++c) {
613 work.live_set[c].reset(work.was_quick_transferred[c]);
614 }
615 }
616
617 if (iter == my_options.max_iterations + 1) {
618 ifault = 2;
619 }
620
621 return Details(std::move(work.cluster_sizes), iter, ifault);
622 }
623};
624
625}
626
627#endif
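As a usage sketch (not taken from the library's own documentation): the include path and the SimpleMatrix constructor argument order below are assumptions to be checked against the installed headers, but the run() call itself follows the signature shown above.

#include "kmeans/RefineHartiganWong.hpp" // assumed install layout
#include <vector>

int main() {
    int ndim = 5, nobs = 1000, ncenters = 3;
    std::vector<double> data(static_cast<size_t>(ndim) * nobs); // observation-major: ndim values per observation.
    std::vector<double> centers(static_cast<size_t>(ndim) * ncenters); // pre-filled initial centroids, e.g. from an Initialize* class.
    std::vector<int> clusters(nobs); // receives the final assignments.

    kmeans::SimpleMatrix<double, int> mat(ndim, nobs, data.data()); // assumed (ndim, nobs, pointer) order.

    kmeans::RefineHartiganWongOptions opt;
    opt.max_iterations = 20;
    opt.num_threads = 4;

    kmeans::RefineHartiganWong<kmeans::SimpleMatrix<double, int>, int, double> refiner(opt);
    auto details = refiner.run(mat, ncenters, centers.data(), clusters.data());
    // 'details' reports the per-cluster sizes, the number of iterations used, and a status code
    // (2 if the iteration limit was reached, 4 if the quick transfers failed to converge and
    // quit_on_quick_transfer_convergence_failure was set).
    return 0;
}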
Report detailed clustering statistics.
Interface for k-means refinement.
Implements the variance partitioning method of Su and Dy (2007).
Definition InitializeVariancePartition.hpp:164
Implements the Hartigan-Wong algorithm for k-means clustering.
Definition RefineHartiganWong.hpp:521
RefineHartiganWong(RefineHartiganWongOptions options)
Definition RefineHartiganWong.hpp:526
RefineHartiganWongOptions & get_options()
Definition RefineHartiganWong.hpp:542
Details< Index_ > run(const Matrix_ &data, Cluster_ ncenters, Float_ *centers, Cluster_ *clusters) const
Definition RefineHartiganWong.hpp:547
Interface for all k-means refinement algorithms.
Definition Refine.hpp:23
Namespace for k-means clustering.
Definition compute_wcss.hpp:12
void parallelize(int num_workers, Task_ num_tasks, Run_ run_task_range)
Definition parallelize.hpp:28
Utilities for parallelization.
Additional statistics from the k-means algorithm.
Definition Details.hpp:20
Options for RefineHartiganWong.
Definition RefineHartiganWong.hpp:27
bool quit_on_quick_transfer_convergence_failure
Definition RefineHartiganWong.hpp:44
int max_iterations
Definition RefineHartiganWong.hpp:32
int num_threads
Definition RefineHartiganWong.hpp:50
int max_quick_transfer_iterations
Definition RefineHartiganWong.hpp:38