// Clusters the observations of `data` into `ncenters` groups using random
// batches of observations.
//
// data:     observation source; queried for its number of observations and
//           dimensions, and for extractors over selected observations.
// ncenters: number of cluster centers.
// centers:  in/out array of center coordinates; center `c` occupies
//           `centers[c * ndim .. c * ndim + ndim)`.
// clusters: in/out array with one cluster assignment per observation.
// Returns a Details<Index_> built from the per-cluster sizes, `iter` and `status`.
//
// NOTE(review): this listing is elided. The embedded line numbers skip values
// (e.g. 128 -> 131) and the closing braces / some loop headers are not shown,
// so the comments below describe only the statements that are visible.
125 Details<Index_> run(
const Matrix_& data, Cluster_ ncenters, Float_* centers, Cluster_* clusters)
const {
126 Index_ nobs = data.num_observations();
// Degenerate combinations of nobs/ncenters are delegated to a dedicated helper.
127 if (internal::is_edge_case(nobs, ncenters)) {
128 return internal::process_edge_case(data, ncenters, centers, clusters);
// `iter` and `status` are passed to the returned Details at the end; the code
// that updates them is not visible here.
131 int iter = 0, status = 0;
// Per-center count used as the denominator of the running-mean update below.
132 std::vector<uint64_t> total_sampled(ncenters);
// Assignment of each observation before the current reassignment step.
133 std::vector<Cluster_> previous(nobs);
// Per-center counters compared against max_change_proportion below.
134 std::vector<uint64_t> last_changed(ncenters), last_sampled(ncenters);
// Start from the full data set; compared against the configured batch size next.
136 Index_ actual_batch_size = nobs;
// Pick `int` or `unsigned int` to match the signedness of Index_ for the
// comparison against my_options.batch_size.
137 typedef typename std::conditional<std::is_signed<Index_>::value, int,
unsigned int>::type SafeCompInt;
// Body of this branch is elided; presumably it caps actual_batch_size at
// my_options.batch_size -- TODO confirm.
138 if (
static_cast<SafeCompInt
>(actual_batch_size) > my_options.
batch_size) {
// Indices of the observations that form the current batch.
141 std::vector<Index_> chosen(actual_batch_size);
// 64-bit Mersenne Twister seeded from the options.
142 std::mt19937_64 eng(my_options.
seed);
144 size_t ndim = data.num_dimensions();
145 internal::QuickSearch<Float_, Cluster_> index;
// Fill `chosen` with actual_batch_size indices drawn from [0, nobs);
// presumably sampled without replacement -- verify against aarand::sample.
148 aarand::sample(nobs, actual_batch_size, chosen.data(), eng);
// Remember the current assignment of every chosen observation.
150 for (
auto o : chosen) {
151 previous[o] = clusters[o];
// Rebuild the search index from the current centers, then reassign the chosen
// observations in [start, start + length) to whatever index.find() returns
// (presumably the closest center -- confirm against QuickSearch).
// NOTE(review): `start` and `length` are defined in elided code.
155 index.reset(ndim, ncenters, centers);
157 auto work = data.new_extractor(chosen.data() + start, length);
158 for (Index_ s = start, end = start + length; s < end; ++s) {
159 auto ptr = work->get_observation();
160 clusters[chosen[s]] = index.find(ptr);
// Update each center with the chosen observations assigned to it.
165 auto work = data.new_extractor(chosen.data(), actual_batch_size);
166 for (
auto o : chosen) {
167 const auto c = clusters[o];
// Reference into total_sampled; any increment happens in elided lines
// (presumably before `mult` is computed -- TODO confirm).
168 auto& n = total_sampled[c];
// Step size 1/n for the update below.
171 Float_ mult =
static_cast<Float_
>(1)/
static_cast<Float_
>(n);
// Start of center c's coordinates; the index is widened to size_t before
// multiplying by ndim.
172 auto ccopy = centers +
static_cast<size_t>(c) * ndim;
173 auto ocopy = work->get_observation();
// Move the center towards the observation by a fraction `mult` per dimension.
175 for (
size_t d = 0; d < ndim; ++d) {
176 ccopy[d] += (
static_cast<Float_
>(ocopy[d]) - ccopy[d]) * mult;
// Read the previous (p) and current (c) assignment of each chosen observation;
// the statements that use them are elided (presumably they update
// last_sampled / last_changed -- TODO confirm).
182 for (
auto o : chosen) {
183 auto p = previous[o];
185 auto c = clusters[o];
// Flag set when any center has last_changed >= last_sampled *
// max_change_proportion; the comparison is carried out in double.
194 bool too_many_changes =
false;
195 for (Cluster_ c = 0; c < ncenters; ++c) {
196 if (
static_cast<double>(last_changed[c]) >=
static_cast<double>(last_sampled[c]) * my_options.
max_change_proportion) {
197 too_many_changes =
true;
// Body of this branch is elided; it cannot be determined from here whether the
// two fills below sit inside or after it.
202 if (!too_many_changes) {
// Zero both per-center counters.
205 std::fill(last_sampled.begin(), last_sampled.end(), 0);
206 std::fill(last_changed.begin(), last_changed.end(), 0);
// Rebuild the search index from the current centers and assign observations in
// [start, start + length) directly (no indirection through `chosen`).
216 index.reset(ndim, ncenters, centers);
218 auto work = data.new_extractor(start, length);
219 for (Index_ s = start, end = start + length; s < end; ++s) {
220 auto ptr = work->get_observation();
221 clusters[s] = index.find(ptr);
// Count how many observations ended up in each cluster.
225 std::vector<Index_> cluster_sizes(ncenters);
226 for (Index_ o = 0; o < nobs; ++o) {
227 ++cluster_sizes[clusters[o]];
// Hand the data, assignments and sizes to compute_centroids (presumably it
// recomputes `centers` from the final assignments -- confirm in the helper),
// then return the sizes together with iter/status.
230 internal::compute_centroids(data, ncenters, centers, clusters, cluster_sizes);
231 return Details<Index_>(std::move(cluster_sizes), iter, status);