Source code for wannierberri.result.kbandresult

import numpy as np
from .result import Result
import itertools
import abc
from ..symmetry.point_symmetry import transform_from_dict


class K__Result(Result, abc.ABC):

    def __init__(self, data=None, transformTR=None, transformInv=None, file_npz=None, rank=None,
                 other_properties=None):
        if other_properties is None:
            other_properties = {}
        assert (data is not None) or (file_npz is not None)
        if file_npz is not None:
            res = np.load(open(file_npz, "rb"), allow_pickle=True)
            self.__init__(
                data=res['data'],
                transformTR=transform_from_dict(res, 'transformTR'),
                transformInv=transform_from_dict(res, 'transformInv'),
            )
        else:
            if data is not None:
                if isinstance(data, list):
                    self.data_list = data
                else:
                    self.data_list = [data]
            self.transformTR = transformTR
            self.transformInv = transformInv
        if rank is None:
            self.rank = self.get_rank()
        else:
            self.rank = rank
        self.other_properties = other_properties

    def get_rank(self):
        raise NotImplementedError()

    def fit(self, other):
        for var in ['transformTR', 'transformInv', 'rank']:
            if getattr(self, var) != getattr(other, var):
                print(f"parameters {var} are not fit : `{getattr(self, var)}` and `{getattr(other, var)}` ")
                return False
        return True

    @property
    def data(self):
        if len(self.data_list) > 1:
            self.data_list = [np.vstack(self.data_list)]
        return self.data_list[0]

    @property
    def nk(self):
        return sum(data.shape[0] for data in self.data_list)

    def __add__(self, other):
        assert self.fit(other)
        return self.__class__(data=self.data_list + other.data_list,
                              transformTR=self.transformTR,
                              transformInv=self.transformInv,
                              rank=self.rank,
                              other_properties=self.other_properties
                              )

    def add(self, other):
        self.data_list = [d1 + d2 for d1, d2 in zip(self.data_list, other.data_list)]

    def __mul__(self, number):
        return self.__class__(data=[d * number for d in self.data_list],
                              transformTR=self.transformTR,
                              transformInv=self.transformInv,
                              rank=self.rank,
                              other_properties=self.other_properties
                              )

    def mul_array(self, other, axes=None):
        if isinstance(axes, int):
            axes = (axes,)
        if axes is None:
            axes = tuple(range(other.ndim))
        axes = tuple((a + 1) for a in axes)  # because 0th dimension is k here
        for i, d in enumerate(other.shape):
            assert d == self.data_list[0].shape[axes[i]], \
                f"shapes  {other.shape} should match the axes {axes} of {self.data_list[0].shape}"
        reshape = tuple((self.data.shape[i] if i in axes else 1) for i in range(self.data_list[0].ndim))
        other_reshape = other.reshape(reshape)
        return self.__class__(
            data=[d * other_reshape for d in self.data_list],
            transformTR=self.transformTR,
            transformInv=self.transformInv,
            rank=self.rank,
            other_properties=self.other_properties
        )

    def __sub__(self, other):
        if (self.transformTR is not None) and (other.transformTR is not None):
            assert self.transformTR == other.transformTR
        if (self.transformInv is not None) and (other.transformInv is not None):
            assert self.transformInv == other.transformInv
        return KBandResult(
            data=self.data - other.data,
            transformTR=self.transformTR,
            transformInv=self.transformInv,
        )

    def __truediv__(self, number):
        return self * 1  # actually a copy

    def as_dict(self):
        """
        returns a dictionary-like object with the following keys:
        - 'E_titles' : list of str - titles of the energies on which the result depends
        - 'Energies_0', ['Energies_1', ... ] - corresponding arrays of energies
        - data : array of shape (len(Energies_0), [ len(Energies_1), ...] , [3  ,[ 3, ... ]] )
        """
        return dict(
            data=self.data,
            transformTR=self.transformTR.as_dict(),
            transformInv=self.transformInv.as_dict()
        )

    def to_grid(self, k_map):
        dataall = self.data
        data = np.array([sum(dataall[ik] for ik in km) / len(km) for km in k_map])
        return self.__class__(data=data,
                              transformTR=self.transformTR,
                              transformInv=self.transformInv,
                              rank=self.rank,
                              other_properties=self.other_properties
                              )

    def average_deg(self, deg):
        for i, D in enumerate(deg):
            for ib1, ib2 in D:
                for j in range(len(self.data_list)):
                    self.data_list[j][i, ib1:ib2] = self.data_list[j][i, ib1:ib2].mean(axis=0)
        return self

    def transform(self, sym):
        data = [sym.transform_tensor(data, rank=self.rank,
                                     transformTR=self.transformTR, transformInv=self.transformInv) for data in
                self.data_list]
        return self.__class__(data,
                              transformTR=self.transformTR,
                              transformInv=self.transformInv,
                              other_properties=self.other_properties,
                              rank=self.rank
                              )

    def get_component_list(self):
        dim = len(self.data.shape[2:])
        comp_list = ["".join(s) for s in itertools.product(*[("x", "y", "z")] * dim)]
        comp_list = [s for s in comp_list if len(s) > 0]
        if self.ndim >= 2:
            comp_list.append("trace")
        if len(comp_list) == 0:
            comp_list = [None]
        return comp_list

    @property
    def ndim(self):
        dims = np.array(self.data.shape[2:])
        if not np.all(dims == 3):
            raise RuntimeError(f"dimensions of all components should be 3, found {dims}")
        return len(dims)

    def get_component(self, component=None):
        return get_component(data=self.data, ndim=self.ndim, component=component)


[docs] class KBandResult(K__Result): def get_rank(self): return len(self.data_list[0].shape) - 2 def fit(self, other): if self.nband != other.nband: print(f"parameter 'nband' does not match : `{self.nband}` and `{other.nband}` ") return False return super().fit(other) @property def nband(self): return self.data_list[0].shape[1] def select_bands(self, ibands): return self.__class__(self.data[:, ibands], transformTR=self.transformTR, transformInv=self.transformInv, rank=self.rank, other_properties=self.other_properties )
class NoComponentError(RuntimeError): def __init__(self, comp, dim, err=""): # Call the base class constructor with the parameters it needs super().__init__(f"component `{comp}` does not exist for tensor with dimension {dim} :\n{err}") def get_component(data, ndim, component=None): xyz = {"x": 0, "y": 1, "z": 2} if isinstance(component, tuple): Xnk = np.copy(data) for k in component[-1::-1]: Xnk = Xnk[..., k] return Xnk elif isinstance(component, str) or component is None: if component is not None: component = component.lower() if ndim == 0: if component is None: return data else: raise NoComponentError(component, 0) elif ndim == 1: if component in ["x", "y", "z"]: return data[..., xyz[component]] elif component == 'norm': return np.linalg.norm(data, axis=-1) elif component == 'sq': return np.linalg.norm(data, axis=-1) ** 2 else: raise NoComponentError(component, 1) else: dims = tuple(np.arange(data.ndim)) _data = data.transpose(dims[-ndim:] + dims[:-ndim]) print(f"dims={dims}, data_shape={data.shape}, , _data_shape={_data.shape}") if component == "trace": return sum([_data[((i,) * ndim)] for i in range(3)]) else: try: return _data[tuple([xyz[c] for c in component])] except IndexError as err: raise NoComponentError(component, 2, str(err)) else: raise ValueError( f"component is given by `{component}` of type {type(component)}. Should be str or tuple")