Source code for wannierberri.symmetry.projections

import copy
from functools import cached_property
import itertools
import numpy as np
from ..symmetry.orbitals import orbitals_sets_dic

try:
    from jax import config
    config.update("jax_enable_x64", True)
    from jax import numpy as jnp
    from jax.scipy.optimize import minimize as jminimize
    from jax import jit as jjit
except ImportError:
    # warnings.warn("jax not found, will use numpy instead")
    import numpy as jnp
    from scipy.optimize import minimize as jminimize
    from functools import partial as jjit


from ..symmetry.orbitals import Orbitals, num_orbitals
from .unique_list import UniqueListMod1
from .wyckoff_position import WyckoffPosition, WyckoffPositionNumeric, get_shifts

# Module-level instance of the Orbitals helper imported from ..symmetry.orbitals.
# NOTE(review): it is not referenced by any definition visible in this module;
# presumably kept for external users of the module - confirm before removing.
ORBITALS = Orbitals()


class Projection:
    """
    A class to store initial projections.

    Parameters
    ----------
    position_sym : str
        comma-separated positions with x,y,z being the free variables,
        e.g. "x,y,z", "x,x-y,1/2", etc.
    position_num : np.array(shape=(n,3,), dtype=float)
        The position(s) of the projection in fractional coordinates.
        A single 3-vector is promoted to shape (1,3).
    spacegroup : irrep.spacegroup.SpaceGroup
        The spacegroup of the structure. All points equivalent to the given
        ones are also added (not needed if wyckoff_position is provided)
    wyckoff_position : WyckoffPosition or WyckoffPositionNumeric
        The wyckoff position of the projection. If provided, position_num,
        position_sym and spacegroup are not used to build the position.
    orbital : str
        The orbital of the projection, e.g. "s", "p", "sp3", etc.
        or several separated by a semicolon (e.g. "s;p")
    void : bool
        if True, create an empty object (no attributes set), to be filled later
    free_var_values : np.array(shape=(n,), dtype=float)
        The values of the free variables in position_sym
    spinor : bool
        If True, the projection is a spinor (overrides the spacegroup.spinor)
    rotate_basis : bool
        If True, the basis for the projection is rotated for each site
        according to the spacegroup (experimental).
        If False, the basis is the same for all sites (old behaviour)
    zaxis : np.array(shape=(3,), dtype=float)
        The z-axis of the basis, if rotate_basis is True
    xaxis : np.array(shape=(3,), dtype=float)
        The x-axis of the basis, if rotate_basis is True
    do_not_split_projections : bool
        If True, `orbital` is stored as a single entry and is not split
        at semicolons.

    Notes
    -----
    * if both xaxis and zaxis are provided, they should be orthogonal
    * if only one of xaxis and zaxis is provided, the other is calculated as
      the perpendicular vector, coplanar with the provided one and the default one
    * if neither xaxis nor zaxis are provided, the default basis is used
    * the yaxis is calculated as the cross product of zaxis and xaxis
    * the spinor basis is NOT rotated, i.e. all wannier functions are in the
      sigma-z basis

    Attributes
    ----------
    orbitals : list(str)
        The orbitals of the projection
    wyckoff_position : WyckoffPosition or WyckoffPositionNumeric
        The wyckoff position of the projection
    spinor : bool
        If True, the projection is a spinor
    basis_list : list(np.array(shape=(3,3), dtype=float))
        The basis for each site (row-vectors)
    positions : np.array(shape=(n,3), dtype=float)
        The positions of the projections
    num_wann_per_site : int
        The number of Wannier functions per site
    num_points : int
        The number of points
    num_wann : int
        The total number of Wannier functions
    orbitals_str : str
        The orbitals of the projection as one string separated by semicolons `;`
    """

    def __init__(self,
                 position_sym=None,
                 position_num=None,
                 spacegroup=None,
                 wyckoff_position=None,
                 orbital='s',
                 void=False,
                 free_var_values=None,
                 spinor=None,
                 rotate_basis=False,
                 zaxis=None,
                 xaxis=None,
                 do_not_split_projections=False):
        if void:
            # empty shell: attributes are filled in later (see `copy`)
            return

        if do_not_split_projections:
            self.orbitals = [orbital]
        else:
            self.orbitals = orbital.split(";")

        if wyckoff_position is not None:
            self.wyckoff_position = wyckoff_position
        else:
            assert spacegroup is not None, "either wyckoff_position or spacegroup should be provided"
            if position_num is None:
                assert position_sym is not None, "either position_num or position_sym should be provided"
                self.wyckoff_position = WyckoffPosition(position_str=position_sym,
                                                        spacegroup=spacegroup,
                                                        free_var_values=free_var_values)
            else:
                assert position_sym is None, "position_num and position_sym should NOT be provided together"
                position_num = np.array(position_num)
                if position_num.ndim == 1:
                    position_num = position_num[None, :]
                self.wyckoff_position = WyckoffPositionNumeric(positions=position_num,
                                                               spacegroup=spacegroup)

        if spinor is None:
            # fall back to the spacegroup setting, wherever a spacegroup is available
            if spacegroup is not None:
                spinor = spacegroup.spinor
            elif wyckoff_position is not None:
                spinor = wyckoff_position.spacegroup.spinor
            else:
                spinor = False
        self.spinor = spinor

        if rotate_basis:
            basis0 = read_xzaxis(xaxis, zaxis)
            self.basis_list = [np.dot(basis0, rot.T) for rot in self.wyckoff_position.rotations_cart]
        else:
            self.basis_list = [np.eye(3, dtype=float)] * self.num_points

    @property
    def positions(self):
        """positions of all points, as stored in the wyckoff position"""
        return self.wyckoff_position.positions

    @property
    def num_wann_per_site(self):
        """number of wannier functions per site (without spin)"""
        return sum(num_orbitals(o) for o in self.orbitals)

    @property
    def num_wann_per_site_spinor(self):
        """number of wannier functions per site (with spin)"""
        return self.num_wann_per_site * (2 if self.spinor else 1)

    @property
    def num_points(self):
        """number of points of the wyckoff position"""
        return self.wyckoff_position.num_points

    @property
    def num_wann(self):
        """total number of wannier functions (num_points * num_wann_per_site, without spin)"""
        return self.num_points * self.num_wann_per_site

    @property
    def orbitals_str(self):
        """the orbitals joined into one string separated by `;`"""
        return ";".join(self.orbitals)

    def split(self):
        """
        assuming that Projections may contain several orbitals,
        this function splits them into separate projections.
        If there is only one - a list with one element is returned.

        The `spinor` flag and (when present) the per-site `basis_list` of this
        projection are carried over to every piece, so that an explicit spinor
        override or a rotated basis is not lost by splitting.

        Returns
        -------
        list(Projection)
            one projection per entry of `self.orbitals`
        """
        pieces = []
        for orb in self.orbitals:
            piece = Projection(wyckoff_position=self.wyckoff_position,
                               orbital=orb,
                               spinor=self.spinor)
            if hasattr(self, "basis_list"):
                piece.basis_list = list(self.basis_list)
            pieces.append(piece)
        return pieces

    def copy(self):
        """
        Return a copy of the projection.

        The list of orbitals (and the basis list, if set) is duplicated, so
        that modifying the copy never changes the original. The wyckoff
        position object itself is shared, not copied.
        """
        new = Projection(void=True)
        new.orbitals = list(self.orbitals)
        new.spinor = self.spinor
        new.wyckoff_position = self.wyckoff_position
        if hasattr(self, "basis_list"):
            new.basis_list = list(self.basis_list)
        return new

    def __add__(self, other):
        """
        Return a new projection on the same wyckoff position that contains the
        orbitals of both operands. `other` may be None (then a copy of `self`
        is returned), which makes `sum(list_of_projections, None)` work.
        Neither operand is modified.
        """
        new = self.copy()
        if other is not None:
            assert self.wyckoff_position == other.wyckoff_position, f"Cannot add projections from different wyckoff positions {self.wyckoff_position} and {other.wyckoff_position}"
            new.orbitals = new.orbitals + list(other.orbitals)
        return new

    def __radd__(self, other):
        return self.__add__(other)

    def __str__(self):
        return (f"Projection {self.wyckoff_position.string}:{self.orbitals} with {self.num_wann} Wannier functions"
                f" on {self.num_points} points ({self.num_wann_per_site} per site)"
                )

    def write_wannier90(self, mod1=False):
        """
        Return the projection lines in the wannier90 input format,
        one line `f=x, y, z: orbital` per orbital and per point.

        Parameters
        ----------
        mod1 : bool
            If True, the positions are printed modulo 1

        Returns
        -------
        str
            the lines, each terminated by a newline
        """
        lines = []
        for o in self.orbitals:
            for pos in self.wyckoff_position.positions:
                if mod1:
                    pos = pos % 1
                lines.append(f"f={pos[0]:.12f}, {pos[1]:.12f}, {pos[2]:.12f}: {o}\n")
        return "".join(lines)

    @property
    def str_short(self):
        """short description `<wyckoff string>:<orbitals>`; recomputed on each access because `orbitals` may change"""
        return f"{self.wyckoff_position.string}:{self.orbitals}"

    def get_positions_and_orbitals(self):
        """
        Returns
        -------
        list(np.ndarray(shape=(3,), dtype=float))
            The positions of the projections
        list(str)
            The orbitals of the projections (each orbital , e.g. pz, sp3-2, dx2-y2, etc.)
        """
        orbitals = []
        positions = []
        for pos in self.positions:
            for orb in self.orbitals:
                # every orbital set (e.g. "p") expands to its individual members
                for o in orbitals_sets_dic[orb]:
                    orbitals.append(o)
                    positions.append(pos)
        return positions, orbitals
