Octopus
kmeans_clustering.F90
Go to the documentation of this file.
1!! Copyright (C) 2024. A Buccheri.
2!!
3!! This program is free software; you can redistribute it and/or modify
4!! it under the terms of the GNU General Public License as published by
5!! the Free Software Foundation; either version 2, or (at your option)
6!! any later version.
7!!
8!! This program is distributed in the hope that it will be useful,
9!! but WITHOUT ANY WARRANTY; without even the implied warranty of
10!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11!! GNU General Public License for more details.
12!!
13!! You should have received a copy of the GNU General Public License
14!! along with this program; if not, write to the Free Software
15!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
16!! 02110-1301, USA.
17
18#include "global.h"
19
21 use, intrinsic :: iso_fortran_env
22 use debug_oct_m
24 use global_oct_m
25 use mesh_oct_m
27 use mpi_oct_m
32 use space_oct_m
33 use sort_oct_m
35 implicit none
36 private
37
42
43contains
44
45 ! TODO(Alex) Issue #1004. Implement `assign_points_to_centroids` for periodic boundary conditions
46
!> Assign each local grid point to its closest centroid (finite boundary conditions only).
!!
!! A centroid together with the set of grid points closest to it defines one cluster.
!! Closeness is measured with the squared Euclidean distance between mesh%x(ip, :) and
!! centroids(:, ic). If a point is (numerically) equidistant from several centroids, it keeps
!! the lowest-index one: a later centroid only wins when it is closer by more than `tol`.
subroutine assign_points_to_centroids_finite_bc(mesh, centroids, ip_to_ic)
  class(mesh_t), intent(in)  :: mesh
  real(real64),  intent(in)  :: centroids(:, :)  !< Centroid positions, (1:dim, 1:n_centroids)
  integer,       intent(out) :: ip_to_ic(:)      !< Out: index of the closest centroid, per local grid point

  integer :: ip, ic, icen
  integer :: n_centroids
  real(real64) :: min_dist, dist
  real(real64), allocatable :: point(:)

  ! Some small finite tolerance is required to distinguish degenerate points, else
  ! `if (dist < min_dist)` can vary on different hardware due to numerical noise.
  ! One could equally choose tol to be some percentage of the grid spacing.
  real(real64), parameter :: tol = 1.0e-13_real64

  push_sub(assign_points_to_centroids_finite_bc)

  ! Grid to centroid index map should have size of the grid
  assert(size(ip_to_ic) == mesh%np)

  n_centroids = size(centroids, 2)
  ! The search below is seeded with centroid 1, so at least one centroid is required
  assert(n_centroids >= 1)

  safe_allocate(point(1:size(centroids, 1)))

  !$omp parallel do default(shared) private(ip, ic, point, icen, min_dist, dist)
  do ip = 1, mesh%np
    ! Compute which centroid grid point `ip` is closest to
    point = mesh%x(ip, :)
    icen = 1
    min_dist = sum((centroids(:, 1) - point(:))**2)
    do ic = 2, n_centroids
      dist = sum((centroids(:, ic) - point(:))**2)
      if (dist < min_dist - tol) then
        min_dist = dist
        icen = ic
      endif
    enddo
    ip_to_ic(ip) = icen
  enddo
  !$omp end parallel do

  safe_deallocate_a(point)

  pop_sub(assign_points_to_centroids_finite_bc)
end subroutine assign_points_to_centroids_finite_bc

