120 auto nobs = data.num_observations();
121 if (internal::is_edge_case(nobs, ncenters)) {
122 return internal::process_edge_case(data, ncenters, centers, clusters);
125 int iter = 0, status = 0;
126 std::vector<uint64_t> total_sampled(ncenters);
127 std::vector<Cluster_> previous(nobs);
128 typedef decltype(nobs) Index_;
129 std::vector<uint64_t> last_changed(ncenters), last_sampled(ncenters);
131 Index_ actual_batch_size = nobs;
132 typedef typename std::conditional<std::is_signed<Index_>::value, int,
unsigned int>::type SafeCompInt;
133 if (
static_cast<SafeCompInt
>(actual_batch_size) > my_options.
batch_size) {
136 std::vector<Index_> chosen(actual_batch_size);
137 std::mt19937_64 eng(my_options.
seed);
139 auto ndim = data.num_dimensions();
140 size_t long_ndim = ndim;
141 internal::QuickSearch<Float_, Cluster_,
decltype(ndim)> index;
144 aarand::sample(nobs, actual_batch_size, chosen.data(), eng);
146 for (
auto o : chosen) {
147 previous[o] = clusters[o];
151 index.reset(ndim, ncenters, centers);
153 auto work = data.create_workspace(chosen.data() + start, length);
154 for (Index_ s = start, end = start + length; s < end; ++s) {
155 auto ptr = data.get_observation(work);
156 clusters[chosen[s]] = index.find(ptr);
161 auto work = data.create_workspace(chosen.data(), actual_batch_size);
162 for (
auto o : chosen) {
163 const auto c = clusters[o];
164 auto& n = total_sampled[c];
167 Float_ mult =
static_cast<Float_
>(1)/
static_cast<Float_
>(n);
168 auto ccopy = centers +
static_cast<size_t>(c) * long_ndim;
169 auto ocopy = data.get_observation(work);
171 for (
decltype(ndim) d = 0; d < ndim; ++d, ++ocopy, ++ccopy) {
172 (*ccopy) += (
static_cast<Float_
>(*ocopy) - *ccopy) * mult;
178 for (
auto o : chosen) {
179 auto p = previous[o];
181 auto c = clusters[o];
190 bool too_many_changes =
false;
191 for (Cluster_ c = 0; c < ncenters; ++c) {
192 if (
static_cast<double>(last_changed[c]) >=
static_cast<double>(last_sampled[c]) * my_options.
max_change_proportion) {
193 too_many_changes =
true;
198 if (!too_many_changes) {
201 std::fill(last_sampled.begin(), last_sampled.end(), 0);
202 std::fill(last_changed.begin(), last_changed.end(), 0);
212 index.reset(ndim, ncenters, centers);
214 auto work = data.create_workspace(start, length);
215 for (Index_ s = start, end = start + length; s < end; ++s) {
216 auto ptr = data.get_observation(work);
217 clusters[s] = index.find(ptr);
221 std::vector<Index_> cluster_sizes(ncenters);
222 for (Index_ o = 0; o < nobs; ++o) {
223 ++cluster_sizes[clusters[o]];
226 internal::compute_centroids(data, ncenters, centers, clusters, cluster_sizes);
227 return Details<Index_>(std::move(cluster_sizes), iter, status);