Octopus
isdf_serial.F90
Go to the documentation of this file.
1!! Copyright (C) 2024 - 2025. Alexander Buccheri
2!!
3!! This program is free software; you can redistribute it and/or modify
4!! it under the terms of the GNU General Public License as published by
5!! the Free Software Foundation; either version 2, or (at your option)
6!! any later version.
7!!
8!! This program is distributed in the hope that it will be useful,
9!! but WITHOUT ANY WARRANTY; without even the implied warranty of
10!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11!! GNU General Public License for more details.
12!!
13!! You should have received a copy of the GNU General Public License
14!! along with this program; if not, write to the Free Software
15!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
16!! 02110-1301, USA.
17
18#include "global.h"
19
22 use, intrinsic :: iso_fortran_env, only: real64, int64
23 use batch_oct_m
25 use blas_oct_m
26 use debug_oct_m
30 use global_oct_m
31 use grid_oct_m
32 use ions_oct_m
40 use math_oct_m
41 use mesh_oct_m
44 use mpi_oct_m, only: mpi_world
52 use space_oct_m
55 use xc_cam_oct_m, only: xc_cam_t
56
57 implicit none
58 private
59
60 public :: &
64
65 ! TODO(Alex) Issue #1195 Extend ISDF to spin-polarised systems
67 integer, parameter :: ik = 1
68
69contains
70
!> @brief Construct the ISDF interpolation vectors and the intermediate quantities, serially.
!!
!! Reference implementation operating on the whole grid. The steps are:
!!  1. Collate the first isdf%n_ks_states states into phi, shape (n_states, np).
!!  2. Sample phi at the interpolation points, giving phi_mu.
!!  3. Build P_r_mu = phi^T phi_mu and square it element-wise to obtain ZC^T.
!!  4. Extract CC^T as the rows of ZC^T selected by `indices`.
!!  5. Invert CC^T with an SVD-based inverse and form isdf_vectors = [ZC^T] [CC^T]^{-1}.
!!  6. Rebuild P_r_mu with the occupations absorbed into it.
!!
!! The routine stops with a fatal/not-implemented message for state- or domain-parallel runs,
!! nspin > 1, complex states, or batches that are not BATCH_PACKED.
subroutine isdf_serial_interpolation_vectors(isdf, namespace, mesh, st, indices, phi_mu, P_r_mu, isdf_vectors)
  type(isdf_options_t),       intent(in ) :: isdf        !< ISDF options: n_ks_states and check_n_interp are read
  type(namespace_t),          intent(in ) :: namespace   !< Namespace passed to messages and matrix output
  class(mesh_t),              intent(in ) :: mesh        !< Mesh; mesh%np sizes the grid dimension of all arrays
  type(states_elec_t),        intent(in ) :: st          !< KS states
  integer(int64), contiguous, intent(in ) :: indices(:)  !< Grid indices of the interpolation points

  real(real64), allocatable,  intent(out) :: phi_mu(:, :)
  !< States defined at interpolation points: \f$ \varphi_i(\mathbf{r}_\mu) \f$, with size (n_states, n_int)
  real(real64), allocatable,  intent(out) :: P_r_mu(:, :)
  !< \f$P_{\mathbf{r},\mu}\f$, with size (np, n_int). On exit it includes the occupations
  real(real64), allocatable,  intent(out) :: isdf_vectors(:, :)
  !< Interpolation vectors [ZC^T] [CC^T]^{-1}, with size (np, n_int)

  real(real64), allocatable :: phi(:, :), cct(:, :)
  integer                   :: n_states, n_int, rank
  logical                   :: data_is_packed

  push_sub_with_profile(isdf_serial_interpolation_vectors)
  ! NOTE(review): retained from the original; the purpose of this call is not evident here - confirm
  call messages_write(1)

  ! Reference serial implementation not parallel in states or domain
  if (st%parallel_in_states .or. mesh%parallel_in_domains) then
    message(1) = "Serial ISDF called when running state or domain-parallel"
    call messages_fatal(1)
  endif

  ! TODO(Alex) Issue #1195 Extend ISDF to spin-polarised systems
  if (st%d%nspin > 1) then
    call messages_not_implemented("ISDF Serial for SPIN_POLARIZED and SPINOR calculations", namespace)
  endif

  ! TODO(Alex) Issue #1196 Template ISDF handle both real and complex states
  if (.not. states_are_real(st)) then
    call messages_not_implemented("ISDF Serial handling of complex states", namespace)
  endif

  ! TODO(Alex) Implement algorithms for unpacked data structure
  ! Use the module-level index ik, consistent with the other batch accesses in this module
  data_is_packed = st%group%psib(st%group%block_start, ik)%status() == batch_packed

  if (.not. data_is_packed) then
    message(1) = "Serial ISDF only implemented for BATCH_PACKED"
    call messages_fatal(1)
  endif

  n_states = isdf%n_ks_states
  n_int = size(indices)

  call collate_batches_get_state(mesh, st, n_states, phi)
  if (debug%info) call output_matrix(namespace, "phi_r_serial.txt", phi)

  safe_allocate(phi_mu(1:n_states, 1:n_int))
  call sample_phi_at_centroids(phi, indices, phi_mu)
  if (debug%info) call output_matrix(namespace, "phi_mu_serial.txt", phi_mu)

  safe_allocate(p_r_mu(1:mesh%np, 1:n_int))
  call construct_density_matrix_packed(phi, phi_mu, p_r_mu)
  if (debug%info) call output_matrix(namespace, "p_r_mu_serial.txt", p_r_mu)

  ! Mutate P_r_mu to [ZC^T] = P_r_mu o P_r_mu
  call construct_zct(p_r_mu)
  if (debug%info) call output_matrix(namespace, "zct_serial.txt", p_r_mu)

  ! [CC^T] = ZC^T[indices, :]
  safe_allocate(cct(1:n_int, 1:n_int))
  call construct_cct(indices, p_r_mu, cct)
  if (debug%info) call output_matrix(namespace, "cct_serial.txt", cct)
  assert(is_symmetric(cct))

  ! Note, in principle one could use lalg_pseudo_inverse and just add an optional
  ! arg for returning the rank, but annoyingly the criterion in that is > rather than >=.
  ! Quantify the optimal number of interpolation points to use.
  ! If the rank of CC^T is <= ISDFNpoints, this does not give any indication how many
  ! additional points are required. It is only indicative when oversampling.
  if (isdf%check_n_interp) then
    rank = lalg_matrix_rank_svd(cct, preserve_mat=.true.)
    ! i0 rather than a fixed width, so ranks above 9999 are not printed as asterisks
    write(message(1),'(a, i0)') "ISDF Serial: Rank of CC^T is ", rank
    if (rank < n_int) then
      write(message(2),'(a)') "  - This rank is the optimal ISDFNpoints to run the calculation with"
    else
      write(message(2),'(a)') "  - This suggests that ISDFNpoints is either optimal, or could be larger"
    endif
    call messages_info(2, namespace=namespace)
  endif

  ! [CC^T]^{-1}, mutating cct in-place
  ! NOTE, if the number of interpolation points exceeds the rank of CC^T, CC^T is by definition
  ! ill-conditioned, and requires inverting with the pseudo-inverse (SVD).
  ! Tests show that this is a much better solution than regularisation of either the diagonals
  ! or off-diagonals of CC^T. If one limits the number of interpolation points, there is no problem
  ! with inversion, but this limits the total accuracy achievable with the method.

  ! As CC^T and its inverse should be symmetric, ultimately want to:
  ! * Only compute the inverse of the upper (or lower) triangle
  ! * Modify the GEMM operation below to only use the upper (or lower) triangle of [CC^T]^{-1}
  write(message(1),'(a)') "ISDF Serial: Inverting [CC^T]"
  call messages_info(1, namespace=namespace, debug_only=.true.)

  ! Invert [CC^T] and symmetrise
  call lalg_svd_inverse(n_int, n_int, cct)
  call symmetrize_matrix(n_int, cct)

  ! Compute interpolation vectors, [ZC^T] [CC^T]^{-1}
  safe_allocate(isdf_vectors(1:mesh%np, 1:n_int))
  ! CC^T is by definition symmetric, implying [CC^T]^{-1} also is
  call lalg_gemm(mesh%np, n_int, n_int, 1.0_real64, p_r_mu, cct, 0.0_real64, isdf_vectors)
  if (debug%info) call output_matrix(namespace, "isdf_serial.txt", isdf_vectors)
  safe_deallocate_a(cct)

  ! Rebuild P_r_mu, with occupation numbers absorbed into it. Used in construction of W_ace
  call construct_density_matrix_with_occ_packed(st, phi, phi_mu, p_r_mu)
  safe_deallocate_a(phi)
  if (debug%info) call output_matrix(namespace, "OccP_r_mu_serial.txt", p_r_mu)

  pop_sub_with_profile(isdf_serial_interpolation_vectors)

