kmeans
A C++ library for k-means
Loading...
Searching...
No Matches
RefineMiniBatch.hpp
Go to the documentation of this file.
1#ifndef KMEANS_REFINE_MINIBATCH_HPP
2#define KMEANS_REFINE_MINIBATCH_HPP
3
4#include <vector>
5#include <algorithm>
6#include <numeric>
7#include <cstdint>
8#include <stdexcept>
9#include <limits>
10#include <random>
11#include <type_traits>
12
13#include "aarand/aarand.hpp"
14
15#include "Refine.hpp"
16#include "Details.hpp"
17#include "QuickSearch.hpp"
18#include "is_edge_case.hpp"
19#include "parallelize.hpp"
20
27namespace kmeans {
28
37 int max_iterations = 100;
38
43 int batch_size = 500;
44
49 double max_change_proportion = 0.01;
50
56
60 uint64_t seed = 1234567890u;
61
66 int num_threads = 1;
67};
68
93template<typename Matrix_ = SimpleMatrix<double, int>, typename Cluster_ = int, typename Float_ = double>
94class RefineMiniBatch : public Refine<Matrix_, Cluster_, Float_> {
95public:
100
104 RefineMiniBatch() = default;
105
106public:
112 return my_options;
113 }
114
115private:
116 RefineMiniBatchOptions my_options;
117
118public:
120 auto nobs = data.num_observations();
121 if (internal::is_edge_case(nobs, ncenters)) {
122 return internal::process_edge_case(data, ncenters, centers, clusters);
123 }
124
125 int iter = 0, status = 0;
126 std::vector<uint64_t> total_sampled(ncenters); // holds the number of sampled observations across iterations, so we need a large integer.
127 std::vector<Cluster_> previous(nobs);
128 typedef decltype(nobs) Index_;
129 std::vector<uint64_t> last_changed(ncenters), last_sampled(ncenters); // holds the number of sampled/changed observation for the last few iterations.
130
131 Index_ actual_batch_size = nobs;
132 typedef typename std::conditional<std::is_signed<Index_>::value, int, unsigned int>::type SafeCompInt; // waiting for C++20's comparison functions...
133 if (static_cast<SafeCompInt>(actual_batch_size) > my_options.batch_size) {
134 actual_batch_size = my_options.batch_size;
135 }
136 std::vector<Index_> chosen(actual_batch_size);
137 std::mt19937_64 eng(my_options.seed);
138
139 auto ndim = data.num_dimensions();
140 size_t long_ndim = ndim;
141 internal::QuickSearch<Float_, Cluster_, decltype(ndim)> index;
142
143 for (iter = 1; iter <= my_options.max_iterations; ++iter) {
144 aarand::sample(nobs, actual_batch_size, chosen.data(), eng);
145 if (iter > 1) {
146 for (auto o : chosen) {
147 previous[o] = clusters[o];
148 }
149 }
150
151 index.reset(ndim, ncenters, centers);
152 parallelize(my_options.num_threads, actual_batch_size, [&](int, Index_ start, Index_ length) {
153 auto work = data.create_workspace(chosen.data() + start, length);
154 for (Index_ s = start, end = start + length; s < end; ++s) {
155 auto ptr = data.get_observation(work);
156 clusters[chosen[s]] = index.find(ptr);
157 }
158 });
159
160 // Updating the means for each cluster.
161 auto work = data.create_workspace(chosen.data(), actual_batch_size);
162 for (auto o : chosen) {
163 const auto c = clusters[o];
164 auto& n = total_sampled[c];
165 ++n;
166
167 Float_ mult = static_cast<Float_>(1)/static_cast<Float_>(n);
168 auto ccopy = centers + static_cast<size_t>(c) * long_ndim;
169 auto ocopy = data.get_observation(work);
170
171 for (decltype(ndim) d = 0; d < ndim; ++d, ++ocopy, ++ccopy) {
172 (*ccopy) += (static_cast<Float_>(*ocopy) - *ccopy) * mult; // cast to ensure consistent precision regardless of Matrix_::data_type.
173 }
174 }
175
176 // Checking for updates.
177 if (iter != 1) {
178 for (auto o : chosen) {
179 auto p = previous[o];
180 ++(last_sampled[p]);
181 auto c = clusters[o];
182 if (p != c) {
183 ++(last_sampled[c]);
184 ++(last_changed[p]);
185 ++(last_changed[c]);
186 }
187 }
188
189 if (iter % my_options.convergence_history == 1) {
190 bool too_many_changes = false;
191 for (Cluster_ c = 0; c < ncenters; ++c) {
192 if (static_cast<double>(last_changed[c]) >= static_cast<double>(last_sampled[c]) * my_options.max_change_proportion) {
193 too_many_changes = true;
194 break;
195 }
196 }
197
198 if (!too_many_changes) {
199 break;
200 }
201 std::fill(last_sampled.begin(), last_sampled.end(), 0);
202 std::fill(last_changed.begin(), last_changed.end(), 0);
203 }
204 }
205 }
206
207 if (iter == my_options.max_iterations + 1) {
208 status = 2;
209 }
210
211 // Run through all observations to make sure they have the latest cluster assignments.
212 index.reset(ndim, ncenters, centers);
213 parallelize(my_options.num_threads, nobs, [&](int, Index_ start, Index_ length) {
214 auto work = data.create_workspace(start, length);
215 for (Index_ s = start, end = start + length; s < end; ++s) {
216 auto ptr = data.get_observation(work);
217 clusters[s] = index.find(ptr);
218 }
219 });
220
221 std::vector<Index_> cluster_sizes(ncenters);
222 for (Index_ o = 0; o < nobs; ++o) {
223 ++cluster_sizes[clusters[o]];
224 }
225
226 internal::compute_centroids(data, ncenters, centers, clusters, cluster_sizes);
227 return Details<Index_>(std::move(cluster_sizes), iter, status);
228 }
229};
230
231}
232
233#endif
Report detailed clustering statistics.
Interface for k-means refinement.
Implements the variance partitioning method of Su and Dy (2007).
Definition InitializeVariancePartition.hpp:164
Implements the mini-batch algorithm for k-means clustering.
Definition RefineMiniBatch.hpp:94
Details< typename Matrix_::index_type > run(const Matrix_ &data, Cluster_ ncenters, Float_ *centers, Cluster_ *clusters) const
Definition RefineMiniBatch.hpp:119
RefineMiniBatchOptions & get_options()
Definition RefineMiniBatch.hpp:111
RefineMiniBatch(RefineMiniBatchOptions options)
Definition RefineMiniBatch.hpp:99
Interface for all k-means refinement algorithms.
Definition Refine.hpp:23
Namespace for k-means clustering.
Definition compute_wcss.hpp:12
void parallelize(int num_workers, Task_ num_tasks, Run_ run_task_range)
Definition parallelize.hpp:28
Utilities for parallelization.
Options for RefineMiniBatch construction.
Definition RefineMiniBatch.hpp:32
int max_iterations
Definition RefineMiniBatch.hpp:37
double max_change_proportion
Definition RefineMiniBatch.hpp:49
int convergence_history
Definition RefineMiniBatch.hpp:55
uint64_t seed
Definition RefineMiniBatch.hpp:60
int num_threads
Definition RefineMiniBatch.hpp:66
int batch_size
Definition RefineMiniBatch.hpp:43