Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/source_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "source_io/to_wannier90_pw.h"
#include "source_io/write_elecstat_pot.h"
#include "source_io/module_parameter/parameter.h"
#include "source_hamilt/module_xc/xc_functional.h"

#include <ATen/kernels/blas.h>
#include <ATen/kernels/lapack.h>
Expand Down
1 change: 1 addition & 0 deletions source/source_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "source_pw/module_pwdft/forces.h"
#include "source_pw/module_pwdft/stress_pw.h"
#include "source_hamilt/module_xc/xc_functional.h" // use XC_Functional

#ifdef __DSP
#include "source_base/kernels/dsp/dsp_connector.h"
Expand Down
1 change: 0 additions & 1 deletion source/source_hamilt/module_xc/exx_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#define EXX_INFO_H

#include "source_lcao/module_ri/conv_coulomb_pot_k.h"
#include "xc_functional.h"

#include <vector>
#include <map>
Expand Down
1 change: 0 additions & 1 deletion source/source_hamilt/module_xc/xc_functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "source_base/global_variable.h"
#include "source_base/vector3.h"
#include "source_base/matrix.h"
#include "exx_info.h"
#include "source_basis/module_pw/pw_basis_k.h"
#include "source_estate/module_charge/charge.h"
#include "source_cell/unitcell.h"
Expand Down
1 change: 1 addition & 0 deletions source/source_io/input_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <algorithm>

#include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info
#include "source_hamilt/module_xc/xc_functional.h" // use XC_Functional
#ifdef __EXX
#include "source_lcao/module_ri/exx_abfs-jle.h"
#endif
Expand Down
142 changes: 0 additions & 142 deletions source/source_lcao/LCAO_hamilt.hpp

This file was deleted.

5 changes: 4 additions & 1 deletion source/source_lcao/spar_dh.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#ifndef SPAR_DH_H
#define SPAR_DH_H

#include "source_base/matrix.h"
#include "source_basis/module_ao/parallel_orbitals.h"
#include "source_basis/module_nao/two_center_bundle.h"
#include "source_basis/module_ao/ORB_read.h"
#include "source_cell/module_neighbor/sltk_atom_arrange.h"
#include "source_cell/module_neighbor/sltk_grid_driver.h"
#include "source_lcao/LCAO_HS_arrays.hpp"
#include "source_lcao/force_stress_arrays.h"
#include "source_lcao/hamilt_lcao.h"
#include <vector>

namespace sparse_format
Expand Down
174 changes: 172 additions & 2 deletions source/source_lcao/spar_exx.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,174 @@
#ifdef __EXX
#include "LCAO_hamilt.hpp"
#endif

#include "spar_exx.h"

// --------------------------------------------------------
// Header files needed for implementation only
// --------------------------------------------------------

#include <RI/global/Global_Func-2.h>
#include <RI/ri/Cell_Nearest.h>
#include <stdexcept>

#include "source_base/abfs-vector3_order.h"
#include "source_base/global_variable.h"
#include "source_base/timer.h"
#include "source_hamilt/module_xc/exx_info.h"
#include "source_io/module_parameter/parameter.h"
#include "source_lcao/module_ri/RI_2D_Comm.h"
#include "source_lcao/module_ri/RI_Util.hpp"

// --------------------------------------------------------
// Implementation of the cal_HR_exx function
// --------------------------------------------------------

