kmeans
k-means clustering in C++
Loading...
Searching...
No Matches
RefineMiniBatch.hpp
Go to the documentation of this file.
1#ifndef KMEANS_REFINE_MINIBATCH_HPP
2#define KMEANS_REFINE_MINIBATCH_HPP
3
4#include <vector>
5#include <algorithm>
6#include <cstddef>
7#include <random>
8
9#include "sanisizer/sanisizer.hpp"
10#include "aarand/aarand.hpp"
11
12#include "Refine.hpp"
13#include "Details.hpp"
14#include "QuickSearch.hpp"
15#include "is_edge_case.hpp"
16#include "parallelize.hpp"
17
24namespace kmeans {
25
/**
 * Random number engine used by `RefineMiniBatch` to draw each mini-batch of observations.
 * It is seeded from `RefineMiniBatchOptions::seed` at the start of every `run()` call.
 */
using RefineMiniBatchRng = std::mt19937_64;
30
39 int max_iterations = 100;
40
45 int batch_size = 500;
46
51 double max_change_proportion = 0.01;
52
58
62 typename RefineMiniBatchRng::result_type seed = sanisizer::cap<typename RefineMiniBatchRng::result_type>(1234567890);
63
68 int num_threads = 1;
69};
70
99template<typename Index_, typename Data_, typename Cluster_, typename Float_, typename Matrix_ = Matrix<Index_, Data_> >
100class RefineMiniBatch : public Refine<Index_, Data_, Cluster_, Float_, Matrix_> {
101public:
105 RefineMiniBatch(RefineMiniBatchOptions options) : my_options(std::move(options)) {}
106
110 RefineMiniBatch() = default;
111
112public:
118 return my_options;
119 }
120
121private:
/**
 * Options read by `run()`; initialized by the constructor.
 */
