# #
# This file is distributed as part of the WannierBerri code #
# under the terms of the GNU General Public License. See the #
# file `LICENSE' in the root directory of the WannierBerri #
# distribution, or http://www.gnu.org/copyleft/gpl.txt #
# #
# The WannierBerri code is hosted on GitHub: #
# https://github.com/stepan-tsirkin/wannier-berri #
# written by #
# Stepan Tsirkin, University of Zurich #
# #
# ------------------------------------------------------------
import numpy as np
import abc
from functools import cached_property
from ..parallel import pool
from ..system.system import System
from ..grid import TetraWeights, TetraWeightsParal, get_bands_in_range, get_bands_below_range
from .. import formula
from ..grid import KpointBZparallel, KpointBZtetra
from ..symmetry.point_symmetry import transform_ident, transform_odd
from .sdct_K import SDCT_K
def _rotate_matrix(X):
    """Rotate one matrix by a transformation matrix.

    ``X`` is an indexable pair ``(matrix, U)`` and the result is
    ``U^dagger . matrix . U``.  The two operands are packed into a single
    argument so that the function can be used directly in a map-style call.
    """
    matrix, basis = X[0], X[1]
    return basis.T.conj().dot(matrix).dot(basis)
def get_transform_Inv(name, der=0):
    """Return the transformation of the quantity `name` under inversion.

    Parameters
    -----------
    name : str
        label of the quantity
    der : int
        number of derivatives taken; every derivative flips the parity once

    Returns
    -------
    `transform_odd` when the total parity is odd, `transform_ident` when it is
    even, and `None` for 'D', 'AA', 'BB' and 'CCab'.

    Raises
    ------
    ValueError
        for unknown quantities
    """
    # Oscar: 'CC', 'FF', 'OO', 'GG', 'SS' listed together with 'Ham'
    even_before_derivative = ('Ham', 'CC', 'FF', 'OO', 'GG', 'SS')
    without_transform = ('D', 'AA', 'BB', 'CCab')

    if name in without_transform:
        return None
    if name not in even_before_derivative:
        raise ValueError(f"parity under inversion unknown for {name}")

    parity = 0  # all known quantities are even before the derivative
    total_is_odd = (parity + der) % 2 == 1
    return transform_odd if total_is_odd else transform_ident
def get_transform_TR(name, der=0):
    """Return the transformation of the quantity `name` under time reversal
    (after a real trace is taken, if appropriate).

    Parameters
    -----------
    name : str
        label of the quantity
    der : int
        number of derivatives taken; every derivative flips the parity once

    Returns
    -------
    `transform_odd` when the total parity is odd, `transform_ident` when it is
    even, and `None` for 'D', 'AA', 'BB' and 'CCab'.

    Raises
    ------
    ValueError
        for unknown quantities
    """
    even_before_derivative = ('Ham',)
    # Oscar: quantities that are odd before the derivative
    odd_before_derivative = ('CC', 'FF', 'OO', 'GG', 'SS')
    without_transform = ('D', 'AA', 'BB', 'CCab')

    if name in without_transform:
        return None
    if name in even_before_derivative:
        parity = 0
    elif name in odd_before_derivative:
        parity = 1
    else:
        raise ValueError(f"parity under TR unknown for {name}")

    total_is_odd = (parity + der) % 2 == 1
    return transform_odd if total_is_odd else transform_ident