end subroutine isdf_serial_interpolation_vectors
190
191
!> @brief Gather the first max_state states from the state blocks into one array.
!!
!! States are fetched block by block with states_elec_get_state and stored row-wise,
!! so that psi(i, :) holds the i-th collated state on the mesh points.
subroutine collate_batches_get_state(mesh, st, max_state, psi)
  class(mesh_t),             intent(in ) :: mesh       !< Mesh; mesh%np sets the second dimension of psi
  type(states_elec_t),       intent(in ) :: st         !< States to read from
  integer,                   intent(in ) :: max_state  !< Index of the last state to collate (1 <= max_state <= st%nst)
  real(real64), allocatable, intent(out) :: psi(:, :)  !< Collated states, allocated here with shape (max_state, np)

  integer :: row, iblock, jst, last_block

  push_sub_with_profile(collate_batches_get_state)

  assert(max_state > 0 .and. max_state <= st%nst)

  safe_allocate(psi(1:max_state, 1:mesh%np))

  ! Block that contains the last requested state
  last_block = st%group%iblock(max_state)

  ! Note: normalising each batch (dmesh_batch_normalize) was tried and did not
  ! affect the condition number of the CC^T matrix, so it is not done here.
  row = 0
  do iblock = 1, last_block
    ! The upper bound is capped so the final block stops at max_state
    do jst = states_elec_block_min(st, iblock), min(states_elec_block_max(st, iblock), max_state)
      row = row + 1
      call states_elec_get_state(st, mesh, st%d%dim, jst, ik, psi(row, :))
    enddo
  enddo

  pop_sub_with_profile(collate_batches_get_state)