class ProjectionsSet:
    """
    class to store the set of projections and corresponding windows

    Parameters
    ----------
    projections : list(Projection), optional
        The initial projections. The list itself is copied (shallow), the
        projections are not. All projections must have the same `spinor` flag.

    Attributes
    ----------
    projections : list(Projection)
        The stored projections
    spinor : bool or None
        The common spinor flag of the projections (None while the set is empty
        and no flag was set)
    """

    # names of the instance attributes that hold cached values
    _CACHED_ATTRIBUTES = ("map_free_vars_cached",
                          "_free_vars",
                          "num_wann_per_site",
                          "num_wann_per_site_list",
                          "num_points",
                          "num_wann",
                          "num_free_vars_wyckoff",
                          )

    def __init__(self, projections=None):
        self.spinor = None
        projections = [] if projections is None else list(projections)
        for i, p in enumerate(projections):
            assert isinstance(p, Projection), f"element {i} of list 'projections' should be a Projection, not {p}"
            self.set_spinor(p.spinor)
        self.projections = projections

    def copy(self):
        """return a new set holding a copy (`Projection.copy`) of every projection"""
        return ProjectionsSet(projections=[p.copy() for p in self.projections])

    def set_spinor(self, spinor: bool):
        """
        Record the spinor flag of the set, or check it against the recorded one.

        The first call stores the value; every later call asserts that the
        value is the same as the stored one.
        """
        if self.spinor is None:
            self.spinor = spinor
        else:
            assert self.spinor == spinor, f"spinor should be the same for all projections. Previously set to {self.spinor}, now trying to set to {spinor}"

    @property
    def num_proj(self):
        """number of projections in the set"""
        return len(self.projections)

    def __len__(self):
        return self.num_proj

    @cached_property
    def num_points(self):
        """total number of points of all wyckoff positions (cached)"""
        return sum([p.wyckoff_position.num_points for p in self.projections])

    @cached_property
    def num_wann(self):
        """total number of wannier functions of all projections (cached)"""
        return sum([p.num_wann for p in self.projections])

    def add(self, projection):
        """
        Append a projection to the set.

        The spinor flag is checked first, so a mismatching projection is not
        added. Cached quantities are dropped because they depend on the list
        of projections.
        """
        self.set_spinor(projection.spinor)
        self.projections.append(projection)
        self.clear_cached_properties()

    def __add__(self, other):
        new = ProjectionsSet(projections=self.projections + other.projections)
        new.clear_cached_properties()
        return new

    def __str__(self):
        return (f"ProjectionsSet with {self.num_wann} Wannier functions and {self.num_free_vars} free variables\n" +
                "\n".join([str(p) for p in self.projections])
                )

    @cached_property
    def num_free_vars_wyckoff(self):
        """sum of the numbers of free variables of the wyckoff positions (cached)"""
        return sum([p.wyckoff_position.num_free_vars for p in self.projections])

    def as_numeric(self):
        """return a copy of the set in which every wyckoff position is replaced by its `as_numeric()` version"""
        new = self.copy()
        for p in new.projections:
            p.wyckoff_position = p.wyckoff_position.as_numeric()
        return new

    def split_orbitals(self):
        """return a new set in which every projection is split into one projection per orbital"""
        pieces = itertools.chain.from_iterable(p.split() for p in self.projections)
        return ProjectionsSet(list(pieces))

    @property
    def map_free_vars(self):
        """
        get the mapping from free variables to positions of the points

        The result is computed once and stored in `map_free_vars_cached`;
        the start/end offsets of every projection in the lists of variables
        and of points are stored along the way.

        Returns
        -------
        np.ndarray(shape=(num_points, 3, num_free_vars_wyckoff), dtype=float)
            The rotation matrices of the symmetry operations
        np.ndarray(shape=(num_points, 3), dtype=float)
            The translation vectors of the symmetry operations
        """
        if not hasattr(self, "map_free_vars_cached"):
            maps = [p.wyckoff_position.map_orbit_on_free_vars for p in self.projections]
            rotations = np.zeros((self.num_points, 3, self.num_free_vars_wyckoff), dtype=float)
            translations = np.zeros((self.num_points, 3), dtype=float)
            rot = [m[0] for m in maps]
            trans = [m[1] for m in maps]
            self._vars_end = np.cumsum([m[0].shape[-1] for m in maps])
            self._vars_start = np.concatenate(([0], self._vars_end[:-1]))
            self._pos_end = np.cumsum([m[0].shape[0] for m in maps])
            self._pos_start = np.concatenate(([0], self._pos_end[:-1]))
            # block structure: every projection acts only on its own points and variables
            for r, t, vs, ve, ps, pe in zip(rot, trans,
                                            self._vars_start, self._vars_end,
                                            self._pos_start, self._pos_end):
                rotations[ps:pe, :, vs:ve] = r
                translations[ps:pe] = t
            self.map_free_vars_cached = rotations, translations
        return self.map_free_vars_cached

    @property
    def num_free_vars(self):
        """number of free variables of the current map (last axis of the rotation array)"""
        return self.map_free_vars[0].shape[2]

    @cached_property
    def num_wann_per_site_list(self):
        """
        Returns
        -------
        np.array(int, shape=(num_points))
            for each point - a value of how many wannier functions there are on this point
        """
        counts = []
        for p in self.projections:
            counts.extend([p.num_wann_per_site] * p.num_points)
        return np.array(counts)

    @property
    def vars_end(self):
        """end offsets of the projections in the list of free variables"""
        self.map_free_vars  # makes sure the offsets are evaluated
        return self._vars_end

    @property
    def vars_start(self):
        """start offsets of the projections in the list of free variables"""
        self.map_free_vars  # makes sure the offsets are evaluated
        return self._vars_start

    @property
    def pos_end(self):
        """end offsets of the projections in the list of points"""
        self.map_free_vars  # makes sure the offsets are evaluated
        return self._pos_end

    @property
    def pos_start(self):
        """start offsets of the projections in the list of points"""
        self.map_free_vars  # makes sure the offsets are evaluated
        return self._pos_start

    def get_positions_from_free_vars(self, free_vars):
        """apply the map `rotations @ free_vars + translations` to the given free variables"""
        rotations, translations = self.map_free_vars
        return rotations @ free_vars + translations

    def get_positions(self):
        """positions of all points evaluated for the current values of the free variables"""
        return self.get_positions_from_free_vars(self.free_var_values)

    def get_distances(self):
        """periodic distances between all pairs of points (lattice taken from the first projection)"""
        pos = self.get_positions()
        return find_distance_periodic(pos, self.projections[0].wyckoff_position.spacegroup.Lattice, max_shift=2)

    def get_min_distance(self):
        """the smallest entry of the upper triangle (diagonal included) of the distance matrix"""
        return np.min([row[i:].min() for i, row in enumerate(self.get_distances())])

    @staticmethod
    def _may_merge(proj_a, proj_b, unmergable_pairs):
        """True if no orbital is shared and no pair of orbitals (in any order) is listed as unmergable"""
        for o1, o2 in itertools.product(set(proj_a.orbitals), set(proj_b.orbitals)):
            if o1 == o2 or (o1, o2) in unmergable_pairs or (o2, o1) in unmergable_pairs:
                return False
        return True

    def join_same_wyckoff(self, unmergable=None, use_unmergable_defaults=True):
        """
        merge different projections on the same wyckoff positions

        A projection joins the first existing group that has an equal wyckoff
        position and whose every member is compatible with it; otherwise it
        starts a new group. The groups are then ordered by the number of free
        variables of their first member (stable ordering) and every group is
        summed into one projection. The set is modified in place.

        Parameters
        ----------
        unmergable : list(tuple(str, str)), optional
            additional pairs of orbitals that must not be merged
        use_unmergable_defaults : bool
            if True, the pairs ('s', 'sp3') and ('p', 'sp3') are unmergable too
        """
        if use_unmergable_defaults:
            unmergable_loc = [('s', 'sp3'), ('p', 'sp3')]  # TODO : add more
        else:
            unmergable_loc = []
        if unmergable is not None:
            unmergable_loc += list(unmergable)

        groups = []
        group_indices = []
        for i, p in enumerate(self.projections):
            for members, indices in zip(groups, group_indices):
                if not (members[0].wyckoff_position == p.wyckoff_position):
                    continue
                if all(self._may_merge(p, q, unmergable_loc) for q in members):
                    members.append(p)
                    indices.append(i)
                    break
            else:
                # no suitable group found
                groups.append([p])
                group_indices.append([i])

        num_free_vars_per_group = [self.vars_end[indices[0]] - self.vars_start[indices[0]]
                                   for indices in group_indices]
        # stable sort keeps the original order of groups with equal numbers of free variables
        order = np.argsort(num_free_vars_per_group, kind="stable")
        self.projections = [sum(groups[k], None) for k in order]
        self.clear_cached_properties()

    def clear_cached_properties(self, attributes=None):
        """
        Clear the cached properties

        Only values already stored on the instance are removed; nothing is
        evaluated in the process.

        Parameters
        ----------
        attributes : list(str)
            The list of attributes to clear. If None, all cached properties are cleared
        """
        if attributes is None:
            attributes = self._CACHED_ATTRIBUTES
        for attr in attributes:
            self.__dict__.pop(attr, None)

    def stick_to_atoms(self, atoms=None):
        """
        NOT SURE IF THIS IS NEEDED NEITHER IF IT WORKS PROPERLY

        Reduces the number of free variables by fixing those of the wyckoff
        positions that stick to atoms (as reported by
        `wyckoff_position.stick_to_atoms`).
        Clears the cached free variables and sets map_free_vars to the new map.

        Parameters
        ----------
        atoms : np.ndarray(shape=(num_atoms,3), dtype=float)
            List of atomic positions
        """
        if atoms is None:
            atoms = []
        fixed = []
        atoms_filled = np.zeros(len(atoms), dtype=bool)
        num_free_vars_new = 0
        for p in self.projections:
            fixed.append(p.wyckoff_position.stick_to_atoms(atoms=atoms, atoms_filled=atoms_filled))
            if fixed[-1] is None:
                # not stuck: the free variables of this position survive
                num_free_vars_new += p.wyckoff_position.num_free_vars
        new_map = np.zeros((self.num_free_vars, num_free_vars_new), dtype=int)
        new_map_fix = np.zeros(self.num_free_vars, dtype=float)
        start = 0
        for i, p in enumerate(self.projections):
            if fixed[i] is None:
                nvar_loc = self.vars_end[i] - self.vars_start[i]
                end = start + nvar_loc
                new_map[self.vars_start[i]:self.vars_end[i], start:end] = np.eye(nvar_loc, dtype=int)
                start = end
            else:
                new_map_fix[self.vars_start[i]:self.vars_end[i]] = fixed[i]
        rot, trans = self.map_free_vars
        self.clear_cached_properties(["_free_vars"])
        self.map_free_vars_cached = rot @ new_map, rot @ new_map_fix + trans

    def write_wannier90(self, mod1=False, beginend=True, numwann=True):
        """
        return a string of wannier90 input file for projections

        Parameters
        ----------
        mod1 : bool
            If True, the positions are printed modulo 1
        beginend : bool
            If True, the lines are enclosed in `begin projections` / `end projections`
        numwann : bool
            If True, a line `num_wann = ...` is written first

        Returns
        -------
        str
            The string for the wannier90 input file
        """
        parts = []
        if numwann:
            parts.append(f"num_wann = {self.num_wann}\n")
        if beginend:
            parts.append("begin projections\n")
        parts.extend(p.write_wannier90(mod1=mod1) for p in self.projections)
        if beginend:
            parts.append("end projections\n")
        return "".join(parts)

    def write_with_multiplicities(self, multiplicities=None, orbit=False):
        """
        return a string describing which projections are taken and how many times(if not zero)

        Parameters
        ----------
        multiplicities : np.ndarray(shape=(num_projections), dtype=int)
            The multiplicity of each projection (all ones if None)
        orbit : bool
            If True, the orbit of the wyckoff position is also printed
        """
        if multiplicities is None:
            multiplicities = np.ones(self.num_proj, dtype=int)
        assert len(multiplicities) == self.num_proj
        breakline = "-" * 80 + "\n"
        parts = [breakline]
        num_wann = 0
        for m, p in zip(multiplicities, self.projections):
            assert m >= 0, f"multiplicity {m} should be non-negative"
            if m > 0:
                parts.append(f"{m} X | {p.str_short} \n")
                num_wann += m * p.num_wann
                if orbit:
                    parts.append(p.wyckoff_position.orbit_str() + "\n")
        parts.append(f"total number of Wannier functions = {num_wann}\n")
        parts.append(breakline)
        return "".join(parts)

    def get_combination(self, multiplicities, dcopy=True):
        """
        get the combination of projections

        Parameters
        ----------
        multiplicities : np.ndarray(shape=(num_projections), dtype=int)
            The multiplicity of each projection
        dcopy : bool
            If True, every entry of the result is a deep copy of the projection

        Returns
        -------
        ProjectionsSet
            The projections set with the given multiplicities
        """
        assert len(multiplicities) == self.num_proj
        new_projections = []
        for m, p in zip(multiplicities, self.projections):
            assert m >= 0, f"multiplicity {m} should be non-negative"
            new_projections.extend([p] * m)
        if dcopy:
            new_projections = [copy.deepcopy(p) for p in new_projections]
        return ProjectionsSet(projections=new_projections)

    def maximize_distance(self, r0=1):
        """
        Minimise a repulsive potential between the points with respect to the
        free variables (BFGS), store the resulting free variables in the
        wyckoff positions and the final potential in `self.potential`.
        Progress is printed to stdout.

        Parameters
        ----------
        r0 : float
            passed to `RepulsivePotential` as `r0`
        """
        rot, trans = self.map_free_vars
        num_free_vars = self.num_free_vars
        real_lattice = self.projections[0].wyckoff_position.spacegroup.Lattice
        same_site = np.eye(self.num_points, dtype=bool)
        repulsive_potential = RepulsivePotential(rotation=rot,
                                                 translation=trans,
                                                 weights=self.num_wann_per_site_list,
                                                 same_site=same_site,
                                                 r0=r0,
                                                 real_lattice=real_lattice)
        jit_potential = jjit(repulsive_potential.potential_jax)
        if num_free_vars > 0:
            free_var_values = self.free_var_values
            print(f"starting minimization with free vars {free_var_values} ")
            print(f"starting potential {jit_potential(free_var_values)}")
            print(f"minimal distance {self.get_min_distance()}")
            res = jminimize(jit_potential, free_var_values, method='BFGS')
            v = res.x
            print(f"minimized free vars {v}")
            print(f"minimized potential {jit_potential(v)}")
        else:
            # nothing to optimise
            v = jnp.zeros(0)
        pot = jit_potential(v)
        self.free_var_values = v
        self.potential = pot
        print(f"minimal distance {self.get_min_distance()}")
        print(f"positions\n {self.get_positions().round(4)}")
        print(f"distances\n {self.get_distances().round(2)}")

    @property
    def free_var_values(self):
        """values of the free variables of all wyckoff positions, concatenated in the order of the projections"""
        values = [proj.wyckoff_position.free_var_values for proj in self.projections]
        if not values:
            return np.zeros(0)
        return np.hstack(values)

    @free_var_values.setter
    def free_var_values(self, value):
        # distribute consecutive slices of `value` over the wyckoff positions
        start = 0
        for proj in self.projections:
            end = start + proj.wyckoff_position.num_free_vars
            proj.wyckoff_position.free_var_values = value[start:end]
            start = end