[docs]
class Data_K(System, abc.ABC):
"""
class to store many data calculated on a specific FFT grid.
The stored data can be used to evaluate many quantities.
Is destroyed after everything is evaluated for the FFT grid
Parameters
-----------
random_gauge : bool
applies random unitary rotations to degenerate states. Needed only for testing, to make sure that gauge
covariance is preserved.
degen_thresh_random_gauge : float
threshold to consider bands as degenerate for random_gauge
fftlib : str
library used to perform FFT : 'fftw' (default) or 'numpy' or 'slow'
"""
# Those are not used at the moment , but will be restored (TODO):
# frozen_max : float
# position of the upper edge of the frozen window. Used in the evaluation of orbital moment. But not necessary.
# If not specified, attempts to read this value from system. Otherwise set to
# delta_fz:float
# size of smearing for B matrix with frozen window, from frozen_max-delta_fz to frozen_max.
def __init__(self, system, dK, grid, Kpoint=None,
             # Those are not used at the moment, but will be restored (TODO):
             # frozen_max = -np.inf,
             # delta_fz = 0.1,
             Emin=-np.inf,
             Emax=np.inf,
             fftlib='fftw',
             npar_k=1,
             random_gauge=False,
             degen_thresh_random_gauge=1e-4
             ):
    """
    Store the inputs and set up the basic bookkeeping of the FFT grid.

    Parameters
    -----------
    system :
        the system object; `num_wann`, `real_lattice`, `force_internal_terms_only`
        (and later `is_phonon`) are read from it
    dK :
        shift that is added to `grid.points_FFT` to obtain `kpoints_all`
    grid :
        grid object providing the three-component `FFT` and `points_FFT`
    Kpoint :
        K-point object; its type selects the implementation used in `tetraWeights`
    Emin, Emax : float
        energy window used by `select_bands`
    fftlib : str
        name of the FFT library (only stored here)
    npar_k : int
        passed to `pool`; the first item of the result is kept as `poolmap`
    random_gauge : bool
        applies random unitary rotations to degenerate states (testing only)
    degen_thresh_random_gauge : float
        threshold to consider bands as degenerate for random_gauge
    """
    self.system = system
    self.Emin = Emin
    self.Emax = Emax
    self.fftlib = fftlib
    self.npar_k = npar_k
    self.random_gauge = random_gauge
    # `degen` reads the threshold under the same name as the parameter
    self.degen_thresh_random_gauge = degen_thresh_random_gauge
    # previous spelling of the same attribute, kept for code that still reads it
    self.degen_threshold_random_gauge = degen_thresh_random_gauge
    self.force_internal_terms_only = system.force_internal_terms_only
    self.grid = grid
    # NKFFT has to be set before `self.nk` (computed from it) is used below
    self.NKFFT = grid.FFT
    self.select_K = np.ones(self.nk, dtype=bool)
    # self.findif = grid.findif
    self.real_lattice = system.real_lattice
    self.num_wann = self.system.num_wann
    self.Kpoint = Kpoint
    self.nkptot = self.NKFFT[0] * self.NKFFT[1] * self.NKFFT[2]
    self.poolmap = pool(self.npar_k)[0]
    self.dK = dK
    self._bar_quantities = {}
    self._covariant_quantities = {}
###########################################
# Now the **_R objects are evaluated only on demand
# - as cached_property (if used more than once)
# as property - if used only once
# let's write them explicitly, for better code readability
###########################
@property
def is_phonon(self):
return self.system.is_phonon
###############################################################
###########
# TOOLS #
###########
def _rotate(self, mat):
    """Rotate `mat` k-point by k-point with the matrices of `UU_K`.

    A rank-3 array is rotated through `poolmap` and returned as a new array.
    For higher ranks every slice along the last axis is rotated recursively and
    written back, so `mat` is modified in place and the same object is returned.
    """
    assert mat.ndim > 2
    if mat.ndim > 3:
        # peel off the trailing axis until rank 3 is reached
        for last in range(mat.shape[-1]):
            mat[..., last] = self._rotate(mat[..., last])
        return mat
    pairs = zip(mat, self.UU_K)
    return np.array(self.poolmap(_rotate_matrix, pairs))