end subroutine collate_batches_get_state
225
226
!> @brief Sample the states at the interpolation points.
!!
!! Copies column indices(ic) of phi_r into column ic of phi_mu, for every interpolation point.
subroutine sample_phi_at_centroids(phi_r, indices, phi_mu)
  real(real64),   contiguous, intent(in ) :: phi_r(:, :)   !< States on the grid, first dimension is the state index
  integer(int64), contiguous, intent(in ) :: indices(:)    !< Grid (second-dimension) indices to sample phi_r at
  real(real64),   contiguous, intent(out) :: phi_mu(:, :)  !< Sampled states, shape (size(phi_r, 1), size(indices))

  integer        :: ic, nst, n_int
  integer(int64) :: np

  push_sub_with_profile(sample_phi_at_centroids)

  write(message(1),'(a)') "ISDF Serial: Sampling phi(r) at mu"
  call messages_info(1, debug_only=.true.)

  nst = size(phi_r, 1)
  assert(size(phi_mu, 1) == nst)

  n_int = size(indices)
  assert(size(phi_mu, 2) == n_int)

  ! Every interpolation index must address a valid grid column of phi_r,
  ! otherwise the copy below would silently read out of bounds
  np = size(phi_r, 2, kind=int64)
  assert(all(indices >= 1_int64 .and. indices <= np))

  do ic = 1, n_int
    phi_mu(:, ic) = phi_r(:, indices(ic))
  enddo

  pop_sub_with_profile(sample_phi_at_centroids)

end subroutine sample_phi_at_centroids
257
258
!> @brief Square every element of the input matrix in place.
!!
!! Used to turn the quasi-density matrix into the contraction ZC^T.
subroutine construct_zct(zct)
  real(real64), contiguous, intent(inout) :: zct(:, :)
  !< In:  Quasi-density matrix
  !! Out: Contraction of Z and C^T matrices == element-wise square of quasi-density matrix

  integer :: irow, icol, n_rows, n_cols

  push_sub_with_profile(construct_zct)

  write(message(1),'(a)') "ISDF Serial: Constructing ZC^T"
  call messages_info(1, debug_only=.true.)

  n_rows = size(zct, 1)
  n_cols = size(zct, 2)

  ! Column index outermost so the inner loop walks contiguous memory
  !$omp parallel do collapse(2)
  do icol = 1, n_cols
    do irow = 1, n_rows
      zct(irow, icol) = zct(irow, icol)**2
    end do
  end do
  !$omp end parallel do

  pop_sub_with_profile(construct_zct)

end subroutine construct_zct
291
292
!> @brief Build CC^T by selecting rows of ZC^T.
!!
!! Row i_mu of cct is row indices(i_mu) of zct, i.e. cct = zct(indices, :).
subroutine construct_cct(indices, zct, cct)
  integer(int64), contiguous, intent(in ) :: indices(:)  !< Row indices of zct to keep, one per interpolation point
  real(real64),   contiguous, intent(in ) :: zct(:, :)   !< ZC^T, with size(indices) columns

  real(real64),   contiguous, intent(out) :: cct(:, :)   !< CC^T, shape (size(indices), size(indices))

  integer :: icol, n_points

  push_sub_with_profile(construct_cct)

  write(message(1),'(a)') "ISDF Serial: Constructing CC^T by sampling ZC^T"
  call messages_info(1, debug_only=.true.)

  n_points = size(indices)
  assert(all(shape(cct) == [n_points, n_points]))
  assert(size(zct, 1) > n_points)
  assert(size(zct, 2) == n_points)

  ! Mask ZC^T to obtain CC^T: gather the selected rows, one column at a time
  do icol = 1, n_points
    cct(:, icol) = zct(indices, icol)
  enddo

  pop_sub_with_profile(construct_cct)

end subroutine construct_cct
325
326
!> @brief Construct the density matrix P_r_mu with shape (np, n_int).
!!
!! Computes P = phi^T phi_mu, contracting over the state index. "Packed" refers to phi
!! being supplied with the state index as its first dimension.
subroutine construct_density_matrix_packed(phi, phi_mu, P_r_mu)
  real(real64), contiguous, intent(in ) :: phi(:, :)
  !< States on the grid, of shape (m_states, np)
  real(real64), contiguous, intent(in ) :: phi_mu(:, :)
  !< States at the interpolation points, of shape (m_states, n_int)

  real(real64), contiguous, intent(out) :: P_r_mu(:, :)
  !< Resulting matrix, of shape (np, n_int)

  integer :: np
  integer :: n_int
  integer :: m_states

  push_sub_with_profile(construct_density_matrix_packed)

  write(message(1),'(a)') "ISDF Serial: Constructing P_r_mu"
  call messages_info(1, debug_only=.true.)

  m_states = size(phi, 1)
  np = size(phi, 2)
  n_int = size(phi_mu, 2)

  assert(size(phi_mu, 1) == m_states)
  assert(size(p_r_mu, 1) == np)
  assert(size(p_r_mu, 2) == n_int)

  ! Contract over the state index, P = phi^T @ phi_mu, with shape (np, m_states) (m_states, n_int)
  call lalg_gemm(phi, phi_mu, p_r_mu, transa='T')

  pop_sub_with_profile(construct_density_matrix_packed)

