22 use,
intrinsic :: iso_fortran_env, only: real64
58 integer,
parameter :: ik = 1
76 class(isdf_options_t),
intent(out) :: this
77 type(namespace_t),
intent(in ) :: namespace
78 type(states_elec_t),
intent(in ) :: st
80 integer :: default_n_interp
84 default_n_interp = 10 * highest_occupied_index(st,
ik)
96 call parse_variable(namespace,
'NCentroidPoints', default_n_interp, this%n_interp)
107 type(namespace_t),
intent(in ) :: namespace
108 class(mesh_t),
intent(in ) :: mesh
109 type(states_elec_t),
intent(in ) :: st
110 integer(int64),
contiguous,
intent(in ) :: int_indices(:)
111 real(real64),
allocatable,
intent(out) :: phi(:, :)
115 real(real64),
allocatable,
intent(out) :: isdf_vectors(:, :)
121 message(1) =
"Serial ISDF"
122 call messages_write(1)
125 if (st%parallel_in_states .or. mesh%parallel_in_domains)
then
126 message(1) =
"Serial ISDF called when running state or domain-parallel"
127 call messages_fatal(1)
131 if (st%d%nspin > 1)
then
132 call messages_not_implemented(
"ISDF Serial for SPIN_POLARIZED and SPINOR calculations", namespace)
136 if (.not. states_are_real(st))
then
137 call messages_not_implemented(
"ISDF Serial handling of complex states", namespace)
140 nocc = highest_occupied_index(st,
ik)
146 if (debug%info)
call output_matrix(namespace,
"phi_r_serial.txt", phi)
147 safe_deallocate_a(phi)
152 if (debug%info)
call output_matrix(namespace,
"phi_r_serial3.txt", phi)
154 select case (st%group%psib(st%group%block_start, 1)%status())
155 case (batch_device_packed)
156 message(1) =
"Serial ISDF not implemented for BATCH_DEVICE_PACKED"
157 call messages_fatal(1)
162 case (batch_not_packed)
163 message(1) =
"Serial ISDF not implemented for BATCH_NOT_PACKED"
164 call messages_fatal(1)
165 assert(all(shape(phi) == [mesh%np, st%nst]))
178 type(namespace_t),
intent(in ) :: namespace
179 class(mesh_t),
intent(in ) :: mesh
180 real(real64),
contiguous,
intent(in ) :: phi(:, :)
181 integer(int64),
contiguous,
intent(in ) :: indices(:)
182 real(real64),
allocatable,
intent(out) :: isdf_vectors(:, :)
184 real(real64),
allocatable :: phi_mu(:, :)
185 real(real64),
allocatable :: P_mu_nu(:, :)
186 real(real64),
allocatable :: zct(:, :)
187 real(real64),
allocatable :: cct(:, :)
189 integer :: n_int, i, j, n_states
190 logical,
parameter :: construct_P_mu_nu = .false.
194 assert(
size(phi, 2) == mesh%np)
196 n_states =
size(phi, 1)
197 n_int =
size(indices)
199 safe_allocate(phi_mu(1:n_states, 1:n_int))
201 if (debug%info)
call output_matrix(namespace,
"phi_mu_serial.txt", phi_mu)
204 safe_allocate(zct(1:mesh%np, 1:n_int))
206 if (debug%info)
call output_matrix(namespace,
"p_r_mu_serial.txt", zct)
210 if (debug%info)
call output_matrix(namespace,
"zct_serial.txt", zct)
213 safe_allocate(cct(1:n_int, 1:n_int))
215 if (debug%info)
call output_matrix(namespace,
"cct_serial.txt", cct)
216 assert(is_symmetric(cct))
219 if (construct_p_mu_nu)
then
220 safe_allocate(p_mu_nu(1:n_int, 1:n_int))
222 assert(is_symmetric(p_mu_nu))
224 safe_deallocate_a(phi_mu)
225 if (debug%info)
call output_matrix(namespace,
"p_mu_nu_serial.txt", p_mu_nu)
228 write(message(1),
'(a)')
"ISDF Serial: Constructing [CC^T] = P_mu_nu o P_mu_nu"
229 call messages_info(1, namespace=namespace, debug_only=.
true.)
233 p_mu_nu(i, j) = p_mu_nu(i, j) * p_mu_nu(i, j)
237 if (debug%info)
call output_matrix(namespace,
"cct_alt_serial.txt", p_mu_nu)
238 safe_deallocate_a(p_mu_nu)
247 write(message(1),
'(a)')
"ISDF Serial: Inverting [CC^T]"
248 call messages_info(1, namespace=namespace, debug_only=.
true.)
251 call lalg_svd_inverse(n_int, n_int, cct)
252 call symmetrize_matrix(n_int, cct)
255 safe_allocate(isdf_vectors(1:mesh%np, 1:n_int))
258 call lalg_gemm(mesh%np, n_int, n_int, 1.0_real64, zct, cct, 0.0_real64, isdf_vectors)
260 if (debug%info)
call output_matrix(namespace,
"isdf_serial.txt", isdf_vectors)
261 safe_deallocate_a(zct)
262 safe_deallocate_a(cct)
273 class(mesh_t),
intent(in ) :: mesh
274 type(states_elec_t),
intent(in ) :: st
275 integer,
intent(in ) :: max_state
276 real(real64),
allocatable,
intent(out) :: psi(:, :)
278 integer :: istate, ib, ist, minst, maxst, block_end
282 assert(max_state <= st%nst)
284 safe_allocate(psi(1:max_state, 1:mesh%np))
285 block_end = st%group%iblock(max_state)
291 minst = states_elec_block_min(st, ib)
292 maxst = min(states_elec_block_max(st, ib), max_state)
293 do ist = minst, maxst
296 if (abs(st%occ(ist,
ik) * st%kweights(
ik)) < m_min_occ)
then
297 psi(istate, :) = 0.0_real64
299 call states_elec_get_state(st, mesh, st%d%dim, ist,
ik, psi(istate, :))
311 integer,
intent(in ) :: np
312 type(states_elec_t),
intent(in ) :: st
313 integer,
intent(in) :: max_state
314 real(real64),
allocatable,
intent(out) :: psi(:, :)
316 integer :: ip, ib, minst, maxst, block_end, ist, ist_local
320 select case (st%group%psib(1,1)%status())
321 case (batch_device_packed)
328 safe_allocate(psi(1:max_state, 1:np))
329 block_end = st%group%iblock(max_state)
333 minst = states_elec_block_min(st, ib)
335 maxst = min(states_elec_block_max(st, ib), max_state)
337 do ist = minst, maxst
338 ist_local = ist - minst + 1
340 if (abs(st%occ(ist,
ik) * st%kweights(
ik)) < m_min_occ)
then
341 psi(ist, ip) = 0.0_real64
343 psi(ist, ip) = st%group%psib(ib,
ik)%dff_pack(ist_local, ip)
349 case (batch_not_packed)
361 real(real64),
contiguous,
intent(in ) :: phi_r(:, :)
362 integer(int64),
contiguous,
intent(in ) :: indices(:)
363 real(real64),
contiguous,
intent(out) :: phi_mu(:, :)
365 integer :: ic, is, nst, n_int
366 integer(int64) :: ipg
370 write(message(1),
'(a)')
"ISDF Serial: Sampling phi(r) at mu"
371 call messages_info(1, debug_only=.
true.)
374 assert(
size(phi_mu, 1) == nst)
376 n_int =
size(indices)
377 assert(
size(phi_mu, 2) == n_int)
382 phi_mu(is, ic) = phi_r(is, ipg)
402 real(real64),
contiguous,
intent(inout) :: zct(:, :)
409 write(message(1),
'(a)')
"ISDF Serial: Constructing ZC^T"
410 call messages_info(1, debug_only=.
true.)
413 do j = 1,
size(zct, 2)
414 do i = 1,
size(zct, 1)
415 zct(i, j) = zct(i, j)**2
428 integer(int64),
contiguous,
intent(in ) :: indices(:)
429 real(real64),
contiguous,
intent(in ) :: zct(:, :)
431 real(real64),
contiguous,
intent(out) :: cct(:, :)
433 integer(int64) :: ipg
434 integer :: i_mu, i_nu, n_int
438 write(message(1),
'(a)')
"ISDF Serial: Constructing CC^T by sampling ZC^T"
439 call messages_info(1, debug_only=.
true.)
441 n_int =
size(indices)
442 assert(all(shape(cct) == [n_int, n_int]))
443 assert(
size(zct, 1) > n_int)
444 assert(
size(zct, 2) == n_int)
450 cct(i_mu, i_nu) = zct(ipg, i_nu)
468 real(real64),
contiguous,
intent(in ) :: phi(:, :)
470 real(real64),
contiguous,
intent(in ) :: phi_mu(:, :)
472 real(real64),
contiguous,
intent(out) :: P_r_mu(:, :)
480 write(message(1),
'(a)')
"ISDF Serial: Constructing P_r_mu"
481 call messages_info(1, debug_only=.
true.)
483 m_states =
size(phi, 1)
485 n_int =
size(phi_mu, 2)
487 assert(
size(phi_mu, 1) == m_states)
488 assert(
size(p_r_mu, 1) == np)
489 assert(
size(p_r_mu, 2) == n_int)
492 call lalg_gemm(phi, phi_mu, p_r_mu, transa=
'T')
508 real(real64),
contiguous,
intent(in ) :: phi_mu(:, :)
510 real(real64),
contiguous,
intent(out) :: P_mu_nu(:, :)
516 write(message(1),
'(a)')
"ISDF Serial: Constructing P_mu_nu"
517 call messages_info(1, debug_only=.
true.)
519 n_int =
size(phi_mu, 2)
520 assert(
size(p_mu_nu, 1) == n_int)
521 assert(
size(p_mu_nu, 2) == n_int)
524 call lalg_gemm(phi_mu, phi_mu, p_mu_nu, transa=
'T')
533 type(namespace_t),
intent(in) :: namespace
534 type(states_elec_t),
intent(in) :: st
535 class(space_t),
intent(in) :: space
536 class(mesh_t),
intent(in) :: mesh
537 class(ions_t),
pointer,
intent(in) :: ions
538 real(real64),
allocatable,
intent(inout) :: phi(:, :)
539 integer(int64),
contiguous,
intent(in) :: indices(:)
540 real(real64),
allocatable,
intent(inout) :: isdf_vectors(:, :)
541 logical,
intent(in) :: output_cubes
543 real(real64),
allocatable :: product_basis(:, :), approx_product_basis(:, :)
544 real(real64),
allocatable :: phi_mu(:, :), phi_occ(:, :)
545 real(real64),
allocatable :: product_error(:)
546 integer :: n_occ, n_products, n_int, i, j, ij, is, ip, unit
547 real(real64) :: mean_error
551 write(message(1),
'(a)')
"ISDF Serial: Computing exact pair products"
552 call messages_info(1, debug_only=.
true.)
554 assert(
size(phi, 2) == mesh%np)
557 n_occ = highest_occupied_index(st,
ik)
558 safe_allocate(phi_occ(1:n_occ, 1:mesh%np))
561 phi_occ(is, ip) = phi(is, ip)
564 safe_deallocate_a(phi)
566 n_products = n_occ * n_occ
567 safe_allocate(product_basis(1:n_products, 1:mesh%np))
568 call column_wise_khatri_rao_product(phi_occ, phi_occ, product_basis)
571 if (output_cubes)
then
576 write(message(1),
'(a)')
"ISDF Serial Test: Computing approximate pair products"
577 call messages_info(1, namespace=namespace, debug_only=.
true.)
580 n_int =
size(indices)
581 safe_allocate(phi_mu(1:n_occ, 1:n_int))
583 safe_deallocate_a(phi_occ)
585 safe_allocate(approx_product_basis(1:n_products, 1:mesh%np))
589 safe_deallocate_a(phi_mu)
590 safe_deallocate_a(isdf_vectors)
592 if (output_cubes)
then
594 approx_product_basis)
598 safe_allocate(product_error(1:n_products))
600 safe_deallocate_a(product_basis)
601 safe_deallocate_a(approx_product_basis)
603 if (mpi_world%is_root())
then
604 open(newunit=unit, file=
"isdf_error_serial.txt")
605 write(unit, *)
'Mean error', mean_error
610 write(unit, *) i, j, product_error(ij)
616 safe_deallocate_a(product_error)
633 real(real64),
contiguous,
intent(in ) :: psi_mu(:, :)
634 real(real64),
contiguous,
intent(in ) :: zeta(:, :)
635 real(real64),
contiguous,
intent(out) :: product_basis(:, :)
637 real(real64),
allocatable :: psi_ij_mu(:, :)
638 integer :: mn_states, n_int, np
642 mn_states =
size(psi_mu, 1)**2
644 n_int =
size(zeta, 2)
646 assert(
size(product_basis, 1) == mn_states)
647 assert(
size(product_basis, 2) == np)
649 safe_allocate(psi_ij_mu(1:mn_states, 1:n_int))
650 call column_wise_khatri_rao_product(psi_mu, psi_mu, psi_ij_mu)
653 call lalg_gemm(psi_ij_mu, zeta, product_basis, transb=
'T')
655 safe_deallocate_a(psi_ij_mu)
670 class(mesh_t),
intent(in ) :: mesh
671 real(real64),
contiguous,
intent(in ) :: product_basis(:, :)
672 real(real64),
contiguous,
intent(in ) :: approx_product_basis(:, :)
674 real(real64),
contiguous,
intent(out) :: error(:)
675 real(real64),
intent(out) :: mean_error
677 integer :: mn_states, np, ij, ip
681 mn_states =
size(product_basis, 1)
682 np =
size(product_basis, 2)
685 assert(mesh%np == np)
688 assert(all(shape(product_basis) == shape(approx_product_basis)))
691 assert(
size(error) == mn_states)
695 error(ij) = (product_basis(ij, 1) - approx_product_basis(ij, 1))**2
700 error(ij) = error(ij) + (product_basis(ij, ip) - approx_product_basis(ij, ip))**2
704 mean_error = 0.0_real64
706 error(ij) =
sqrt(mesh%volume_element * error(ij))
707 mean_error = mean_error + error(ij)
710 mean_error = mean_error / real(mn_states, real64)
719 type(namespace_t),
intent(in) :: namespace
720 class(space_t),
intent(in) :: space
721 class(mesh_t),
intent(in) :: mesh
722 class(ions_t),
pointer,
intent(in) :: ions
723 character(len=*),
intent(in) :: file_prefix
724 real(real64),
contiguous,
intent(in) :: data(:, :)
725 integer,
optional,
intent(in) :: limits(2)
727 integer :: m_states, limit_j, limit_i, i, j, ij, ierr
728 real(real64) :: size_data
729 character(len=4) :: i_char, j_char
730 character(len=120) :: file_name
733 size_data = real(
size(
data, 1), real64)
734 m_states = int(
sqrt(size_data))
736 if (
present(limits))
then
746 ij = j + (i - 1) * m_states
747 write(i_char,
'(I4)') i
748 write(j_char,
'(I4)') j
749 file_name = trim(adjustl(file_prefix)) // trim(adjustl(i_char)) //
'_' // trim(adjustl(j_char))
750 call dio_function_output(option__outputformat__cube,
"./cubes", trim(adjustl(file_name)), namespace, space, mesh, &
751 data(ij,:) , unit_one, ierr, pos=ions%pos, atoms=ions%atom)
double sqrt(double __x) __attribute__((__nothrow__
This module implements batches of mesh functions.
This module implements common operations on batches of mesh functions.
This module contains interfaces for BLAS routines. You should not use these routines directly...
This module implements the underlying real-space grid.
Serial prototype for benchmarking and validating the ISDF implementation.
subroutine, public serial_interpolative_separable_density_fitting_vectors(namespace, mesh, st, int_indices, phi, isdf_vectors)
Compute a set of ISDF interpolation vectors in serial, for code validation.
subroutine, public quantify_error_and_visualise(namespace, st, space, mesh, ions, phi, indices, isdf_vectors, output_cubes)
Wrapper for quantifying the error in the expansion of the product basis.
subroutine generate_product_state_cubes(namespace, space, mesh, ions, file_prefix, data, limits)
Helper function to output a set of pair product states.
subroutine collate_batches_get_state(mesh, st, max_state, psi)
Loop over states per block, which makes applying the maximum state limit much simpler. Use this to com...
subroutine isdf_options_init(this, namespace, st)
Initialise an isdf_options_t instance.
subroutine sample_phi_at_centroids(phi_r, indices, phi_mu)
Sample KS states at centroid points.
subroutine collate_batches(np, st, max_state, psi)
Put batches into a single 2D array.
subroutine construct_density_matrix_all_centroids_packed(phi_mu, P_mu_nu)
Construct the density matrix with shape (n_int, n_int). Denoted "packed" because it expects ph...
subroutine error_in_product_basis(mesh, product_basis, approx_product_basis, error, mean_error)
Quantify the error in the product basis expansion.
subroutine construct_zct(zct)
Construct the product $ZC^T$ from the element-wise product of the quasi-density matrix with itself (i.e. its element-wise square).
subroutine isdf_construct_interpolation_vectors_packed(namespace, mesh, phi, indices, isdf_vectors)
Compute a set of ISDF interpolation vectors, where intermediate quantities such as phi are constructe...
subroutine construct_cct(indices, zct, cct)
Construct the product $CC^T$ from $ZC^T$ by masking the first dimension of $ZC^T$, i.e. sampling its rows at the interpolation-point (centroid) indices.
integer, parameter ik
Hard-coded for Gamma-point, spin-unpolarised calculations.
subroutine construct_density_matrix_packed(phi, phi_mu, P_r_mu)
Construct the density matrix with shape (np, n_int). Denoted "packed" because it expects phi i...
subroutine approximate_pair_products(psi_mu, zeta, product_basis)
Construct a set of approximate pair products using the ISDF interpolation vectors.
This module is intended to contain "only mathematical" functions and procedures.
This module defines functions over batches of mesh functions.
This module defines the meshes, which are used in Octopus.
type(mpi_grp_t), public mpi_world
integer, parameter, public smear_semiconductor
pure logical function, public states_are_real(st)
This module provides routines for communicating states when using states parallelization.
This module defines the unit system, used for input and output.
type(unit_t), public unit_angstrom
For XYZ files.
type(unit_t), public unit_one
Some special units required for particular quantities.
Class describing the electron system.