class RepulsivePotential:
    """
    A class to store the repulsive potential between the projections

    The positions are `rotation @ free_vars + translation`; the potential is a
    sum over pairs of positions of a lattice-periodic function built from
    Gaussian-damped plane waves, weighted by `weights[i] * weights[j]`.

    Parameters
    ----------
    rotation : np.ndarray(shape=(num_points, 3, nfree_vars), dtype=float)
        The rotation matrices to get the symmetry operations
    translation : np.ndarray(shape=(num_points, 3), dtype=float)
        The translation vectors of the symmetry operations
    weights : np.ndarray(shape=(num_points,), dtype=float)
        The weights of the points (all ones if None)
    same_site : np.ndarray(shape=(num_points, num_points), dtype=bool)
        Pairs for which the weight is set to zero (the diagonal if None)
    r0 : float
        width of the Gaussian; internally divided by num_points**(1/3)
    real_lattice : array-like(shape=(3,3))
        The lattice vectors (identity if None); internally rescaled to unit volume
    max_G_r0 : float
        reciprocal vectors with |G| <= max_G_r0 / r0 (rescaled r0) are kept
    maxG : int
        the integer components of the reciprocal vectors are taken from
        -maxG..maxG. The default (10) reproduces the previously hard-coded
        grid; increase it when max_G_r0 / r0 exceeds what this grid covers.

    Attributes
    ----------
    weights : np.ndarray(shape=(num_points, num_points))
        pair weights
    rotation, translation : arrays
        the input maps
    num_free_vars : int
        number of free variables (last axis of `rotation`)
    G : array(shape=(nG, 3), dtype=int)
        the selected integer reciprocal vectors
    Ug : array(shape=(nG,))
        the Gaussian factor exp(-g**2 * r0**2 / 2) of every selected vector
    """

    def __init__(self, rotation, translation,
                 weights=None,
                 same_site=None,
                 r0=1,
                 real_lattice=None,
                 max_G_r0=5,
                 maxG=10):
        assert rotation.ndim == 3, f"rotation.ndim = {rotation.ndim}, should be 3"
        assert translation.ndim == 2, f"translation.ndim = {translation.ndim}, should be 2"
        assert rotation.shape[0] == translation.shape[0], f"rotation.shape = {rotation.shape}, translation.shape = {translation.shape}"

        if weights is None:
            weights = np.ones(rotation.shape[0])
        else:
            weights = np.array(weights)
        assert weights.ndim == 1
        assert weights.shape[0] == rotation.shape[0], f"weights.shape = {weights.shape}, rotation.shape = {rotation.shape}, weights = {weights}"
        self.weights = weights[:, None] * weights[None, :]

        if same_site is None:
            same_site = np.eye(self.weights.shape[0], dtype=bool)
        else:
            assert same_site.shape == self.weights.shape
        # a point does not repel itself (nor the points declared to be on the same site)
        self.weights[same_site] = 0

        self.rotation = jnp.array(rotation)
        self.translation = jnp.array(translation)
        self.num_free_vars = self.rotation.shape[2]
        num_pos = self.rotation.shape[0]
        assert self.rotation.shape[0] == self.translation.shape[0]

        if real_lattice is None:
            real_lattice = np.eye(3)
        # rescale the lattice to unit cell volume
        real_lattice = np.array(real_lattice) / abs(np.linalg.det(real_lattice))**(1 / 3)
        reciprocal_lattice = np.linalg.inv(real_lattice).T
        r0 = r0 / num_pos**(1 / 3)
        max_mod_G = max_G_r0 / r0

        index_range = range(-maxG, maxG + 1)
        G = np.array(list(itertools.product(index_range, repeat=3)))
        g = np.linalg.norm(G @ reciprocal_lattice, axis=1)
        select = g <= max_mod_G
        self.G = jnp.array(G[select])
        g = g[select]
        self.Ug = jnp.exp(-g**2 * r0**2 / 2.0)

    def potential_jax(self, free_vars):
        """
        Evaluate the potential for the given values of the free variables.

        Parameters
        ----------
        free_vars : array(shape=(num_free_vars,))
            values of the free variables

        Returns
        -------
        scalar
            sum over pairs (i,j) of weights[i,j] * sum_G Ug[G] * cos(2 pi G.(r_j - r_i))
        """
        V = (self.rotation @ free_vars + self.translation) % 1
        diff = (V[None, :] - V[:, None])
        return jnp.sum((jnp.cos(2 * np.pi * jnp.dot(diff, self.G.T)) @ self.Ug) * self.weights)