end subroutine construct_density_matrix_packed
365
366
!> @brief Construct the density matrix P_r_mu with occupations and k-weight absorbed into it.
!!
!! Computes P = phi^T (f o phi_mu), where each row ist of phi_mu is scaled by
!! st%kweights(ik) * st%occ(ist, ik) before contracting over the state index.
subroutine construct_density_matrix_with_occ_packed(st, phi, phi_mu, P_r_mu)
  type(states_elec_t), intent(in ) :: st
  !< States; only st%kweights(ik) and st%occ(:, ik) are read
  real(real64),        intent(in ) :: phi(:, :)
  !< States on the grid, of shape (m_states, np)
  real(real64),        intent(in ) :: phi_mu(:, :)
  !< States at the interpolation points, of shape (m_states, n_int)

  real(real64),        intent(out) :: P_r_mu(:, :)
  !< Resulting matrix, of shape (np, n_int)

  integer :: np, n_int, m_states, imu, ist
  real(real64) :: kweight
  real(real64), allocatable :: focc_phi_mu(:, :)

  ! The profiling label matches the routine name
  push_sub_with_profile(construct_density_matrix_with_occ_packed)

  write(message(1),'(a)') "ISDF Serial: Constructing P_r_mu with occupations"
  call messages_info(1, debug_only=.true.)

  m_states = size(phi_mu, 1)
  np = size(phi, 2)
  n_int = size(phi_mu, 2)

  assert(size(phi, 1) == m_states)
  assert(size(p_r_mu, 1) == np)
  assert(size(p_r_mu, 2) == n_int)
  ! An occupation must exist for every state being weighted
  assert(m_states <= size(st%occ, 1))

  ! Element-wise multiply the occupations with the smaller of the two arrays
  kweight = st%kweights(ik)
  safe_allocate(focc_phi_mu(1:m_states, 1:n_int))
  do imu = 1, n_int
    do ist = 1, m_states
      focc_phi_mu(ist, imu) = kweight * st%occ(ist, ik) * phi_mu(ist, imu)
    enddo
  enddo

  ! Contract over the state index, P = phi^T @ focc_phi_mu, with shape (np, m_states) (m_states, n_int)
  call lalg_gemm(phi, focc_phi_mu, p_r_mu, transa='T')
  safe_deallocate_a(focc_phi_mu)

  pop_sub_with_profile(construct_density_matrix_with_occ_packed)

end subroutine construct_density_matrix_with_occ_packed
406
407
!> @brief Quantify the error of the ISDF expansion of the pair-product basis.
!!
!! Builds the exact pair products of the first isdf%n_ks_states states and their ISDF
!! approximation, optionally writes both sets as cube files, and writes the per-pair and
!! mean errors to "isdf_error_serial.txt" from the root process of mpi_world.
!!
!! @note isdf_vectors is deallocated by this routine once the approximate products are built.
subroutine quantify_error_and_visualise(isdf, namespace, st, space, mesh, ions, indices, isdf_vectors, output_cubes)
  type(isdf_options_t),       intent(in)    :: isdf                !< ISDF options; n_ks_states is read
  type(namespace_t),          intent(in)    :: namespace           !< Namespace for messages and cube output
  type(states_elec_t),        intent(in)    :: st                  !< KS states
  class(space_t),             intent(in)    :: space               !< Passed through to the cube output
  class(mesh_t),              intent(in)    :: mesh                !< Mesh; mesh%np is the grid size
  class(ions_t), pointer,     intent(in)    :: ions                !< Passed through to the cube output
  integer(int64), contiguous, intent(in)    :: indices(:)          !< Grid indices of the interpolation points
  real(real64), allocatable,  intent(inout) :: isdf_vectors(:, :)  !< Interpolation vectors; deallocated on exit
  logical,                    intent(in)    :: output_cubes        !< If true, write exact and approximate products as cubes

  real(real64), allocatable :: product_basis(:, :), approx_product_basis(:, :)
  real(real64), allocatable :: phi(:, :), phi_mu(:, :)
  real(real64), allocatable :: product_error(:)
  integer                   :: n_occ, n_products, n_int, i, j, ij, unit, ios
  real(real64)              :: mean_error

  push_sub_with_profile(quantify_error_and_visualise)

  write(message(1),'(a)') "ISDF Serial: Computing exact pair products"
  ! namespace is passed here as well, consistent with the second debug message below
  call messages_info(1, namespace=namespace, debug_only=.true.)
  ! Rebuild phi matrix
  n_occ = isdf%n_ks_states
  call collate_batches_get_state(mesh, st, n_occ, phi)
  assert(size(phi, 2) == mesh%np)

  n_products = n_occ * n_occ
  safe_allocate(product_basis(1:n_products, 1:mesh%np))
  call column_wise_khatri_rao_product(phi, phi, product_basis)

  if (output_cubes) then
    call generate_product_state_cubes(namespace, space, mesh, ions, "exact_product_", &
      product_basis)
  endif

  write(message(1),'(a)') "ISDF Serial Test: Computing approximate pair products"
  call messages_info(1, namespace=namespace, debug_only=.true.)

  ! Rebuild phi_mu, again only for occupied states
  n_int = size(indices)
  safe_allocate(phi_mu(1:n_occ, 1:n_int))
  call sample_phi_at_centroids(phi, indices, phi_mu)
  safe_deallocate_a(phi)

  safe_allocate(approx_product_basis(1:n_products, 1:mesh%np))
  call approximate_pair_products(phi_mu, isdf_vectors, approx_product_basis)
  ! if (debug%info) call output_matrix(namespace, "approx_product_blas.txt", approx_product_basis)

  ! Free the large inputs as soon as they are no longer needed
  safe_deallocate_a(phi_mu)
  safe_deallocate_a(isdf_vectors)

  if (output_cubes) then
    call generate_product_state_cubes(namespace, space, mesh, ions, "approx_product_", &
      approx_product_basis)
  endif

  ! Quantify the error
  safe_allocate(product_error(1:n_products))
  call error_in_product_basis(mesh, product_basis, approx_product_basis, product_error, mean_error)
  safe_deallocate_a(product_basis)
  safe_deallocate_a(approx_product_basis)

  if (mpi_world%is_root()) then
    ! Report a failed open through the message system instead of a bare runtime error
    open(newunit=unit, file="isdf_error_serial.txt", action="write", iostat=ios)
    if (ios /= 0) then
      message(1) = "ISDF Serial: Unable to open isdf_error_serial.txt for writing"
      call messages_fatal(1, namespace=namespace)
    endif
    write(unit, *) 'Mean error', mean_error
    ! ij runs over the pairs in the same order as the rows of the product basis
    ij = 0
    do i = 1, n_occ
      do j = 1, n_occ
        ij = ij + 1
        write(unit, *) i, j, product_error(ij)
      enddo
    enddo
    close(unit)
  endif

  safe_deallocate_a(product_error)

  pop_sub_with_profile(quantify_error_and_visualise)