#####################
# Basic variables #
#####################
@cached_property
def nbands(self):
return self.num_wann
@cached_property
def kpoints_all(self):
return (self.grid.points_FFT + self.dK[None]) % 1
@cached_property
def nk(self):
return np.prod(self.NKFFT)
@cached_property
def tetraWeights(self):
    """Tetrahedron weights object matching the type of `Kpoint`.

    Returns
    -------
    `TetraWeightsParal` (corner energies from `E_K_corners_parallel`) when
    `Kpoint` is a `KpointBZparallel`, `TetraWeights` (corner energies from
    `E_K_corners_tetra`) when it is a `KpointBZtetra`.

    Raises
    ------
    RuntimeError
        if `Kpoint` is of neither type
    """
    if isinstance(self.Kpoint, KpointBZparallel):
        return TetraWeightsParal(eCenter=self.E_K, eCorners=self.E_K_corners_parallel())
    if isinstance(self.Kpoint, KpointBZtetra):
        return TetraWeights(eCenter=self.E_K, eCorners=self.E_K_corners_tetra())
    # same exception type as before, but now it tells what went wrong
    raise RuntimeError(
        "tetraWeights requires Kpoint to be a KpointBZparallel or a KpointBZtetra, "
        f"got {type(self.Kpoint).__name__}"
    )
[docs]
def get_bands_in_range_groups_ik(self, ik, emin, emax, degen_thresh=-1, degen_Kramers=False, sea=False,
                                 Emin=-np.inf, Emax=np.inf):
    """Groups of bands of k-point `ik` found by `get_bands_in_range` for [emin, emax].

    Returns
    -------
    dict
        `(ib1, ib2) -> mean energy of bands ib1:ib2` for every group. With
        `sea=True` an extra entry `(0, bandmax) -> -inf` is added, where
        `bandmax` comes from `get_bands_below_range` and is capped by the
        start of the first group.

    `Emin` and `Emax` are accepted but not used in this method.
    """
    energies = self.E_K[ik]
    groups = get_bands_in_range(
        emin, emax, energies, degen_thresh=degen_thresh, degen_Kramers=degen_Kramers)
    weights = {}
    for first, last in groups:
        weights[(first, last)] = energies[first:last].mean()
    if not sea:
        return weights
    top = get_bands_below_range(emin, self.E_K[ik])
    if len(groups) > 0:
        top = min(top, groups[0][0])
    if top > 0:
        weights[(0, top)] = -np.inf
    return weights
[docs]
def get_bands_in_range_groups(self, emin, emax, degen_thresh=-1, degen_Kramers=False, sea=False, Emin=-np.inf,
Emax=np.inf):
res = []
for ik in range(self.nk):
res.append(self.get_bands_in_range_groups_ik(ik, emin, emax, degen_thresh, degen_Kramers, sea, Emin=Emin,
Emax=Emax))
return res
###################################################
# Basic variables and their standard derivatives #
###################################################
[docs]
def select_bands(self, energies):
if hasattr(self, 'bands_selected'):
return
energies = energies.reshape((energies.shape[0], -1, energies.shape[-1]))
select = np.any(energies > self.Emin, axis=1) * np.any(energies < self.Emax, axis=1)
self.select_K = np.any(select, axis=1)
self.select_B = np.any(select, axis=0)
self.nk_selected = self.select_K.sum()
self.nb_selected = self.select_B.sum()
self.bands_selected = True
@cached_property
def E_K(self):
EUU = self.poolmap(np.linalg.eigh, self.HH_K)
E_K = self.phonon_freq_from_square(np.array([euu[0] for euu in EUU]))
# print ("E_K = ",E_K.min(), E_K.max(), E_K.mean())
self.select_bands(E_K)
self._UU = np.array([euu[1] for euu in EUU])[self.select_K, :][:, self.select_B]
return E_K[self.select_K, :][:, self.select_B]
# evaluate the energies in the corners of the parallelepiped, in order to use tetrahedron method
[docs]
def phonon_freq_from_square(self, E):
"""takes sqrt(|E|)*sign(E) for phonons, returns input for electrons"""
if self.is_phonon:
e = np.sqrt(np.abs(E))
e[E < 0] = -e[E < 0]
return e
else:
return E
@property
@abc.abstractmethod
def HH_K(self):
    """returns Wannier Hamiltonian for all points of the FFT grid

    Abstract property: no implementation is given here, concrete subclasses
    have to provide it. Within this class the result is handed to `poolmap`,
    which applies `numpy.linalg.eigh` to each of its items (see `E_K`).
    """