def get_orbit(spacegroup, p, tol=1e-5):
    """
    Get the orbit of a point p under the symmetry operations of a structure.

    Parameters
    ----------
    spacegroup : irrep.spacegroup.SpaceGroup
        The spacegroup of the structure.
    p : np.ndarray(shape=(3,), dtype=float)
        Point for which to calculate the orbit in the reduced coordinates.
    tol : float
        tolerance passed to UniqueListMod1

    Returns
    -------
    UniqueListMod1 of np.ndarray(shape=(3,), dtype=float)
        The images of p (modulo 1) under the symmetry operations of the
        structure, collected into a UniqueListMod1.
    """
    images = []
    for symop in spacegroup.symmetries:
        images.append(symop.transform_r(p) % 1)
    return UniqueListMod1(images, tol=tol)
def check_orbit(spacegroup, positions, tol=1e-5):
    """
    check if the positions are in the same orbit of the spacegroup

    The orbit of the first position is generated and every other position is
    tested for membership in it.

    Parameters
    ----------
    spacegroup : irrep.spacegroup.SpaceGroup
        The spacegroup of the structure.
    positions : np.ndarray(shape=(N,3), dtype=float)
        Points which are checked to transform into each other under the symmetry operations.
    tol : float
        tolerance passed to `get_orbit`

    Returns
    -------
    bool
        True if the points are in the same orbit, False otherwise.
    """
    reference_orbit = get_orbit(spacegroup, positions[0], tol=tol)
    return all(point in reference_orbit for point in positions[1:])