end subroutine quantify_error_and_visualise
490
491
!> @brief Build approximate pair products from the ISDF interpolation vectors.
!!
!! Forms the column-wise Khatri-Rao product of psi_mu with itself, and contracts it with
!! zeta^T over the interpolation-point index.
subroutine approximate_pair_products(psi_mu, zeta, product_basis)
  real(real64), contiguous, intent(in ) :: psi_mu(:, :)         !< States at the interpolation points
  real(real64), contiguous, intent(in ) :: zeta(:, :)           !< Interpolation vectors, shape (np, n_int)
  real(real64), contiguous, intent(out) :: product_basis(:, :)  !< Approximate products, shape (size(psi_mu, 1)**2, np)

  real(real64), allocatable :: pair_coeffs(:, :)
  integer                   :: n_pairs, n_points, n_interp

  push_sub_with_profile(approximate_pair_products)

  n_points = size(zeta, 1)
  n_interp = size(zeta, 2)
  n_pairs  = size(psi_mu, 1)**2

  assert(size(product_basis, 1) == n_pairs)
  assert(size(product_basis, 2) == n_points)

  ! Pair products evaluated at the interpolation points only
  safe_allocate(pair_coeffs(1:n_pairs, 1:n_interp))
  call column_wise_khatri_rao_product(psi_mu, psi_mu, pair_coeffs)

  ! Contract product_basis = [pair_coeffs] [zeta]^T over interpolation vector index
  call lalg_gemm(pair_coeffs, zeta, product_basis, transb='T')

  safe_deallocate_a(pair_coeffs)

  pop_sub_with_profile(approximate_pair_products)

end subroutine approximate_pair_products
530
531
!> @brief Quantify the error in the product basis expansion.
!!
!! For each product ij, error(ij) = sqrt(volume_element * sum_ip (exact - approx)^2),
!! and mean_error is the arithmetic mean of error over all products.
subroutine error_in_product_basis(mesh, product_basis, approx_product_basis, error, mean_error)
  class(mesh_t),            intent(in ) :: mesh                        !< Mesh; np and volume_element are read
  real(real64), contiguous, intent(in ) :: product_basis(:, :)         !< Exact products, shape (mn_states, np)
  real(real64), contiguous, intent(in ) :: approx_product_basis(:, :)  !< Approximate products, same shape

  real(real64), contiguous, intent(out) :: error(:)                    !< Error per product, size mn_states
  real(real64),             intent(out) :: mean_error                  !< Mean of error; zero if there are no products

  integer :: mn_states, np, ij, ip

  push_sub_with_profile(error_in_product_basis)

  mn_states = size(product_basis, 1)
  np = size(product_basis, 2)

  ! product_basis shape is not as expected
  assert(mesh%np == np)

  ! Two arrays should be the same dimensions
  assert(all(shape(product_basis) == shape(approx_product_basis)))

  ! error should be allocated, and with the correct size
  assert(size(error) == mn_states)

  ! Accumulate from zero rather than seeding with grid point 1, so an empty grid is
  ! handled without reading out of bounds. The sums are unchanged for np >= 1.
  error(:) = 0.0_real64
  do ip = 1, np
    do ij = 1, mn_states
      error(ij) = error(ij) + (product_basis(ij, ip) - approx_product_basis(ij, ip))**2
    enddo
  enddo

  mean_error = 0.0_real64
  do ij = 1, mn_states
    error(ij) = sqrt(mesh%volume_element * error(ij))
    mean_error = mean_error + error(ij)
  enddo

  ! Guard the average against an empty set of products
  if (mn_states > 0) mean_error = mean_error / real(mn_states, real64)

  pop_sub_with_profile(error_in_product_basis)

