kmeans
A C++ library for k-means
Loading...
Searching...
No Matches
RefineMiniBatch.hpp
Go to the documentation of this file.
1#ifndef KMEANS_REFINE_MINIBATCH_HPP
2#define KMEANS_REFINE_MINIBATCH_HPP
3
4#include <vector>
5#include <algorithm>
6#include <numeric>
7#include <cstdint>
8#include <stdexcept>
9#include <limits>
10#include <random>
11#include <type_traits>
12
13#include "aarand/aarand.hpp"
14
15#include "Refine.hpp"
16#include "Details.hpp"
17#include "QuickSearch.hpp"
18#include "is_edge_case.hpp"
19#include "parallelize.hpp"
20
27namespace kmeans {
28
37 int max_iterations = 100;
38
43 int batch_size = 500;
44
49 double max_change_proportion = 0.01;
50
56
60 uint64_t seed = 1234567890u;
61
66 int num_threads = 1;
67};
68
96template<typename Index_, typename Data_, typename Cluster_, typename Float_, typename Matrix_ = Matrix<Index_, Data_> >
97class RefineMiniBatch : public Refine<Index_, Data_, Cluster_, Float_, Matrix_> {
98public:
102 RefineMiniBatch(RefineMiniBatchOptions options) : my_options(std::move(options)) {}
103
107 RefineMiniBatch() = default;
108
109public:
115 return my_options;
116 }
117
118private:
119 RefineMiniBatchOptions my_options;
120
121public:
125 Details<Index_> run(const Matrix_& data, Cluster_ ncenters, Float_* centers, Cluster_* clusters) const {
126 Index_ nobs = data.num_observations();
127 if (internal::is_edge_case(nobs, ncenters)) {
128 return internal::process_edge_case(data, ncenters, centers, clusters);
129 }
130
131 int iter = 0, status = 0;
132 std::vector<uint64_t> total_sampled(ncenters); // holds the number of sampled observations across iterations, so we need a large integer.
133 std::vector<Cluster_> previous(nobs);
134 std::vector<uint64_t> last_changed(ncenters), last_sampled(ncenters); // holds the number of sampled/changed observation for the last few iterations.
135
136 Index_ actual_batch_size = nobs;
137 typedef typename std::conditional<std::is_signed<Index_>::value, int, unsigned int>::type SafeCompInt; // waiting for C++20's comparison functions...
138 if (static_cast<SafeCompInt>(actual_batch_size) > my_options.batch_size) {
139 actual_batch_size = my_options.batch_size;
140 }
141 std::vector<Index_> chosen(actual_batch_size);
142 std::mt19937_64 eng(my_options.seed);
143
144 size_t ndim = data.num_dimensions();
145 internal::QuickSearch<Float_, Cluster_> index;
146
147 for (iter = 1; iter <= my_options.max_iterations; ++iter) {
148 aarand::sample(nobs, actual_batch_size, chosen.data(), eng);
149 if (iter > 1) {
150 for (auto o : chosen) {
151 previous[o] = clusters[o];
152 }
153 }
154
155 index.reset(ndim, ncenters, centers);
156 parallelize(my_options.num_threads, actual_batch_size, [&](int, Index_ start, Index_ length) -> void {
157 auto work = data.new_extractor(chosen.data() + start, length);
158 for (Index_ s = start, end = start + length; s < end; ++s) {
159 auto ptr = work->get_observation();
160 clusters[chosen[s]] = index.find(ptr);
161 }
162 });
163
164 // Updating the means for each cluster.
165 auto work = data.new_extractor(chosen.data(), actual_batch_size);
166 for (auto o : chosen) {
167 const auto c = clusters[o];
168 auto& n = total_sampled[c];
169 ++n;
170
171 Float_ mult = static_cast<Float_>(1)/static_cast<Float_>(n);
172 auto ccopy = centers + static_cast<size_t>(c) * ndim; // cast to size_t to avoid overflow.
173 auto ocopy = work->get_observation();
174
175 for (size_t d = 0; d < ndim; ++d) {
176 ccopy[d] += (static_cast<Float_>(ocopy[d]) - ccopy[d]) * mult; // cast to ensure consistent precision regardless of Matrix_::data_type.
177 }
178 }
179
180 // Checking for updates.
181 if (iter != 1) {
182 for (auto o : chosen) {
183 auto p = previous[o];
184 ++(last_sampled[p]);
185 auto c = clusters[o];
186 if (p != c) {
187 ++(last_sampled[c]);
188 ++(last_changed[p]);
189 ++(last_changed[c]);
190 }
191 }
192
193 if (iter % my_options.convergence_history == 1) {
194 bool too_many_changes = false;
195 for (Cluster_ c = 0; c < ncenters; ++c) {
196 if (static_cast<double>(last_changed[c]) >= static_cast<double>(last_sampled[c]) * my_options.max_change_proportion) {
197 too_many_changes = true;
198 break;
199 }
200 }
201
202 if (!too_many_changes) {
203 break;
204 }
205 std::fill(last_sampled.begin(), last_sampled.end(), 0);
206 std::fill(last_changed.begin(), last_changed.end(), 0);
207 }
208 }
209 }
210
211 if (iter == my_options.max_iterations + 1) {
212 status = 2;
213 }
214
215 // Run through all observations to make sure they have the latest cluster assignments.
216 index.reset(ndim, ncenters, centers);
217 parallelize(my_options.num_threads, nobs, [&](int, Index_ start, Index_ length) -> void {
218 auto work = data.new_extractor(start, length);
219 for (Index_ s = start, end = start + length; s < end; ++s) {
220 auto ptr = work->get_observation();
221 clusters[s] = index.find(ptr);
222 }
223 });
224
225 std::vector<Index_> cluster_sizes(ncenters);
226 for (Index_ o = 0; o < nobs; ++o) {
227 ++cluster_sizes[clusters[o]];
228 }
229
230 internal::compute_centroids(data, ncenters, centers, clusters, cluster_sizes);
231 return Details<Index_>(std::move(cluster_sizes), iter, status);
232 }
236};
237
238}
239
240#endif
Report detailed clustering statistics.
Interface for k-means refinement.
Implements the mini-batch algorithm for k-means clustering.
Definition RefineMiniBatch.hpp:97
RefineMiniBatchOptions & get_options()
Definition RefineMiniBatch.hpp:114
RefineMiniBatch(RefineMiniBatchOptions options)
Definition RefineMiniBatch.hpp:102
Interface for k-means refinement algorithms.
Definition Refine.hpp:26
virtual Details< Index_ > run(const Matrix_ &data, Cluster_ num_centers, Float_ *centers, Cluster_ *clusters) const =0
Namespace for k-means clustering.
Definition compute_wcss.hpp:12
void parallelize(int num_workers, Task_ num_tasks, Run_ run_task_range)
Definition parallelize.hpp:28
Utilities for parallelization.
Additional statistics from the k-means algorithm.
Definition Details.hpp:20
Options for RefineMiniBatch construction.
Definition RefineMiniBatch.hpp:32
int max_iterations
Definition RefineMiniBatch.hpp:37
double max_change_proportion
Definition RefineMiniBatch.hpp:49
int convergence_history
Definition RefineMiniBatch.hpp:55
uint64_t seed
Definition RefineMiniBatch.hpp:60
int num_threads
Definition RefineMiniBatch.hpp:66
int batch_size
Definition RefineMiniBatch.hpp:43