# #
# This file is distributed as part of the WannierBerri code #
# under the terms of the GNU General Public License. See the #
# file `LICENSE' in the root directory of the WannierBerri #
# distribution, or http://www.gnu.org/copyleft/gpl.txt #
# #
# The WannierBerri code is hosted on GitHub: #
# https://github.com/stepan-tsirkin/wannier-berri #
# written by #
# Stepan Tsirkin, University of Zurich #
# some parts of this file originate #
# from the translation of Wannier90 code #
#------------------------------------------------------------#
import numpy as np
import functools
import multiprocessing
from ..__utility import iterate3dpm, real_recip_lattice, fourier_q_to_R
from .system import System
from ..__w90_files import EIG, MMN, CheckPoint, SPN, UHU, SIU, SHU
from time import time
class System_w90(System):
    """
    System initialized from the Wannier functions generated by `Wannier90 <http://wannier.org>`__ code.

    Reads the ``.chk``, ``.eig`` and optionally ``.mmn``, ``.spn``, ``.uHu``, ``.sIu``, and ``.sHu`` files

    Parameters
    ----------
    seedname : str
        the seedname used in Wannier90
    transl_inv : bool
        Use Eq.(31) of `Marzari&Vanderbilt PRB 56, 12847 (1997) <https://journals.aps.org/prb/abstract/10.1103/PhysRevB.56.12847>`_ for band-diagonal position matrix elements
    guiding_centers : bool
        If True, enable overwriting the diagonal elements of the AA_R matrix at R=0 with the
        Wannier centers calculated from Wannier90.
    fft : str
        library used to perform the fast Fourier transform from **q** to **R**. ``fftw`` or ``numpy``. (practically does not affect performance,
        anyway mostly time of the constructor is consumed by reading the input files)
    npar : int
        number of processes used in the constructor
    kmesh_tol : float
        tolerance to consider the b_k vectors (connecting to neighbouring k-points on the grid) belonging to the same shell
    bk_complete_tol : float
        tolerance to consider the set of b_k shells as complete.

    Raises
    ------
    ValueError
        if the k-point grid read from the ``.chk`` file does not contain the Gamma point, or if
        ``transl_inv=True`` and the evaluated Wannier centers differ (by 1e-6 or more in some
        component) from those read from the ``.chk`` file while ``guiding_centers=False``.

    Notes
    -----
    see also parameters of the :class:`~wannierberri.System`
    """

    def __init__(
            self,
            seedname="wannier90",
            transl_inv=True,
            guiding_centers=False,
            fft='fftw',
            npar=multiprocessing.cpu_count(),
            kmesh_tol=1e-7,
            bk_complete_tol=1e-5,
            **parameters):
        self.set_parameters(**parameters)
        self.npar = npar
        self.seedname = seedname

        chk = CheckPoint(self.seedname, kmesh_tol=kmesh_tol, bk_complete_tol=bk_complete_tol)
        self.real_lattice, self.recip_lattice = real_recip_lattice(chk.real_lattice, chk.recip_lattice)
        if self.mp_grid is None:
            self.mp_grid = chk.mp_grid
        # The R-vectors are generated from the k-grid stored in the .chk file, which is also
        # the grid used by the q -> R Fourier transforms below.
        self.iRvec, self.Ndegen = self.wigner_seitz(chk.mp_grid)
        self.nRvec0 = len(self.iRvec)
        self.num_wann = chk.num_wann
        self.wannier_centers_cart_auto = chk.wannier_centers

        eig = EIG(seedname)
        # Every matrix built below except 'Ham', 'SS' and 'SH' uses the .mmn overlaps.
        # Reading them only for 'AA'/'BB' left `mmn` undefined (NameError) when, e.g.,
        # only 'CC' or 'SA' was requested.
        mmn = None
        if self.need_R_any(['AA', 'BB', 'CC', 'SR', 'SHR', 'SA', 'SHA']):
            mmn = MMN(seedname, npar=npar)

        # Reduced k-points of the .chk file converted to integer grid indices modulo mp_grid.
        kpt_mp_grid = [
            tuple(k) for k in np.array(np.round(chk.kpt_latt * np.array(chk.mp_grid)[None, :]), dtype=int) % chk.mp_grid
        ]
        if (0, 0, 0) not in kpt_mp_grid:
            raise ValueError(
                "the grid of k-points read from .chk file is not Gamma-centered. Please, use Gamma-centered grids in the ab initio calculation"
            )

        fourier_q_to_R_loc = functools.partial(
            fourier_q_to_R,
            mp_grid=chk.mp_grid,
            kpt_mp_grid=kpt_mp_grid,
            iRvec=self.iRvec,
            ndegen=self.Ndegen,
            numthreads=npar,
            fft=fft)

        time_fft = 0

        def _set_from_q(key, mat_q):
            """Fourier-transform `mat_q` from q to R, store it as R-matrix `key`, and time only that step."""
            nonlocal time_fft
            t0 = time()
            self.set_R_mat(key, fourier_q_to_R_loc(mat_q))
            time_fft += time() - t0

        _set_from_q('Ham', chk.get_HH_q(eig))

        if self.need_R_any('AA'):
            _set_from_q('AA', chk.get_AA_q(mmn, transl_inv=transl_inv))
            if transl_inv:
                # The band-diagonal elements of AA at R=0 are the evaluated Wannier centers.
                wannier_centers_cart_new = np.diagonal(
                    self.get_R_mat('AA')[:, :, self.iR0, :], axis1=0, axis2=1).transpose()
                if not np.all(abs(wannier_centers_cart_new - self.wannier_centers_cart_auto) < 1e-6):
                    if guiding_centers:
                        print(
                            f"The read Wannier centers\n{self.wannier_centers_cart_auto}\n"
                            f"are different from the evaluated Wannier centers\n{wannier_centers_cart_new}\n"
                            "This can happen if guiding_centres was set to true in Wannier90.\n"
                            "Overwrite the evaluated centers using the read centers.")
                        # Writes in place into the array returned by get_R_mat, as before.
                        for iw in range(self.num_wann):
                            self.get_R_mat('AA')[iw, iw, self.iR0, :] = self.wannier_centers_cart_auto[iw, :]
                    else:
                        raise ValueError(
                            f"the difference between read\n{self.wannier_centers_cart_auto}\n"
                            f"and evaluated \n{wannier_centers_cart_new}\n wannier centers is\n"
                            f"{self.wannier_centers_cart_auto-wannier_centers_cart_new}\n"
                            "If guiding_centres was set to true in Wannier90, pass guiding_centers = True to System_w90."
                        )

        if self.need_R_any('BB'):
            # BB is obtained from get_AA_q with the eigenvalues passed as second argument.
            _set_from_q('BB', chk.get_AA_q(mmn, eig))

        if self.need_R_any('CC'):
            uhu = UHU(seedname)
            _set_from_q('CC', chk.get_CC_q(uhu, mmn))
            del uhu

        if self.need_R_any(['SS', 'SR', 'SH', 'SHR']):
            spn = SPN(seedname)
            if self.need_R_any('SS'):
                _set_from_q('SS', chk.get_SS_q(spn))
            if self.need_R_any('SR'):
                _set_from_q('SR', chk.get_SR_q(spn, mmn))
            if self.need_R_any('SH'):
                _set_from_q('SH', chk.get_SH_q(spn, eig))
            if self.need_R_any('SHR'):
                _set_from_q('SHR', chk.get_SHR_q(spn, mmn, eig))
            # `spn` exists exactly in this branch, so no NameError guard is needed.
            del spn

        if self.need_R_any('SA'):
            siu = SIU(seedname)
            _set_from_q('SA', chk.get_SA_q(siu, mmn))
            del siu

        if self.need_R_any('SHA'):
            shu = SHU(seedname)
            _set_from_q('SHA', chk.get_SHA_q(shu, mmn))
            del shu

        # The overlaps are not used past this point; drop the reference before finishing init.
        del mmn

        print("time for FFT_q_to_R : {} s".format(time_fft))
        self.do_at_end_of_init()
        print("Real-space lattice:\n", self.real_lattice)

    def wigner_seitz(self, mp_grid):
        """
        Find the lattice vectors R inside the Wigner-Seitz cell of the supercell defined by `mp_grid`.

        A candidate n (integer coordinates in the range spanned by ``iterate3dpm(mp_grid)``) is kept
        when its squared length ``n . G . n`` (with ``G = real_lattice . real_lattice^T``) is, within
        an absolute tolerance of 1e-7, the smallest among all its images ``n - i * mp_grid``
        with ``|i_a| <= 2``.

        Parameters
        ----------
        mp_grid : sequence of 3 int
            dimensions of the k-point grid (i.e. of the real-space supercell)

        Returns
        -------
        iRvec : numpy.ndarray of int
            the accepted vectors n, one per row
        Ndegen : numpy.ndarray of int
            for each accepted n, the number of images ``n - i * mp_grid`` that reach the
            minimal distance (the degeneracy used to weight that R-vector)
        """
        ws_search_size = np.array([1] * 3)
        mp_grid = np.array(mp_grid)
        real_metric = self.real_lattice.dot(self.real_lattice.T)
        # Supercell translations i * mp_grid with |i_a| <= ws_search_size[a] + 1;
        # they do not depend on n, so build them once.
        shifts = np.array(list(iterate3dpm(ws_search_size + 1))) * mp_grid
        # Row of the zero translation (the "home" image of n), located explicitly
        # instead of relying on the iteration order of iterate3dpm.
        origin = int(np.flatnonzero(np.all(shifts == 0, axis=1))[0])
        irvec = []
        ndegen = []
        for n in iterate3dpm(mp_grid * ws_search_size):
            ndiff = np.asarray(n) - shifts
            # Squared lengths of all images at once: ndiff_i . G . ndiff_i
            dist = np.einsum('ij,jk,ik->i', ndiff, real_metric, ndiff)
            dist_min = dist.min()
            if abs(dist[origin] - dist_min) < 1.e-7:
                irvec.append(n)
                ndegen.append(np.sum(abs(dist - dist_min) < 1.e-7))
        return np.array(irvec), np.array(ndegen)