Source code for wannierberri.run

#                                                            #
# This file is distributed as part of the WannierBerri code  #
# under the terms of the GNU General Public License. See the #
# file `LICENSE' in the root directory of the WannierBerri   #
# distribution, or http://www.gnu.org/copyleft/gpl.txt       #
#                                                            #
# The WannierBerri code is hosted on GitHub:                 #
# https://github.com/stepan-tsirkin/wannier-berri            #
#                     written by                             #
#           Stepan Tsirkin, University of Zurich             #
#                                                            #
#------------------------------------------------------------

import numpy as np
from collections.abc import Iterable
from time import time
import pickle
import glob

from .data_K import Data_K
from .grid import exclude_equiv_points, Path, Grid, GridTetra
from .parallel import Serial
from .result import ResultDict


def print_progress(count, total, t0, tprev, print_progress_step):
    t = time() - t0
    t_remain = None
    if count == 0:
        t_remain = "unknown"
    else:
        t_rem_s = t / count * (total - count)
        t_remain = "{:22.1f}".format(t_rem_s)
    if t-tprev>print_progress_step:
        print("{:20d}{:17.1f}{:>22s}".format(count, t, t_remain), flush=True)
        tprev=t
    return tprev


def process(paralfunc, K_list, parallel, symgroup=None, remote_parameters={}, print_progress_step=5):
    print(f"symgroup : {symgroup}")
    t0 = time()
    t_print_prev = t0
    selK = [ik for ik, k in enumerate(K_list) if k.res is None]
    numK = len(selK)
    dK_list = [K_list[ik] for ik in selK]
    if len(dK_list) == 0:
        print("nothing to process now")
        return 0

    if parallel.method == 'ray':
        remotes = [paralfunc.remote(dK, **remote_parameters) for dK in dK_list]


    print("processing {0} K points :".format(len(dK_list)), end=" ")
    if parallel.method == 'serial':
        print("in serial.")
    else:
        print("using  {} processes.".format(parallel.npar_K))

    print("# K-points calculated  Wall time (sec)  Est. remaining (sec)", flush=True)
    res = []
    nstep_print = parallel.progress_step(numK, parallel.npar_K)
    if parallel.method == 'serial':
        for count, Kp in enumerate(dK_list):
            res.append(paralfunc(Kp, **remote_parameters))
            if (count + 1) % nstep_print == 0:
                t_print_prev = print_progress(count + 1, numK, t0, t_print_prev, print_progress_step)
    elif parallel.method == 'ray':
        num_remotes = len(remotes)
        num_remotes_calculated = 0
        while True:
            remotes_calculated, _ = parallel.ray.wait(
                remotes, num_returns=min(num_remotes_calculated + nstep_print, num_remotes),
                timeout=60)  # even, if the required number of remotes had not finished,
            # the progress will be printed every minute
            num_remotes_calculated = len(remotes_calculated)
            if num_remotes_calculated >= num_remotes:
                break
            t_print_prev = print_progress(num_remotes_calculated, numK, t0, t_print_prev, print_progress_step)
        res = parallel.ray.get(remotes)
    else:
        raise RuntimeError(f"unsupported parallel method : '{parallel.method}'")

    if not (symgroup is None):
        res = [symgroup.symmetrize(r) for r in res]
    for i, ik in enumerate(selK):
        K_list[ik].set_res(res[i])

    t = time() - t0
    if parallel.method == 'serial':
        print("time for processing {0:6d} K-points in serial: ".format(numK), end="")
        nproc_ = 1
    else:
        print("time for processing {0:6d} K-points on {1:3d} processes: ".format(numK, parallel.npar_K), end="")
        nproc_ = parallel.npar_K
    print(
        "{0:10.4f} ; per K-point {1:15.4f} ; proc-sec per K-point {2:15.4f}".format(t, t / numK, t * nproc_ / numK),
        flush=True)
    return len(dK_list)


