# Source code for wannierberri.w90files.chk

from functools import cached_property
from time import time
import numpy as np
from .utility import readstr
from ..io import FortranFileR
from ..utility import alpha_A, beta_A


class CheckPoint:
    """
    A class to store the data about wannierisation, written by Wannier90

    An instance is usually filled by :meth:`from_w90_file` (reads ``seedname.chk``)
    or by :meth:`from_win` (creates a checkpoint without `v_matrix`, i.e. not wannierised).

    Parameters
    ----------
    real_lattice : array-like(3, 3), optional
        the real-space lattice; the reciprocal lattice is computed as 2*pi*inv(real_lattice).T
    num_wann : int, optional
        number of Wannier functions (inferred from the other arrays if not given)
    num_bands : int, optional
        number of bands (inferred from `v_matrix` if not given)
    num_kpts : int, optional
        number of k-points (inferred from `kpt_latt` or `v_matrix` if not given)
    wannier_centers_cart : array-like(num_wann, 3), optional
        the Wannier centers in Cartesian coordinates
    wannier_spreads : array-like(num_wann), optional
        the Wannier spreads
    v_matrix : array-like(num_kpts, num_bands, num_wann), optional
        the matrices that rotate the Bloch states into the Wannier gauge
        (see :meth:`wannier_gauge`)
    kpt_latt : array-like(num_kpts, 3), optional
        the k-points in reduced (fractional) coordinates
    mp_grid : array-like(3), optional
        the Monkhorst-Pack grid
    kmesh_tol : float
        tolerance to distinguish different/same k-points
    bk_complete_tol : float
        tolerance for the completeness relation for finite-difference scheme
    """

    def __init__(self,
                 real_lattice=None,
                 num_wann=None,
                 num_bands=None,
                 num_kpts=None,
                 wannier_centers_cart=None,
                 wannier_spreads=None,
                 v_matrix=None,
                 kpt_latt=None,
                 mp_grid=None,
                 kmesh_tol=1e-7,
                 bk_complete_tol=1e-5,
                 ):
        if real_lattice is not None:
            real_lattice = np.array(real_lattice, dtype=float)
            assert real_lattice.shape == (3, 3), f"real_lattice should be of shape (3, 3), but got {real_lattice.shape}"
            self.recip_lattice = 2 * np.pi * np.linalg.inv(real_lattice).T
            self.real_lattice = real_lattice
        else:
            self.recip_lattice = None
            self.real_lattice = None

        if wannier_centers_cart is not None:
            wannier_centers_cart = np.array(wannier_centers_cart, dtype=float)
            if num_wann is None:
                num_wann = wannier_centers_cart.shape[0]
            else:
                assert wannier_centers_cart.shape == (num_wann, 3), f"wannier_centers should be of shape ({num_wann}, 3), but got {wannier_centers_cart.shape}"
            self.wannier_centers_cart = wannier_centers_cart

        if wannier_spreads is not None:
            # convert first, so that any array-like (e.g. a plain list) is accepted
            wannier_spreads = np.array(wannier_spreads, dtype=float)
            if num_wann is None:
                num_wann = wannier_spreads.shape[0]
            else:
                assert len(wannier_spreads) == num_wann, f"wannier_spreads should be of shape ({num_wann},), but got {len(wannier_spreads)}"
            self.wannier_spreads = wannier_spreads

        if kpt_latt is not None:
            kpt_latt = np.array(kpt_latt, dtype=float)
            if num_kpts is None:
                num_kpts = kpt_latt.shape[0]
            assert kpt_latt.shape == (num_kpts, 3), f"kpt_latt should be of shape ({num_kpts}, 3), but got {kpt_latt.shape}"
            self.kpt_latt = kpt_latt

        if mp_grid is not None:
            mp_grid = np.array(mp_grid, dtype=int)
            assert mp_grid.shape == (3,), f"mp_grid should be of shape (3,), but got {mp_grid.shape}"
        self.mp_grid = mp_grid

        self.kmesh_tol = kmesh_tol
        self.bk_complete_tol = bk_complete_tol

        if v_matrix is not None:
            self.v_matrix = np.array(v_matrix, dtype=complex)
            # each dimension of v_matrix either defines or is checked against the corresponding size
            if num_kpts is None:
                num_kpts = self.v_matrix.shape[0]
            else:
                assert self.v_matrix.shape[0] == num_kpts, f"v_matrix should be of shape ({num_kpts}, num_bands, num_wann), but got {v_matrix.shape}"
            if num_bands is None:
                num_bands = self.v_matrix.shape[1]
            else:
                assert self.v_matrix.shape[1] == num_bands, f"v_matrix should be of shape (num_kpts, {num_bands}, num_wann), but got {v_matrix.shape}"
            if num_wann is None:
                num_wann = self.v_matrix.shape[2]
            else:
                assert self.v_matrix.shape[2] == num_wann, f"v_matrix should be of shape (num_kpts, num_bands, {num_wann}), but got {v_matrix.shape}"

        self.num_wann = num_wann
        self.num_bands = num_bands
        self.num_kpts = num_kpts

    def from_w90_file(self, seedname, kmesh_tol=1e-7, bk_complete_tol=1e-5):
        """
        Read the checkpoint from the Wannier90 file ``seedname.chk`` and re-initialise `self`

        Parameters
        ----------
        seedname : str
            the prefix of the file (including relative/absolute path, but not including the extension `.chk`)
        kmesh_tol : float
            tolerance to distinguish different/same k-points
        bk_complete_tol : float
            tolerance for the completeness relation for finite-difference scheme

        Returns
        -------
        CheckPoint
            `self`, re-initialised with the data read from the file
        """
        t0 = time()
        seedname = seedname.strip()
        FIN = FortranFileR(seedname + '.chk')

        def readint():
            return FIN.read_record('i4')

        def readfloat():
            return FIN.read_record('f8')

        def readcomplex():
            # complex numbers are stored as consecutive (real, imag) pairs of float64
            a = readfloat()
            return a[::2] + 1j * a[1::2]

        print('Reading restart information from file ' + seedname + '.chk :')
        try:
            readstr(FIN)  # comment line
            num_bands = readint()[0]
            num_exclude_bands = readint()[0]
            exclude_bands = readint()
            assert len(exclude_bands) == num_exclude_bands, f"read exclude_bands are {exclude_bands}, length={len(exclude_bands)} while num_exclude_bands={num_exclude_bands}"
            real_lattice = readfloat().reshape((3, 3), order='F')
            recip_lattice = readfloat().reshape((3, 3), order='F')
            # use the freshly read arrays: self.real_lattice may not exist yet (or be stale)
            lattice_product = real_lattice.dot(recip_lattice.T) / (2 * np.pi)
            assert np.linalg.norm(lattice_product - np.eye(3)) < 1e-14, f"the read real and reciprocal lattices are not consistent {lattice_product}!=identity"
            num_kpts = readint()[0]
            mp_grid = readint()
            assert len(mp_grid) == 3
            assert num_kpts == np.prod(mp_grid), f"the number of k-points is not consistent with the mesh {num_kpts}!={np.prod(mp_grid)}"
            kpt_latt = readfloat().reshape((num_kpts, 3))
            nntot = readint()[0]
            num_wann = readint()[0]
            readstr(FIN)  # checkpoint string
            have_disentangled = bool(readint()[0])
            if have_disentangled:
                self.omega_invariant = readfloat()[0]
                lwindow = np.array(readint().reshape((num_kpts, num_bands)), dtype=bool)
                ndimwin = readint()
                # stored in Fortran order (band, wann, k) -> (k, band, wann)
                u_matrix_opt = readcomplex().reshape((num_kpts, num_wann, num_bands)).swapaxes(1, 2)
                win_min = np.array([np.where(lwin)[0].min() for lwin in lwindow])
                win_max = np.array([np.where(lwin)[0].max() for lwin in lwindow])
                # the outer window must be one contiguous range of ndimwin[ik] bands
                for ik in range(num_kpts):
                    assert win_max[ik] - win_min[ik] + 1 == ndimwin[ik], f"win_max={win_max}, win_min={win_min}, ndimwin={ndimwin} - inconsistent"
                    assert np.sum(lwindow[ik]) == ndimwin[ik], f"lwindow={lwindow}, ndimwin={ndimwin} - inconsistent"
                    assert np.all(lwindow[ik, win_min[ik]:win_max[ik] + 1])
                    if win_min[ik] > 0:
                        assert np.all(np.logical_not(lwindow[ik, :win_min[ik]]))
                    if win_max[ik] < num_bands - 1:
                        assert np.all(np.logical_not(lwindow[ik, win_max[ik] + 1:]))
            u_matrix = readcomplex().reshape((num_kpts, num_wann, num_wann)).swapaxes(1, 2)
            # m_matrix: not used, read to advance the file; the reshape validates the record size
            readcomplex().reshape((num_kpts, nntot, num_wann, num_wann))
            wannier_centers_cart = readfloat().reshape((num_wann, 3))
            wannier_spreads = readfloat().reshape((num_wann))
        finally:
            FIN.close()

        if have_disentangled:
            # V = U_opt @ U, placed into the rows of the bands inside the outer window
            v_matrix = np.zeros((num_kpts, num_bands, num_wann), dtype=complex)
            for ik in range(num_kpts):
                u = u_matrix[ik]
                u_opt = u_matrix_opt[ik]
                nd = ndimwin[ik]
                assert np.linalg.norm(u_opt[nd:]) < 1e-12, f"u_opt[nd:]={u_opt[nd:]} - not zero"
                v_matrix[ik, lwindow[ik], :] = u_opt[:nd, :].dot(u)
        else:
            v_matrix = u_matrix

        print(f"Time to read .chk : {time() - t0}")
        self.__init__(real_lattice=real_lattice,
                      v_matrix=v_matrix,
                      wannier_centers_cart=wannier_centers_cart,
                      wannier_spreads=wannier_spreads,
                      kmesh_tol=kmesh_tol,
                      bk_complete_tol=bk_complete_tol,
                      kpt_latt=kpt_latt,
                      mp_grid=mp_grid,
                      )
        return self

    @property
    def wannierised(self):
        """True if `v_matrix` is set and is not None"""
        return getattr(self, "v_matrix", None) is not None

    def spin_order_block_to_interlace(self):
        """
        If the chk was obtain from a block ordering (like in old VASP versions),
        the ordering should be changed to interlace

        The first half of the Wannier functions goes to the even positions,
        the second half to the odd positions (`v_matrix` is replaced in place).
        """
        v_matrix = np.zeros((self.num_kpts, self.num_bands, self.num_wann), dtype=complex)
        v_matrix[:, :, 0::2] = self.v_matrix[:, :, :self.num_wann // 2]
        v_matrix[:, :, 1::2] = self.v_matrix[:, :, self.num_wann // 2:]
        self.v_matrix = v_matrix

    def spin_order_interlace_to_block(self):
        """
        If the chk was obtain from an interlace ordering, one may want to change the ordering to block

        The Wannier functions at even positions form the first half,
        those at odd positions the second half (`v_matrix` is replaced in place).
        """
        v_matrix = np.zeros((self.num_kpts, self.num_bands, self.num_wann), dtype=complex)
        v_matrix[:, :, :self.num_wann // 2] = self.v_matrix[:, :, 0::2]
        v_matrix[:, :, self.num_wann // 2:] = self.v_matrix[:, :, 1::2]
        self.v_matrix = v_matrix

    @cached_property
    def kpt_latt_int(self):
        """
        Returns the k-points as integers: `kpt_latt` multiplied by `mp_grid` and rounded
        """
        return np.array(np.round(self.kpt_latt * self.mp_grid[None, :]), dtype=int)

    def wannier_gauge(self, mat, ik1, ik2):
        """
        Returns the matrix elements in the Wannier gauge

        Parameters
        ----------
        mat : np.ndarray
            the matrix elements in the Hamiltonian gauge, of shape (num_bands, num_bands, ...);
            a 1D array is treated as the diagonal of a (num_bands, num_bands) matrix
        ik1, ik2 : int
            the indices of the k-points

        Returns
        -------
        np.ndarray(shape=(num_wann, num_wann, ...))
            V(ik1)^dagger @ mat @ V(ik2), with any trailing axes of `mat` preserved
        """
        # data should be of form NBxNBx ... - any form later
        if len(mat.shape) == 1:
            mat = np.diag(mat)
        assert mat.shape[:2] == (self.num_bands,) * 2, f"mat.shape={mat.shape}, num_bands={self.num_bands}"
        v1 = self.v_matrix[ik1].conj().T
        v2 = self.v_matrix[ik2]
        # the second tensordot puts the Wannier index last; move it back to position 1
        return np.tensordot(np.tensordot(v1, mat, axes=(1, 0)), v2, axes=(1, 0)).transpose(
            (0, -1) + tuple(range(1, mat.ndim - 1)))

    def get_HH_q(self, eig):
        """
        Returns the Hamiltonian matrix in the Wannier gauge

        Parameters
        ----------
        eig : `~wannierberri.w90files.EIG`
            the eigenvalues of the Hamiltonian

        Returns
        -------
        np.ndarray(shape=(num_kpts, num_wann, num_wann))
            the Hamiltonian matrix in the Wannier gauge (explicitly hermitised)
        """
        assert (eig.NK, eig.NB) == (self.num_kpts, self.num_bands), f"eig file has NK={eig.NK}, NB={eig.NB}, while the checkpoint has NK={self.num_kpts}, NB={self.num_bands}"
        HH_q = np.array([self.wannier_gauge(E, ik, ik) for ik, E in enumerate(eig.data)])
        return 0.5 * (HH_q + HH_q.transpose(0, 2, 1).conj())

    def get_SS_q(self, spn):
        """
        Returns the spin matrix in the Wannier gauge

        Parameters
        ----------
        spn : `~wannierberri.w90files.SPN`
            the spin matrix

        Returns
        -------
        np.ndarray(shape=(num_kpts, num_wann, num_wann, 3))
            the spin matrix in the Wannier gauge (explicitly hermitised)
        """
        assert (spn.NK, spn.NB) == (self.num_kpts, self.num_bands), f"spn file has NK={spn.NK}, NB={spn.NB}, while the checkpoint has NK={self.num_kpts}, NB={self.num_bands}"
        SS_q = np.array([self.wannier_gauge(S, ik, ik) for ik, S in enumerate(spn.data)])
        return 0.5 * (SS_q + SS_q.transpose(0, 2, 1, 3).conj())

    ###########################################################################
    # Oscar #
    ###########################################################################
    # Depart from the original matrix elements in the ab initio mesh
    # (Hamiltonian gauge) to obtain the corresponding matrix elements in the
    # Wannier gauge. The last constitute the basis to construct the real-space
    # matrix elements for Wannier interpolation, independently of the
    # finite-difference scheme used.

    def get_AABB_qb(self, mmn, transl_inv=False, eig=None, phase=None, sum_b=False):
        """
        Returns the matrix elements AA or BB (if eig is not None) in the Wannier gauge

        Parameters
        ----------
        mmn : `~wannierberri.w90files.MMN`
            the overlap matrix elements between the Wavefunctions at neighbouring k-points
        transl_inv : bool
            if True, the band-diagonal matrix elements are calculated using the Marzari & Vanderbilt
            translational invariant formula (only allowed when eig is None)
        eig : `~wannierberri.w90files.EIG`
            the eigenvalues of the Hamiltonian, needed to calculate BB (if None, the matrix elements are AA)
        phase : np.ndarray(shape=(num_wann, num_wann, nnb), dtype=complex)
            the phase factors to be applied to the matrix elements (if None, no phase factors are applied)
        sum_b : bool
            if True, the matrix elements are summed over the neighbouring k-points.
            Otherwise, the matrix elements are stored in a 5D array of shape (num_kpts, num_wann, num_wann, nnb, 3)

        Returns
        -------
        np.ndarray(shape=(num_kpts, num_wann, num_wann, nnb, 3), dtype=complex) (if sum_b=False) or
        np.ndarray(shape=(num_kpts, num_wann, num_wann, 3), dtype=complex) (if sum_b=True)
            the q-resolved matrix elements AA or BB in the Wannier gauge
        """
        assert (not transl_inv) or eig is None, "transl_inv cannot be used for BB matrix elements"
        if sum_b:
            AA_qb = np.zeros((self.num_kpts, self.num_wann, self.num_wann, 3), dtype=complex)
        else:
            AA_qb = np.zeros((self.num_kpts, self.num_wann, self.num_wann, mmn.NNB, 3), dtype=complex)
        for ik in range(self.num_kpts):
            for ib in range(mmn.NNB):
                iknb = mmn.neighbours[ik, ib]
                ib_unique = mmn.ib_unique_map[ik, ib]
                # Matrix < u_k | u_k+b > (mmn)
                data = mmn.data[ik, ib]  # Hamiltonian gauge
                if eig is not None:
                    data = data * eig.data[ik, :, None]  # Hamiltonian gauge (add energies)
                AAW = self.wannier_gauge(data, ik, iknb)  # Wannier gauge
                # Matrix for finite-difference schemes
                AA_q_ik_ib = 1.j * AAW[:, :, None] * mmn.wk[ik, ib] * mmn.bk_cart[ik, ib, None, None, :]
                # Marzari & Vanderbilt formula for band-diagonal matrix elements
                if transl_inv:
                    AA_q_ik_ib[range(self.num_wann), range(self.num_wann)] = -np.log(
                        AAW.diagonal()).imag[:, None] * mmn.wk[ik, ib] * mmn.bk_cart[ik, ib, None, :]
                if phase is not None:
                    AA_q_ik_ib *= phase[:, :, ib_unique, None]
                if sum_b:
                    AA_qb[ik] += AA_q_ik_ib
                else:
                    AA_qb[ik, :, :, ib_unique, :] = AA_q_ik_ib
        return AA_qb

    # --- A_a(q,b) matrix --- #

    def get_AA_qb(self, mmn, transl_inv=False, phase=None, sum_b=False):
        """
        A wrapper for get_AABB_qb with eig=None
        see :meth:`~wannierberri.w90files.CheckPoint.get_AABB_qb` for more details
        """
        return self.get_AABB_qb(mmn, transl_inv=transl_inv, phase=phase, sum_b=sum_b)

    def get_AA_q(self, mmn, transl_inv=False):
        """
        A wrapper for get_AA_qb with sum_b=True
        see :meth:`~wannierberri.w90files.CheckPoint.get_AA_qb` for more details

        Returns
        -------
        np.ndarray(shape=(num_kpts, num_wann, num_wann, 3), dtype=complex)
            the AA matrix elements summed over the neighbouring k-points
        """
        # sum_b=True already sums over b; the last axis is the Cartesian one and must be kept
        return self.get_AA_qb(mmn=mmn, transl_inv=transl_inv, sum_b=True)

    def get_wannier_centers(self, mmn, spreads=False):
        """
        calculate wannier centers  with the Marzarri-Vanderbilt translational invariant formula
        and optionally the spreads

        Parameters
        ----------
        mmn : :class:`~wannierberri.w90files.MMN`
            the overlap matrix elements between the Wavefunctions at neighbouring k-points
        spreads : bool
            if True, the spreads are calculated

        Returns
        -------
        np.ndarray(shape=(num_wann, 3), dtype=float)
            the wannier centers
        np.ndarray(shape=(num_wann,), dtype=float)
            the wannier spreads (in Angstrom^2) (if spreads=True)
        """
        wcc = np.zeros((self.num_wann, 3), dtype=float)
        if spreads:
            r2 = np.zeros(self.num_wann, dtype=float)
        for ik in range(mmn.NK):
            for ib in range(mmn.NNB):
                iknb = mmn.neighbours[ik, ib]
                mmn_loc = self.wannier_gauge(mmn.data[ik, ib], ik, iknb)
                mmn_loc = mmn_loc.diagonal()
                log_loc = np.angle(mmn_loc)  # = Im ln M_nn
                wcc += -log_loc[:, None] * mmn.wk[ik, ib] * mmn.bk_cart[ik, ib]
                if spreads:
                    r2 += (1 - np.abs(mmn_loc) ** 2 + log_loc ** 2) * mmn.wk[ik, ib]
        wcc /= mmn.NK
        if spreads:
            # spread = <r^2> - <r>^2
            return wcc, r2 / mmn.NK - np.sum(wcc**2, axis=1)
        else:
            return wcc

    # --- B_a(q,b) matrix --- #

    def get_BB_qb(self, mmn, eig, phase=None, sum_b=False):
        """
        a wrapper for get_AABB_qb to evaluate BB matrix elements. (transl_inv is disabled)
        see :meth:`~wannierberri.w90files.CheckPoint.get_AABB_qb` for more details
        """
        return self.get_AABB_qb(mmn, eig=eig, phase=phase, sum_b=sum_b)

    def get_CCOOGG_qb(self, mmn, uhu, antisym=True, phase=None, sum_b=False):
        """
        Returns the matrix elements CC, OO or GG in the Wannier gauge

        Parameters
        ----------
        mmn : :class:`~wannierberri.w90files.MMN`
            the overlap matrix elements between the Wavefunctions at neighbouring k-points
        uhu : :class:`~wannierberri.w90files.UHU` or :class:`~wannierberri.w90files.UIU`
            the matrix elements uhu or uiu produced by pw2wannier90
        antisym : bool
            if True, the antisymmetric piece of the matrix elements is calculated
            (one Cartesian index). Otherwise, the full matrix is calculated (two Cartesian indices)
        phase : np.ndarray(shape=(num_wann, num_wann, nnb, nnb), dtype=complex)
            the phase factors to be applied to the matrix elements (if None, no phase factors are applied)
        sum_b : bool
            if True, the matrix elements are summed over the neighbouring k-points.
            Otherwise, the neighbour indices (nnb, nnb) are kept

        Returns
        -------
        np.ndarray(dtype=complex)
            the q-resolved matrix elements CC, OO or GG in the Wannier gauge, of shape
            (num_kpts, num_wann, num_wann, nnb, nnb) + cart (if sum_b=False) or
            (num_kpts, num_wann, num_wann) + cart (if sum_b=True),
            where cart is (3,) if antisym=True and (3, 3) otherwise
        """
        nd_cart = 1 if antisym else 2
        shape_NNB = () if sum_b else (mmn.NNB, mmn.NNB)
        shape = (self.num_kpts, self.num_wann, self.num_wann) + shape_NNB + (3,) * nd_cart
        CC_qb = np.zeros(shape, dtype=complex)
        if phase is not None:
            # append singleton axes so that the phase broadcasts over the Cartesian indices
            phase = np.reshape(phase, np.shape(phase)[:4] + (1,) * nd_cart)
        for ik in range(self.num_kpts):
            for ib1 in range(mmn.NNB):
                iknb1 = mmn.neighbours[ik, ib1]
                ib1_unique = mmn.ib_unique_map[ik, ib1]
                for ib2 in range(mmn.NNB):
                    iknb2 = mmn.neighbours[ik, ib2]
                    ib2_unique = mmn.ib_unique_map[ik, ib2]
                    # Matrix < u_k+b1 | H_k | u_k+b2 > (uHu)
                    data = uhu.data[ik, ib1, ib2]  # Hamiltonian gauge
                    CCW = self.wannier_gauge(data, iknb1, iknb2)  # Wannier gauge
                    if antisym:
                        # Matrix for finite-difference schemes (takes antisymmetric piece only)
                        CC_q_ik_ib = 1.j * CCW[:, :, None] * (
                            mmn.wk[ik, ib1] * mmn.wk[ik, ib2] * (
                                mmn.bk_cart[ik, ib1, alpha_A] * mmn.bk_cart[ik, ib2, beta_A] -
                                mmn.bk_cart[ik, ib1, beta_A] * mmn.bk_cart[ik, ib2, alpha_A]))[None, None, :]
                    else:
                        # Matrix for finite-difference schemes (takes symmetric piece only)
                        CC_q_ik_ib = CCW[:, :, None, None] * (
                            mmn.wk[ik, ib1] * mmn.wk[ik, ib2] * (
                                mmn.bk_cart[ik, ib1, :, None] *
                                mmn.bk_cart[ik, ib2, None, :]))[None, None, :, :]
                    if phase is not None:
                        CC_q_ik_ib *= phase[:, :, ib1_unique, ib2_unique]
                    if sum_b:
                        CC_qb[ik] += CC_q_ik_ib
                    else:
                        CC_qb[ik, :, :, ib1_unique, ib2_unique] = CC_q_ik_ib
        return CC_qb

    # --- C_a(q,b1,b2) matrix --- #

    def get_CC_qb(self, mmn, uhu, phase=None, sum_b=False):
        """
        A wrapper for get_CCOOGG_qb with antisym=True
        see :meth:`~wannierberri.w90files.CheckPoint.get_CCOOGG_qb` for more details
        """
        return self.get_CCOOGG_qb(mmn, uhu, phase=phase, sum_b=sum_b)

    # --- O_a(q,b1,b2) matrix --- #

    def get_OO_qb(self, mmn, uiu, phase=None, sum_b=False):
        """
        A wrapper for get_CCOOGG_qb with antisym=True
        see :meth:`~wannierberri.w90files.CheckPoint.get_CCOOGG_qb` for more details
        (actually, the same as :meth:`~wannierberri.w90files.CheckPoint.get_CC_qb`)
        """
        return self.get_CCOOGG_qb(mmn, uiu, phase=phase, sum_b=sum_b)

    # Symmetric G_bc(q,b1,b2) matrix

    def get_GG_qb(self, mmn, uiu, phase=None, sum_b=False):
        """
        A wrapper for get_CCOOGG_qb with antisym=False
        see :meth:`~wannierberri.w90files.CheckPoint.get_CCOOGG_qb` for more details
        """
        return self.get_CCOOGG_qb(mmn, uiu, antisym=False, phase=phase, sum_b=sum_b)

    ###########################################################################

    def get_SH_q(self, spn, eig):
        """
        Returns the product of the spin matrix and the energies (S_mn * E_n) in the Wannier gauge

        Parameters
        ----------
        spn : `~wannierberri.w90files.SPN`
            the spin matrix
        eig : `~wannierberri.w90files.EIG`
            the eigenvalues of the Hamiltonian

        Returns
        -------
        np.ndarray(shape=(num_kpts, num_wann, num_wann, 3), dtype=complex)
        """
        SH_q = np.zeros((self.num_kpts, self.num_wann, self.num_wann, 3), dtype=complex)
        assert (spn.NK, spn.NB) == (self.num_kpts, self.num_bands), f"spn file has NK={spn.NK}, NB={spn.NB}, while the checkpoint has NK={self.num_kpts}, NB={self.num_bands}"
        assert (eig.NK, eig.NB) == (self.num_kpts, self.num_bands), f"eig file has NK={eig.NK}, NB={eig.NB}, while the checkpoint has NK={self.num_kpts}, NB={self.num_bands}"
        for ik in range(self.num_kpts):
            SH_q[ik, :, :, :] = self.wannier_gauge(spn.data[ik, :, :, :] * eig.data[ik, None, :, None], ik, ik)
        return SH_q

    def get_SHA_q(self, shu, mmn, phase=None, sum_b=False):
        """
        SHA or SA (if siu is used instead of shu)

        Calls ``mmn.set_bk_chk(self)`` first.

        Returns
        -------
        np.ndarray(shape=(num_kpts, num_wann, num_wann, nnb, 3, 3), dtype=complex) (if sum_b=False) or
        np.ndarray(shape=(num_kpts, num_wann, num_wann, 3, 3), dtype=complex) (if sum_b=True)
        """
        mmn.set_bk_chk(self)
        if sum_b:
            SHA_qb = np.zeros((self.num_kpts, self.num_wann, self.num_wann, 3, 3), dtype=complex)
        else:
            SHA_qb = np.zeros((self.num_kpts, self.num_wann, self.num_wann, mmn.NNB, 3, 3), dtype=complex)
        assert shu.NNB == mmn.NNB, f"shu.NNB={shu.NNB}, mmn.NNB={mmn.NNB} - mismatch"
        for ik in range(self.num_kpts):
            for ib in range(mmn.NNB):
                iknb = mmn.neighbours[ik, ib]
                ib_unique = mmn.ib_unique_map[ik, ib]
                SHAW = self.wannier_gauge(shu.data[ik, ib], ik, iknb)
                # first Cartesian index comes from b, the second from the spin component
                SHA_q_ik_ib = 1.j * SHAW[:, :, None, :] * mmn.wk[ik, ib] * mmn.bk_cart[ik, ib, None, None, :, None]
                if phase is not None:
                    SHA_q_ik_ib *= phase[:, :, ib_unique, None, None]
                if sum_b:
                    SHA_qb[ik] += SHA_q_ik_ib
                else:
                    SHA_qb[ik, :, :, ib_unique, :, :] = SHA_q_ik_ib
        return SHA_qb

    def get_SHR_q(self, spn, mmn, eig=None, phase=None):
        """
        SHR or SR(if eig is None)

        Calls ``mmn.set_bk_chk(self)`` first.

        Returns
        -------
        np.ndarray(shape=(num_kpts, num_wann, num_wann, 3, 3), dtype=complex)
            summed over the neighbouring k-points
        """
        mmn.set_bk_chk(self)
        SHR_q = np.zeros((self.num_kpts, self.num_wann, self.num_wann, 3, 3), dtype=complex)
        assert (spn.NK, spn.NB) == (self.num_kpts, self.num_bands), f"spn file has NK={spn.NK}, NB={spn.NB}, while the checkpoint has NK={self.num_kpts}, NB={self.num_bands}"
        assert (mmn.NK, mmn.NB) == (self.num_kpts, self.num_bands), f"mmn file has NK={mmn.NK}, NB={mmn.NB}, while the checkpoint has NK={self.num_kpts}, NB={self.num_bands}"
        for ik in range(self.num_kpts):
            SH = spn.data[ik, :, :, :]
            if eig is not None:
                SH = SH * eig.data[ik, None, :, None]
            SHW = self.wannier_gauge(SH, ik, ik)
            for ib in range(mmn.NNB):
                iknb = mmn.neighbours[ik, ib]
                ib_unique = mmn.ib_unique_map[ik, ib]
                # (SH @ M) with the spin component moved back to the last axis
                SHM = np.tensordot(SH, mmn.data[ik, ib], axes=((1,), (0,))).swapaxes(-1, -2)
                SHRW = self.wannier_gauge(SHM, ik, iknb)
                if phase is not None:
                    SHRW = SHRW * phase[:, :, ib_unique, None]
                SHRW = SHRW - SHW
                SHR_q[ik, :, :, :, :] += 1.j * SHRW[:, :, None] * mmn.wk[ik, ib] * mmn.bk_cart[ik, ib, None, None, :, None]
        return SHR_q

    def from_win(self, win):
        """
        Create an empty (not wannierised) checkpoint from the data of a Win object

        Uses ``win.data["mp_grid"]``, ``win.get_kpoints()``, ``win.get_unit_cell_cart_ang()``
        and, if present, ``win["num_wann"]`` and ``win["num_bands"]``.

        Returns
        -------
        CheckPoint
            `self`, re-initialised (no `v_matrix` is set)
        """
        print("creating empty CheckPoint from Win file")
        mp_grid = np.array(win.data["mp_grid"])
        kpt_latt = win.get_kpoints()
        real_lattice = win.get_unit_cell_cart_ang()
        try:
            num_wann = win["num_wann"]
        except KeyError:
            num_wann = None
        try:
            num_bands = win["num_bands"]
        except KeyError:
            num_bands = None
        self.__init__(real_lattice=real_lattice, num_wann=num_wann, num_bands=num_bands,
                      kpt_latt=kpt_latt, mp_grid=mp_grid,)
        return self

    def select_bands(self, selected_bands):
        """
        Reduce `num_bands` to the number of selected bands (allowed only before `v_matrix` is set)

        Parameters
        ----------
        selected_bands : index, list of indices, boolean mask or None
            the bands to keep (anything valid for indexing a numpy array of length `num_bands`);
            if None, nothing is done
        """
        if selected_bands is not None:
            assert not self.wannierised, "v_matrix already set, cannot select bands"
            selected_bands_bool = np.zeros(self.num_bands, dtype=bool)
            selected_bands_bool[selected_bands] = True
            assert np.any(selected_bands_bool), "No bands selected"
            self.num_bands = int(np.count_nonzero(selected_bands_bool))