diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index 97823e7c1b..8768d3bce4 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -30,6 +30,7 @@ #include "source_io/to_wannier90_pw.h" #include "source_io/write_elecstat_pot.h" #include "source_io/module_parameter/parameter.h" +#include "source_hamilt/module_xc/xc_functional.h" #include #include diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index b98fe6490e..11088a6daf 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -17,6 +17,7 @@ #include "source_pw/module_pwdft/forces.h" #include "source_pw/module_pwdft/stress_pw.h" +#include "source_hamilt/module_xc/xc_functional.h" // use XC_Functional #ifdef __DSP #include "source_base/kernels/dsp/dsp_connector.h" diff --git a/source/source_hamilt/module_xc/exx_info.h b/source/source_hamilt/module_xc/exx_info.h index 55ae60668f..c9584f9b60 100644 --- a/source/source_hamilt/module_xc/exx_info.h +++ b/source/source_hamilt/module_xc/exx_info.h @@ -2,7 +2,6 @@ #define EXX_INFO_H #include "source_lcao/module_ri/conv_coulomb_pot_k.h" -#include "xc_functional.h" #include #include diff --git a/source/source_hamilt/module_xc/xc_functional.h b/source/source_hamilt/module_xc/xc_functional.h index 067ceef7cf..86dcb4049d 100644 --- a/source/source_hamilt/module_xc/xc_functional.h +++ b/source/source_hamilt/module_xc/xc_functional.h @@ -15,7 +15,6 @@ #include "source_base/global_variable.h" #include "source_base/vector3.h" #include "source_base/matrix.h" -#include "exx_info.h" #include "source_basis/module_pw/pw_basis_k.h" #include "source_estate/module_charge/charge.h" #include "source_cell/unitcell.h" diff --git a/source/source_io/input_conv.cpp b/source/source_io/input_conv.cpp index d174cefeae..111040d812 100644 --- a/source/source_io/input_conv.cpp +++ b/source/source_io/input_conv.cpp @@ -14,6 +14,7 @@ #include #include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info +#include "source_hamilt/module_xc/xc_functional.h" // use XC_Functional #ifdef __EXX #include "source_lcao/module_ri/exx_abfs-jle.h" #endif diff --git a/source/source_lcao/LCAO_hamilt.hpp b/source/source_lcao/LCAO_hamilt.hpp deleted file mode 100644 index 69179f64c1..0000000000 --- a/source/source_lcao/LCAO_hamilt.hpp +++ /dev/null @@ -1,142 +0,0 @@ -#include "source_io/module_parameter/parameter.h" - -#ifndef LCAO_HAMILT_HPP -#define LCAO_HAMILT_HPP - -#include "source_base/abfs-vector3_order.h" -#include "source_base/global_variable.h" -#include "source_base/timer.h" -#include "source_lcao/module_ri/RI_2D_Comm.h" -#include "source_lcao/spar_exx.h" - -#include -#include -#include -#include -#include -#include -#include - -#ifdef __EXX -// Peize Lin add 2022.09.13 - -template -void sparse_format::cal_HR_exx( - const UnitCell& ucell, - const Parallel_Orbitals& pv, - LCAO_HS_Arrays& HS_Arrays, - const int& current_spin, - const double& sparse_threshold, - const int (&nmp)[3], - const std::vector>, RI::Tensor>>>& Hexxs) -{ - ModuleBase::TITLE("sparse_format", "cal_HR_exx"); - ModuleBase::timer::tick("sparse_format", "cal_HR_exx"); - - const Tdata frac = GlobalC::exx_info.info_global.hybrid_alpha; - - std::map> atoms_pos; - for (int iat = 0; iat < ucell.nat; ++iat) - { - atoms_pos[iat] = RI_Util::Vector3_to_array3(ucell.atoms[ucell.iat2it[iat]].tau[ucell.iat2ia[iat]]); - } - const std::array, 3> latvec - = {RI_Util::Vector3_to_array3(ucell.a1), // too bad to use GlobalC here, - RI_Util::Vector3_to_array3(ucell.a2), - RI_Util::Vector3_to_array3(ucell.a3)}; - - const std::array Rs_period = {nmp[0], nmp[1], nmp[2]}; - - RI::Cell_Nearest cell_nearest; - cell_nearest.init(atoms_pos, latvec, Rs_period); - - const std::vector is_list - = (PARAM.inp.nspin != 4) ? std::vector{current_spin} : std::vector{0, 1, 2, 3}; - - for (const int is: is_list) - { - int is0_b = 0; - int is1_b = 0; - std::tie(is0_b, is1_b) = RI_2D_Comm::split_is_block(is); - - if (Hexxs.empty()) - { - break; - } - - for (const auto& HexxA: Hexxs[is]) - { - const int iat0 = HexxA.first; - for (const auto& HexxB: HexxA.second) - { - const int iat1 = HexxB.first.first; - - const Abfs::Vector3_Order R = RI_Util::array3_to_Vector3( - cell_nearest.get_cell_nearest_discrete(iat0, iat1, HexxB.first.second)); - - HS_Arrays.all_R_coor.insert(R); - - const RI::Tensor& Hexx = HexxB.second; - - for (size_t iw0 = 0; iw0 < Hexx.shape[0]; ++iw0) - { - const int iwt0 = RI_2D_Comm::get_iwt(ucell, iat0, iw0, is0_b); - const int iwt0_local = pv.global2local_row(iwt0); - - if (iwt0_local < 0) - { - continue; - } - - for (size_t iw1 = 0; iw1 < Hexx.shape[1]; ++iw1) - { - const int iwt1 = RI_2D_Comm::get_iwt(ucell, iat1, iw1, is1_b); - const int iwt1_local = pv.global2local_col(iwt1); - - if (iwt1_local < 0) - { - continue; - } - - if (std::abs(Hexx(iw0, iw1)) > sparse_threshold) - { - if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2) - { - auto& HR_sparse_ptr = HS_Arrays.HR_sparse[current_spin][R][iwt0]; - double& HR_sparse = HR_sparse_ptr[iwt1]; - HR_sparse += RI::Global_Func::convert(frac * Hexx(iw0, iw1)); - if (std::abs(HR_sparse) <= sparse_threshold) - { - HR_sparse_ptr.erase(iwt1); - } - } - else if (PARAM.inp.nspin == 4) - { - auto& HR_sparse_ptr = HS_Arrays.HR_soc_sparse[R][iwt0]; - - std::complex& HR_sparse = HR_sparse_ptr[iwt1]; - - HR_sparse += RI::Global_Func::convert>(frac * Hexx(iw0, iw1)); - - if (std::abs(HR_sparse) <= sparse_threshold) - { - HR_sparse_ptr.erase(iwt1); - } - } - else - { - throw std::invalid_argument(std::string(__FILE__) + " line " - + std::to_string(__LINE__)); - } - } - } - } - } - } - } - - ModuleBase::timer::tick("sparse_format", "cal_HR_exx"); -} -#endif - -#endif diff --git a/source/source_lcao/spar_dh.h b/source/source_lcao/spar_dh.h index c176096e51..3df85fefc9 100644 --- a/source/source_lcao/spar_dh.h +++ b/source/source_lcao/spar_dh.h @@ -1,11 +1,14 @@ #ifndef SPAR_DH_H #define SPAR_DH_H +#include "source_base/matrix.h" +#include "source_basis/module_ao/parallel_orbitals.h" +#include "source_basis/module_nao/two_center_bundle.h" +#include "source_basis/module_ao/ORB_read.h" #include "source_cell/module_neighbor/sltk_atom_arrange.h" #include "source_cell/module_neighbor/sltk_grid_driver.h" #include "source_lcao/LCAO_HS_arrays.hpp" #include "source_lcao/force_stress_arrays.h" -#include "source_lcao/hamilt_lcao.h" #include namespace sparse_format diff --git a/source/source_lcao/spar_exx.cpp b/source/source_lcao/spar_exx.cpp index 6538da7072..b19eee7c39 100644 --- a/source/source_lcao/spar_exx.cpp +++ b/source/source_lcao/spar_exx.cpp @@ -1,4 +1,174 @@ #ifdef __EXX -#include "LCAO_hamilt.hpp" -#endif + +#include "spar_exx.h" + +// -------------------------------------------------------- +// Header files needed for implementation only +// -------------------------------------------------------- + +#include +#include +#include + +#include "source_base/abfs-vector3_order.h" +#include "source_base/global_variable.h" +#include "source_base/timer.h" +#include "source_hamilt/module_xc/exx_info.h" +#include "source_io/module_parameter/parameter.h" +#include "source_lcao/module_ri/RI_2D_Comm.h" +#include "source_lcao/module_ri/RI_Util.hpp" + +// -------------------------------------------------------- +// Implementation of the cal_HR_exx function +// -------------------------------------------------------- + +namespace sparse_format +{ + +/** + * @brief Implementation of the cal_HR_exx function + */ +template +void cal_HR_exx( + const UnitCell& ucell, + const Parallel_Orbitals& pv, + LCAO_HS_Arrays& HS_Arrays, + const int& current_spin, + const double& sparse_threshold, + const int (&nmp)[3], + const std::vector>, RI::Tensor>>>& Hexxs) +{ + ModuleBase::TITLE("sparse_format", "cal_HR_exx"); + ModuleBase::timer::tick("sparse_format", "cal_HR_exx"); + + const Tdata frac = GlobalC::exx_info.info_global.hybrid_alpha; + + std::map> atoms_pos; + for (int iat = 0; iat < ucell.nat; ++iat) + { + atoms_pos[iat] = RI_Util::Vector3_to_array3(ucell.atoms[ucell.iat2it[iat]].tau[ucell.iat2ia[iat]]); + } + const std::array, 3> latvec + = {RI_Util::Vector3_to_array3(ucell.a1), + RI_Util::Vector3_to_array3(ucell.a2), + RI_Util::Vector3_to_array3(ucell.a3)}; + + const std::array Rs_period = {nmp[0], nmp[1], nmp[2]}; + + RI::Cell_Nearest cell_nearest; + cell_nearest.init(atoms_pos, latvec, Rs_period); + + const std::vector is_list + = (PARAM.inp.nspin != 4) ? std::vector{current_spin} : std::vector{0, 1, 2, 3}; + + for (const int is: is_list) + { + int is0_b = 0; + int is1_b = 0; + std::tie(is0_b, is1_b) = RI_2D_Comm::split_is_block(is); + + if (Hexxs.empty()) + { + break; + } + + for (const auto& HexxA: Hexxs[is]) + { + const int iat0 = HexxA.first; + for (const auto& HexxB: HexxA.second) + { + const int iat1 = HexxB.first.first; + + const Abfs::Vector3_Order R = RI_Util::array3_to_Vector3( + cell_nearest.get_cell_nearest_discrete(iat0, iat1, HexxB.first.second)); + + HS_Arrays.all_R_coor.insert(R); + + const RI::Tensor& Hexx = HexxB.second; + + for (size_t iw0 = 0; iw0 < Hexx.shape[0]; ++iw0) + { + const int iwt0 = RI_2D_Comm::get_iwt(ucell, iat0, iw0, is0_b); + const int iwt0_local = pv.global2local_row(iwt0); + + if (iwt0_local < 0) + { + continue; + } + + for (size_t iw1 = 0; iw1 < Hexx.shape[1]; ++iw1) + { + const int iwt1 = RI_2D_Comm::get_iwt(ucell, iat1, iw1, is1_b); + const int iwt1_local = pv.global2local_col(iwt1); + + if (iwt1_local < 0) + { + continue; + } + + if (std::abs(Hexx(iw0, iw1)) > sparse_threshold) + { + if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2) + { + auto& HR_sparse_ptr = HS_Arrays.HR_sparse[current_spin][R][iwt0]; + double& HR_sparse = HR_sparse_ptr[iwt1]; + HR_sparse += RI::Global_Func::convert(frac * Hexx(iw0, iw1)); + if (std::abs(HR_sparse) <= sparse_threshold) + { + HR_sparse_ptr.erase(iwt1); + } + } + else if (PARAM.inp.nspin == 4) + { + auto& HR_sparse_ptr = HS_Arrays.HR_soc_sparse[R][iwt0]; + + std::complex& HR_sparse = HR_sparse_ptr[iwt1]; + + HR_sparse += RI::Global_Func::convert>(frac * Hexx(iw0, iw1)); + + if (std::abs(HR_sparse) <= sparse_threshold) + { + HR_sparse_ptr.erase(iwt1); + } + } + else + { + throw std::invalid_argument(std::string(__FILE__) + " line " + + std::to_string(__LINE__)); + } + } + } + } + } + } + } + + ModuleBase::timer::tick("sparse_format", "cal_HR_exx"); +} + +// -------------------------------------------------------- +// Explicit instantiations for double and complex types +// -------------------------------------------------------- + +template void cal_HR_exx( + const UnitCell& ucell, + const Parallel_Orbitals& pv, + LCAO_HS_Arrays& HS_Arrays, + const int& current_spin, + const double& sparse_thr, + const int (&nmp)[3], + const std::vector>, RI::Tensor>>>& Hexxs); + +template void cal_HR_exx>( + const UnitCell& ucell, + const Parallel_Orbitals& pv, + LCAO_HS_Arrays& HS_Arrays, + const int& current_spin, + const double& sparse_thr, + const int (&nmp)[3], + const std::vector>, RI::Tensor>>>>& Hexxs); + +} // namespace sparse_format + +#endif // __EXX diff --git a/source/source_lcao/spar_exx.h b/source/source_lcao/spar_exx.h index c12c4a14ce..4aca75aa5a 100644 --- a/source/source_lcao/spar_exx.h +++ b/source/source_lcao/spar_exx.h @@ -3,20 +3,44 @@ #ifdef __EXX -#include -#include +// -------------------------------------------------------- +// Header files - minimal set for declaration only +// -------------------------------------------------------- + +#include #include #include #include -#include "source_lcao/LCAO_HS_arrays.hpp" #include "source_basis/module_ao/parallel_orbitals.h" #include "source_cell/unitcell.h" +#include "source_lcao/LCAO_HS_arrays.hpp" + +// -------------------------------------------------------- +// Namespace - merged into one block +// -------------------------------------------------------- + namespace sparse_format { +/** + * @brief Calculate the Hamiltonian matrix elements in real space using EXX data in sparse format + * + * This function computes the Hamiltonian matrix elements in real space (HR) from EXX data, + * which is stored in a sparse tensor format. The results are added to the HS_Arrays structure. + * + * @tparam Tdata Data type for the matrix elements (double or std::complex) + * @param ucell Unit cell information + * @param pv Parallel orbitals information for distributed computation + * @param HS_Arrays Structure to store Hamiltonian and overlap matrix arrays + * @param current_spin Current spin channel (0 or 1 for spin-polarized calculations) + * @param sparse_thr Threshold for sparse matrix construction + * @param nmp Periodic boundary conditions in reciprocal space + * @param Hexxs EXX data stored as a nested map structure of tensors + */ template -void cal_HR_exx(const UnitCell& ucell, +void cal_HR_exx( + const UnitCell& ucell, const Parallel_Orbitals& pv, LCAO_HS_Arrays& HS_Arrays, const int& current_spin, @@ -24,7 +48,26 @@ void cal_HR_exx(const UnitCell& ucell, const int (&nmp)[3], const std::vector>, RI::Tensor>>>& Hexxs); +// Explicit instantiations for double and complex types +extern template void cal_HR_exx( + const UnitCell& ucell, + const Parallel_Orbitals& pv, + LCAO_HS_Arrays& HS_Arrays, + const int& current_spin, + const double& sparse_thr, + const int (&nmp)[3], + const std::vector>, RI::Tensor>>>& Hexxs); + +extern template void cal_HR_exx>( + const UnitCell& ucell, + const Parallel_Orbitals& pv, + LCAO_HS_Arrays& HS_Arrays, + const int& current_spin, + const double& sparse_thr, + const int (&nmp)[3], + const std::vector>, RI::Tensor>>>>& Hexxs); + } -#include "source_lcao/LCAO_hamilt.hpp" -#endif -#endif + +#endif // __EXX +#endif // SPARSE_FORMAT_EXX_H diff --git a/source/source_lcao/spar_hsr.cpp b/source/source_lcao/spar_hsr.cpp index b0c1ce2139..fb74a15f23 100644 --- a/source/source_lcao/spar_hsr.cpp +++ b/source/source_lcao/spar_hsr.cpp @@ -1,4 +1,5 @@ #include "spar_hsr.h" +#include "source_lcao/hamilt_lcao.h" #include "source_io/module_parameter/parameter.h" #include "source_lcao/module_hcontainer/hcontainer.h" diff --git a/source/source_lcao/spar_hsr.h b/source/source_lcao/spar_hsr.h index b9c1334914..b196d3be3b 100644 --- a/source/source_lcao/spar_hsr.h +++ b/source/source_lcao/spar_hsr.h @@ -2,9 +2,14 @@ #define SPARSE_FORMAT_HSR_H #include "source_lcao/LCAO_HS_arrays.hpp" -#include "source_lcao/hamilt_lcao.h" +#include "source_lcao/module_hcontainer/hcontainer.h" +#include "source_hamilt/hamilt.h" #include "source_lcao/module_dftu/dftu.h" // mohan add 20251107 +#ifdef __EXX +#include +#endif + namespace sparse_format { #ifdef __MPI diff --git a/source/source_lcao/spar_u.h b/source/source_lcao/spar_u.h index 5ac5c8a365..4940a3c5ff 100644 --- a/source/source_lcao/spar_u.h +++ b/source/source_lcao/spar_u.h @@ -1,11 +1,10 @@ #ifndef SPARSE_FORMAT_U_H #define SPARSE_FORMAT_U_H +#include "source_base/abfs-vector3_order.h" #include "source_cell/module_neighbor/sltk_atom_arrange.h" #include "source_cell/module_neighbor/sltk_grid_driver.h" -#include "source_lcao/hamilt_lcao.h" #include "source_lcao/module_dftu/dftu.h" // mohan add 20251107 -#include "source_base/abfs-vector3_order.h" namespace sparse_format { diff --git a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp index c5be374e1d..a949a806c7 100644 --- a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp +++ b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp @@ -3,6 +3,7 @@ #include "source_base/constants.h" #include "source_base/global_variable.h" #include "source_base/parallel_common.h" +#include "source_base/parallel_comm.h" // use KP_WORLD #include "source_base/parallel_reduce.h" #include "source_base/module_external/lapack_connector.h" #include "source_base/timer.h" diff --git a/source/source_pw/module_pwdft/stress_pw.cpp b/source/source_pw/module_pwdft/stress_pw.cpp index 919d1288bf..8cd8959ff9 100644 --- a/source/source_pw/module_pwdft/stress_pw.cpp +++ b/source/source_pw/module_pwdft/stress_pw.cpp @@ -1,9 +1,11 @@ #include "stress_pw.h" #include "source_base/timer.h" +#include "source_base/global_variable.h" // use GlobalC #include "source_hamilt/module_vdw/vdw.h" #include "source_io/output_log.h" #include "source_hamilt/module_xc/xc_functional.h" +#include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info template void Stress_PW::cal_stress(ModuleBase::matrix& sigmatot,