@cached_property
def delE_K(self):
delE_K = np.einsum("klla->kla", self.Xbar('Ham', 1))
check = np.abs(delE_K).imag.max()
if check > 1e-10:
raise RuntimeError(f"The band derivatives have considerable imaginary part: {check}")
return delE_K.real
[docs]
def covariant(self, name, commader=0, gender=0, save=True):
assert commader * gender == 0, "cannot mix comm and generalized derivatives"
key = (name, commader, gender)
if key not in self._covariant_quantities:
if gender == 0:
res = formula.Matrix_ln(
self.Xbar(name, commader),
transformTR=get_transform_TR(name, commader),
transformInv=get_transform_Inv(name, commader),
)
elif gender == 1:
if name == 'Ham':
res = self.V_covariant
else:
res = formula.Matrix_GenDer_ln(
self.covariant(name),
self.covariant(name, commader=1),
self.Dcov,
transformTR=get_transform_TR(name, gender),
transformInv=get_transform_Inv(name, gender)
)
else:
raise NotImplementedError()
if not save:
return res
else:
self._covariant_quantities[key] = res
return self._covariant_quantities[key]
@property
def V_covariant(self):
    """`formula.Matrix_ln` built from `Xbar('Ham', der=1)`, odd under both TR and
    inversion, whose `ln` method returns zeros."""

    class V(formula.Matrix_ln):

        def __init__(self, matrix):
            super().__init__(matrix, transformTR=transform_odd, transformInv=transform_odd)

        def ln(self, ik, inn, out):
            shape = (len(out), len(inn), 3)
            return np.zeros(shape, dtype=complex)

    velocity_bar = self.Xbar('Ham', der=1)
    return V(velocity_bar)
@cached_property
def Dcov(self):
    """`formula.covariant.Dcov` object constructed from this data object."""
    builder = formula.covariant.Dcov
    return builder(self)
@cached_property
def dEig_inv(self):
dEig_threshold = 1e-7
dEig = self.E_K[:, :, None] - self.E_K[:, None, :]
select = abs(dEig) < dEig_threshold
dEig[select] = dEig_threshold
dEig = 1. / dEig
dEig[select] = 0.
return dEig
# defining sets of degenerate states - needed only for testing with random_gauge
@cached_property
def degen(self):
A = [np.where(E[1:] - E[:-1] > self.degen_thresh_random_gauge)[0] + 1 for E in self.E_K]
A = [[
0,
] + list(a) + [len(E)] for a, E in zip(A, self.E_K)]
return [[(ib1, ib2) for ib1, ib2 in zip(a, a[1:]) if ib2 - ib1 > 1] for a in A]
@cached_property
def UU_K(self):
self.E_K
# the following is needed only for testing :
if self.random_gauge:
from scipy.stats import unitary_group
cnt = 0
s = 0
for ik, deg in enumerate(self.true):
for ib1, ib2 in deg:
self._UU[ik, :, ib1:ib2] = self._UU[ik, :, ib1:ib2].dot(unitary_group.rvs(ib2 - ib1))
cnt += 1
s += ib2 - ib1
return self._UU
@cached_property
def D_H(self):
return -self.Xbar('Ham', 1) * self.dEig_inv[:, :, :, None]
@cached_property
def A_H(self):
'''Generalized Berry connection matrix, A^(H) as defined in eqn. (25) of 10.1103/PhysRevB.74.195118.'''
return self.Xbar('AA') + 1j * self.D_H
@property
def A_H_internal(self):
'''Generalized Berry connection matrix, A^(H) as defined in eqn. (25) of 10.1103/PhysRevB.74.195118. only internal term'''
return 1j * self.D_H
@cached_property
def SDCT(self):
    """returns the SDC term, an `SDCT_K` object constructed from this data object"""
    sdc_term = SDCT_K(self)
    return sdc_term
#########################################################################################################################################