119
!> Compute a new set of centroids as the weighted mean position of each cluster.
!!
!! \f[ \mathbf{c}_k = \frac{\sum_{i \in k} w_i \mathbf{r}_i}{\sum_{i \in k} w_i} \f]
!!
!! Numerator and denominator are summed over all domains of the mesh before dividing.
!! A cluster whose summed weight is exactly zero has no defined weighted mean. This happens
!! when no grid points were assigned to it, or when its initial centroid was poorly chosen
!! at a point with no associated weight (such as the vacuum of a crystal cell). Such a
!! centroid keeps its prior position rather than becoming non-finite through a division by zero.
subroutine update_centroids(mesh, weight, ip_to_ic, centroids)
  class(mesh_t), intent(in) :: mesh
  real(real64), intent(in) :: weight(:)                      !< Weight of each local grid point
  integer, intent(in) :: ip_to_ic(:)                         !< Centroid index of each local grid point
  real(real64), contiguous, intent(inout) :: centroids(:, :) !< In: current centroids. Out: updated centroids

  integer :: n_centroids, ic, ip
  real(real64) :: one_over_denom
  real(real64), allocatable :: denominator(:)
  real(real64), allocatable :: prior_centroids(:, :)

  push_sub(update_centroids)

  ! The indexing of weight and grid must be consistent => must belong to the same spatial distribution.
  ! This does not explicitly assert this, but if the grid sizes differ, this is a clear indicator of a problem.
  assert(mesh%np == size(weight))
  assert(size(ip_to_ic) == mesh%np)

  n_centroids = size(centroids, 2)
  safe_allocate(denominator(1:n_centroids))
  ! Kept so that zero-weight clusters can fall back to their current position
  safe_allocate_source(prior_centroids(1:size(centroids, 1), 1:n_centroids), centroids)

  centroids = 0._real64
  denominator = 0._real64

  !$omp parallel do private(ic) reduction(+ : centroids, denominator)
  do ip = 1, mesh%np
    ic = ip_to_ic(ip)
    ! Initially accumulate the numerator in `centroids`
    centroids(:, ic) = centroids(:, ic) + (mesh%x(ip, :) * weight(ip))
    denominator(ic) = denominator(ic) + weight(ip)
  enddo
  !$omp end parallel do

  ! Gather contributions to numerator and denominator of all centroids, from all domains of the mesh/grid
  call mesh%allreduce(centroids)
  call mesh%allreduce(denominator)

  ! Each column is normalised independently, so no reduction clause is needed here.
  ! An exact-zero test is intended: it is precisely the case that would divide by zero.
  !$omp parallel do private(one_over_denom)
  do ic = 1, n_centroids
    if (denominator(ic) /= 0._real64) then
      one_over_denom = 1._real64 / denominator(ic)
      centroids(:, ic) = centroids(:, ic) * one_over_denom
    else
      centroids(:, ic) = prior_centroids(:, ic)
    endif
  enddo
  !$omp end parallel do

  safe_deallocate_a(prior_centroids)
  safe_deallocate_a(denominator)
  pop_sub(update_centroids)

end subroutine update_centroids
182
183
!> Flag, point by point, whether two sets of points differ by more than a tolerance.
!!
!! Point ip is flagged when any Cartesian component of
!! |updated_points(:, ip) - points(:, ip)| exceeds `tol`. If debug info output is enabled,
!! the flagged points are reported via report_differences_in_grids.
subroutine compute_grid_difference(points, updated_points, tol, points_differ)
  real(real64), intent(in)  :: points(:, :)         !< Prior points, (1:dim, 1:n_points)
  real(real64), intent(in)  :: updated_points(:, :) !< Updated points, same shape as `points`
  real(real64), intent(in)  :: tol                  !< Per-component tolerance
  logical,      intent(out) :: points_differ(:)     !< Out: .true. for each point that differs by more than tol

  integer :: ip

  push_sub(compute_grid_difference)

  ! Mismatched shapes would otherwise read/write out of bounds
  assert(all(shape(updated_points) == shape(points)))
  assert(size(points_differ) == size(points, 2))

  !$omp parallel do default(shared)
  do ip = 1, size(points, 2)
    points_differ(ip) = any(abs(updated_points(:, ip) - points(:, ip)) > tol)
  enddo
  !$omp end parallel do

  if (debug%info) then
    call report_differences_in_grids(points, updated_points, tol, points_differ)
  endif

  pop_sub(compute_grid_difference)

end subroutine compute_grid_difference
216
217
!> Report the points flagged as differing by compute_grid_difference.
!!
!! For each flagged point, one line is written containing the current point, the prior point,
!! the component-wise absolute difference and the tolerance, followed by a summary line.
subroutine report_differences_in_grids(points, updated_points, tol, points_differ)
  real(real64), intent(in) :: points(:, :)         !< Prior points, (1:dim, 1:n_points)
  real(real64), intent(in) :: updated_points(:, :) !< Current points, same shape as `points`
  real(real64), intent(in) :: tol                  !< Tolerance used to flag the points
  logical,      intent(in) :: points_differ(:)     !< Flag per point, as returned by compute_grid_difference

  integer, allocatable :: indices(:)
  integer :: i, j, n_unconverged, ndim
  character(len=100) :: f_string
  real(real64), allocatable :: diff(:)

  push_sub(report_differences_in_grids)

  ! Indices of the flagged points
  indices = pack([(i, i=1, size(points_differ))], points_differ)
  n_unconverged = size(indices)
  ndim = size(points, 1)
  safe_allocate(diff(1:ndim))

  ! Row format: current point, prior point and |difference| (ndim values each), then tol.
  ! i0 keeps the repeat counts valid for any number of dimensions; 1X is the standard spacing descriptor.
  write(f_string, '(a, i0, a, i0, a, i0, a)') '(', &
    ndim, '(F16.10, 1X), ', &
    ndim, '(F16.10, 1X), ', &
    ndim, '(F16.10, 1X), F16.10)'

  write(message(1), '(a)') "# Current Point , Prior Point , |ri - r_{i-1}| , tol"
  call messages_info(1)
  do j = 1, n_unconverged
    i = indices(j)
    diff(:) = abs(updated_points(:, i) - points(:, i))
    write(message(1), f_string) updated_points(:, i), points(:, i), diff, tol
    call messages_info(1)
  enddo
  write(message(1), *) "Summary:", n_unconverged, "out of", size(points, 2), "are not converged"
  call messages_info(1)

  safe_deallocate_a(diff)
  pop_sub(report_differences_in_grids)