def orbit_and_rottrans(spacegroup, p):
    """
    Get the orbit of a point p under the symmetry operations of a structure,
    and the corresponding rotation matrices and translation vectors.

    The symmetry operations are those indexed by `appended_indices` of the
    orbit returned by `get_orbit`.

    Parameters
    ----------
    spacegroup : irrep.spacegroup.SpaceGroup
        The spacegroup object.
    p : np.ndarray(shape=(3,), dtype=float)
        Point for which to calculate the orbit in the reduced coordinates.

    Returns
    -------
    np.ndarray(shape=(N, 3))
        The orbit of p under the symmetry operations of the structure.
    np.ndarray(shape=(N, 3, 3))
        The rotation matrices of the symmetry operations
    np.ndarray(shape=(N, 3))
        The translation vectors of the symmetry operations
    """
    orbit = get_orbit(spacegroup, p)
    generating_ops = [spacegroup.symmetries[index] for index in orbit.appended_indices]
    rotations = [op.rotation for op in generating_ops]
    translations = [op.translation for op in generating_ops]
    return np.array(orbit), np.array(rotations), np.array(translations)
def read_xzaxis(xaxis, zaxis):
    """
    Build the basis (rows: x, y, z axes) from optional x and z axes.

    Parameters
    ----------
    xaxis : array-like(3,) or None
        The x-axis (need not be normalized)
    zaxis : array-like(3,) or None
        The z-axis (need not be normalized)

    Returns
    -------
    np.ndarray(shape=(3,3), dtype=float)
        * the identity if neither axis is given
        * otherwise rows xaxis, yaxis = zaxis x xaxis, zaxis, where a missing
          axis is obtained with `get_perpendicular_coplanar_vector` from the
          given one and the default axis ([1,0,0] for x, [0,0,1] for z)

    Raises
    ------
    AssertionError
        if a given axis is not a 3-vector, is (almost) zero, or if both axes
        are given and are not orthogonal
    """

    def normalized(vector, label):
        # validate the vector and scale it to unit length
        vector = np.array(vector)
        assert vector.shape == (3,), f"{label} should be a 3-vector, not an array of {vector.shape}"
        assert np.linalg.norm(vector) > 1e-3, f"{label} should be a non-zero vector, found length {np.linalg.norm(vector)}"
        return vector / np.linalg.norm(vector)

    if zaxis is not None:
        zaxis = normalized(zaxis, "zaxis")
    if xaxis is not None:
        xaxis = normalized(xaxis, "xaxis")

    if xaxis is None and zaxis is None:
        return np.eye(3, dtype=float)
    elif xaxis is None:
        xaxis = get_perpendicular_coplanar_vector(zaxis, np.array([1, 0, 0]))
    elif zaxis is None:
        zaxis = get_perpendicular_coplanar_vector(xaxis, np.array([0, 0, 1]))
    else:
        assert np.abs(np.dot(xaxis, zaxis)) < 1e-3, f"xaxis and zaxis should be orthogonal, found dot product of normalized vectors : {np.dot(xaxis, zaxis)}"

    yaxis = np.cross(zaxis, xaxis)
    return np.array([xaxis, yaxis, zaxis])