end subroutine error_in_product_basis
585
586
!> @brief Helper to output a set of pair-product states as cube files in "./cubes".
!!
!! Row ij = j + (i - 1) * m_states of data is written to a file named
!! <file_prefix><i>_<j>, where m_states is the integer square root of size(data, 1).
subroutine generate_product_state_cubes(namespace, space, mesh, ions, file_prefix, data, limits)
  type(namespace_t),        intent(in) :: namespace    !< Passed through to dio_function_output
  class(space_t),           intent(in) :: space        !< Passed through to dio_function_output
  class(mesh_t),            intent(in) :: mesh         !< Passed through to dio_function_output
  class(ions_t), pointer,   intent(in) :: ions         !< Supplies pos and atom for the cube header
  character(len=*),         intent(in) :: file_prefix  !< Prefix of every output file name
  real(real64), contiguous, intent(in) :: data(:, :)   !< Product states, first dimension of size m_states**2
  integer, optional,        intent(in) :: limits(2)    !< Upper limits of the (j, i) loops; each capped at m_states

  integer            :: m_states, limit_j, limit_i, i, j, ij, ierr
  real(real64)       :: size_data
  ! Wide enough for any default integer written with i0
  character(len=12)  :: i_char, j_char
  character(len=120) :: file_name

  ! product basis size is currently defined as (m_states * m_states)
  size_data = real(size(data, 1), real64)
  m_states = int(sqrt(size_data))

  if (present(limits)) then
    ! Cap the requested limits so the row index ij can never leave data
    ! or alias a row that belongs to a different (i, j) pair
    limit_j = min(limits(1), m_states)
    limit_i = min(limits(2), m_states)
  else
    limit_j = m_states
    limit_i = m_states
  endif

  do i = 1, limit_i
    do j = 1, limit_j
      ij = j + (i - 1) * m_states
      ! i0 avoids the '****' a fixed-width I4 field gives for indices above 9999
      write(i_char, '(i0)') i
      write(j_char, '(i0)') j
      file_name = trim(adjustl(file_prefix)) // trim(adjustl(i_char)) // '_' // trim(adjustl(j_char))
      call dio_function_output(option__outputformat__cube, "./cubes", trim(adjustl(file_name)), namespace, space, mesh, &
        data(ij,:) , unit_one, ierr, pos=ions%pos, atoms=ions%atom)
    enddo
  enddo

end subroutine generate_product_state_cubes
626
627
!> @brief ISDF wrapper computing the interpolation vectors and using them to build Vx_on_st.
!!
!! Obtains the interpolation point indices from exxop%isdf%centroids, builds the ISDF
!! vectors, constructs W for adaptively-compressed exchange, and stores W in Vx_on_st.
!! Only gamma-point, non-periodic, nspin == 1 calculations are accepted (asserted).
subroutine isdf_serial_ace_compute_potentials(exxop, namespace, space, mesh, st, Vx_on_st, kpoints)
  type(exchange_operator_t), intent(in ) :: exxop
  !< Exchange operator. Supplies the ISDF options and interpolation points (exxop%isdf),
  !! the Poisson solver (exxop%psolver) and the cam parameters (exxop%cam).
  !! An ISDF instance is not passed directly so this API is consistent with the other
  !! "compute_potential" routines.
  type(namespace_t),         intent(in ) :: namespace
  class(space_t),            intent(in ) :: space
  !< Only queried in an assertion here (must not be periodic)
  class(mesh_t),             intent(in ) :: mesh
  type(states_elec_t),       intent(in ) :: st
  type(kpoints_t),           intent(in ) :: kpoints
  !< Only queried in an assertion here (must be gamma-only)

  type(states_elec_t),       intent(out) :: Vx_on_st
  !< Result, filled by isdf_serial_ace_batch_w from the W_ace array

  real(real64), allocatable   :: psi_mu(:, :), P_r_mu(:, :), W_ace(:, :), isdf_vectors(:, :)
  integer(int64), allocatable :: indices(:)


  ! TODO(Alex) Issue #1195 Extend ISDF to spin-polarised and periodic systems
  assert(kpoints%gamma_only())
  assert(.not. space%is_periodic())
  assert(st%d%nspin == 1)

  indices = exxop%isdf%centroids%global_mesh_indices()

  call isdf_serial_interpolation_vectors(exxop%isdf, namespace, mesh, st, &
    indices, psi_mu, p_r_mu, isdf_vectors)

  safe_allocate(w_ace(1:mesh%np, exxop%isdf%n_ks_states))
  call isdf_serial_ace_w_unpacked(namespace, p_r_mu, isdf_vectors, psi_mu, exxop%psolver, exxop%cam, st, w_ace)
  ! The intermediates are no longer needed once W has been built
  safe_deallocate_a(psi_mu)
  safe_deallocate_a(p_r_mu)
  safe_deallocate_a(isdf_vectors)

  call isdf_serial_ace_batch_w(exxop%isdf, st, w_ace, vx_on_st)
  safe_deallocate_a(w_ace)

