Wave functions
Work in progress!
Wavefunctions in Octopus
The wave functions in Octopus are referred to as the states.
They are handled by the module states_abst_oct_m, which is defined in states/states_abst.F90.
Exactly how states are stored in memory is flexible and depends on the optimization (and accelerator) settings in the input file. In particular, the order of indices depends on the PACKED setting.
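To make the remark about index order concrete, the following toy program stores the same values in two layouts. The array names, the sizes and the exact shapes are assumptions chosen for illustration; they are not taken from the Octopus sources, which may organise the packed storage differently.

```fortran
program packed_layout_sketch
  implicit none
  integer, parameter :: np = 4, nst = 3   ! hypothetical number of mesh points and states
  real :: unpacked(np, nst)               ! mesh index runs fastest in memory
  real :: packed(nst, np)                 ! state index runs fastest in memory
  integer :: ip, ist

  ! Fill the "unpacked" layout with recognisable values.
  do ist = 1, nst
    do ip = 1, np
      unpacked(ip, ist) = real(10*ist + ip)
    end do
  end do

  ! In this sketch, packing is simply a swap of the two indices.
  packed = transpose(unpacked)

  ! Both layouts hold the same numbers; only the index order differs.
  do ist = 1, nst
    do ip = 1, np
      if (packed(ist, ip) /= unpacked(ip, ist)) error stop 'layouts disagree'
    end do
  end do
  print '(a)', 'same values, different index order'
end program packed_layout_sketch
```

With the state index running fastest, the values of all states at one mesh point are contiguous in memory, which is the kind of layout that tends to suit vectorised or accelerator kernels.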
States are stored in a hierarchy of ‘containers’. Concepts of this hierarchy include groups, batches and blocks.
The abstract class
The top-level data structure describing states is:
type, abstract :: states_abst_t
private
type(type_t), public :: wfs_type !< real (TYPE_FLOAT) or complex (TYPE_CMPLX) wavefunctions
integer, public :: nst !< Number of states in each irreducible subspace
logical, public :: packed
contains
procedure(nullify), deferred :: nullify
procedure(pack), deferred :: pack
procedure(unpack), deferred :: unpack
procedure(write_info), deferred :: write_info
procedure(set_zero), deferred :: set_zero
procedure, non_overridable :: are_packed
procedure, non_overridable :: get_type
end type states_abst_t
This structure contains mainly metadata about states, describing how states are represented, and defines the interface to the class. As it is an abstract class, it cannot contain any information about the actual system.
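The pattern of an abstract type with deferred bindings can be sketched in a few lines. Everything prefixed with toy_ below is hypothetical and much smaller than the real classes; the sketch only shows how a deferred procedure declared in the abstract type is supplied by an extension.

```fortran
module toy_states_m
  implicit none
  private
  public :: toy_states_abst_t, toy_states_elec_t

  ! Abstract parent: metadata plus the interface every extension must provide.
  type, abstract :: toy_states_abst_t
    integer :: nst = 0
    logical :: packed = .false.
  contains
    procedure(set_zero_iface), deferred :: set_zero
    procedure, non_overridable :: are_packed
  end type toy_states_abst_t

  abstract interface
    subroutine set_zero_iface(this)
      import :: toy_states_abst_t
      class(toy_states_abst_t), intent(inout) :: this
    end subroutine set_zero_iface
  end interface

  ! Concrete extension: adds the actual data and implements set_zero.
  type, extends(toy_states_abst_t) :: toy_states_elec_t
    real, allocatable :: psi(:, :)
  contains
    procedure :: set_zero => toy_states_elec_set_zero
  end type toy_states_elec_t

contains

  logical function are_packed(this)
    class(toy_states_abst_t), intent(in) :: this
    are_packed = this%packed
  end function are_packed

  subroutine toy_states_elec_set_zero(this)
    class(toy_states_elec_t), intent(inout) :: this
    if (allocated(this%psi)) this%psi = 0.0
  end subroutine toy_states_elec_set_zero

end module toy_states_m

program abstract_states_sketch
  use toy_states_m
  implicit none
  type(toy_states_elec_t) :: st

  st%nst = 2
  allocate(st%psi(5, st%nst))
  st%psi = 1.0
  call st%set_zero()
  if (any(st%psi /= 0.0)) error stop 'set_zero failed'
  print *, 'packed? ', st%are_packed()
end program abstract_states_sketch
```

Because toy_states_abst_t is abstract, a variable of that type cannot be declared directly; only extensions that implement set_zero can be instantiated.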
The states class for electrons
The class states_elec_t specializes the abstract class and contains more data specific to the electron system,
as well as pointers to other quantities that are common to all states, such as the density, the current, etc.
The dimensions object contains a number of variables. The most relevant for this discussion is the dim variable, which denotes the dimension of one state: 1 for spin-less states and 2 for spinors.
type states_elec_dim_t
! Components are public by default
integer :: dim !< Dimension of the state (one, or two for spinors)
integer :: nik !< Number of irreducible subspaces
integer :: ispin !< spin mode (unpolarized, spin-polarized, spinors)
integer :: nspin !< dimension of rho (1, 2 or 4)
integer :: spin_channels !< 1 or 2, whether spin is or not considered.
FLOAT, allocatable :: kweights(:) !< weights for the k-point integrations
type(distributed_t) :: kpt
integer :: block_size
integer :: orth_method = 0
logical :: pack_states
FLOAT :: cl_states_mem
contains
procedure :: get_spin_index => states_elec_dim_get_spin_index
procedure :: get_kpoint_index => states_elec_dim_get_kpoint_index
end type states_elec_dim_t
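The comments on dim, nspin and spin_channels suggest a simple relation to the spin mode. The sketch below spells out one plausible mapping; the constant names and the specific values per mode are assumptions that are merely consistent with the comments in the type definition, not a copy of the Octopus initialization code.

```fortran
program spin_dims_sketch
  implicit none
  ! Hypothetical labels for the three spin modes named in the ispin comment.
  integer, parameter :: UNPOLARIZED = 1, SPIN_POLARIZED = 2, SPINORS = 3
  integer :: ispin, dim, nspin, spin_channels

  do ispin = UNPOLARIZED, SPINORS
    select case (ispin)
    case (UNPOLARIZED)
      dim = 1; nspin = 1; spin_channels = 1
    case (SPIN_POLARIZED)
      dim = 1; nspin = 2; spin_channels = 2
    case (SPINORS)
      ! Only spinors use two-component states; rho then has four components.
      dim = 2; nspin = 4; spin_channels = 2
    end select
    print '(4(a,i2))', 'ispin=', ispin, ' dim=', dim, &
      ' nspin=', nspin, ' channels=', spin_channels
  end do
end program spin_dims_sketch
```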
The wave functions themselves are stored in
type(states_elec_group_t) :: group
which, in turn, is defined in the module states_elec_group_oct_m in src/states/states_elec_group.F90:
The group contains all wave functions, grouped together in blocks or batches.
They are organised in an array of wfs_elec_t structures, a type that extends batch_t.
type states_elec_group_t
! Components are public by default
type(wfs_elec_t), allocatable :: psib(:, :) !< A set of wave-functions blocks
integer :: nblocks !< The number of blocks
integer :: block_start !< The lowest index of local blocks
integer :: block_end !< The highest index of local blocks
integer, allocatable :: iblock(:, :) !< A map, that for each state index, returns the index of block containing it
integer, allocatable :: block_range(:, :) !< Each block contains states from block_range(:, 1) to block_range(:, 2)
integer, allocatable :: block_size(:) !< The number of states in each block.
logical, allocatable :: block_is_local(:, :) !< It is true if the block is in this node.
integer, allocatable :: block_node(:) !< The node that contains each block
integer, allocatable :: rma_win(:, :) !< The MPI window for one side communication
logical :: block_initialized = .false. !< For keeping track of the blocks to avoid memory leaks
end type states_elec_group_t
type(wfs_elec_t), allocatable :: psib(:, :) !< A set of wave-functions blocks
The indexing is as follows: psib(ib, iqn), where ib is the block index and iqn the k-point. See below for the routine states_elec_init_block(st, mesh, verbose, skip, packed), which creates the group object. On a given node, only wave functions of local blocks are available.
The group object does, however, contain all information on how the batches are distributed over nodes.
type, extends(batch_t) :: wfs_elec_t
private
integer, public :: ik
logical, public :: has_phase
contains
procedure :: clone_to => wfs_elec_clone_to
procedure :: clone_to_array => wfs_elec_clone_to_array
procedure :: copy_to => wfs_elec_copy_to
procedure :: check_compatibility_with => wfs_elec_check_compatibility_with
procedure :: end => wfs_elec_end
end type wfs_elec_t
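The psib(ib, iqn) indexing and the restriction to local blocks can be illustrated with plain arrays. The sizes and the choice of which blocks are local are made up for this sketch; the sentinel values -1 and -2 are the ones the initialization routine assigns before it finds any local block.

```fortran
program local_blocks_sketch
  implicit none
  integer, parameter :: nblocks = 4, kpt_start = 1, kpt_end = 2   ! hypothetical sizes
  logical :: block_is_local(nblocks, kpt_start:kpt_end)
  integer :: block_start, block_end, ib, iqn, nvisited

  ! Before initialization the range -1:-2 is empty, so the loop body never runs.
  block_start = -1
  block_end = -2
  nvisited = 0
  do ib = block_start, block_end
    nvisited = nvisited + 1
  end do
  if (nvisited /= 0) error stop 'empty range should not iterate'

  ! Pretend this node owns blocks 2 and 3 for all of its k-points.
  block_is_local = .false.
  block_start = 2
  block_end = 3
  block_is_local(block_start:block_end, :) = .true.

  do iqn = kpt_start, kpt_end
    do ib = block_start, block_end
      if (.not. block_is_local(ib, iqn)) error stop 'non-local block in local range'
      nvisited = nvisited + 1   ! real code would operate on psib(ib, iqn) here
    end do
  end do
  print '(a,i0)', 'local (block, k-point) pairs visited: ', nvisited
end program local_blocks_sketch
```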
Creating the wave functions
A number of steps in initializing the states_t object are called from the system_init() routine:
states_init(): parses states-related input variables, and allocates memory for some bookkeeping variables. It does not allocate any memory for the states themselves.
states_distribute_nodes(): …
states_density_init(): allocates memory for the density (rho) and the core density (rho_core).
states_exec_init():
- Fills in the block size (st%d%block_size);
- Finds out whether or not to pack the states (st%d%pack_states);
- Finds out the orthogonalization method (st%d%orth_method).
Memory for the actual wave functions is allocated in states_elec_allocate_wfns(), which is called from the corresponding *_run() routines, such as scf_run() or td_run().
subroutine states_elec_allocate_wfns(st, mesh, wfs_type, skip, packed)
type(states_elec_t), intent(inout) :: st
class(mesh_t), intent(in) :: mesh
type(type_t), optional, intent(in) :: wfs_type
logical, optional, intent(in) :: skip(:)
logical, optional, intent(in) :: packed
PUSH_SUB(states_elec_allocate_wfns)
if (present(wfs_type)) then
ASSERT(wfs_type == TYPE_FLOAT .or. wfs_type == TYPE_CMPLX)
st%wfs_type = wfs_type
end if
call states_elec_init_block(st, mesh, skip = skip, packed=packed)
call states_elec_set_zero(st)
POP_SUB(states_elec_allocate_wfns)
end subroutine states_elec_allocate_wfns
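The optional skip mask is passed straight through to states_elec_init_block, where it narrows the range of states that receive memory. The stand-alone sketch below reproduces only that range search with a made-up mask; the program name and the mask values are illustrative assumptions.

```fortran
program skip_range_sketch
  implicit none
  integer, parameter :: nst = 8          ! hypothetical number of states
  logical :: skip(nst)
  integer :: ist, istmin, istmax

  skip = [.true., .true., .false., .false., .true., .false., .true., .true.]

  ! First state that is not skipped.
  istmin = 1
  do ist = 1, nst
    if (.not. skip(ist)) then
      istmin = ist
      exit
    end if
  end do

  ! Last state that is not skipped, searching downwards.
  istmax = nst
  do ist = nst, istmin, -1
    if (.not. skip(ist)) then
      istmax = ist
      exit
    end if
  end do

  if (istmin /= 3 .or. istmax /= 6) error stop 'unexpected range'
  print '(a,i0,a,i0)', 'allocating states ', istmin, ' to ', istmax
end program skip_range_sketch
```

Note that only the leading and trailing skipped states are trimmed: state 5 in this example is marked as skipped but still lies inside the allocated range 3 to 6.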
The routine states_elec_init_block initializes the data components in st%group that describe how the states are distributed in blocks:
- st%group%nblocks: the number of blocks into which the states are divided. Note that this is the total number of blocks, regardless of how many are actually stored on each node.
- st%group%block_start: on each node, the index of the first local block.
- st%group%block_end: on each node, the index of the last local block. If the states are not parallelized, then block_start is 1 and block_end is st%group%nblocks.
- st%group%iblock(1:st%nst, 1:st%d%nik): points, for each state, to the block that contains it.
- st%group%block_is_local(): st%group%block_is_local(ib, iqn) is .true. if block ib for k-point iqn is stored on the running node.
- st%group%block_range(1:st%group%nblocks, 1:2): block ib contains the states from st%group%block_range(ib, 1) to st%group%block_range(ib, 2).
- st%group%block_size(1:st%group%nblocks): block ib contains st%group%block_size(ib) states.
- st%group%block_initialized: should be .false. on entry, and .true. after exiting this routine.
The set of batches st%group%psib(1:st%group%nblocks, st%d%kpt%start:st%d%kpt%end) contains the blocks themselves.
subroutine states_elec_init_block(st, mesh, verbose, skip, packed)
type(states_elec_t), intent(inout) :: st
type(mesh_t), intent(in) :: mesh
logical, optional, intent(in) :: verbose
logical, optional, intent(in) :: skip(:)
logical, optional, intent(in) :: packed
integer :: ib, iqn, ist, istmin, istmax
logical :: same_node, verbose_, packed_
integer, allocatable :: bstart(:), bend(:)
PUSH_SUB(states_elec_init_block)
SAFE_ALLOCATE(bstart(1:st%nst))
SAFE_ALLOCATE(bend(1:st%nst))
SAFE_ALLOCATE(st%group%iblock(1:st%nst, 1:st%d%nik))
st%group%iblock = 0
verbose_ = optional_default(verbose, .true.)
packed_ = optional_default(packed, .false.)
!In case we have a list of state to skip, we do not allocate them
istmin = 1
if (present(skip)) then
do ist = 1, st%nst
if (.not. skip(ist)) then
istmin = ist
exit
end if
end do
end if
istmax = st%nst
if (present(skip)) then
do ist = st%nst, istmin, -1
if (.not. skip(ist)) then
istmax = ist
exit
end if
end do
end if
if (present(skip) .and. verbose_) then
call messages_write('Info: Allocating states from ')
call messages_write(istmin, fmt = 'i8')
call messages_write(' to ')
call messages_write(istmax, fmt = 'i8')
call messages_info()
end if
! count and assign blocks
ib = 0
st%group%nblocks = 0
bstart(1) = istmin
do ist = istmin, istmax
ib = ib + 1
st%group%iblock(ist, st%d%kpt%start:st%d%kpt%end) = st%group%nblocks + 1
same_node = .true.
if (st%parallel_in_states .and. ist /= istmax) then
! We have to avoid that states that are in different nodes end
! up in the same block
same_node = (st%node(ist + 1) == st%node(ist))
end if
if (ib == st%d%block_size .or. ist == istmax .or. .not. same_node) then
ib = 0
st%group%nblocks = st%group%nblocks + 1
bend(st%group%nblocks) = ist
if (ist /= istmax) bstart(st%group%nblocks + 1) = ist + 1
end if
end do
SAFE_ALLOCATE(st%group%psib(1:st%group%nblocks, st%d%kpt%start:st%d%kpt%end))
SAFE_ALLOCATE(st%group%block_is_local(1:st%group%nblocks, st%d%kpt%start:st%d%kpt%end))
st%group%block_is_local = .false.
st%group%block_start = -1
st%group%block_end = -2 ! this will make that loops block_start:block_end do not run if not initialized
do ib = 1, st%group%nblocks
if (bstart(ib) >= st%st_start .and. bend(ib) <= st%st_end) then
if (st%group%block_start == -1) st%group%block_start = ib
st%group%block_end = ib
do iqn = st%d%kpt%start, st%d%kpt%end
st%group%block_is_local(ib, iqn) = .true.
if (states_are_real(st)) then
call dwfs_elec_init(st%group%psib(ib, iqn), st%d%dim, bstart(ib), bend(ib), mesh%np_part, iqn, &
special=.true., packed=packed_)
else
call zwfs_elec_init(st%group%psib(ib, iqn), st%d%dim, bstart(ib), bend(ib), mesh%np_part, iqn, &
special=.true., packed=packed_)
end if
end do
end if
end do
SAFE_ALLOCATE(st%group%block_range(1:st%group%nblocks, 1:2))
SAFE_ALLOCATE(st%group%block_size(1:st%group%nblocks))
st%group%block_range(1:st%group%nblocks, 1) = bstart(1:st%group%nblocks)
st%group%block_range(1:st%group%nblocks, 2) = bend(1:st%group%nblocks)
st%group%block_size(1:st%group%nblocks) = bend(1:st%group%nblocks) - bstart(1:st%group%nblocks) + 1
st%group%block_initialized = .true.
SAFE_ALLOCATE(st%group%block_node(1:st%group%nblocks))
ASSERT(allocated(st%node))
ASSERT(all(st%node >= 0) .and. all(st%node < st%mpi_grp%size))
do ib = 1, st%group%nblocks
st%group%block_node(ib) = st%node(st%group%block_range(ib, 1))
ASSERT(st%group%block_node(ib) == st%node(st%group%block_range(ib, 2)))
end do
if (verbose_) then
call messages_write('Info: Blocks of states')
call messages_info()
do ib = 1, st%group%nblocks
call messages_write(' Block ')
call messages_write(ib, fmt = 'i8')
call messages_write(' contains ')
call messages_write(st%group%block_size(ib), fmt = 'i8')
call messages_write(' states')
if (st%group%block_size(ib) > 0) then
call messages_write(':')
call messages_write(st%group%block_range(ib, 1), fmt = 'i8')
call messages_write(' - ')
call messages_write(st%group%block_range(ib, 2), fmt = 'i8')
end if
call messages_info()
end do
end if
!!$!!!!DEBUG
!!$ ! some debug output that I will keep here for the moment
!!$ if (mpi_grp_is_root(mpi_world)) then
!!$ print*, "NST ", st%nst
!!$ print*, "BLOCKSIZE ", st%d%block_size
!!$ print*, "NBLOCKS ", st%group%nblocks
!!$
!!$ print*, "==============="
!!$ do ist = 1, st%nst
!!$ print*, st%node(ist), ist, st%group%iblock(ist, 1)
!!$ end do
!!$ print*, "==============="
!!$
!!$ do ib = 1, st%group%nblocks
!!$ print*, ib, bstart(ib), bend(ib)
!!$ end do
!!$
!!$ end if
!!$!!!!ENDOFDEBUG
SAFE_DEALLOCATE_A(bstart)
SAFE_DEALLOCATE_A(bend)
POP_SUB(states_elec_init_block)
end subroutine states_elec_init_block
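The "count and assign blocks" loop in the routine above can be exercised in isolation. The following sketch copies its logic onto plain arrays, with a made-up distribution of ten states over two nodes and a block size of four; no skip mask is used, so the loop runs over all states.

```fortran
program block_partition_sketch
  implicit none
  integer, parameter :: nst = 10, block_size = 4   ! hypothetical sizes
  integer :: node(nst), iblock(nst), bstart(nst), bend(nst)
  integer :: ib, ist, nblocks
  logical :: same_node

  ! States 1-5 live on node 0, states 6-10 on node 1.
  node = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

  ib = 0
  nblocks = 0
  bstart(1) = 1
  do ist = 1, nst
    ib = ib + 1
    iblock(ist) = nblocks + 1
    ! A block must never span two nodes.
    same_node = .true.
    if (ist /= nst) same_node = (node(ist + 1) == node(ist))
    ! Close the block when it is full, at the last state, or at a node boundary.
    if (ib == block_size .or. ist == nst .or. .not. same_node) then
      ib = 0
      nblocks = nblocks + 1
      bend(nblocks) = ist
      if (ist /= nst) bstart(nblocks + 1) = ist + 1
    end if
  end do

  if (nblocks /= 4) error stop 'unexpected number of blocks'
  if (any(bstart(1:4) /= [1, 5, 6, 10])) error stop 'unexpected block starts'
  if (any(bend(1:4) /= [4, 5, 9, 10])) error stop 'unexpected block ends'
  if (any(iblock /= [1, 1, 1, 1, 2, 3, 3, 3, 3, 4])) error stop 'unexpected iblock map'

  do ib = 1, nblocks
    print '(a,i0,a,i0,a,i0,a,i0)', 'block ', ib, ': states ', bstart(ib), &
      ' - ', bend(ib), ' on node ', node(bstart(ib))
  end do
end program block_partition_sketch
```

The node boundary after state 5 forces a block containing only that state, so the blocks come out as 1-4, 5-5, 6-9 and 10-10 rather than three blocks of sizes 4, 4 and 2.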
The allocation of memory for the actual wave functions is performed in batch_init_empty() and X(batch_allocate)(). This routine and the related X(batch_add_state)() show most clearly how the different memory blocks are related.
batch_init_empty() allocates the memory for the batch_state_t component states and the batch_states_l_t component states_linear, and nullifies the pointers within these types. Note that no memory for the actual wave functions has been allocated at this point.
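The two-stage idea, descriptors first and data later, can be sketched with a small derived type. The type toy_state_t and the array names are hypothetical stand-ins and do not reproduce the actual batch data structures.

```fortran
program empty_batch_sketch
  implicit none

  ! Hypothetical per-state descriptor that only holds a pointer to the data.
  type :: toy_state_t
    real, pointer :: psi(:) => null()
  end type toy_state_t

  integer, parameter :: nst = 3, np = 5
  type(toy_state_t), allocatable :: states(:)
  real, allocatable, target :: storage(:, :)
  integer :: ist

  ! Stage 1: allocate only the descriptors and nullify their pointers.
  allocate(states(nst))
  do ist = 1, nst
    nullify(states(ist)%psi)
    if (associated(states(ist)%psi)) error stop 'pointer should be null'
  end do

  ! Stage 2 (later): allocate the wave-function data and attach it state by state.
  allocate(storage(np, nst))
  storage = 0.0
  do ist = 1, nst
    states(ist)%psi => storage(:, ist)
  end do

  ! Writing through a descriptor changes the shared storage.
  states(2)%psi(1) = 1.0
  if (storage(1, 2) /= 1.0) error stop 'descriptor does not alias storage'
  print '(a)', 'descriptors now point into the shared storage'
end program empty_batch_sketch
```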
Questions:
How are the different objects pointing to states related?
- The usual storage for states is in states_linear, which can be shadowed by pack in case packed states are used.
What is the difference between batch_add_state and batch_add_state_linear?