end subroutine report_differences_in_grids
256
257
!> Weighted K-means clustering of the mesh points.
!!
!! Each iteration assigns every grid point to its closest centroid, then moves each centroid
!! to the weighted mean of its cluster. Iteration stops once no centroid component moves by
!! more than `centroid_tol`, or after `n_iter` iterations.
!!
!! @param[in]    n_iter        Maximum number of iterations (default 200).
!! @param[in]    centroid_tol  Convergence tolerance on centroid components (default 1.e-4).
!! @param[in]    discretize    Snap the final centroids to mesh points (default .true.).
!! @param[out]   inertia       Weighted sum of squared distances of the points to their
!!                             closest returned centroid.
subroutine weighted_kmeans(space, mesh, weight, centroids, n_iter, centroid_tol, discretize, inertia)
  class(space_t), intent(in) :: space
  class(mesh_t), intent(in) :: mesh
  real(real64), intent(in) :: weight(:)
  real(real64), contiguous, intent(inout) :: centroids(:, :) !< In: initial centroids. Out: final centroids
  integer, optional, intent(in ) :: n_iter
  real(real64), optional, intent(in ) :: centroid_tol
  logical, optional, intent(in ) :: discretize
  real(real64), optional, intent(out) :: inertia

  logical :: discretize_centroids, converged
  integer :: n_iterations, n_centroid, i
  real(real64) :: tol
  integer, allocatable :: ip_to_ic(:)
  real(real64), allocatable :: prior_centroids(:, :)
  logical, allocatable :: points_differ(:)

  push_sub(weighted_kmeans)

  n_iterations = optional_default(n_iter, 200)
  tol = optional_default(centroid_tol, 1.e-4_real64)
  discretize_centroids = optional_default(discretize, .true.)

  ! Should use a positive number of iterations
  assert(n_iterations >= 1)
  ! Number of weights inconsistent with number of grid points
  assert(size(weight) == mesh%np)
  ! Spatial dimensions of centroids array is inconsistent
  assert(size(centroids, 1) == space%dim)
  ! Assignment of points to centroids only implemented for finite BCs
  assert(.not. space%is_periodic())

  ! Work arrays
  n_centroid = size(centroids, 2)
  safe_allocate_source(prior_centroids(space%dim, size(centroids, 2)), centroids)
  safe_allocate(ip_to_ic(1:mesh%np))
  safe_allocate(points_differ(1:n_centroid))

  write(message(1), '(a)') 'Debug: Performing weighted Kmeans clustering '
  call messages_info(1, debug_only=.true.)

  converged = .false.
  do i = 1, n_iterations
    ! i0 rather than a fixed width, so large iteration counts do not print as '***'
    write(message(1), '(a, i0)') 'Debug: Iteration ', i
    call messages_info(1, debug_only=.true.)
    ! TODO(Alex) Issue #1004. Implement `assign_points_to_centroids` for periodic boundary conditions
    call assign_points_to_centroids_finite_bc(mesh, centroids, ip_to_ic)

    call update_centroids(mesh, weight, ip_to_ic, centroids)
    call compute_grid_difference(prior_centroids, centroids, tol, points_differ)

    if (.not. any(points_differ)) then
      converged = .true.
      write(message(1), '(a)') 'Debug: All centroid points converged'
      call messages_info(1, debug_only=.true.)
      exit
    endif
    prior_centroids = centroids
  enddo

  if (.not. converged) then
    write(message(1), '(a, i0, a)') 'Debug: Centroids not converged after ', n_iterations, ' iterations'
    call messages_info(1, debug_only=.true.)
  endif

  if (discretize_centroids) then
    call mesh_discretize_values_to_mesh(mesh, centroids)
  endif

  if (present(inertia)) then
    ! ip_to_ic was built from the centroids *before* the last update (and before any
    ! discretization). Reassign so that the inertia refers to the centroids actually returned.
    call assign_points_to_centroids_finite_bc(mesh, centroids, ip_to_ic)
    call compute_centroid_inertia(mesh, centroids, weight, ip_to_ic, inertia)
  endif

  safe_deallocate_a(prior_centroids)
  safe_deallocate_a(ip_to_ic)
  safe_deallocate_a(points_differ)

  pop_sub(weighted_kmeans)