122 RefineMiniBatchOptions my_options;
123
124public:
128 Details<Index_> run(const Matrix_& data, const Cluster_ ncenters, Float_* const centers, Cluster_* const clusters) const {
129 const auto nobs = data.num_observations();
130 if (internal::is_edge_case(nobs, ncenters)) {
131 return internal::process_edge_case(data, ncenters, centers, clusters);
132 }
133
134 auto total_sampled = sanisizer::create<std::vector<unsigned long long> >(ncenters); // holds the number of sampled observations across iterations, so we need a large integer.
135 auto last_changed = sanisizer::create<std::vector<unsigned long long> >(ncenters); // holds the number of sampled/changed observation for the last few iterations.
136 auto last_sampled = sanisizer::create<std::vector<unsigned long long> >(ncenters);
137 auto previous = sanisizer::create<std::vector<Cluster_> >(nobs);
138
139 const decltype(I(nobs)) actual_batch_size = sanisizer::min(nobs, my_options.batch_size);
140 sanisizer::cast<std::size_t>(actual_batch_size); // check that static_cast for new_extractor() calls will be safe.
141 auto chosen = sanisizer::create<std::vector<Index_> >(actual_batch_size);
142 RefineMiniBatchRng eng(my_options.seed);
143
144 const auto ndim = data.num_dimensions();
145 internal::QuickSearch<Float_, Cluster_> index;
146
147 decltype(I(my_options.max_iterations)) iter = 0;
148 for (; iter < my_options.max_iterations; ++iter) {
149 aarand::sample(nobs, actual_batch_size, chosen.data(), eng);
150 if (iter > 0) {
151 for (const auto o : chosen) {
152 previous[o] = clusters[o];
153 }
154 }
155
156 index.reset(ndim, ncenters, centers);
157 parallelize(my_options.num_threads, actual_batch_size, [&](const int, const Index_ start, const Index_ length) -> void {
158 auto work = data.new_extractor(chosen.data() + start, static_cast<std::size_t>(length));
159 for (Index_ s = start, end = start + length; s < end; ++s) {
160 const auto ptr = work->get_observation();
161 clusters[chosen[s]] = index.find(ptr);
162 }
163 });
164
165 // Updating the means for each cluster.
166 auto work = data.new_extractor(chosen.data(), static_cast<std::size_t>(chosen.size()));
167 for (const auto o : chosen) {
168 const auto c = clusters[o];
169 auto& n = total_sampled[c];
170 ++n;
171
172 const auto ocopy = work->get_observation();
173 for (decltype(I(ndim)) d = 0; d < ndim; ++d) {
174 auto& curcenter = centers[sanisizer::nd_offset<std::size_t>(d, ndim, c)];
175 curcenter += (static_cast<Float_>(ocopy[d]) - curcenter) / n; // cast to ensure consistent precision regardless of Matrix_::data_type.
176 }
177 }
178
179 // Checking for updates.
180 if (iter != 0) {
181 for (const auto o : chosen) {
182 const auto p = previous[o];
183 ++(last_sampled[p]);
184 const auto c = clusters[o];
185 if (p != c) {
186 ++(last_sampled[c]);
187 ++(last_changed[p]);
188 ++(last_changed[c]);
189 }
190 }
191
192 if (iter % my_options.convergence_history == 0) {
193 bool too_many_changes = false;
194 for (Cluster_ c = 0; c < ncenters; ++c) {
195 if (static_cast<double>(last_changed[c]) >= static_cast<double>(last_sampled[c]) * my_options.max_change_proportion) {
196 too_many_changes = true;
197 break;
198 }
199 }
200
201 if (!too_many_changes) {
202 break;
203 }
204 std::fill(last_sampled.begin(), last_sampled.end(), 0);
205 std::fill(last_changed.begin(), last_changed.end(), 0);
206 }
207 }
208 }
209
210 // Run through all observations to make sure they have the latest cluster assignments.
211 index.reset(ndim, ncenters, centers);
212 parallelize(my_options.num_threads, nobs, [&](const int, const Index_ start, const Index_ length) -> void {
213 auto work = data.new_extractor(start, length);
214 for (Index_ s = start, end = start + length; s < end; ++s) {
215 const auto ptr = work->get_observation();
216 clusters[s] = index.find(ptr);
217 }
218 });
219
220 auto cluster_sizes = sanisizer::create<std::vector<Index_> >(ncenters);
221 for (Index_ o = 0; o < nobs; ++o) {
222 ++cluster_sizes[clusters[o]];
223 }
224 internal::compute_centroids(data, ncenters, centers, clusters, cluster_sizes);
225
226 int status = 0;
227 if (iter == my_options.max_iterations) {
228 status = 2;
229 } else {
230 ++iter; // make it 1-based.
231 }
232 return Details<Index_>(std::move(cluster_sizes), iter, status);
233 }
237};
238
239}
240
241#endif
Report detailed clustering statistics.
Interface for k-means refinement.
Implements the mini-batch algorithm for k-means clustering.
Definition RefineMiniBatch.hpp:100
RefineMiniBatchOptions & get_options()
Definition RefineMiniBatch.hpp:117
RefineMiniBatch(RefineMiniBatchOptions options)
Definition RefineMiniBatch.hpp:105
Interface for k-means refinement algorithms.
Definition Refine.hpp:30
virtual Details< Index_ > run(const Matrix_ &data, Cluster_ num_centers, Float_ *centers, Cluster_ *clusters) const =0
Perform k-means clustering.
Definition compute_wcss.hpp:16
std::mt19937_64 RefineMiniBatchRng
Definition RefineMiniBatch.hpp:29
void parallelize(const int num_workers, const Task_ num_tasks, Run_ run_task_range)
Definition parallelize.hpp:28
Utilities for parallelization.
Additional statistics from the k-means algorithm.
Definition Details.hpp:20
Options for RefineMiniBatch.
Definition RefineMiniBatch.hpp:34
RefineMiniBatchRng::result_type seed
Definition RefineMiniBatch.hpp:62
int max_iterations
Definition RefineMiniBatch.hpp:39
double max_change_proportion
Definition RefineMiniBatch.hpp:51
int convergence_history
Definition RefineMiniBatch.hpp:57
int num_threads
Definition RefineMiniBatch.hpp:68
int batch_size
Definition RefineMiniBatch.hpp:45