end subroutine isdf_serial_ace_compute_potentials
682
!> @brief Compute W for adaptively-compressed exchange as a bare (np, nst) array.
!!
!! With V_r_nu obtained from isdf_potential applied to the interpolation vectors,
!!   W(r, i) = - sum_mu P_r_mu(r, mu) * V_r_nu(r, mu) * weight * psi_mu(i, mu),
!! where weight = cam%alpha / st%smear%el_per_state.
!! Stops with a fatal message if an external kernel would be needed
!! (st%nik > st%d%spin_channels, or cam%omega > m_epsilon).
subroutine isdf_serial_ace_w_unpacked(namespace, P_r_mu, isdf_vectors, psi_mu, poisson_solver, cam, st, W_ace)
  type(namespace_t),    intent(in )             :: namespace           !< Namespace for messages
  real(real64),         intent(in ), contiguous :: P_r_mu(:, :)        !< Shape (np, n_int)
  real(real64),         intent(in ), contiguous :: isdf_vectors(:, :)  !< Same shape as P_r_mu
  real(real64),         intent(in ), contiguous :: psi_mu(:, :)        !< States at interpolation points, (nst, n_int)
  type(poisson_t),      intent(in )             :: poisson_solver      !< Passed to isdf_potential
  type(xc_cam_t),       intent(in)              :: cam                 !< alpha and omega are read
  type(states_elec_t),  intent(in )             :: st                  !< nik, d%spin_channels and smear%el_per_state are read

  real(real64),         intent(out), contiguous :: W_ace(:, :)         !< Result, shape (np, nst)

  integer                   :: imu, jst, n_points, n_interp, n_states
  real(real64)              :: scaled_psi, prefactor
  real(real64), allocatable :: V_r_nu(:, :)

  push_sub_with_profile(isdf_serial_ace_w_unpacked)

  ! Number of states defines the number used in ISDF, which is typically ~ N occupied states
  n_states = size(psi_mu, 1)
  n_points = size(p_r_mu, 1)
  n_interp = size(p_r_mu, 2)

  assert(all(shape(p_r_mu) == shape(isdf_vectors)))
  assert(size(psi_mu, 2) == n_interp)
  assert(size(w_ace, 1) == n_points)
  ! Implies a size issue with either W_ace or psi_mu
  assert(size(w_ace, 2) == n_states)

  ! Cases that would require an external kernel are rejected
  if (st%nik > st%d%spin_channels .or. cam%omega > m_epsilon) then
    message(1) = "External kernel not supported in ISDF"
    call messages_fatal(1)
  endif
  prefactor = cam%alpha / st%smear%el_per_state

  safe_allocate(v_r_nu(1:n_points, 1:n_interp))
  call isdf_potential(namespace, poisson_solver, isdf_vectors, v_r_nu)

  write(message(1),'(a)') "ISDF: Writing V from isdf_ace_w_unpacked"
  call messages_info(1, namespace=namespace, debug_only=.true.)

  ! Seed W_ace with the contribution of the first interpolation point
  do jst = 1, n_states
    scaled_psi = prefactor * psi_mu(jst, 1)
    w_ace(:, jst) = - (p_r_mu(:, 1) * v_r_nu(:, 1) * scaled_psi)
  enddo

  ! Accumulate the remaining interpolation points
  do imu = 2, n_interp
    do jst = 1, n_states
      scaled_psi = prefactor * psi_mu(jst, imu)
      w_ace(:, jst) = w_ace(:, jst) - (p_r_mu(:, imu) * v_r_nu(:, imu) * scaled_psi)
    enddo
  enddo

  safe_deallocate_a(v_r_nu)

  pop_sub_with_profile(isdf_serial_ace_w_unpacked)

end subroutine isdf_serial_ace_w_unpacked
760
761
!> @brief Put the bare array representation of W into the batches of Vx_on_st.
!!
!! Vx_on_st is created as a copy of st and zeroed; then column ist of W_ace is written
!! into the matching batch slot for every state up to min(isdf%n_ks_states, st%st_end).
subroutine isdf_serial_ace_batch_w(isdf, st, W_ace, Vx_on_st)
  type(isdf_options_t), intent(in )             :: isdf        !< ISDF options; n_ks_states is read
  type(states_elec_t),  intent(in )             :: st          !< States providing the layout for Vx_on_st
  real(real64),         intent(in ), contiguous :: w_ace(:, :) !< W with shape (np, isdf%n_ks_states)

  type(states_elec_t),  intent(out)             :: vx_on_st    !< Copy of st, zeroed, then filled from w_ace

  integer :: jst, iblock, slot, n_points, last_state, first_in_block, last_in_block, last_block

  push_sub_with_profile(isdf_serial_ace_batch_w)

  assert(size(w_ace, 2) == isdf%n_ks_states)
  assert(st%d%dim == 1)

  ! Copy memory layout, without data
  call states_elec_copy(vx_on_st, st)
  call states_elec_set_zero(vx_on_st)
  n_points = size(w_ace, 1)

  ! Ensure we do not go beyond the total number of occupied states
  last_state = min(isdf%n_ks_states, st%st_end)
  last_block = min(st%group%block_end, st%group%iblock(last_state))

  do iblock = st%group%block_start, last_block
    first_in_block = states_elec_block_min(st, iblock)
    last_in_block  = min(states_elec_block_max(st, iblock), last_state)
    ! Walk the global state index; the position within the block follows from it
    do jst = first_in_block, last_in_block
      slot = jst - first_in_block + 1
      call batch_set_state(vx_on_st%group%psib(iblock, 1), slot, n_points, w_ace(:, jst))
    enddo
  enddo

  pop_sub_with_profile(isdf_serial_ace_batch_w)