end subroutine weighted_kmeans
360
361
!> Sample initial centroids from the full (global) mesh.
!!
!! size(centroids, 2) global grid-point indices are drawn from [1, np_global] by a
!! Fisher-Yates shuffle, and the Cartesian coordinates of those points are returned
!! as the initial centroids.
subroutine sample_initial_centroids(mesh, centroids, seed_value)
  class(mesh_t), intent(in ) :: mesh
  real(real64), contiguous, intent(out) :: centroids(:, :)  !< Out: initial centroid positions
  integer(int64), intent(inout), optional :: seed_value    !< Random seed; mutated by the Fisher Yates shuffle

  integer(int32) :: ic, n_centroids
  integer(int64), allocatable :: picked_ipg(:)

  n_centroids = size(centroids, 2)
  safe_allocate(picked_ipg(1:n_centroids))

  ! Global point indices, drawn without replacement
  call fisher_yates_shuffle(n_centroids, mesh%np_global, seed_value, picked_ipg)

  ! Global index -> Cartesian position
  do ic = 1, n_centroids
    centroids(:, ic) = mesh_x_global(mesh, picked_ipg(ic))
  enddo

  safe_deallocate_a(picked_ipg)

end subroutine sample_initial_centroids
396
397
!> Compute the weighted inertia of a clustering.
!!
!! \f[ I = \sum_i w_i \, |\mathbf{c}_{k(i)} - \mathbf{r}_i|^2, \qquad k(i) = \mathrm{ip\_to\_ic}(i) \f]
!!
!! The sum runs over the local grid points and is then reduced over all domains of the mesh.
subroutine compute_centroid_inertia(mesh, centroids, weight, ip_to_ic, inertia)
  class(mesh_t), intent(in)  :: mesh
  real(real64),  intent(in)  :: centroids(:, :) !< Centroid positions, (1:dim, 1:n_centroids)
  real(real64),  intent(in)  :: weight(:)       !< Weight of each local grid point
  integer,       intent(in)  :: ip_to_ic(:)     !< Centroid index of each local grid point
  real(real64),  intent(out) :: inertia         !< Out: weighted inertia

  integer :: ip, ic

  push_sub(compute_centroid_inertia)

  assert(size(weight) == mesh%np)
  assert(size(ip_to_ic) == mesh%np)

  inertia = 0.0_real64

  ! `ic` is written in every iteration, so it must be private to each thread (a shared `ic` races)
  !$omp parallel do private(ip, ic) reduction(+ : inertia)
  do ip = 1, mesh%np
    ic = ip_to_ic(ip)
    inertia = inertia + weight(ip) * sum((centroids(:, ic) - mesh%x(ip, :))**2)
  enddo
  !$omp end parallel do

  call mesh%allreduce(inertia)

  pop_sub(compute_centroid_inertia)

end subroutine compute_centroid_inertia
427
429
430!! Local Variables:
431!! mode: f90
432!! coding: utf-8
433!! End:
type(debug_t), save, public debug
Definition: debug.F90:156
subroutine report_differences_in_grids(points, updated_points, tol, points_differ)
Report differences returned from compute_grid_difference.
subroutine, public weighted_kmeans(space, mesh, weight, centroids, n_iter, centroid_tol, discretize, inertia)
Weighted K-means clustering.
subroutine, public sample_initial_centroids(mesh, centroids, seed_value)
Sample initial centroids from the full mesh.
subroutine, public assign_points_to_centroids_finite_bc(mesh, centroids, ip_to_ic)
Assign each grid point to the closest centroid. A centroid and its set of nearest grid points defines...
subroutine, public update_centroids(mesh, weight, ip_to_ic, centroids)
Compute a new set of centroids.
subroutine compute_centroid_inertia(mesh, centroids, weight, ip_to_ic, inertia)
Compute the inertia of all centroids.
subroutine compute_grid_difference(points, updated_points, tol, points_differ)
Compute the difference in two grids as $|\mathbf{r}_i - \mathbf{r}_{i-1}|$, component-wise for each point.
This module defines the meshes, which are used in Octopus.
Definition: mesh.F90:118
subroutine, public mesh_discretize_values_to_mesh(mesh, values)
Assign a set of values to their nearest discrete points on the mesh.
Definition: mesh.F90:410
real(real64) function, dimension(1:mesh%box%dim), public mesh_x_global(mesh, ipg)
Definition: mesh.F90:804
character(len=256), dimension(max_lines), public message
to be output by fatal, warning
Definition: messages.F90:160
subroutine, public messages_info(no_lines, iunit, debug_only, stress, all_nodes, namespace)
Definition: messages.F90:616
This module contains some common usage patterns of MPI routines.
Definition: mpi_lib.F90:115
subroutine, public fisher_yates_shuffle(m, n, seed, values)
Return m random numbers from the range $[1, n]$ with no replacement.
Definition: quickrnd.F90:309
This module is intended to contain "only mathematical" functions and procedures.
Definition: sort.F90:117
Describes mesh distribution to nodes.
Definition: mesh.F90:186
int true(void)