namespace sparse_format
{

/**
* @brief Implementation of the cal_HR_exx function
*/
template <typename Tdata>
void cal_HR_exx(
const UnitCell& ucell,
const Parallel_Orbitals& pv,
LCAO_HS_Arrays& HS_Arrays,
const int& current_spin,
const double& sparse_threshold,
const int (&nmp)[3],
const std::vector<std::map<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<Tdata>>>>& Hexxs)
{
ModuleBase::TITLE("sparse_format", "cal_HR_exx");
ModuleBase::timer::tick("sparse_format", "cal_HR_exx");

const Tdata frac = GlobalC::exx_info.info_global.hybrid_alpha;

std::map<int, std::array<double, 3>> atoms_pos;
for (int iat = 0; iat < ucell.nat; ++iat)
{
atoms_pos[iat] = RI_Util::Vector3_to_array3(ucell.atoms[ucell.iat2it[iat]].tau[ucell.iat2ia[iat]]);
}
const std::array<std::array<double, 3>, 3> latvec
= {RI_Util::Vector3_to_array3(ucell.a1),
RI_Util::Vector3_to_array3(ucell.a2),
RI_Util::Vector3_to_array3(ucell.a3)};

const std::array<int, 3> Rs_period = {nmp[0], nmp[1], nmp[2]};

RI::Cell_Nearest<int, int, 3, double, 3> cell_nearest;
cell_nearest.init(atoms_pos, latvec, Rs_period);

const std::vector<int> is_list
= (PARAM.inp.nspin != 4) ? std::vector<int>{current_spin} : std::vector<int>{0, 1, 2, 3};

for (const int is: is_list)
{
int is0_b = 0;
int is1_b = 0;
std::tie(is0_b, is1_b) = RI_2D_Comm::split_is_block(is);

if (Hexxs.empty())
{
break;
}

for (const auto& HexxA: Hexxs[is])
{
const int iat0 = HexxA.first;
for (const auto& HexxB: HexxA.second)
{
const int iat1 = HexxB.first.first;

const Abfs::Vector3_Order<int> R = RI_Util::array3_to_Vector3(
cell_nearest.get_cell_nearest_discrete(iat0, iat1, HexxB.first.second));

HS_Arrays.all_R_coor.insert(R);

const RI::Tensor<Tdata>& Hexx = HexxB.second;

for (size_t iw0 = 0; iw0 < Hexx.shape[0]; ++iw0)
{
const int iwt0 = RI_2D_Comm::get_iwt(ucell, iat0, iw0, is0_b);
const int iwt0_local = pv.global2local_row(iwt0);

if (iwt0_local < 0)
{
continue;
}

for (size_t iw1 = 0; iw1 < Hexx.shape[1]; ++iw1)
{
const int iwt1 = RI_2D_Comm::get_iwt(ucell, iat1, iw1, is1_b);
const int iwt1_local = pv.global2local_col(iwt1);

if (iwt1_local < 0)
{
continue;
}

if (std::abs(Hexx(iw0, iw1)) > sparse_threshold)
{
if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2)
{
auto& HR_sparse_ptr = HS_Arrays.HR_sparse[current_spin][R][iwt0];
double& HR_sparse = HR_sparse_ptr[iwt1];
HR_sparse += RI::Global_Func::convert<double>(frac * Hexx(iw0, iw1));
if (std::abs(HR_sparse) <= sparse_threshold)
{
HR_sparse_ptr.erase(iwt1);
}
}
else if (PARAM.inp.nspin == 4)
{
auto& HR_sparse_ptr = HS_Arrays.HR_soc_sparse[R][iwt0];

std::complex<double>& HR_sparse = HR_sparse_ptr[iwt1];

HR_sparse += RI::Global_Func::convert<std::complex<double>>(frac * Hexx(iw0, iw1));

if (std::abs(HR_sparse) <= sparse_threshold)
{
HR_sparse_ptr.erase(iwt1);
}
}
else
{
throw std::invalid_argument(std::string(__FILE__) + " line "
+ std::to_string(__LINE__));
}
}
}
}
}
}
}

ModuleBase::timer::tick("sparse_format", "cal_HR_exx");
}

// --------------------------------------------------------
// Explicit instantiations for double and complex<double> types
// --------------------------------------------------------

template void cal_HR_exx<double>(
const UnitCell& ucell,
const Parallel_Orbitals& pv,
LCAO_HS_Arrays& HS_Arrays,
const int& current_spin,
const double& sparse_thr,
const int (&nmp)[3],
const std::vector<std::map<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<double>>>>& Hexxs);

template void cal_HR_exx<std::complex<double>>(
const UnitCell& ucell,
const Parallel_Orbitals& pv,
LCAO_HS_Arrays& HS_Arrays,
const int& current_spin,
const double& sparse_thr,
const int (&nmp)[3],
const std::vector<std::map<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<std::complex<double>>>>>& Hexxs);

} // namespace sparse_format

#endif // __EXX

Loading
Loading