22 use,
intrinsic :: iso_fortran_env, only: real64
57 integer,
parameter :: ik = 1
75 class(isdf_options_t),
intent(out) :: this
76 type(namespace_t),
intent(in ) :: namespace
77 type(states_elec_t),
intent(in ) :: st
79 integer :: default_n_interp
95 call parse_variable(namespace,
'NCentroidPoints', default_n_interp, this%n_interp)
106 type(namespace_t),
intent(in ) :: namespace
107 class(mesh_t),
intent(in ) :: mesh
108 type(states_elec_t),
intent(in ) :: st
109 integer(int64),
intent(in ) :: int_indices(:)
110 real(real64),
allocatable,
intent(out) :: phi(:, :)
114 real(real64),
allocatable,
intent(out) :: isdf_vectors(:, :)
120 message(1) =
"Serial ISDF"
121 call messages_write(1)
124 if (st%parallel_in_states .or. mesh%parallel_in_domains)
then
125 message(1) =
"Serial ISDF called when running state or domain-parallel"
126 call messages_fatal(1)
130 if (st%d%nspin > 1)
then
131 call messages_not_implemented(
"ISDF Serial for SPIN_POLARIZED and SPINOR calculations", namespace)
135 if (.not. states_are_real(st))
then
136 call messages_not_implemented(
"ISDF Serial handling of complex states", namespace)
145 if (debug%info)
call output_matrix(namespace,
"phi_r_serial.txt", phi)
146 safe_deallocate_a(phi)
151 if (debug%info)
call output_matrix(namespace,
"phi_r_serial3.txt", phi)
153 select case (st%group%psib(st%group%block_start, 1)%status())
154 case (batch_device_packed)
155 message(1) =
"Serial ISDF not implemented for BATCH_DEVICE_PACKED"
156 call messages_fatal(1)
161 case (batch_not_packed)
162 message(1) =
"Serial ISDF not implemented for BATCH_NOT_PACKED"
163 call messages_fatal(1)
164 assert(all(shape(phi) == [mesh%np, st%nst]))
177 type(namespace_t),
intent(in ) :: namespace
178 class(mesh_t),
intent(in ) :: mesh
179 real(real64),
intent(in ) :: phi(:, :)
180 integer(int64),
intent(in ) :: indices(:)
181 real(real64),
allocatable,
intent(out) :: isdf_vectors(:, :)
183 real(real64),
allocatable :: phi_mu(:, :)
184 real(real64),
allocatable :: P_mu_nu(:, :)
185 real(real64),
allocatable :: zct(:, :)
186 real(real64),
allocatable :: cct(:, :)
188 integer :: n_int, i, j, n_states
189 logical,
parameter :: construct_P_mu_nu = .false.
193 assert(
size(phi, 2) == mesh%np)
195 n_states =
size(phi, 1)
196 n_int =
size(indices)
198 safe_allocate(phi_mu(1:n_states, 1:n_int))
200 if (debug%info)
call output_matrix(namespace,
"phi_mu_serial.txt", phi_mu)
203 safe_allocate(zct(1:mesh%np, 1:n_int))
205 if (debug%info)
call output_matrix(namespace,
"p_r_mu_serial.txt", zct)
209 if (debug%info)
call output_matrix(namespace,
"zct_serial.txt", zct)
212 safe_allocate(cct(1:n_int, 1:n_int))
214 if (debug%info)
call output_matrix(namespace,
"cct_serial.txt", cct)
215 assert(is_symmetric(cct))
218 if (construct_p_mu_nu)
then
219 safe_allocate(p_mu_nu(1:n_int, 1:n_int))
221 assert(is_symmetric(p_mu_nu))
223 safe_deallocate_a(phi_mu)
224 if (debug%info)
call output_matrix(namespace,
"p_mu_nu_serial.txt", p_mu_nu)
227 write(message(1),
'(a)')
"ISDF Serial: Constructing [CC^T] = P_mu_nu o P_mu_nu"
228 call messages_info(1, namespace=namespace, debug_only=.
true.)
232 p_mu_nu(i, j) = p_mu_nu(i, j) * p_mu_nu(i, j)
236 if (debug%info)
call output_matrix(namespace,
"cct_alt_serial.txt", p_mu_nu)
237 safe_deallocate_a(p_mu_nu)
246 write(message(1),
'(a)')
"ISDF Serial: Inverting [CC^T]"
247 call messages_info(1, namespace=namespace, debug_only=.
true.)
250 call lalg_svd_inverse(n_int, n_int, cct)
251 call symmetrize_matrix(n_int, cct)
254 safe_allocate(isdf_vectors(1:mesh%np, 1:n_int))
257 call lalg_gemm(mesh%np, n_int, n_int, 1.0_real64, zct, cct, 0.0_real64, isdf_vectors)
259 if (debug%info)
call output_matrix(namespace,
"isdf_serial.txt", isdf_vectors)
260 safe_deallocate_a(zct)
261 safe_deallocate_a(cct)
272 class(mesh_t),
intent(in ) :: mesh
273 type(states_elec_t),
intent(in ) :: st
274 integer,
intent(in ) :: max_state
275 real(real64),
allocatable,
intent(out) :: psi(:, :)
277 integer :: istate, ib, ist, minst, maxst, block_end
281 assert(max_state <= st%nst)
283 safe_allocate(psi(1:max_state, 1:mesh%np))
284 block_end = st%group%iblock(max_state)
290 minst = states_elec_block_min(st, ib)
291 maxst = min(states_elec_block_max(st, ib), max_state)
292 do ist = minst, maxst
295 if (abs(st%occ(ist,
ik) * st%kweights(
ik)) < m_min_occ)
then
296 psi(istate, :) = 0.0_real64
298 call states_elec_get_state(st, mesh, st%d%dim, ist,
ik, psi(istate, :))
310 integer,
intent(in ) :: np
311 type(states_elec_t),
intent(in ) :: st
312 integer,
intent(in) :: max_state
313 real(real64),
allocatable,
intent(out) :: psi(:, :)
315 integer :: ip, ib, minst, maxst, block_end, ist, ist_local
319 select case (st%group%psib(1,1)%status())
320 case (batch_device_packed)
327 safe_allocate(psi(1:max_state, 1:np))
328 block_end = st%group%iblock(max_state)
332 minst = states_elec_block_min(st, ib)
334 maxst = min(states_elec_block_max(st, ib), max_state)
336 do ist = minst, maxst
337 ist_local = ist - minst + 1
339 if (abs(st%occ(ist,
ik) * st%kweights(
ik)) < m_min_occ)
then
340 psi(ist, ip) = 0.0_real64
342 psi(ist, ip) = st%group%psib(ib,
ik)%dff_pack(ist_local, ip)
348 case (batch_not_packed)
360 real(real64),
intent(in ) :: phi_r(:, :)
361 integer(int64),
intent(in ) :: indices(:)
362 real(real64),
intent(out) :: phi_mu(:, :)
364 integer :: ic, is, nst, n_int
365 integer(int64) :: ipg
369 write(message(1),
'(a)')
"ISDF Serial: Sampling phi(r) at mu"
370 call messages_info(1, debug_only=.
true.)
373 assert(
size(phi_mu, 1) == nst)
375 n_int =
size(indices)
376 assert(
size(phi_mu, 2) == n_int)
381 phi_mu(is, ic) = phi_r(is, ipg)
401 real(real64),
intent(inout) :: zct(:, :)
408 write(message(1),
'(a)')
"ISDF Serial: Constructing ZC^T"
409 call messages_info(1, debug_only=.
true.)
412 do j = 1,
size(zct, 2)
413 do i = 1,
size(zct, 1)
414 zct(i, j) = zct(i, j)**2
427 integer(int64),
intent(in ) :: indices(:)
428 real(real64),
intent(in ) :: zct(:, :)
430 real(real64),
intent(out) :: cct(:, :)
432 integer(int64) :: ipg
433 integer :: i_mu, i_nu, n_int
437 write(message(1),
'(a)')
"ISDF Serial: Constructing CC^T by sampling ZC^T"
438 call messages_info(1, debug_only=.
true.)
440 n_int =
size(indices)
441 assert(all(shape(cct) == [n_int, n_int]))
442 assert(
size(zct, 1) > n_int)
443 assert(
size(zct, 2) == n_int)
449 cct(i_mu, i_nu) = zct(ipg, i_nu)
467 real(real64),
intent(in ) :: phi(:, :)
469 real(real64),
intent(in ) :: phi_mu(:, :)
471 real(real64),
intent(out) :: P_r_mu(:, :)
479 write(message(1),
'(a)')
"ISDF Serial: Constructing P_r_mu"
480 call messages_info(1, debug_only=.
true.)
482 m_states =
size(phi, 1)
484 n_int =
size(phi_mu, 2)
486 assert(
size(phi_mu, 1) == m_states)
487 assert(
size(p_r_mu, 1) == np)
488 assert(
size(p_r_mu, 2) == n_int)
491 call lalg_gemm(phi, phi_mu, p_r_mu, transa=
'T')
507 real(real64),
intent(in ) :: phi_mu(:, :)
509 real(real64),
intent(out) :: P_mu_nu(:, :)
515 write(message(1),
'(a)')
"ISDF Serial: Constructing P_mu_nu"
516 call messages_info(1, debug_only=.
true.)
518 n_int =
size(phi_mu, 2)
519 assert(
size(p_mu_nu, 1) == n_int)
520 assert(
size(p_mu_nu, 2) == n_int)
523 call lalg_gemm(phi_mu, phi_mu, p_mu_nu, transa=
'T')
532 type(namespace_t),
intent(in) :: namespace
533 type(states_elec_t),
intent(in) :: st
534 class(space_t),
intent(in) :: space
535 class(mesh_t),
intent(in) :: mesh
536 class(ions_t),
pointer,
intent(in) :: ions
537 real(real64),
allocatable,
intent(inout) :: phi(:, :)
538 integer(int64),
intent(in) :: indices(:)
539 real(real64),
allocatable,
intent(inout) :: isdf_vectors(:, :)
540 logical,
intent(in) :: output_cubes
542 real(real64),
allocatable :: product_basis(:, :), approx_product_basis(:, :)
543 real(real64),
allocatable :: phi_mu(:, :), phi_occ(:, :)
544 real(real64),
allocatable :: product_error(:)
545 integer :: n_occ, n_products, n_int, i, j, ij, is, ip, unit
546 real(real64) :: mean_error
550 write(message(1),
'(a)')
"ISDF Serial: Computing exact pair products"
551 call messages_info(1, debug_only=.
true.)
553 assert(
size(phi, 2) == mesh%np)
557 safe_allocate(phi_occ(1:n_occ, 1:mesh%np))
560 phi_occ(is, ip) = phi(is, ip)
563 safe_deallocate_a(phi)
565 n_products = n_occ * n_occ
566 safe_allocate(product_basis(1:n_products, 1:mesh%np))
567 call column_wise_khatri_rao_product(phi_occ, phi_occ, product_basis)
570 if (output_cubes)
then
575 write(message(1),
'(a)')
"ISDF Serial Test: Computing approximate pair products"
576 call messages_info(1, namespace=namespace, debug_only=.
true.)
579 n_int =
size(indices)
580 safe_allocate(phi_mu(1:n_occ, 1:n_int))
582 safe_deallocate_a(phi_occ)
584 safe_allocate(approx_product_basis(1:n_products, 1:mesh%np))
588 safe_deallocate_a(phi_mu)
589 safe_deallocate_a(isdf_vectors)
591 if (output_cubes)
then
593 approx_product_basis)
597 safe_allocate(product_error(1:n_products))
599 safe_deallocate_a(product_basis)
600 safe_deallocate_a(approx_product_basis)
602 if (mpi_world%is_root())
then
603 open(newunit=unit, file=
"isdf_error_serial.txt")
604 write(unit, *)
'Mean error', mean_error
609 write(unit, *) i, j, product_error(ij)
615 safe_deallocate_a(product_error)
633 real(real64),
intent(in ) :: psi_mu(:, :)
634 real(real64),
intent(in ) :: zeta(:, :)
635 real(real64),
intent(out) :: product_basis(:, :)
637 real(real64),
allocatable :: psi_ij_mu(:, :)
638 integer :: mn_states, n_int, np
642 mn_states =
size(psi_mu, 1)**2
644 n_int =
size(zeta, 2)
646 assert(
size(product_basis, 1) == mn_states)
647 assert(
size(product_basis, 2) == np)
649 safe_allocate(psi_ij_mu(1:mn_states, 1:n_int))
650 call column_wise_khatri_rao_product(psi_mu, psi_mu, psi_ij_mu)
653 call lalg_gemm(psi_ij_mu, zeta, product_basis, transb=
'T')
655 safe_deallocate_a(psi_ij_mu)
670 class(mesh_t),
intent(in ) :: mesh
671 real(real64),
intent(in ) :: product_basis(:, :)
672 real(real64),
intent(in ) :: approx_product_basis(:, :)
674 real(real64),
intent(out) :: error(:)
675 real(real64),
intent(out) :: mean_error
677 integer :: mn_states, np, ij, ip
681 mn_states =
size(product_basis, 1)
682 np =
size(product_basis, 2)
685 assert(mesh%np == np)
688 assert(all(shape(product_basis) == shape(approx_product_basis)))
691 assert(
size(error) == mn_states)
695 error(ij) = (product_basis(ij, 1) - approx_product_basis(ij, 1))**2
700 error(ij) = error(ij) + (product_basis(ij, ip) - approx_product_basis(ij, ip))**2
704 mean_error = 0.0_real64
706 error(ij) =
sqrt(mesh%volume_element * error(ij))
707 mean_error = mean_error + error(ij)
710 mean_error = mean_error / real(mn_states, real64)
719 type(namespace_t),
intent(in) :: namespace
720 character(len=*),
intent(in) :: fname
721 real(real64),
intent(in) :: matrix(:, :)
723 integer :: i, j, unit
727 write(message(1),
'(a)')
"ISDF Serial. Outputting: " // trim(adjustl(fname))
728 call messages_info(1, namespace=namespace, debug_only=.
true.)
730 if (mpi_world%is_root())
then
732 open(newunit=unit, file=trim(adjustl(fname)))
733 do j = 1,
size(matrix, 2)
734 do i = 1,
size(matrix, 1)
735 write(unit, *) matrix(i, j)
749 type(namespace_t),
intent(in) :: namespace
750 class(space_t),
intent(in) :: space
751 class(mesh_t),
intent(in) :: mesh
752 class(ions_t),
pointer,
intent(in) :: ions
753 character(len=*),
intent(in) :: file_prefix
754 real(real64),
intent(in) :: data(:, :)
755 integer,
optional,
intent(in) :: limits(2)
757 integer :: m_states, limit_j, limit_i, i, j, ij, ierr
758 real(real64) :: size_data
759 character(len=4) :: i_char, j_char
760 character(len=120) :: file_name
763 size_data = real(
size(
data, 1), real64)
764 m_states = int(
sqrt(size_data))
766 if (
present(limits))
then
776 ij = j + (i - 1) * m_states
777 write(i_char,
'(I4)') i
778 write(j_char,
'(I4)') j
779 file_name = trim(adjustl(file_prefix)) // trim(adjustl(i_char)) //
'_' // trim(adjustl(j_char))
780 call dio_function_output(option__outputformat__cube,
"./cubes", trim(adjustl(file_name)), namespace, space, mesh, &
781 data(ij,:) , unit_one, ierr, pos=ions%pos, atoms=ions%atom)
798 type(states_elec_t),
intent(in) :: st
799 integer,
intent(in) :: ik_index
806 if (st%smear%method /= smear_semiconductor)
then
812 if (abs(st%occ(ist, ik_index)) < m_min_occ)
exit
815 assert(i_max_occ > 0)
double sqrt(double __x) __attribute__((__nothrow__
This module implements batches of mesh functions.
This module implements common operations on batches of mesh functions.
This module contains interfaces for BLAS routines. You should not use these routines directly. ...
This module implements the underlying real-space grid.
Serial prototype for benchmarking and validating ISDF implementation.
subroutine, public serial_interpolative_separable_density_fitting_vectors(namespace, mesh, st, int_indices, phi, isdf_vectors)
Compute a set of ISDF interpolation vectors in serial, for code validation.
subroutine, public quantify_error_and_visualise(namespace, st, space, mesh, ions, phi, indices, isdf_vectors, output_cubes)
Wrapper for quantifying the error in the expansion of the product basis.
subroutine generate_product_state_cubes(namespace, space, mesh, ions, file_prefix, data, limits)
Helper routine to output a set of pair-product states as cube files.
subroutine output_matrix(namespace, fname, matrix)
Helper routine to output a 2D matrix.
subroutine collate_batches_get_state(mesh, st, max_state, psi)
Loop over states per block, which makes applying the maximum state limit much simpler. Use this to com...
subroutine isdf_options_init(this, namespace, st)
Initialise an isdf_options_t instance.
integer function highest_occupied_index(st, ik_index)
Return the index of the highest occupied Kohn-Sham state for k-point ik_index.
subroutine sample_phi_at_centroids(phi_r, indices, phi_mu)
Sample KS states at centroid points.
subroutine collate_batches(np, st, max_state, psi)
Put batches into a single 2D array.
subroutine construct_density_matrix_all_centroids_packed(phi_mu, P_mu_nu)
Construct the density matrix with shape (n_int, n_int). Denoted packed, because it expects ph...
subroutine error_in_product_basis(mesh, product_basis, approx_product_basis, error, mean_error)
Quantify the error in the product basis expansion.
subroutine construct_zct(zct)
Construct the product of the Z and C matrices, $ZC^T$, as the element-wise square of the quasi-density matrix.
subroutine isdf_construct_interpolation_vectors_packed(namespace, mesh, phi, indices, isdf_vectors)
Compute a set of ISDF interpolation vectors, where intermediate quantities such as phi are constructe...
subroutine construct_cct(indices, zct, cct)
Construct the product $CC^T$ from $ZC^T$ by masking the first dimension of $ZC^T$, i.e. keeping only its rows at the centroid indices.
integer, parameter ik
Hard-coded for Gamma-point, spin-unpolarised calculations.
subroutine construct_density_matrix_packed(phi, phi_mu, P_r_mu)
Construct the density matrix with shape (np, n_int). Denoted packed, because it expects phi i...
subroutine approximate_pair_products(psi_mu, zeta, product_basis)
Construct a set of approximate pair products using the ISDF interpolation vectors.
This module is intended to contain "only mathematical" functions and procedures.
This module defines functions over batches of mesh functions.
This module defines the meshes, which are used in Octopus.
type(mpi_grp_t), public mpi_world
integer, parameter, public smear_semiconductor
pure logical function, public states_are_real(st)
This module provides routines for communicating states when using states parallelization.
This module defines the unit system, used for input and output.
type(unit_t), public unit_angstrom
For XYZ files.
type(unit_t), public unit_one
Some special units required for particular quantities.
Class describing the electron system.