# Source code for wannierberri.run

#                                                            #
# This file is distributed as part of the WannierBerri code  #
# under the terms of the GNU General Public License. See the #
# file `LICENSE' in the root directory of the WannierBerri   #
# distribution, or http://www.gnu.org/copyleft/gpl.txt       #
#                                                            #
# The WannierBerri code is hosted on GitHub:                 #
# https://github.com/stepan-tsirkin/wannier-berri            #
#                     written by                             #
#           Stepan Tsirkin, University of Zurich             #
#                                                            #
# ------------------------------------------------------------

import numpy as np
from collections.abc import Iterable
from time import time
import pickle
import glob
from termcolor import cprint
import warnings
from .utility import remove_file
from .data_K import get_data_k
from .grid import exclude_equiv_points, Path, Grid, GridTetra
from .parallel import Serial
from .result import ResultDict


def print_progress(count, total, t0, tprev, print_progress_step):
    """Print one progress line if enough time has passed since the last one.

    Parameters
    ----------
    count : int
        number of K-points finished so far
    total : int
        total number of K-points to be evaluated
    t0 : float
        value of `time()` at which the processing started
    tprev : float
        elapsed time (seconds since `t0`) at which the previous line was printed
    print_progress_step : float or int
        minimal interval (seconds) between two printed lines

    Returns
    -------
    float
        the elapsed time if a line was printed now, otherwise `tprev` unchanged
    """
    elapsed = time() - t0
    # linear extrapolation of the remaining time; impossible before the first point is done
    remaining = "unknown" if count == 0 else f"{elapsed / count * (total - count):22.1f}"
    is_due = elapsed - tprev > print_progress_step
    if not is_due:
        return tprev
    print(f"{count:20d}{elapsed:17.1f}{remaining:>22s}", flush=True)
    return elapsed


def process(paralfunc, K_list, parallel, pointgroup=None, remote_parameters=None, print_progress_step=5):
    """Evaluate all not-yet-evaluated K-points of `K_list` and store the results in them.

    Parameters
    ----------
    paralfunc : callable
        function evaluating one K-point. Called as `paralfunc(K, **remote_parameters)`
        in serial mode and as `paralfunc.remote(K, **remote_parameters)` with ray.
    K_list : list
        list of K-point objects. Only those with `K.res is None` are evaluated;
        the result is stored via `K.set_res(result)`.
    parallel : object
        parallelization descriptor. Must provide `method` (`'serial'` or `'ray'`),
        `npar_K`, `progress_step(numK, npar_K)` and, for ray, the `ray` module as `parallel.ray`.
    pointgroup : object or None
        if not `None`, every result is passed through `pointgroup.symmetrize()`
    remote_parameters : dict or None
        extra keyword arguments passed to `paralfunc`
    print_progress_step : float or int
        minimal interval (seconds) between printed progress lines

    Returns
    -------
    int
        number of K-points that were evaluated in this call (0 if there was nothing to do)

    Raises
    ------
    RuntimeError
        if `parallel.method` is neither `'serial'` nor `'ray'`
    """
    if remote_parameters is None:
        remote_parameters = {}
    t0 = time()
    # `print_progress` compares this value with the time *elapsed since t0*,
    # so it has to start at zero. Starting at `t0` (seconds since the epoch)
    # made the difference hugely negative and no progress line was ever printed.
    t_print_prev = 0.0
    selK = [ik for ik, k in enumerate(K_list) if k.res is None]
    numK = len(selK)
    if numK == 0:
        print("nothing to process now")
        return 0
    dK_list = [K_list[ik] for ik in selK]

    print(f"processing {numK} K points :", end=" ")
    if parallel.method == 'serial':
        print("in serial.")
    else:
        print(f"using  {parallel.npar_K} processes.")

    print("# K-points calculated  Wall time (sec)  Est. remaining (sec)", flush=True)
    res = []
    nstep_print = parallel.progress_step(numK, parallel.npar_K)
    if parallel.method == 'serial':
        for count, Kp in enumerate(dK_list):
            res.append(paralfunc(Kp, **remote_parameters))
            if (count + 1) % nstep_print == 0:
                t_print_prev = print_progress(count + 1, numK, t0, t_print_prev, print_progress_step)
    elif parallel.method == 'ray':
        remotes = [paralfunc.remote(dK, **remote_parameters) for dK in dK_list]
        num_remotes = len(remotes)
        num_remotes_calculated = 0
        while True:
            # wait for the next batch of `nstep_print` tasks; even if the required
            # number of remotes has not finished, the wait returns after a minute
            # so that progress may still be printed
            remotes_calculated, _ = parallel.ray.wait(
                remotes, num_returns=min(num_remotes_calculated + nstep_print, num_remotes),
                timeout=60)
            num_remotes_calculated = len(remotes_calculated)
            if num_remotes_calculated >= num_remotes:
                break
            t_print_prev = print_progress(num_remotes_calculated, numK, t0, t_print_prev, print_progress_step)
        res = parallel.ray.get(remotes)
    else:
        raise RuntimeError(f"unsupported parallel method : '{parallel.method}'")

    if pointgroup is not None:
        res = [pointgroup.symmetrize(r) for r in res]
    # results come back in the order of `dK_list`, i.e. of `selK`
    for ik, r in zip(selK, res):
        K_list[ik].set_res(r)

    t = time() - t0
    if parallel.method == 'serial':
        print(f"time for processing {numK:6d} K-points in serial: ", end="")
        nproc_ = 1
    else:
        print(f"time for processing {numK:6d} K-points on {parallel.npar_K:3d} processes: ", end="")
        nproc_ = parallel.npar_K
    print(f"{t:10.4f} ; per K-point {t / numK:15.4f} ; proc-sec per K-point {t * nproc_ / numK:15.4f}", flush=True)
    return numK