def get_perpendicular_coplanar_vector(a, b):
    """
    return a vector c perpendicular to a and coplanar with both a and b
    and such that (b.c)>0

    Computed as (a x b) x a, normalized.

    Parameters
    ----------
    a : np.ndarray(3,)
        The first vector
    b : np.ndarray(3,)
        The second vector

    Returns
    -------
    np.ndarray(3,)
        The perpendicular vector (normalized)

    Raises
    ------
    ValueError
        if the norm of the cross product of a and b does not exceed 1e-5
        (the vectors are collinear)
    """
    c = np.cross(a, b)
    if not np.linalg.norm(c) > 1e-5:
        raise ValueError(f"the vectors {a} and {b} are collinear, their cross product is {c}, norm {np.linalg.norm(c)}")
    in_plane = np.cross(c, a)
    return in_plane / np.linalg.norm(in_plane)
def find_distance_periodic(positions, real_lattice, max_shift=2):
    """
    find the distances between the pairs of atoms in a list of positions
    the distance to the closest image in the periodic lattice is returned

    Parameters
    ----------
    positions : np.ndarray( (num_atoms,3), dtype=float)
        The list of atomic positions in reduced coordinates
        (they are taken modulo 1 first).
    real_lattice : np.ndarray((3,3), dtype=float)
        The lattice vectors.
    max_shift : int
        passed to `get_shifts` to generate the lattice shifts that are considered

    Returns
    -------
    np.ndarray( (num_atoms,num_atoms), dtype=float)
        The distance between the pairs atoms. On the diagonal - the distance
        of an atom to its nearest image.
        For an empty list of positions `[[inf]]` is returned.
    """
    if len(positions) == 0:
        return np.array([[np.inf]])
    reduced = np.array(positions) % 1
    shifts = get_shifts(max_shift)
    # displacement for every pair (i,j) and every lattice shift l
    pair_diff = reduced[:, None, None, :] - reduced[None, :, None, :]
    diff = pair_diff + shifts[None, None, :, :]
    metric = real_lattice @ real_lattice.T
    dist2 = np.einsum('ijla,ab,ijlb->ijl', diff, metric, diff)
    # distance to itself is not interesting, so the distance to its nearest image is counted
    # NOTE(review): this relies on shift number 0 being the zero shift - confirm against get_shifts
    diagonal = np.arange(len(reduced))
    dist2[diagonal, diagonal, 0] = np.inf
    return np.sqrt(dist2.min(axis=2))