[docs]def run( system, grid, calculators, adpt_num_iter=0, use_irred_kpt=True, symmetrize=True, fout_name="result", suffix="", parameters_K={}, file_Klist=None, restart=False, Klist_part=10, parallel=Serial(), # serial by default print_Kpoints=True, adpt_mesh=2, adpt_fac=1, fast_iter=True, print_progress_step=5, ): """ The function to run a calculation. Substitutes the old :func:`~wannierberri.integrate` and :func:`~wannierberri.tabulate` and allows to integrate and tabulate in one run. Parameters ---------- system : :class:`~wannierberri.system.System` System under investigation grid : :class:`~wannierberri.Grid` or :class:`~wannierberri.Path` initial grid for integration. or path for tabulation calculators : dict a dictionary where keys aare any string identifiers, and the values are of :class:`~wannierberri.calculators.Calculator` adpt_num_iter : int number of recursive adaptive refinement iterations. See :ref:`sec-refine` adpt_mesh : int the size of the refinement grid (usuallay no need to change) adpt_fac : int number of K-points to be refined per quantity and criteria. parallel : :class:`~wannierberri.parallel.Parallel` object describing parallelization scheme use_irred_kpt : bool evaluate only symmetry-irreducible K-points symmetrize : bool symmetrize the result (always `True` if `use_irred_kpt == True`) fout_name : str beginning of the output files for each quantity after each iteration suffix : str extra marker inserted into output files to mark this particular calculation run print_Kpoints : bool print the list of K points file_Klist : str or None name of file where to store the Kpoint list of each iteration. May be needed to restart a calculation to get more iterations. If `None` -- the file is not written restart : bool if `True` : reads restart information from `file_Klist` and starts from there Klist_part : int write the file_Klist by portions. Increase for speed, decrease for memory saving fast_iter : bool if under iterations appear peaks that arte not further removed, set this parameter to False. print_progress_step : float or int intervals to print progress Returns -------- dictionary of :class:`~wannierberri.result.EnergyResult` Notes ----- Results are also printed to ASCII files """ # along a path only tabulating is possible if isinstance(grid,Path): print ("Calculation along a path - checking calculators for compatibility") for key,calc in calculators.items(): print (key,calc) if not calc.allow_path: raise ValueError(f"Calculation along a Path is running, but calculator `{key}` is not compatible with a Path") print ("All calculators are compatible") if symmetrize: print ("Symmetrization switched off for Path") symmetrize = False if isinstance(grid,GridTetra): print ("Grid is tetrahedral") else: print ("Grid is regular") if file_Klist is not None: do_write_Klist = True if not file_Klist.endswith(".pickle"): file_Klist += ".pickle" file_Klist_factor_changed = file_Klist+".changed_factors.txt" else: file_Klist_factor_changed = file_Klist[:-7]+".changed_factors.txt" else: do_write_Klist = False file_Klist_factor_changed = None print(f"The set of k points is a {grid.str_short}") remote_parameters = {'_system': system, '_grid': grid, 'npar_k': parallel.npar_k, '_calculators': calculators} if parallel.method == 'ray': ray = parallel.ray remote_parameters = {k: ray.put(v) for k, v in remote_parameters.items()} @ray.remote def paralfunc(Kpoint, _system, _grid, _calculators, npar_k): data = Data_K(_system, Kpoint.Kp_fullBZ, grid=_grid, Kpoint=Kpoint, **parameters_K) return ResultDict({k: v(data) for k, v in _calculators.items()}) else: def paralfunc(Kpoint, _system, _grid, _calculators, npar_k): data = Data_K(_system, Kpoint.Kp_fullBZ, grid=_grid, Kpoint=Kpoint, **parameters_K) return ResultDict({k: v(data) for k, v in _calculators.items()}) if restart: try: fr = open(file_Klist, "rb") K_list = [] while True: try: K_list += pickle.load(fr) except EOFError: print("Finished reading Klist from file {0}".format(file_Klist)) break print("{0} K-points were read from {1}".format(len(K_list), file_Klist)) if len(K_list) == 0: print("WARNING : {0} contains zero points starting from scrath".format(file_Klist)) restart = False fr.close() nk_prev = len(K_list) try: # patching the Klist by updating the factors fr_div = open(file_Klist_factor_changed, "r") factor_changed_K_list = [] for line in fr_div: line_ = line.split() iK = int(line_[0]) fac = float(line_[1]) factor_changed_K_list.append(iK) K_list[iK].factor = fac print("{0} K-points were read from {1}".format(len(factor_changed_K_list), file_Klist_factor_changed)) fr_div.close() except FileNotFoundError: print (f"File with changed factors {file_Klist_factor_changed} not found, assume they were not changed") except Exception as err: restart = False # print("WARNING: {}".format(err)) raise RuntimeError("{1}: reading from {0} failed, starting from scrath".format(file_Klist,err)) else: K_list = grid.get_K_list(use_symmetry=use_irred_kpt) print("Done, sum of weights:{}".format(sum(Kp.factor for Kp in K_list))) start_iter = 0 nk_prev = 0 if not restart: import os def remove_file(filename): if filename is not None and os.path.exists(filename): os.remove(filename) remove_file(file_Klist) remove_file(file_Klist_factor_changed) # suffix="-"+suffix if len(suffix)>0 else "" if restart: print("searching for start_iter") try: start_iter = int( sorted(glob.glob(fout_name + "*" + suffix + "_iter-*.dat"))[-1].split("-")[-1].split(".")[0]) print(f"start_iter = {start_iter}") except Exception as err: print("WARNING : {0} : failed to read start_iter. Setting to zero".format(err)) start_iter = 0 if adpt_num_iter < 0: adpt_num_iter = -adpt_num_iter * np.prod(grid.div) / np.prod(adpt_mesh) / adpt_fac / 3 adpt_num_iter = int(round(adpt_num_iter)) if (adpt_mesh is None) or np.max(adpt_mesh) <= 1: adpt_num_iter = 0 else: if not isinstance(adpt_mesh, Iterable): adpt_mesh = [adpt_mesh] * 3 adpt_mesh = np.array(adpt_mesh) counter = 0 result_all = None result_excluded = None for i_iter in range(adpt_num_iter + 1): if print_Kpoints: print( "iteration {0} - {1} points. New points are:".format(i_iter+start_iter, len([K for K in K_list if K.res is None]))) for i, K in enumerate(K_list): if not K.evaluated: print(" K-point {0} : {1} ".format(i, K)) counter += process( paralfunc, K_list, parallel, symgroup=system.symgroup if symmetrize else None, print_progress_step=print_progress_step, remote_parameters=remote_parameters) nk = len(K_list) try: if do_write_Klist: # append new (refined) k-points only fw = open(file_Klist, "ab") for ink in range(nk_prev, nk, Klist_part): pickle.dump(K_list[ink:ink + Klist_part], fw) fw.close() except Exception as err: print("Warning: {0} \n the K_list was not pickled".format(err)) time0 = time() if (result_all is None) or (not fast_iter): result_all = sum(kp.get_res for kp in K_list) else: if result_excluded is not None: result_all -= result_excluded result_all += sum(kp.get_res for kp in K_list[nk_prev:]) time1 = time() print("time1 = ", time1 - time0) if not (restart and i_iter == 0): result_all.savedata(prefix=fout_name, suffix=suffix, i_iter=i_iter + start_iter) if i_iter >= adpt_num_iter: break # Now add some more points Kmax = np.array([K.max for K in K_list]).T select_points = set().union(*(np.argsort(Km)[-adpt_fac:] for Km in Kmax)) time2 = time() print("time2 = ", time2 - time1) l1 = len(K_list) excluded_Klist = [] result_excluded = None nk_prev = nk for iK in select_points: results = K_list[iK].get_res K_list += K_list[iK].divide(adpt_mesh, periodic=system.periodic, use_symmetry=use_irred_kpt) if abs(K_list[iK].factor) < 1.e-10: excluded_Klist.append(iK) if result_excluded is None: result_excluded = results - K_list[iK].get_res else: result_excluded += results - K_list[iK].get_res if use_irred_kpt and isinstance(grid,Grid): print("checking for equivalent points in all points (of new {} points)".format(len(K_list) - l1)) nexcl, weight_changed_old = exclude_equiv_points(K_list, new_points=len(K_list) - l1) print(" excluded {0} points".format(nexcl)) else: weight_changed_old = {} print("sum of weights now :{}".format(sum(Kp.factor for Kp in K_list))) for iK, prev_factor in weight_changed_old.items(): result_excluded += K_list[iK].res * (prev_factor - K_list[iK].factor) if do_write_Klist: print (f"Writing file_Klist_factor_changed to {file_Klist_factor_changed}") fw_changed = open(file_Klist_factor_changed, "a") for iK in excluded_Klist: fw_changed.write("{0} {1} # refined\n".format(iK, 0.0)) for iK in weight_changed_old: fw_changed.write("{0} {1} # changed\n".format(iK, K_list[iK].factor)) fw_changed.close() print("Totally processed {0} K-points ".format(counter)) return result_all