def run(
        system,
        grid,
        calculators,
        adpt_num_iter=0,
        use_irred_kpt=True,
        symmetrize=True,
        fout_name="result",
        suffix="",
        parameters_K=None,
        file_Klist=None,
        restart=False,
        Klist_part=10,
        parallel=None,  # serial by default
        print_Kpoints=False,
        adpt_mesh=2,
        adpt_fac=1,
        fast_iter=True,
        print_progress_step=5,
):
    """
    The function to run a calculation. Substitutes the old (obsolete and removed)
    `integrate()` and `tabulate()` and allows to integrate and tabulate in one run.

    Parameters
    ----------
    system : :class:`~wannierberri.system.System`
        System under investigation
    grid : :class:`~wannierberri.Grid` or :class:`~wannierberri.Path`
        initial grid for integration, or path for tabulation
    calculators : dict
        a dictionary where keys are any string identifiers, and the values are of
        :class:`~wannierberri.calculators.Calculator`
    adpt_num_iter : int
        number of recursive adaptive refinement iterations. See :ref:`sec-refine`
    adpt_mesh : int
        the size of the refinement grid (usually no need to change)
    adpt_fac : int
        number of K-points to be refined per quantity and criteria.
    parallel : :class:`~wannierberri.parallel.Parallel`
        object describing parallelization scheme (serial if `None`)
    use_irred_kpt : bool
        evaluate only symmetry-irreducible K-points
    symmetrize : bool
        symmetrize the result (always `True` on a grid if `use_irred_kpt == True`,
        always switched off for a path)
    fout_name : str
        beginning of the output files for each quantity after each iteration
    suffix : str
        extra marker inserted into output files to mark this particular calculation run
    print_Kpoints : bool
        print the list of K points
    file_Klist : str or None
        name of file where to store the Kpoint list of each iteration. May be needed to restart a calculation
        to get more iterations. If `None` -- the file is not written
    restart : bool
        if `True` : reads restart information from `file_Klist` and starts from there
    Klist_part : int
        write the file_Klist by portions. Increase for speed, decrease for memory saving
    parameters_K : dict
        parameters to be passed to :class:`~wannierberri.data_K.Data_K` class
    fast_iter : bool
        if under iterations appear peaks that are not further removed, set this parameter to False.
    print_progress_step : float or int
        intervals to print progress

    Returns
    --------
    dictionary of :class:`~wannierberri.result.EnergyResult`

    Raises
    ------
    ValueError
        if a calculator is not compatible with the kind of `grid` (path or grid)
    RuntimeError
        if `restart` is requested but the restart information cannot be read

    Notes
    -----
    Results are also printed to ASCII files
    """
    if parallel is None:
        parallel = Serial()
    cprint("Starting run()", 'red', attrs=['bold'])
    if parameters_K is None:
        parameters_K = {}
    print_calculators(calculators)

    # along a path only tabulating is possible
    if isinstance(grid, Path):
        print("Calculation along a path - checking calculators for compatibility")
        for key, calc in calculators.items():
            print(key, calc)
            if not calc.allow_path:
                raise ValueError(
                    f"Calculation along a Path is running, but calculator `{key}` is not compatible with a Path")
        print("All calculators are compatible")
        if symmetrize:
            print("Symmetrization switched off for Path")
        symmetrize = False
    else:
        print("Calculation on grid - checking calculators for compatibility")
        if use_irred_kpt:
            symmetrize = True
        for key, calc in calculators.items():
            print(key, calc)
            if not calc.allow_grid:
                raise ValueError(
                    f"Calculation on Grid is running, but calculator `{key}` is not compatible with a Grid")
        print("All calculators are compatible")
        if isinstance(grid, GridTetra):
            print("Grid is tetrahedral")
        else:
            print("Grid is regular")

    if file_Klist is not None:
        do_write_Klist = True
        # NOTE(review): the name of the "changed factors" file depends on whether
        # the user supplied the ".pickle" extension ("X.pickle.changed_factors.txt"
        # vs "X.changed_factors.txt"). Kept as is, so that existing restart files
        # are still found.
        if not file_Klist.endswith(".pickle"):
            file_Klist += ".pickle"
            file_Klist_factor_changed = file_Klist + ".changed_factors.txt"
        else:
            file_Klist_factor_changed = file_Klist[:-7] + ".changed_factors.txt"
    else:
        do_write_Klist = False
        file_Klist_factor_changed = None

    print(f"The set of k points is a {grid.str_short}")

    remote_parameters = {'_system': system, '_grid': grid, 'npar_k': parallel.npar_k, '_calculators': calculators}

    def paralfunc(Kpoint, _system, _grid, _calculators, npar_k):
        # evaluate every calculator on the data of one K-point
        data = get_data_k(_system, Kpoint.Kp_fullBZ, npar_k=npar_k, grid=_grid, Kpoint=Kpoint, **parameters_K)
        return ResultDict({k: v(data) for k, v in _calculators.items()})

    if parallel.method == 'ray':
        ray = parallel.ray
        # put the shared objects into the ray object store once, pass references to the tasks
        remote_parameters = {k: ray.put(v) for k, v in remote_parameters.items()}
        paralfunc = ray.remote(paralfunc)

    if restart:
        try:
            K_list = []
            # NOTE(security): pickle.load may execute arbitrary code;
            # `file_Klist` must come from a trusted source.
            with open(file_Klist, "rb") as fr:
                while True:
                    try:
                        K_list += pickle.load(fr)
                    except EOFError:
                        print(f"Finished reading Klist from file {file_Klist}")
                        break
            print(f"{len(K_list)} K-points were read from {file_Klist}")
            if len(K_list) == 0:
                warnings.warn(f"{file_Klist} contains zero points starting from scrath")
                # really start from scratch (see the `else` branch below) instead of
                # continuing with an empty list, which cannot produce any result
                restart = False
            else:
                nk_prev = len(K_list)
                try:
                    # patching the Klist by updating the factors
                    num_factors_changed = 0
                    with open(file_Klist_factor_changed, "r") as fr_div:
                        for line in fr_div:
                            line_ = line.split()
                            iK = int(line_[0])
                            fac = float(line_[1])
                            num_factors_changed += 1
                            K_list[iK].factor = fac
                    print(f"{num_factors_changed} K-points were read from {file_Klist_factor_changed}")
                except FileNotFoundError:
                    print(f"File with changed factors {file_Klist_factor_changed} not found, "
                          "assume they were not changed")
        except Exception as err:
            raise RuntimeError(f"{err}: reading from {file_Klist} failed, starting from scrath") from err

    if restart:
        print("searching for start_iter")
        try:
            start_iter = int(
                sorted(glob.glob(fout_name + "*" + suffix + "_iter-*.dat"))[-1].split("-")[-1].split(".")[0])
            print(f"start_iter = {start_iter}")
        except Exception as err:
            warnings.warn(f"{err} : failed to read start_iter. Setting to zero")
            start_iter = 0
    else:
        K_list = grid.get_K_list(use_symmetry=use_irred_kpt)
        print("Done, sum of weights:{}".format(sum(Kp.factor for Kp in K_list)))
        start_iter = 0
        nk_prev = 0
        remove_file(file_Klist)
        remove_file(file_Klist_factor_changed)

    if adpt_num_iter < 0:
        # a negative value is rescaled by the grid size, refinement mesh and adpt_fac
        adpt_num_iter = -adpt_num_iter * np.prod(grid.div) / np.prod(adpt_mesh) / adpt_fac / 3
    adpt_num_iter = int(round(adpt_num_iter))

    if (adpt_mesh is None) or np.max(adpt_mesh) <= 1:
        adpt_num_iter = 0
    else:
        if not isinstance(adpt_mesh, Iterable):
            adpt_mesh = [adpt_mesh] * 3
        adpt_mesh = np.array(adpt_mesh)

    counter = 0
    result_all = None
    result_excluded = None

    for i_iter in range(adpt_num_iter + 1):
        if print_Kpoints:
            print("iteration {0} - {1} points. New points are:".format(
                i_iter + start_iter, len([K for K in K_list if K.res is None])))
            for i, K in enumerate(K_list):
                if not K.evaluated:
                    print(f" K-point {i} : {K} ")
        counter += process(
            paralfunc,
            K_list,
            parallel,
            pointgroup=system.pointgroup if symmetrize else None,
            print_progress_step=print_progress_step,
            remote_parameters=remote_parameters)

        nk = len(K_list)
        try:
            if do_write_Klist:
                # append new (refined) k-points only
                with open(file_Klist, "ab") as fw:
                    for ink in range(nk_prev, nk, Klist_part):
                        pickle.dump(K_list[ink:ink + Klist_part], fw)
        except Exception as err:
            # failing to store restart information is not fatal for the calculation
            warnings.warn(f" {err} \n the K_list was not pickled")

        time0 = time()
        if (result_all is None) or (not fast_iter):
            result_all = sum(kp.get_res for kp in K_list)
        else:
            # fast update: remove what the refined/re-weighted points contributed
            # before, and add the contribution of the new points only
            if result_excluded is not None:
                result_all -= result_excluded
            result_all += sum(kp.get_res for kp in K_list[nk_prev:])
        time1 = time()
        print("time1 = ", time1 - time0)
        if not (restart and i_iter == 0):
            result_all.savedata(prefix=fout_name, suffix=suffix, i_iter=i_iter + start_iter)

        if i_iter >= adpt_num_iter:
            break

        # Now add some more points
        Kmax = np.array([K.max for K in K_list]).T
        select_points = set().union(*(np.argsort(Km)[-adpt_fac:] for Km in Kmax))
        time2 = time()
        print("time2 = ", time2 - time1)

        l1 = len(K_list)
        excluded_Klist = []
        result_excluded = None
        nk_prev = nk

        for iK in select_points:
            results = K_list[iK].get_res
            K_list += K_list[iK].divide(adpt_mesh, periodic=system.periodic, use_symmetry=use_irred_kpt)
            if abs(K_list[iK].factor) < 1.e-10:
                excluded_Klist.append(iK)
            # change of the contribution of the divided point
            if result_excluded is None:
                result_excluded = results - K_list[iK].get_res
            else:
                result_excluded += results - K_list[iK].get_res

        if use_irred_kpt and isinstance(grid, Grid):
            print(f"checking for equivalent points in all points (of new {len(K_list) - l1} points)")
            nexcl, weight_changed_old = exclude_equiv_points(K_list, new_points=len(K_list) - l1)
            print(f"excluded {nexcl} points")
        else:
            weight_changed_old = {}
        print("sum of weights now :{}".format(sum(Kp.factor for Kp in K_list)))

        for iK, prev_factor in weight_changed_old.items():
            weight_correction = K_list[iK].res * (prev_factor - K_list[iK].factor)
            # `result_excluded` is still None if no point was divided above
            if result_excluded is None:
                result_excluded = weight_correction
            else:
                result_excluded += weight_correction

        if do_write_Klist:
            print(f"Writing file_Klist_factor_changed to {file_Klist_factor_changed}")
            with open(file_Klist_factor_changed, "a") as fw_changed:
                for iK in excluded_Klist:
                    fw_changed.write(f"{iK} 0.0 # refined\n")
                for iK in weight_changed_old:
                    fw_changed.write(f"{iK} {K_list[iK].factor} # changed\n")

    print(f"Totally processed {counter} K-points ")
    print("run() finished")

    return result_all
def print_calculators(calculators):
    """Print a colored list of the calculators that are going to be used.

    Parameters
    ----------
    calculators : dict
        keys are string identifiers, values are calculator objects;
        each value is printed together with its `comment` attribute
    """
    separator = "#" * 60
    cprint("Using the follwing calculators : \n" + separator + "\n", "cyan", attrs=["bold"])
    for label, calculator in calculators.items():
        # one line per calculator:  'label' : calculator : comment
        cprint(f" '{label}' ", "magenta", attrs=["bold"], end="")
        print(" : ", end="")
        cprint(f" {calculator} ", "yellow", attrs=["bold"], end="")
        print(f" : {calculator.comment}")
    cprint(separator, "cyan", attrs=["bold"])