end subroutine isdf_serial_ace_batch_w
803
804end module isdf_serial_oct_m
805
806!! Local Variables:
807!! mode: f90
808!! coding: utf-8
809!! End:
There are several ways to call batch_set_state and batch_get_state:
Definition: batch_ops.F90:203
Matrix-matrix multiplication plus matrix.
Definition: lalg_basic.F90:229
double sqrt(double __x) __attribute__((__nothrow__
This module implements batches of mesh functions.
Definition: batch.F90:135
integer, parameter, public batch_packed
functions are stored in CPU memory, in transposed (packed) order
Definition: batch.F90:286
This module implements common operations on batches of mesh functions.
Definition: batch_ops.F90:118
This module contains interfaces for BLAS routines You should not use these routines directly....
Definition: blas.F90:120
type(debug_t), save, public debug
Definition: debug.F90:158
subroutine, public column_wise_khatri_rao_product(y, x, z)
Column-wise Kronecker product.
real(real64), parameter, public m_epsilon
Definition: global.F90:206
This module implements the underlying real-space grid.
Definition: grid.F90:119
subroutine, public dio_function_output(how, dir, fname, namespace, space, mesh, ff, unit, ierr, pos, atoms, grp, root)
Top-level IO routine for functions defined on the mesh.
Serial prototype for benchmarking and validating ISDF implementation.
subroutine generate_product_state_cubes(namespace, space, mesh, ions, file_prefix, data, limits)
Helper function to output a set of pair product states.
subroutine construct_density_matrix_with_occ_packed(st, phi, phi_mu, P_r_mu)
subroutine collate_batches_get_state(mesh, st, max_state, psi)
Loop over states per block, which makes applying the maximum state limit much simpler Use this to com...
subroutine isdf_serial_ace_w_unpacked(namespace, P_r_mu, isdf_vectors, psi_mu, poisson_solver, cam, st, W_ace)
Compute the action of the exchange potential on KS states for adaptively-compressed exchange.
subroutine sample_phi_at_centroids(phi_r, indices, phi_mu)
Sample KS states at centroid points.
subroutine, public isdf_serial_interpolation_vectors(isdf, namespace, mesh, st, indices, phi_mu, P_r_mu, isdf_vectors)
Construct interpolative separable density fitting (ISDF) vectors and other intermediate quantities re...
subroutine error_in_product_basis(mesh, product_basis, approx_product_basis, error, mean_error)
Quantify the error in the product basis expansion.
subroutine, public isdf_serial_ace_compute_potentials(exxop, namespace, space, mesh, st, Vx_on_st, kpoints)
ISDF wrapper computing interpolation points and vectors, which are used to build the potential used ...
subroutine construct_zct(zct)
Construct the product of Z and C matrices from the element-wise product of the quasi-density matrix.
subroutine construct_cct(indices, zct, cct)
Construct the product CC^T from ZC^T by masking the first dimension of ZC^T.
subroutine construct_density_matrix_packed(phi, phi_mu, P_r_mu)
Construct the density matrix with shape (np, n_int). Denoted packed, because it expects phi i...
subroutine approximate_pair_products(psi_mu, zeta, product_basis)
Construct a set of approximate pair products using the ISDF interpolation vectors.
subroutine, public quantify_error_and_visualise(isdf, namespace, st, space, mesh, ions, indices, isdf_vectors, output_cubes)
Wrapper for quantifying the error in the expansion of the product basis.
subroutine isdf_serial_ace_batch_w(isdf, st, W_ace, Vx_on_st)
Put the bare array representation of W into a batch.
subroutine, public output_matrix(namespace, fname, matrix)
Helper routine to output a 2D matrix.
Definition: isdf_utils.F90:148
subroutine, public isdf_potential(namespace, poisson_solver, isdf_vectors, V_r_nu)
Compute the effective potential in the ISDF vector basis.
Definition: isdf_utils.F90:182
This module is intended to contain "only mathematical" functions and procedures.
Definition: math.F90:117
logical function, public is_symmetric(a, tol)
Check if a 2D array is symmetric.
Definition: math.F90:1493
This module defines functions over batches of mesh functions.
Definition: mesh_batch.F90:118
This module defines the meshes, which are used in Octopus.
Definition: mesh.F90:120
subroutine, public messages_not_implemented(feature, namespace)
Definition: messages.F90:1097
character(len=256), dimension(max_lines), public message
to be output by fatal, warning
Definition: messages.F90:162
subroutine, public messages_fatal(no_lines, only_root_writes, namespace)
Definition: messages.F90:416
subroutine, public messages_info(no_lines, iunit, debug_only, stress, all_nodes, namespace)
Definition: messages.F90:600
type(mpi_grp_t), public mpi_world
Definition: mpi.F90:272
pure logical function, public states_are_real(st)
integer pure function, public states_elec_block_max(st, ib)
return index of last state in block ib
subroutine, public states_elec_copy(stout, stin, exclude_wfns, exclude_eigenval, special)
make a (selective) copy of a states_elec_t object
integer pure function, public states_elec_block_min(st, ib)
return index of first state in block ib
subroutine, public states_elec_set_zero(st)
Explicitly set all wave functions in the states to zero.
This module provides routines for communicating states when using states parallelization.
This module defines the unit system, used for input and output.
type(unit_t), public unit_angstrom
For XYZ files.
type(unit_t), public unit_one
some special units required for particular quantities
Describes mesh distribution to nodes.
Definition: mesh.F90:187
The states_elec_t class contains all electronic wave functions.
Coulomb-attenuating method parameters, used in the partitioning of the Coulomb potential into a short...
Definition: xc_cam.F90:141
int true(void)