import re
import json
import numpy as np
from pathlib import Path
import requests as req
import inspect
from itertools import combinations, product
from functools import lru_cache
from typing import NamedTuple
from scipy.spatial import ConvexHull, KDTree
import plotly.graph_objects as go
import matplotlib.pyplot as plt # For viewpoint
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.patches import Rectangle
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
import matplotlib.colors as mplc
from ipywidgets import interactive, IntSlider
# Inside packages import
from .core import plot_toolkit as ptk
from .core import parser as vp, serializer
from .core.spatial_toolkit import (
to_plane,
rotation,
inside_convexhull, # be there for export
to_basis,
to_R3,
get_TM,
get_bz,
coplanar,
)
from .core.plot_toolkit import quiver3d
from .utils import color as tcolor
# These colors are taken from Mathematica's ColorData["Atoms"]
_atom_colors = {
"H": (0.7, 0.8, 0.7),
"He": (0.8367, 1.0, 1.0),
"Li": (0.7994, 0.9976, 0.5436),
"Be": (0.7706, 0.0442, 0.9643),
"B": (1.0, 0.5, 0),
"C": (0.4, 0.4, 0.4),
"N": (143 / 255, 143 / 255, 1),
"O": (0.8005, 0.1921, 0.2015),
"F": (128 / 255, 1, 0),
"Ne": (0.6773, 0.9553, 0.9284),
"Na": (0.6587, 0.8428, 0.4922),
"Mg": (0.6283, 0.0783, 0.8506),
"Al": (173 / 255, 178 / 255, 189 / 255),
"Si": (248 / 255, 209 / 255, 152 / 255),
"P": (1, 165 / 255, 0),
"S": (1, 200 / 255, 50 / 255),
"Cl": (0, 0.9, 0),
"Ar": (0.5461, 0.8921, 0.8442),
"K": (0.534, 0.7056, 0.4207),
"Ca": (0.4801, 0.0955, 0.7446),
"Sc": (0.902, 0.902, 0.902),
"Ti": (0.749, 0.7804, 0.7608),
"V": (0.651, 0.6706, 0.651),
"Cr": (0.5412, 0.7804, 0.6),
"Mn": (0.6118, 0.7804, 0.4784),
"Fe": (0.32, 0.33, 0.35),
"Co": (0.9412, 0.6275, 0.5647),
"Ni": (141 / 255, 142 / 255, 140 / 255),
"Cu": (184 / 255, 115 / 255, 51 / 255),
"Zn": (186 / 255, 196 / 255, 200 / 255),
"Ga": (90 / 255, 180 / 255, 189 / 255),
"Ge": (0.6051, 0.5765, 0.6325),
"As": (50 / 255, 71 / 255, 57 / 255),
"Se": (0.9172, 0.0707, 0.6578),
"Br": (161 / 255, 61 / 255, 45 / 255),
"Kr": (0.426, 0.8104, 0.7475),
"Rb": (0.4254, 0.5859, 0.3292),
"Sr": (0.326, 0.096, 0.6464),
"Y": (0.531, 1.0, 1.0),
"Zr": (0.4586, 0.9186, 0.9175),
"Nb": (0.385, 0.8417, 0.8349),
"Mo": (0.3103, 0.7693, 0.7522),
"Tc": (0.2345, 0.7015, 0.6694),
"Ru": (0.1575, 0.6382, 0.5865),
"Rh": (0.0793, 0.5795, 0.5036),
"Pd": (0.0, 0.5252, 0.4206),
"Ag": (0.7529, 0.7529, 0.7529),
"Cd": (0.8, 0.67, 0.73),
"In": (228 / 255, 228 / 255, 228 / 255),
"Sn": (0.398, 0.4956, 0.4915),
"Sb": (158 / 255, 99 / 255, 181 / 255),
"Te": (0.8167, 0.0101, 0.4513),
"I": (48 / 255, 25 / 255, 52 / 255),
"Xe": (0.3169, 0.7103, 0.6381),
"Cs": (0.3328, 0.4837, 0.2177),
"Ba": (0.1659, 0.0797, 0.556),
"La": (0.9281, 0.3294, 0.7161),
"Ce": (0.8948, 0.3251, 0.7314),
"Pr": (0.8652, 0.3153, 0.708),
"Nd": (0.8378, 0.3016, 0.663),
"Pm": (0.812, 0.2856, 0.6079),
"Sm": (0.7876, 0.2683, 0.5499),
"Eu": (0.7646, 0.2504, 0.4933),
"Gd": (0.7432, 0.2327, 0.4401),
"Tb": (0.7228, 0.2158, 0.3914),
"Dy": (0.7024, 0.2004, 0.3477),
"Ho": (0.68, 0.1874, 0.3092),
"Er": (0.652, 0.1778, 0.2768),
"Tm": (0.6136, 0.173, 0.2515),
"Yb": (0.5579, 0.1749, 0.2346),
"Lu": (0.4757, 0.1856, 0.2276),
"Hf": (0.7815, 0.7166, 0.7174),
"Ta": (0.7344, 0.6835, 0.5445),
"W": (0.6812, 0.6368, 0.3604),
"Re": (0.6052, 0.5563, 0.3676),
"Os": (0.5218, 0.4692, 0.3821),
"Ir": (0.4456, 0.3991, 0.3732),
"Pt": (0.8157, 0.8784, 0.8157),
"Au": (0.8, 0.7, 0.2),
"Hg": (0.7216, 0.8157, 0.7216),
"Tl": (0.651, 0.302, 0.3294),
"Pb": (0.3412, 0.3804, 0.349),
"Bi": (10 / 255, 49 / 255, 93 / 255),
"Po": (0.6706, 0.0, 0.3608),
"At": (0.4588, 0.2706, 0.3098),
"Rn": (0.2188, 0.5916, 0.5161),
"Fr": (0.2563, 0.3989, 0.0861),
"Ra": (0.0, 0.0465, 0.4735),
"Ac": (0.322, 0.9885, 0.7169),
"Th": (0.3608, 0.943, 0.6717),
"Pa": (0.3975, 0.8989, 0.628),
"U": (0.432, 0.856, 0.586),
"Np": (0.4645, 0.8145, 0.5455),
"Pu": (0.4949, 0.7744, 0.5067),
"Am": (0.5233, 0.7355, 0.4695),
"Cm": (0.5495, 0.698, 0.4338),
"Bk": (0.5736, 0.6618, 0.3998),
"Cf": (0.5957, 0.6269, 0.3675),
"Es": (0.6156, 0.5934, 0.3367),
"Fm": (0.6335, 0.5612, 0.3075),
"Md": (0.6493, 0.5303, 0.2799),
"No": (0.663, 0.5007, 0.254),
"Lr": (0.6746, 0.4725, 0.2296),
"Rf": (0.6841, 0.4456, 0.2069),
"Db": (0.6915, 0.42, 0.1858),
"Sg": (0.6969, 0.3958, 0.1663),
"Bh": (0.7001, 0.3728, 0.1484),
"Hs": (0.7013, 0.3512, 0.1321),
"Mt": (0.7004, 0.331, 0.1174),
"Ds": (0.6973, 0.312, 0.1043),
"Rg": (0.6922, 0.2944, 0.0928),
"Cn": (0.6851, 0.2781, 0.083),
"Nh": (0.6758, 0.2631, 0.0747),
"Fl": (0.6644, 0.2495, 0.0681),
"Mc": (0.6509, 0.2372, 0.0631),
"Lv": (0.6354, 0.2262, 0.0597),
"Ts": (0.6354, 0.2262, 0.0566),
"Og": (0.6354, 0.2262, 0.0528),
}
_atom_numbers = {k: i for i, k in enumerate(_atom_colors.keys())}
def atomic_number(atom):
"Return atomic number of atom"
return _atom_numbers[atom]
def atoms_color():
"Defualt color per atom used for plotting the crystal lattice"
return serializer.Dict2Data(
{k: [round(_v, 4) for _v in rgb] for k, rgb in _atom_colors.items()}
)
[docs]
def periodic_table(selection=None):
"Display colorerd elements in periodic table. Use a list of atoms to only color a selection."
_copy_names = np.array(
[f"$^{{{str(i+1)}}}${k}" for i, k in enumerate(_atom_colors.keys())]
)
blank = []
if isinstance(selection,(list, tuple, str)):
if isinstance(selection, str):
selection = selection.split()
blank = [key for key in _atom_colors if not (key in selection)]
_copy_array = np.array([[1,1,1,0] if key in blank else [*value,1] for key, value in _atom_colors.items()])
names = ["" for i in range(180)] # keep as list before modification
fc = np.ones((180, 4))
ec = np.zeros((180,3)) + (0.4 if blank else 0.9 )
offsets = np.array([[(i,j) for i in range(18)] for j in range(10)]).reshape((-1,2)) - 0.5
inds = np.array([
(0, 0),
(17, 1),
(18, 2),
(19, 3),
*[(30 + i, 4 + i) for i in range(8)],
*[(48 + i, 12 + i) for i in range(6)],
*[(54 + i, 18 + i) for i in range(18)],
*[(72 + i, 36 + i) for i in range(18)],
*[(90 + i, 54 + i) for i in range(3)],
*[(93 + i, 71 + i) for i in range(15)],
*[(108 + i, 86 + i) for i in range(3)],
*[(111 + i, 103 + i) for i in range(15)],
*[(147 + i, 57 + i) for i in range(14)],
*[(165 + i, 89 + i) for i in range(14)],
], dtype=int)
for i, j in inds:
fc[i,:] = _copy_array[j]
names[i] = _copy_names[j]
fidx = [i for i, _ in inds] # only plot at elements posistions,otherwise they overlap
offsets = offsets[fidx]
fc, ec = fc[fidx], ec[fidx]
names = np.array(names)[fidx]
# We are adding patches, because imshow does not properly appear in PDF of latex
ax = ptk.get_axes((7, 3.9),left=0.01,right=0.99,top=0.99,bottom=0.01)
patches = np.array([Rectangle(offset,0.9 if i in [92,110] else 1,1) for i, offset in zip(fidx,offsets)])
pc = PatchCollection(patches, facecolors=fc, edgecolors=ec,linewidths=(0.7,))
ax.add_collection(pc)
for (x,y), text, c in zip(offsets + 0.5, names, fc):
c = "k" if np.linalg.norm(c[:3]) > 1 else "w"
plt.text(x,y, text, color=c, ha="center", va="center")
ax.set_axis_off()
ax.set(xlim=[-0.6,17.6],ylim=[9.6,-0.6]) # to show borders correctly
return ax
def write_poscar(poscar_data, outfile=None, selective_dynamics=None, overwrite=False, comment="", scale=None, system=None):
"""Writes POSCAR data to a file or returns string
Parameters
----------
outfile : PathLike
selective_dynamics : callable
If given, should be a function like `f(a) -> (a.p < 1/4)` or `f(a) -> (a.x < 1/4, a.y < 1/4, a.z < 1/4)`
which turns on/off selective dynamics for each atom based in each dimension.
See `ipyvasp.POSCAR.data.get_selective_dynamics` for more info.
overwrite: bool
If file already exists, overwrite=True changes it.
comment: str
Add comment, previous comment will be there too.
scale: float
Scale factor for the basis vectors. Default is provided by loaded data.
system: str
System name to be used in POSCAR file instead of the one in `poscar_data.SYSTEM`.
.. note::
POSCAR is only written in direct format even if it was loaded from cartesian format.
"""
_comment = poscar_data.metadata.comment + comment
out_str = f"{system or poscar_data.SYSTEM} # " + (_comment or "Created by ipyvasp")
if scale is None:
scale = poscar_data.metadata.scale
elif not isinstance(scale, (int, float)):
raise TypeError("scale must be a number or None.")
elif scale == 0:
raise ValueError("scale can not be zero.")
out_str += "\n {:<20.14f}\n".format(scale)
out_str += "\n".join(
["{:>22.16f}{:>22.16f}{:>22.16f}".format(*a) for a in poscar_data.basis / scale]
)
uelems = poscar_data.types.to_dict()
out_str += "\n " + " ".join(uelems.keys())
out_str += "\n " + " ".join([str(len(v)) for v in uelems.values()])
if selective_dynamics is not None:
out_str += "\nSelective Dynamics"
out_str += "\nDirect\n"
positions = poscar_data.positions
pos_list = ["{:>21.16f}{:>21.16f}{:>21.16f}".format(*a) for a in positions]
if selective_dynamics is not None:
sd = poscar_data.get_selective_dynamics(selective_dynamics).values()
pos_list = [f"{p} {s}" for p, s in zip(pos_list, sd)]
out_str += "\n".join(pos_list)
if outfile:
path = Path(outfile)
if not path.is_file():
with path.open("w", encoding="utf-8") as f:
f.write(out_str)
elif overwrite and path.is_file():
with path.open("w", encoding="utf-8") as f:
f.write(out_str)
else:
raise FileExistsError(
f"{outfile!r} exists, can not overwrite, \nuse overwrite=True if you want to change."
)
else:
print(out_str)
def export_poscar(path=None, content=None):
"""Export POSCAR file to python objects.
Parameters
----------
path : PathLike
Path/to/POSCAR file. Auto picks in CWD.
content : str
POSCAR content as string, This takes precedence to path.
"""
if content and isinstance(content, str):
file_lines = [f"{line}\n" for line in content.splitlines()]
else:
P = Path(path or "./POSCAR")
if not P.is_file():
raise FileNotFoundError(f"{str(P)} not found.")
with P.open("r", encoding="utf-8") as f:
file_lines = f.readlines()
header = file_lines[0].split("#", 1)
SYSTEM = header[0].strip()
comment = header[1].strip() if len(header) > 1 else "Exported by Pivopty"
scale = float(file_lines[1].strip().split()[0]) # some people add comments here too
if scale < 0: # If that is for volume
scale = 1
basis = scale * vp.gen2numpy(file_lines[2:5], (3, 3), [-1, -1], exclude=None)
# volume = np.linalg.det(basis)
# rec_basis = np.linalg.inv(basis).T # general formula
out_dict = {
"SYSTEM": SYSTEM, #'volume':volume,
"basis": basis, #'rec_basis':rec_basis,
"metadata": {"comment": comment, "scale": scale},
}
elems = file_lines[5].split()
ions = [int(i) for i in file_lines[6].split()]
N = int(np.sum(ions)) # Must be py int, not numpy
inds = np.cumsum([0, *ions]).astype(int)
# Check Cartesian and Selective Dynamics
lines = [l.strip() for l in file_lines[7:9]] # remove whitespace or tabs
out_dict["metadata"]["cartesian"] = (
True if ((lines[0][0] in "cCkK") or (lines[1][0] in "cCkK")) else False
)
poslines = vp.gen2numpy(
file_lines[7:],
(N, 6),
(-1, [0, 1, 2]),
exclude="^\s+[a-zA-Z]|^[a-zA-Z]",
raw=True,
).splitlines() # handle selective dynamics word here
positions = np.array(
[line.split()[:3] for line in poslines], dtype=float
) # this makes sure only first 3 columns are taken
if out_dict["metadata"]["cartesian"]:
positions = scale * to_basis(basis, positions)
print(("Cartesian format found in POSCAR file, converted to direct format."))
unique_d = {}
for i, e in enumerate(elems):
unique_d.update({e: range(inds[i], inds[i + 1])})
elem_labels = []
for i, name in enumerate(elems):
for ind in range(inds[i], inds[i + 1]):
elem_labels.append(f"{name} {str(ind - inds[i] + 1)}")
out_dict.update({"positions": positions, "types": unique_d}) #'labels':elem_labels,
return serializer.PoscarData(out_dict)
# Cell
def _save_mp_API(api_key):
"Save materials project api key for autoload in functions. This works only for legacy API."
path = Path.home() / ".ipyvasprc"
lines = []
if path.is_file():
with path.open("r") as fr:
lines = fr.readlines()
lines = [line for line in lines if "MP_API_KEY" not in line]
with path.open("w") as fw:
fw.write("MP_API_KEY = {}".format(api_key))
for line in lines:
fw.write(line)
# Cell
def _load_mp_data(formula, api_key=None, mp_id=None, max_sites=None, min_sites=None):
if api_key is None:
try:
path = Path.home() / ".ipyvasprc"
with path.open("r") as f:
lines = f.readlines()
for line in lines:
if "MP_API_KEY" in line:
api_key = line.split("=")[1].strip()
except:
raise ValueError(
"api_key not given. provide in argument or generate in file using `_save_mp_API(your_mp_api_key)"
)
# url must be a raw string
url = r"https://legacy.materialsproject.org/rest/v2/materials/{}/vasp?API_KEY={}".format(
formula, api_key
)
resp = req.request(method="GET", url=url)
if resp.status_code != 200:
raise ValueError("Error in fetching data from materials project. Try again!")
jl = json.loads(resp.text)
if not "response" in jl: # check if response
raise ValueError("Either formula {!r} or API_KEY is incorrect.".format(formula))
all_res = jl["response"]
if max_sites != None and min_sites != None:
lower, upper = min_sites, max_sites
elif max_sites == None and min_sites != None:
lower, upper = min_sites, min_sites + 1
elif max_sites != None and min_sites == None:
lower, upper = max_sites - 1, max_sites
else:
lower, upper = "-1", "-1" # Unknown
if lower != "-1" and upper != "-1":
sel_res = []
for res in all_res:
if res["nsites"] <= upper and res["nsites"] >= lower:
sel_res.append(res)
return sel_res
# Filter to mp_id at last. more preferred
if mp_id != None:
for res in all_res:
if mp_id == res["material_id"]:
return [res]
return all_res
def _cif_str_to_poscar_str(cif_str, comment=None):
# Using it in other places too
lines = [
line for line in cif_str.splitlines() if line.strip()
] # remove empty lines
abc = []
abc_ang = []
index = 0
for ys in lines:
if "_cell" in ys:
if "_length" in ys:
abc.append(ys.split()[1])
if "_angle" in ys:
abc_ang.append(ys.split()[1])
if "_volume" in ys:
volume = float(ys.split()[1])
if "_structural" in ys:
top = ys.split()[1] + f" # {comment}" if comment else ys.split()[1]
for i, ys in enumerate(lines):
if "_atom_site_occupancy" in ys:
index = i + 1 # start collecting pos.
poses = lines[index:]
pos_str = ""
for pos in poses:
s_p = pos.split()
pos_str += "{0:>12} {1:>12} {2:>12} {3}\n".format(*s_p[3:6], s_p[0])
names = [re.sub("\d+", "", pos.split()[1]).strip() for pos in poses]
types = []
for name in names:
if name not in types:
types.append(name) # unique types, don't use numpy here.
# ======== Cleaning ===========
abc_ang = [float(ang) for ang in abc_ang]
abc = [float(a) for a in abc]
a = "{0:>22.16f} {1:>22.16f} {2:>22.16f}".format(1.0, 0.0, 0.0) # lattic vector a.
to_rad = 0.017453292519
gamma = abc_ang[2] * to_rad
bx, by = abc[1] * np.cos(gamma), abc[1] * np.sin(gamma)
b = "{0:>22.16f} {1:>22.16f} {2:>22.16f}".format(
bx / abc[0], by / abc[0], 0.0
) # lattic vector b.
cz = volume / (abc[0] * by)
cx = abc[2] * np.cos(abc_ang[1] * to_rad)
cy = (abc[1] * abc[2] * np.cos(abc_ang[0] * to_rad) - bx * cx) / by
c = "{0:>22.16f} {1:>22.16f} {2:>22.16f}".format(
cx / abc[0], cy / abc[0], cz / abc[0]
) # lattic vector b.
elems = "\t".join(types)
nums = [str(len([n for n in names if n == t])) for t in types]
nums = "\t".join(nums)
content = (
f"{top}\n {abc[0]}\n {a}\n {b}\n {c}\n {elems}\n {nums}\nDirect\n{pos_str}"
)
return content
class InvokeMaterialsProject:
"""Connect to materials project and get data using `api_key` from their site.
Usage
-----
>>> from ipyvaspr.sio import InvokeMaterialsProject # or import ipyvasp.InvokeMaterialsProject as InvokeMaterialsProject
>>> mp = InvokeMaterialsProject(api_key='your_api_key')
>>> outputs = mp.request(formula='NaCl') #returns list of structures from response
>>> outupts[0].export_poscar() #returns poscar data
>>> outputs[0].cif #returns cif data
"""
def __init__(self, api_key=None):
"Request Materials Project acess. api_key is on their site. Your only need once and it is saved for later."
self.api_key = api_key
self.__response = None
self.success = False
def save_api_key(self, api_key):
"Save api_key for auto reloading later."
_save_mp_API(api_key)
@lru_cache(maxsize=2) # cache for 2 calls
def request(self, formula, mp_id=None, max_sites=None, min_sites=None):
"Fetch data using request api of python form materials project website. After request, you can access `cifs` and `poscars`."
self.__response = _load_mp_data(
formula=formula,
api_key=self.api_key,
mp_id=mp_id,
max_sites=max_sites,
min_sites=min_sites,
)
if self.__response == []:
raise req.HTTPError("Error in request. Check your api_key or formula.")
class Structure:
def __init__(self, response):
self._cif = response["cif"]
self.symbol = response["spacegroup"]["symbol"]
self.crystal = response["spacegroup"]["crystal_system"]
self.unit = response["unit_cell_formula"]
self.mp_id = response["material_id"]
@property
def cif(self):
return self._cif
def __repr__(self):
return f"Structure(unit={self.unit},mp_id={self.mp_id!r},symbol={self.symbol!r},crystal={self.crystal!r},cif='{self._cif[:10]}...')"
def write_cif(self, outfile=None):
if isinstance(outfile, str):
with open(outfile, "w") as f:
f.write(self._cif)
else:
print(self._cif)
def write_poscar(self, outfile=None, overwrite=False, comment="",scale=None):
"Use `ipyvasp.lattice.POSCAR.write` if you need extra options."
write_poscar(self.export_poscar(), outfile=outfile, overwrite=overwrite, comment=comment, scale=scale)
def export_poscar(self):
"Export poscar data form cif content."
content = _cif_str_to_poscar_str(
self._cif,
comment=f"[{self.mp_id!r}][{self.symbol!r}][{self.crystal!r}] Created by ipyvasp using Materials Project Database",
)
return export_poscar(content=content)
# get cifs
structures = []
for res in self.__response:
structures.append(Structure(res))
self.success = True # set success flag
return structures
def _str2kpoints(kpts_str):
try:
with open(kpts_str, "r", encoding="utf-8") as f:
kpts_str = f.read()
except:
pass
hsk_list = []
for j, line in enumerate(kpts_str.splitlines()):
if line.strip(): # Make sure line is not empty
data = line.split()
if len(data) < 3:
raise ValueError(f"Line {j + 1} has less than 3 values.")
point = [float(i) for i in data[:3]]
if len(data) == 4:
_4th = (
data[3]
if re.search("\$\\\\[a-zA-Z]+\$|[a-zA-Z]+|[α-ωΑ-Ω]+|\|_", line)
else int(data[3])
)
point.append(_4th)
elif len(data) == 5:
_5th = int(data[4])
point = point + [data[3], _5th]
hsk_list.append(point)
return hsk_list
[docs]
def get_kpath(
kpoints,
n: int = 5,
weight: float = None,
ibzkpt: str = None,
outfile: str = None,
rec_basis=None,
):
"""Generate list of kpoints along high symmetry path. Options are write to file or return KPOINTS list.
It generates uniformly spaced point with input `n` as just a scale factor of number of points per average length of `rec_basis`.
Parameters
----------
kpoints : list or str
Any number points as [(x,y,z,[label],[N]), ...]. N adds as many points in current interval.
To disconnect path at a point, provide it as (x,y,z,[label], 0), next point will be start of other patch.
If `kpoints` is a multiline string, it is converted to list of points. Each line should be in format "x y z [label] [N]".
A file path can be provided to read kpoints from file with same format as multiline string.
n : int
Number of point per averge length of `rec_basis`, this makes uniform steps based on distance between points.
If (x,y,z,[label], N) is provided, this is ignored for that specific interval. If `rec_basis` is not provided, each interval has exactly `n` points.
Number of points in each interval is at least 2 even if `n` is less than 2 to keep end points anyway.
weight : float
None by default to auto generates weights.
ibzkpt : PathLike
Path to ibzkpt file, required for HSE calculations.
outfile : PathLike
Path/to/file to write kpoints.
rec_basis : array_like
Reciprocal basis 3x3 array to use for calculating uniform points.
If `outfile = None`, KPONITS file content is printed.
"""
if isinstance(kpoints, str):
kpoints = _str2kpoints(kpoints)
elif not isinstance(kpoints, (list, tuple, np.ndarray)):
raise TypeError(
f"kpoints must be a sequence as [(x,y,z,[label],[N]), ...], or multiline string got {kpoints}"
)
if len(kpoints) < 2:
raise ValueError("At least two points are required.")
fixed_patches = []
where_zero = []
for idx, point in enumerate(kpoints):
if not isinstance(point, (list, tuple)):
raise TypeError(
f"kpoint must be a list or tuple as (x,y,z,[label],[N]), got {point}"
)
cpt = point # same for length 5, 4 with last entry as string
if len(point) == 3:
cpt = [*point, ""] # make (x,y,z,label)
elif len(point) == 4:
if isinstance(point[3], (int, np.integer)):
cpt = [*point[:3], "", point[-1]] # add full point as (x,y,z,label, N)
elif not isinstance(point[3], str):
raise TypeError(
f"4th entry in kpoint should be string label or int number of points for next interval if label is skipped, got {point}"
)
elif len(point) == 5:
if not isinstance(point[3], str):
raise TypeError(
f"4th entry in kpoint should be string label when 5 entries are given, got {point}"
)
if not isinstance(point[4], (int, np.integer)):
raise TypeError(
f"5th entry in kpoint should be an integer to add that many points in interval, got {point}"
)
else:
raise ValueError(f"Expects kpoint as (x,y,z,[label],[N]), got {point}")
if isinstance(cpt[-1], (int, np.integer)) and cpt[-1] == 0:
if idx - 1 in where_zero:
raise ValueError(
f"Break at adjacent kpoints {idx}, {idx+1} is not allowed!"
)
if any([idx < 1, idx > (len(kpoints) - 3)]):
raise ValueError("Bad break at edges!")
where_zero.append(idx)
fixed_patches.append(cpt)
def add_points(p1, p2, npts, rec_basis):
lab = p2[3] # end point label
if len(p1) == 5:
m = p1[4] # number of points given explicitly.
lab = (
f"<={p1[3]}|{lab}" if m == 0 else lab
) # merge labels in case user wants to break path
elif rec_basis is not None and np.size(rec_basis) == 9:
basis = np.array(rec_basis)
coords = to_R3(basis, [p1[:3], p2[:3]])
_mean = np.mean(
np.linalg.norm(basis, axis=1)
) # average length of basis vectors
m = np.rint(npts * np.linalg.norm(coords[0] - coords[1]) / _mean).astype(
int
) # number of points in interval
else:
m = npts # equal number of points in each interval, given by n.
# Doing m - 1 in an interval, so along with last point, total n points are generated per interval.
Np = max(m - 1, 1) # At least 2 points. one is given by end point of interval.
X = np.linspace(p1[0], p2[0], Np, endpoint=False)
Y = np.linspace(p1[1], p2[1], Np, endpoint=False)
Z = np.linspace(p1[2], p2[2], Np, endpoint=False)
kpts = [(x, y, z) for x, y, z in zip(X, Y, Z)]
return (kpts, Np, lab)
points, numbers, labels = [], [0], [fixed_patches[0][3]]
for p1, p2 in zip(fixed_patches[:-1], fixed_patches[1:]):
kp, m, lab = add_points(p1, p2, n, rec_basis)
points.extend(kp)
numbers.append(numbers[-1] + m)
labels.append(lab)
if lab.startswith("<="):
labels[-2] = "" # remove label for end of interval if broken, added here
else: # Add last point at end of for loop
points.append(p2[:3])
if weight is None and points:
weight = 0 if ibzkpt else 1 / len(points) # With IBZKPT, we need zero weight
out_str = [
"{0:>16.10f}{1:>16.10f}{2:>16.10f}{3:>12.6f}".format(x, y, z, weight)
for x, y, z in points
]
out_str = "\n".join(out_str)
N = len(points)
if (PI := Path(ibzkpt or "")).is_file(): # handles None automatically
with PI.open("r") as f:
lines = f.readlines()
N = int(lines[1].strip()) + N # Update N.
slines = lines[3 : N + 4]
ibz_str = "".join(slines)
out_str = "{}\n{}".format(ibz_str.strip("\n"), out_str)
path_info = ", ".join(
f"{idx}:{lab}" for idx, lab in zip(numbers, labels) if lab != ""
)
top_str = "Automatically generated using ipyvasp for HSK-PATH {}\n\t{}\nReciprocal Lattice".format(
path_info, N
)
out_str = "{}\n{}".format(top_str, out_str)
if outfile != None:
with open(outfile, "w", encoding="utf-8") as f: # allow unicode
f.write(out_str)
else:
print(out_str)
# Cell
[docs]
def get_kmesh(
poscar_data,
*args,
shift=0,
weight=None,
cartesian=False,
ibzkpt=None,
outfile=None,
endpoint=True,
):
"""Generates uniform mesh of kpoints. Options are write to file, or return KPOINTS list.
Parameters
----------
poscar_data : ipyvasp.POSCAR.data
*args : tuple
1 or 3 integers which decide shape of mesh. If 1, mesh points equally spaced based on data from POSCAR.
shift : float
Only works if cartesian = False. Defualt is 0. Could be a number or list of three numbers to add to interval [0,1].
weight : float
If None, auto generates weights.
cartesian : bool
If True, generates cartesian mesh.
ibzkpt : PathLike
Path to ibzkpt file, required for HSE calculations.
outfile : PathLike
Path/to/file to write kpoints.
endpoint : bool
Default True, include endpoints in mesh at edges away from origin.
If `outfile = None`, KPOINTS file content is printed.
"""
if len(args) not in [1, 3]:
raise ValueError("get_kmesh() takes 1 or 3 args!")
if cartesian:
norms = np.ptp(poscar_data.rec_basis, axis=0)
else:
norms = np.linalg.norm(poscar_data.rec_basis, axis=1)
if len(args) == 1:
if not isinstance(args[0], (int, np.integer)):
raise ValueError("get_kmesh expects integer for first positional argument!")
nx, ny, nz = [args[0] for _ in range(3)]
weights = norms / np.max(norms) # For making largest side at given n
nx, ny, nz = np.rint(weights * args[0]).astype(int)
elif len(args) == 3:
for i, a in enumerate(args):
if not isinstance(a, (int, np.integer)):
raise ValueError("get_kmesh expects integer at position {}!".format(i))
nx, ny, nz = list(args)
low, high = np.array([[0, 0, 0], [1, 1, 1]]) + shift
if cartesian:
verts = get_bz(poscar_data.rec_basis, primitive=False).vertices
low, high = np.min(verts, axis=0), np.max(verts, axis=0)
low = (low * 2 * np.pi / poscar_data.metadata.scale).round(
12
) # Cartesian KPOINTS are in unit of 2pi/SCALE
high = (high * 2 * np.pi / poscar_data.metadata.scale).round(12)
(lx, ly, lz), (hx, hy, hz) = low, high
points = []
for k in np.linspace(lz, hz, nz, endpoint=endpoint):
for j in np.linspace(ly, hy, ny, endpoint=endpoint):
for i in np.linspace(lx, hx, nx, endpoint=endpoint):
points.append([i, j, k])
points = np.array(points)
points[np.abs(points) < 1e-10] = 0
if len(points) == 0:
raise ValueError("No KPOINTS in BZ from given input. Try larger input!")
if weight == None and len(points) != 0:
weight = float(1 / len(points))
out_str = [
"{0:>16.10f}{1:>16.10f}{2:>16.10f}{3:>12.6f}".format(x, y, z, weight)
for x, y, z in points
]
out_str = "\n".join(out_str)
N = len(points)
if ibzkpt and (PI := Path(ibzkpt)):
with PI.open("r", encoding="utf-8") as f:
lines = f.readlines()
if (cartesian == False) and (lines[2].strip()[0] in "cCkK"):
raise ValueError(
"ibzkpt file is in cartesian coordinates, use get_kmesh(...,cartesian = True)!"
)
N = int(lines[1].strip()) + N # Update N.
slines = lines[3 : N + 4]
ibz_str = "".join(slines)
out_str = "{}\n{}".format(ibz_str, out_str) # Update out_str
mode = "Reciprocal" if cartesian == False else "Cartesian"
top_str = "Generated uniform mesh using ipyvasp, GRID-SHAPE = [{},{},{}]\n\t{}\n{}".format(
nx, ny, nz, N, mode
)
out_str = "{}\n{}".format(top_str, out_str)
if outfile != None:
with open(outfile, "w", encoding="utf-8") as f:
f.write(out_str)
else:
print(out_str)
# Cell
[docs]
def splot_bz(
bz_data,
plane=None,
ax=None,
color="blue",
fill=False,
fill_zorder=0,
vectors=(0, 1, 2),
colormap=None,
shade=True,
alpha=0.4,
zoffset=0,
**kwargs,
):
"""Plots matplotlib's static figure of BZ/Cell. You can also plot in 2D on a 3D axes.
Parameters
----------
bz_data : Output of `get_bz`.
plane : str
Default is None and plots 3D surface. Can take 'xy','yz','zx' to plot in 2D.
fill : bool
True by defult, determines whether to fill surface of BZ or not.
fill_zorder : int
Default is 0, determines zorder of filled surface in 2D plots if `fill=True`.
color : Any
Color to fill surface and stroke color. Default is 'blue'. Can be any valid matplotlib color.
vectors : tuple
Tuple of indices of basis vectors to plot. Default is (0,1,2). All three are plotted in 3D
(you can turn of by None or empty tuple), whhile you can specify any two/three in 2D.
Vectors do not appear if given data is subzone data.
ax : matplotlib.pyplot.Axes
Auto generated by default, 2D/3D axes, auto converts in 3D on demand as well.
colormap : str
If None, single color is applied, only works in 3D and `fill=True`. Colormap is applied along z.
shade : bool
Shade polygons or not. Only works in 3D and `fill=True`.
alpha : float
Opacity of filling in range [0,1]. Increase for clear viewpoint.
zoffset : float
Only used if plotting in 2D over a 3D axis. Default is 0. Any plane 'xy','yz' etc.
kwargs are passed to `plt.plot` or `Poly3DCollection` if `fill=True`.
Returns
-------
matplotlib.pyplot.Axes
Matplotlib's 2D axes if `plane=None` otherswise 3D axes.
"""
vname = "a" if bz_data.__class__.__name__ == "CellData" else "b"
label = r"$k_{}/2π$" if vname == "b" else "{}"
_label = r"\vec{" + vname + "}" # For both
if vectors and not isinstance(vectors, (tuple, list)):
raise ValueError(f"`vectors` expects tuple or list, got {vectors!r}")
if vectors is None:
vectors = () # Empty tuple to make things work below
for v in vectors:
if v not in [0, 1, 2]:
raise ValueError(f"`vectors` expects values in [0,1,2], got {vectors!r}")
name = kwargs.pop("label", None) # will set only on single line
kwargs.pop("zdir", None) # remove , no need
is_subzone = hasattr(bz_data, "_specials") # For subzone
if plane: # Project 2D, works on 3D axes as well
if not ax: # Create 2D axes if not given
ax = ptk.get_axes(figsize=(3.4, 3.4)) # For better display
kwargs = {"solid_capstyle": "round", **kwargs}
is3d = getattr(ax, "name", "") == "3d"
normals = {
"xy": (0, 0, 1),
"yz": (1, 0, 0),
"zx": (0, 1, 0),
"yx": (0, 0, -1),
"zy": (-1, 0, 0),
"xz": (0, -1, 0),
}
if plane not in normals:
raise ValueError(
f"`plane` expects value in 'xyzxzyx' or None, got {plane!r}"
)
z0 = (
[0, 0, zoffset]
if plane in "xyx"
else [0, zoffset, 0]
if plane in "xzx"
else [zoffset, 0, 0]
)
idxs = {
"xy": [0, 1],
"yz": [1, 2],
"zx": [2, 0],
"yx": [1, 0],
"zy": [2, 1],
"xz": [0, 2],
}
for idx, f in enumerate(bz_data.faces_coords):
g = to_plane(normals[plane], f) + z0
(line,) = ax.plot(
*(g.T if is3d else g[:, idxs[plane]].T), color=color, **kwargs
)
if idx == 0:
line.set_label(name) # only one line
if fill and not is3d:
ax.fill(
*g[:, idxs[plane]].T,
facecolor=color,
edgecolor=color,
linewidth=0.0001,
alpha=alpha,
zorder=fill_zorder,
)
elif fill and is3d:
poly = Poly3DCollection(
[g], # 3D fill in plane
edgecolors=[color],
facecolors=[color],
alpha=alpha,
shade=shade,
zorder=fill_zorder,
)
ax.add_collection(poly)
ax.autoscale_view()
if vectors and not is_subzone:
s_basis = to_plane(normals[plane], bz_data.basis[(vectors,)])
for k, b in zip(vectors, s_basis):
x, y = b[idxs[plane]]
l = r" ${}_{} $".format(_label, k + 1)
l = l + "\n" if y < 0 else "\n" + l
ha = "right" if x < 0 else "left"
xyz = 0.8 * b + z0 if is3d else np.array([0.8 * x, 0.8 * y])
ax.text(
*xyz, l, va="center", ha=ha, clip_on=True
) # must clip to have limits of axes working.
ax.scatter(
*(xyz / 0.8), color="w", s=0.0005
) # Must be to scale below arrow.
if is3d:
XYZ, UVW = (np.ones_like(s_basis) * z0).T, s_basis.T
quiver3d(
*XYZ,
*UVW,
C=color,
L=0.7,
ax=ax,
arrowstyle="-|>",
mutation_scale=7,
)
else:
s_zero = [0 for _ in s_basis] # either 3 or 2.
ax.quiver(
s_zero,
s_zero,
*s_basis[:, idxs[plane]].T,
lw=0.7,
color=color,
angles="xy",
scale_units="xy",
scale=1,
)
ax.set_xlabel(label.format(plane[0]))
ax.set_ylabel(label.format(plane[1]))
if is3d:
lab = [v for v in "xyz" if v not in plane][0]
ax.set_zlabel(label.format(lab))
ax.set_aspect("equal")
zmin, zmax = ax.get_zlim()
if zoffset > zmax:
zmax = zoffset
elif zoffset < zmin:
zmin = zoffset
ax.set_zlim([zmin, zmax])
else:
ax.set_aspect("equal") # Must for 2D axes to show actual lengths of BZ
else: # Plot 3D
if not ax: # For 3D.
ax = ptk.get_axes(figsize=(3.4, 3.4), axes_3d=True)
if getattr(ax, "name", "") != "3d":
raise ValueError("3D axes required for 3D plot.")
if fill:
if colormap:
colormap = colormap if colormap in plt.colormaps() else "viridis"
cz = [
np.mean(np.unique(f, axis=0), axis=0)[2]
for f in bz_data.faces_coords
]
levels = (cz - np.min(cz)) / np.ptp(cz) # along Z.
colors = plt.cm.get_cmap(colormap)(levels)
else:
colors = np.array(
[[*mplc.to_rgb(color)] for f in bz_data.faces_coords]
) # Single color.
poly = Poly3DCollection(
bz_data.faces_coords,
edgecolors=[color],
facecolors=colors,
alpha=alpha,
shade=shade,
label=name,
**kwargs,
)
ax.add_collection(poly)
ax.autoscale_view()
else:
kwargs = {"solid_capstyle": "round", **kwargs}
(line,) = [
ax.plot3D(f[:, 0], f[:, 1], f[:, 2], color=(color), **kwargs)
for f in bz_data.faces_coords
][0]
line.set_label(name) # only one line
if vectors and not is_subzone:
for k, v in enumerate(0.35 * bz_data.basis):
ax.text(*v, r"${}_{}$".format(_label, k + 1), va="center", ha="center")
XYZ, UVW = [[0, 0, 0], [0, 0, 0], [0, 0, 0]], 0.3 * bz_data.basis.T
quiver3d(
*XYZ, *UVW, C="k", L=0.7, ax=ax, arrowstyle="-|>", mutation_scale=7
)
l_ = np.min(bz_data.vertices, axis=0)
h_ = np.max(bz_data.vertices, axis=0)
ax.set_xlim([l_[0], h_[0]])
ax.set_ylim([l_[1], h_[1]])
ax.set_zlim([l_[2], h_[2]])
# Set aspect to same as data.
ax.set_box_aspect(np.ptp(bz_data.vertices, axis=0))
ax.set_xlabel(label.format("x"))
ax.set_ylabel(label.format("y"))
ax.set_zlabel(label.format("z"))
if vname == "b": # These needed for splot_kpath internally
type(bz_data)._splot_kws = dict(plane=plane, zoffset=zoffset, ax=ax)
return ax
def splot_kpath(
bz_data, kpoints, labels=None, fmt_label=lambda x: (x, {"color": "blue"}), **kwargs
):
"""Plot k-path over existing BZ. It will take ``ax``, ``plane`` and ``zoffset`` internally from most recent call to ``splot_bz``/``bz.splot``.
Parameters
----------
kpoints : array_like
List of k-points in fractional coordinates. e.g. [(0,0,0),(0.5,0.5,0.5),(1,1,1)] in order of path.
labels : list
List of labels for each k-point in same order as kpoints.
fmt_label : callable
Function that takes a label from labels and should return a string or (str, dict) of which dict is passed to ``plt.text``.
kwargs are passed to ``plt.plot`` with some defaults.
You can get ``kpoints = POSCAR.get_bz().specials.masked(lambda x,y,z : (-0.1 < z 0.1) & (x >= 0) & (y >= 0))`` to get k-points in positive xy plane.
Then you can reorder them by an indexer like ``kpoints = kpoints[[0,1,2,0,7,6]]``, note double brackets, and also that point at zero index is taken twice.
.. tip::
You can use this function multiple times to plot multiple/broken paths over same BZ.
"""
if not hasattr(bz_data, "_splot_kws"):
raise ValueError("Plot BZ first to get ax, plane and zoffset.")
if not np.ndim(kpoints) == 2 and np.shape(kpoints)[-1] == 3:
raise ValueError("kpoints must be 2D array of shape (N,3)")
plane, ax, zoffset = [
bz_data._splot_kws.get(attr, default) # class level attributes
for attr, default in zip(["plane", "ax", "zoffset"], [None, None, 0])
]
ijk = [0, 1, 2]
_mapping = {
"xy": [0, 1],
"xz": [0, 2],
"yz": [1, 2],
"zx": [2, 0],
"zy": [2, 1],
"yx": [1, 0],
}
_zoffset = [0, 0, 0]
if plane:
_zoffset = (
[0, 0, zoffset]
if plane in "xyx"
else [0, zoffset, 0]
if plane in "xzx"
else [zoffset, 0, 0]
)
if isinstance(plane, str) and plane in _mapping:
if getattr(ax, "name", None) != "3d":
ijk = _mapping[
plane
] # only change indices if axes is not 3d, even if plane is given
if not labels:
labels = [
"[{0:5.2f}, {1:5.2f}, {2:5.2f}]".format(x, y, z) for x, y, z in kpoints
]
if fmt_label is None:
fmt_label = lambda x: (x, {"color": "blue"})
_validate_label_func(fmt_label,labels[0])
coords = bz_data.to_cartesian(kpoints)
if _zoffset and plane:
normal = (
[0, 0, 1] if plane in "xyx" else [0, 1, 0] if plane in "xzx" else [1, 0, 0]
)
coords = to_plane(normal, coords) + _zoffset
coords = coords[:, ijk] # select only required indices
kwargs = {
**dict(color="blue", linewidth=0.8, marker=".", markersize=10),
**kwargs,
} # need some defaults
ax.plot(*coords.T, **kwargs)
for c, text in zip(coords, labels):
lab, textkws = fmt_label(text), {}
if isinstance(lab, (list, tuple)):
lab, textkws = lab
ax.text(*c, lab, **textkws)
return ax
# Cell
[docs]
def iplot_bz(
bz_data,
fill=False,
color="rgba(84,102,108,0.8)",
special_kpoints=True,
alpha=0.4,
ortho3d=True,
fig=None,
**kwargs,
):
"""Plots interactive figure showing axes,BZ surface, special points and basis, each of which could be hidden or shown.
Parameters
----------
bz_data : Output of `get_bz`.
fill : bool
False by defult, determines whether to fill surface of BZ or not.
color : str
Color to fill surface 'rgba(84,102,108,0.8)` by default. This sholud be a valid Plotly color.
special_kpoints : bool or callable
True by default, determines whether to plot special points or not.
You can also proivide a mask function f(x,y,z) -> bool which will be used to filter special points
based on their fractional coordinates. This is ignored if BZ is primitive.
alpha : float
Opacity of BZ planes.
ortho3d : bool
Default is True, decides whether x,y,z are orthogonal or perspective.
fig : plotly.graph_objects.Figure
Plotly's `go.Figure`. If you want to plot on another plotly's figure, provide that.
kwargs are passed to `plotly.graph_objects.Scatter3d` for BZ lines.
Returns
-------
plotly.graph_objects.Figure
"""
if not fig:
fig = go.Figure()
# Name fixing
vname = "a" if bz_data.__class__.__name__ == "CellData" else "b"
axes_text = [
"<b>k</b><sub>x</sub>/2π",
"",
"<b>k</b><sub>y</sub>/2π",
"",
"<b>k</b><sub>z</sub>/2π",
]
if vname == "a":
axes_text = ["<b>x</b>", "", "<b>y</b>", "", "<b>z</b>"] # Real space
zone_name = kwargs.pop("name", "BZ" if vname == "b" else "Lattice")
is_subzone = hasattr(bz_data, "_specials") # For subzone
if not is_subzone: # No basis, axes for subzone
# Axes
_len = 0.5 * np.mean(bz_data.basis)
fig.add_trace(
go.Scatter3d(
x=[_len, 0, 0, 0, 0],
y=[0, 0, _len, 0, 0],
z=[0, 0, 0, 0, _len],
mode="lines+text",
text=axes_text,
line_color="skyblue",
legendgroup="Axes",
name="Axes",
)
)
fig.add_trace(
go.Cone(
x=[_len, 0, 0],
y=[0, _len, 0],
z=[0, 0, _len],
u=[1, 0, 0],
v=[0, 1, 0],
w=[0, 0, 1],
showscale=False,
sizemode="absolute",
sizeref=0.5,
anchor="tail",
colorscale=["skyblue" for _ in range(3)],
legendgroup="Axes",
name="Axes",
)
)
# Basis
for i, b in enumerate(bz_data.basis):
fig.add_trace(
go.Scatter3d(
x=[0, b[0]],
y=[0, b[1]],
z=[0, b[2]],
mode="lines+text",
legendgroup="{}<sub>{}</sub>".format(vname, i + 1),
line_color="red",
name="<b>{}</b><sub>{}</sub>".format(vname, i + 1),
text=["", "<b>{}</b><sub>{}</sub>".format(vname, i + 1)],
)
)
uvw = b / np.linalg.norm(b) # Unit vector for cones
fig.add_trace(
go.Cone(
x=[b[0]],
y=[b[1]],
z=[b[2]],
u=uvw[0:1],
v=uvw[1:2],
w=uvw[2:],
showscale=False,
colorscale="Reds",
sizemode="absolute",
sizeref=0.02,
anchor="tail",
legendgroup="{}<sub>{}</sub>".format(vname, i + 1),
name="<b>{}</b><sub>{}</sub>".format(vname, i + 1),
)
)
# Rest of the code is same for both subzone and BZ/Cell
# Faces
legend = True
for pts in bz_data.faces_coords:
fig.add_trace(
go.Scatter3d(
x=pts[:, 0],
y=pts[:, 1],
z=pts[:, 2],
mode="lines",
line_color=color,
legendgroup=zone_name,
name=zone_name,
showlegend=legend,
surfaceaxis=2 if fill and coplanar(bz_data.vertices) else -1,
surfacecolor=color, # fills z-axis because its where 2D projection is
opacity=alpha,
**kwargs,
)
)
legend = False # Only first legend to show for all
if fill and not coplanar(bz_data.vertices):
# coplanar fill is not supported by Mesh3d
xc = bz_data.vertices[ConvexHull(bz_data.vertices).vertices]
fig.add_trace(
go.Mesh3d(
x=xc[:, 0],
y=xc[:, 1],
z=xc[:, 2],
color=color,
opacity=alpha,
alphahull=0,# convex body
lighting=dict(diffuse=0.5),
legendgroup=zone_name,
name=zone_name,
)
)
# Special Points only if in reciprocal space and regular BZ
if vname == "b" and (not getattr(bz_data, "primitive", False)) and special_kpoints:
if callable(special_kpoints):
skpts = bz_data.specials.masked(special_kpoints)
else:
skpts = bz_data.specials
for tr in fig.data: # hide all traces hover made before
tr.hoverinfo = "none" # avoid overlapping with special points
texts, values = [], []
norms = np.round(np.linalg.norm(skpts.coords, axis=1), 8)
for key, value, norm in zip(skpts.kpoints.round(6), skpts.coords, norms):
texts.append("K = {}</br>d = {}".format(key, norm))
values.append([[*value, norm]])
values = np.array(values).reshape((-1, 4))
norm_max = np.max(values[:, 3])
c_vals = np.array([int(v * 255 / norm_max) for v in values[:, 3]])
colors = [0 for i in c_vals]
_unique = np.unique(np.sort(c_vals))[::-1]
_lnp = np.linspace(0, 255, len(_unique) - 1)
_u_colors = ["rgb({},0,{})".format(r, b) for b, r in zip(_lnp, _lnp[::-1])]
for _un, _uc in zip(_unique[:-1], _u_colors):
_index = np.where(c_vals == _un)[0]
for _ind in _index:
colors[_ind] = _uc
colors[0] = "rgb(255,215,0)" # Gold color at Gamma!.
fig.add_trace(
go.Scatter3d(
x=values[:, 0],
y=values[:, 1],
z=values[:, 2],
hovertext=texts,
name="HSK",
marker=dict(color=colors, size=4),
mode="markers",
)
)
proj = dict(projection=dict(type="orthographic")) if ortho3d else {}
camera = dict(center=dict(x=0.1, y=0.1, z=0.1), **proj)
fig.update_layout(
template="plotly_white",
scene_camera=camera,
font_family="Times New Roman",
font_size=14,
scene=dict(
aspectmode="data",
xaxis=dict(showbackground=False, visible=False),
yaxis=dict(showbackground=False, visible=False),
zaxis=dict(showbackground=False, visible=False),
),
margin=dict(r=10, l=10, b=10, t=30),
)
return fig
# Cell
def _fix_sites(
poscar_data, tol=1e-2, eqv_sites=False, translate=None, origin=(0, 0, 0)
):
"""Add equivalent sites to make a full data shape of lattice. Returns same data after fixing.
It should not be exposed mostly be used in visualizations"""
if not isinstance(origin, (tuple, list, np.ndarray)) or len(origin) != 3:
raise ValueError("origin must be a list, tuple or numpy array of length 3.")
pos = (
poscar_data.positions.copy()
) # We can also do poscar_data.copy().positions that copies all contents.
labels = np.array(poscar_data.labels) # We need to store equivalent labels as well
out_dict = poscar_data.to_dict() # For output
if isinstance(translate, (int, np.integer, float)):
pos = pos + (translate - int(translate)) # Only translate in 0 - 1
elif isinstance(translate,(tuple, list, np.ndarray)) and len(translate) == 3:
txyz = np.array([translate])
pos = pos + (txyz - txyz.astype(int))
# Fix coordinates of sites distributed on edges and faces
if getattr(poscar_data.metadata, 'eqv_fix', True): # no more fixing over there
pos -= (pos > (1 - tol)).astype(int) # Move towards orign for common fixing like in joining POSCARs
out_dict["positions"] = pos
out_dict["metadata"]["comment"] = "Modified by ipyvasp"
# Add equivalent sites on edges and faces if given,handle each sepecies separately
if eqv_sites and getattr(poscar_data.metadata, 'eqv_fix', True):
new_dict, start = {}, 0
for k, v in out_dict["types"].items():
vpos = pos[v]
vlabs = labels[v]
inds = np.array(v)
ivpos = np.concatenate([np.indices((len(vpos),)).reshape((-1,1)),vpos],axis=1) # track of indices
ivpos = np.array([ivpos + [0, *p] for p in product([-1,0,1],[-1,0,1],[-1,0,1])]).reshape((-1,4))
ivpos = ivpos[(ivpos[:,1:] > -tol).all(axis=1) & (ivpos[:,1:] < 1 + tol).all(axis=1)]
ivpos = ivpos[ivpos[:,0].argsort()]
idxs = ivpos[:,0].ravel().astype(int).tolist()
new_dict[k] = {"pos": ivpos[:,1:], "lab": vlabs[idxs], "inds": inds[idxs]}
new_dict[k]["range"] = range(start, start + len(new_dict[k]["pos"]))
start += len(new_dict[k]["pos"])
out_dict["positions"] = np.vstack([new_dict[k]["pos"] for k in new_dict.keys()])
out_dict["positions"] -= origin # origin given by user to subtract
out_dict["metadata"]["eqv_labels"] = np.hstack(
[new_dict[k]["lab"] for k in new_dict.keys()]
)
out_dict["metadata"]["eqv_indices"] = np.hstack(
[new_dict[k]["inds"] for k in new_dict.keys()]
)
out_dict["types"] = {k: new_dict[k]["range"] for k in new_dict.keys()}
return serializer.PoscarData(out_dict)
def translate_poscar(poscar_data, offset):
"""Translate sites of a POSCAR with a given offset as a number or list of three number.
Usully a farction of integarers like 1/2,1/4 etc."""
return _fix_sites(poscar_data, translate=offset, eqv_sites=False)
def get_pairs(coords, r, tol=1e-3):
"""Returns a tuple of Points(coords,pairs, dist), so coords[pairs] given nearest site bonds.
Parameters
----------
coords : array_like
Array(N,3) of cartesian positions of lattice sites.
r : float
Cartesian distance between the pairs in units of Angstrom e.g. 1.2 -> 1.2E-10.
tol : float
Tolerance value. Default is 10^-3.
"""
if np.ndim(coords) != 2 and np.shape(coords)[1] != 3:
raise ValueError("coords must be a 2D array of shape (N,3).")
tree = KDTree(coords)
inds = np.array([[*p] for p in tree.query_pairs(r, eps=tol)])
if len(inds) > 0:
dist = np.linalg.norm(coords[inds[:, 0],] - coords[inds[:, 1],], axis=1)
else:
dist = np.array([])
return serializer.dict2tuple(
"Points", {"coords": coords, "pairs": inds, "dist": dist}
)
def _get_bond_length(poscar_data, bond_length=None):
"Given `bond_length` should be in unit of Angstrom, and can be a number of dict like {'Fe-O':1.2,...}"
if bond_length is not None:
if isinstance(bond_length, (int, float, np.integer)):
return bond_length
elif isinstance(bond_length, dict):
for k, v in bond_length.items():
if not isinstance(v, (int, float, np.integer)):
raise TypeError(
f"Value to key `{k}` should be a number in unit of Angstrom."
)
if not isinstance(k, str) or k.count("-") != 1:
raise TypeError(
f"key `{k}` should be a string connecting two elements like 'Fe-O'."
)
return max(
list(bond_length.values())
) # return the maximum distance, will filter later
else:
raise TypeError("`bon_length` should be a number or a dict.")
else:
keys = list(poscar_data.types.keys())
if len(keys) == 1:
keys = [*keys, *keys] # still need it to be a list of two elements
dists = [poscar_data.get_distance(k1, k2) for k1, k2 in combinations(keys, 2)]
return (
np.mean(dists) * 1.05
) # Add 5% margin over mean distance, this covers same species too, and in multiple species, this will stop bonding between same species.
class _Atom(NamedTuple):
"Object passed to POSCAR operations `func` where atomic sites are modified. Additinal property p -> array([x,y,z])."
symbol : str
number : int
index : int
x : float
y : float
z : float
@property
def p(self): return np.array([self.x,self.y,self.z]) # for robust operations
class _AtomLabel(str):
"Object passed to `fmt_label` in plotting. `number` and `symbol` are additional attributes and `to_latex` is a method."
@property
def number(self): return int(self.split()[1])
@property
def symbol(self): return self.split()[0]
def to_latex(self): return "{}$_{{{}}}$".format(*self.split())
def _validate_func(func):
if not callable(func):
raise ValueError("`func` must be a callable function with single parameter `Atom(symbol,number, index,x,y,z)`.")
if len(inspect.signature(func).parameters) != 1:
raise ValueError(
"`func` takes exactly 1 argument: `Atom(symbol, number, index,x,y,z)` in fractional coordinates"
)
ret = func(_Atom('',0,0,0,0,0))
if not isinstance(ret, (bool, np.bool_)):
raise ValueError(
f"`func` must be a function that returns a bool, got {type(ret)}."
)
def _masked_data(poscar_data, func):
"Returns indices of sites which satisfy the func."
_validate_func(func)
eqv_inds = tuple(getattr(poscar_data.metadata, "eqv_indices",[]))
pick = []
for i, pos in enumerate(poscar_data.positions):
idx = eqv_inds[i] if eqv_inds else i # map to original index
if func(_Atom(*poscar_data._sn[i], idx, *pos)): # labels based on i, not eqv_idx
pick.append(i)
return pick # could be duplicate indices
def _filter_pairs(labels, pairs, dist, bond_length):
"""Filter pairs based on bond_length dict like {1.2:['Fe','O'],...}. Returns same pairs otherwise."""
if isinstance(bond_length, dict):
new_pairs = []
for pair, d in zip(pairs, dist):
t1, t2 = [labels[idx].split()[0] for idx in pair]
for k, v in bond_length.items():
p = tuple(k.split("-"))
if p in [(t1, t2), (t2, t1)] and d <= v:
new_pairs.append(pair)
return np.unique(new_pairs, axis=0) # remove duplicates
# Return all pairs otherwise
return pairs # None -> auto calculate bond_length, number -> use that number
def filter_atoms(poscar_data, func, tol = 0.01):
"""Filter atomic sites based on a function that acts on an atom such as `lambda a: (a.p < 1/2).all()`.
`atom` passed to function is a namedtuple like `Atom(symbol,number,index,x,y,z)` which has extra attribute `p = array([x,y,z])`.
This may include equivalent sites, so it should be used for plotting purpose only, e.g. showing atoms on a plane.
An attribute `source_indices` is added to metadata which is useful to pick other things such as `OUTCAR.ion_pot[POSCAR.filter(...).data.metadata.source_indices]`.
>>> filter_atoms(..., lambda a: a.symbol=='Ga' or a.number in range(2)) # picks all Ga atoms and first two atoms of every other types.
Note: If you are filtering a plane with more than one non-zero hkl like 110, you may first need to translate or set boundary on POSCAR to bring desired plane in full view to include all atoms.
"""
if hasattr(poscar_data.metadata, 'source_indices'):
raise ValueError("Cannot filter an already filtered POSCAR data.")
poscar_data = _fix_sites(poscar_data, tol = tol, eqv_sites=True)
idxs = _masked_data(poscar_data, func)
data = poscar_data.to_dict()
eqvi = data['metadata'].pop('eqv_indices', []) # no need of this
all_pos, npos, eqv_labs, finds = [], [0,],[],[]
for value in poscar_data.types.values():
indices = [i for i in value if i in idxs] # search from value make sure only non-equivalent sites added
finds.extend(eqvi[indices] if len(eqvi) else indices)
eqv_labs.extend(poscar_data.labels[indices])
pos = data['positions'][indices]
all_pos.append(pos)
npos.append(len(pos))
if not np.sum(npos):
raise ValueError("No sites found with given filter func!")
data['positions'] = np.concatenate(all_pos, axis = 0)
data['metadata']['source_indices'] = np.array(finds)
data['metadata']['eqv_fix'] = False
data['metadata']['eqv_labels'] = np.array(eqv_labs) # need these for compare to previous
ranges = np.cumsum(npos)
data['types'] = {key: range(i,j) for key, i,j in zip(data['types'],ranges[:-1],ranges[1:]) if range(i,j)} # avoid empty
return serializer.PoscarData(data)
# Cell
def iplot_lattice(
poscar_data,
sizes=10,
colors=None,
bond_length=None,
tol=1e-2,
eqv_sites=True,
translate=None,
origin=(0, 0, 0),
fig=None,
ortho3d=True,
bond_kws=dict(line_width=4),
site_kws=dict(line_color="rgba(1,1,1,0)", line_width=0.001, opacity=1),
plot_cell=True,
label_sites = False,
**kwargs,
):
"""Plotly's interactive plot of lattice.
Parameters
----------
sizes : float or dict of type -> float
Size of sites. Either one int/float or a mapping like {'Ga': 2, ...}.
colors : color or dict of type -> color
Mapping of colors like {'Ga': 'red, ...} or a single color. Automatically generated color for missing types.
bond_length : float or dict
Length of bond in Angstrom. Auto calculated if not provides. Can be a dict like {'Fe-O':3.2,...} to specify bond length between specific types.
bond_kws : dict
Keyword arguments passed to `plotly.graph_objects.Scatter3d` for bonds.
Default is jus hint, you can use any keyword argument that is accepted by `plotly.graph_objects.Scatter3d`.
site_kws : dict
Keyword arguments passed to `plotly.graph_objects.Scatter3d` for sites.
Default is jus hint, you can use any keyword argument that is accepted by `plotly.graph_objects.Scatter3d`.
plot_cell : bool
Defult is True. Plot unit cell with default settings.
If you want to customize, use `POSCAR.iplot_cell(fig = <return of iplot_lattice>)` function.
kwargs are passed to `iplot_bz`.
"""
if len(poscar_data.positions) < 1:
raise ValueError("Need at least 1 atom!")
poscar_data = _fix_sites(
poscar_data, tol=tol, eqv_sites=eqv_sites, translate=translate, origin=origin
)
blen = _get_bond_length(poscar_data, bond_length)
coords, pairs, dist = get_pairs(poscar_data.coords, r=blen)
_labels = poscar_data.labels
pairs = _filter_pairs(_labels, pairs, dist, bond_length)
if not fig:
fig = go.Figure()
uelems = poscar_data.types.to_dict()
_fcs = _fix_color_size(uelems, colors, sizes, 10, backend = 'plotly')
sizes = [v['size'] for v in _fcs.values()]
colors = [v['color'] for v in _fcs.values()]
_colors = np.array([colors[i] for i, vs in enumerate(uelems.values()) for v in vs],dtype=object) # could be mixed color types
if np.any(pairs):
coords_p = coords[pairs] # paired points
_colors = _colors[pairs] # Colors at pairs
coords_n = []
colors_n = []
for c_p, _c in zip(coords_p, _colors):
mid = np.mean(c_p, axis=0)
arr = np.concatenate([c_p[0], mid, mid, c_p[1]]).reshape((-1, 2, 3))
coords_n = [*coords_n, *arr] # Same shape
colors_n = [*colors_n, *_c] # same shape.
coords_n = np.array(coords_n)
colors_n = np.array(colors_n, dtype=object)
# Instead of plotting for each pair, we can make only as little lines as types of atoms to speec up
unqc = [] # mixed colors type can't be sorted otherwise
for c in colors_n:
if c not in unqc:
unqc.append(c)
clabs = [unqc.index(c) for c in colors_n] # few colors categories
corder = np.argsort(clabs) # coordinates order for those categories
groups = dict([(i,[]) for i in range(len(unqc))])
for co in corder:
groups[clabs[co]].append(coords_n[co])
groups[clabs[co]].append([[np.nan, np.nan, np.nan]]) # nan to break links outside bonds
for i in range(len(unqc)):
groups[i] = np.concatenate(groups[i], axis=0)
bond_kws = {"line_width": 4, **bond_kws}
for i, cp in groups.items():
showlegend = True if i == 0 else False
fig.add_trace(go.Scatter3d(x=cp[:, 0].T, y=cp[:, 1].T, z=cp[:, 2].T,
mode="lines",
line_color=unqc[i],
legendgroup="Bonds",
showlegend=showlegend,
hoverinfo='skip',
name="Bonds",
**bond_kws,
))
site_kws = {
**dict(line_color="rgba(1,1,1,0)", line_width=0.001, opacity=1),
**site_kws,
}
eqv_idxs = getattr(poscar_data.metadata, 'eqv_indices',np.array(range(poscar_data.positions.shape[0])))
for (k, v), c, s in zip(uelems.items(), colors, sizes):
coords = poscar_data.coords[v]
labs = poscar_data.labels[v]
idxs = eqv_idxs[v]
hovertext = [f"<br>{x:7.3f} {y:7.3f} {z:7.3f}<br>Index: {idx} Label: {lab}" for lab,idx, (x,y,z) in zip(labs,idxs, poscar_data.positions[idxs])]
fig.add_trace(
go.Scatter3d(
x=coords[:, 0].T,
y=coords[:, 1].T,
z=coords[:, 2].T,
mode="markers+text" if label_sites else "markers",
marker_color=c,
hovertext=hovertext,
text=["{}<sub>{}</sub>".format(*l.split()) for l in labs] if label_sites else None,
marker_size=s,
name=k,
**site_kws,
)
)
if plot_cell:
bz_data = serializer.CellData(
get_bz(poscar_data.basis, primitive=True).to_dict()
) # Make cell for correct vector notations
iplot_bz(bz_data, fig=fig, ortho3d=ortho3d, special_kpoints=False, **kwargs)
else:
if kwargs:
print("Warning: kwargs are ignored as plot_cell is False.")
# These thing are update in iplot_bz function, but if plot_cell is False, then we need to update them here.
proj = dict(projection=dict(type="orthographic")) if ortho3d else {}
camera = dict(center=dict(x=0.1, y=0.1, z=0.1), **proj)
fig.update_layout(
template="plotly_white",
scene_camera=camera,
font_family="Times New Roman",
font_size=14,
scene=dict(
aspectmode="data",
xaxis=dict(showbackground=False, visible=False),
yaxis=dict(showbackground=False, visible=False),
zaxis=dict(showbackground=False, visible=False),
),
margin=dict(r=10, l=10, b=10, t=30),
)
return fig
def _validate_label_func(fmt_label,label):
if not callable(fmt_label):
raise ValueError("fmt_label must be a callable function.")
if len(inspect.signature(fmt_label).parameters.values()) != 1:
raise ValueError("fmt_label must have only one argument that accepts a str like 'Ga 1'.")
test_out = fmt_label(_AtomLabel(label))
if isinstance(test_out, (list, tuple)):
if len(test_out) != 2:
raise ValueError(
"fmt_label must return string or a list/tuple of length 2."
)
if not isinstance(test_out[0], str):
raise ValueError(
"Fisrt item in return of `fmt_label` must return a string! got {}".format(
type(test_out[0])
)
)
if not isinstance(test_out[1], dict):
raise ValueError(
"Second item in return of `fmt_label` must return a dictionary of keywords to pass to `plt.text`! got {}".format(
type(test_out[1])
)
)
elif not isinstance(test_out, str):
raise ValueError("fmt_label must return a string or a list/tuple of length 2.")
def _fix_color_size(types, colors, sizes, default_size, backend=None):
cs = {key: {'color': _atom_colors.get(key, 'blue'), 'size': default_size} for key in types}
for k in cs:
if len(cs[k]['color']) == 3: # otherwise its blue
if backend == 'plotly':
cs[k]['color'] = "rgb({},{},{})".format(*[int(255*c) for c in cs[k]['color']])
elif backend == 'ngl':
cs[k]['color'] = mplc.to_hex(cs[k]['color'])
if isinstance(sizes,(int,float,np.integer)):
for k in cs:
cs[k]['size'] = sizes
elif isinstance(sizes, dict):
for k,v in sizes.items():
cs[k]['size'] = v
else:
raise TypeError("sizes should be a single int/float or dict as {'Ga':10,'As':15,...}")
if isinstance(colors,dict):
for k,v in colors.items():
cs[k]['color'] = v
elif isinstance(colors,(str,list,tuple,np.ndarray)):
for k in cs:
cs[k]['color'] = colors
elif colors is not None:
raise TypeError("colors should be a single valid color or dict as {'Ga':'red','As':'blue',...}")
return cs
# Cell
def splot_lattice(
poscar_data,
plane=None,
sizes=50,
colors=None,
bond_length=None,
tol=1e-2,
eqv_sites=True,
translate=None,
origin=(0, 0, 0),
ax=None,
showlegend=True,
fmt_label=None,
site_kws=dict(alpha=0.7),
bond_kws=dict(alpha=0.7, lw=1),
plot_cell=True,
**kwargs,
):
"""Matplotlib Static plot of lattice.
Parameters
----------
plane : str
Plane to plot. Either 'xy','xz','yz' or None for 3D plot.
sizes : float or dict of type -> float
Size of sites. Either one int/float or a mapping like {'Ga': 2, ...}.
colors : color or dict of type -> color
Mapping of colors like {'Ga': 'red, ...} or a single color. Automatically generated color for missing types.
bond_length : float or dict
Length of bond in Angstrom. Auto calculated if not provides. Can be a dict like {'Fe-O':3.2,...} to specify bond length between specific types.
alpha : float
Opacity of points and bonds.
showlegend : bool
Default is True, show legend for each ion type.
site_kws : dict
Keyword arguments to pass to `plt.scatter` for plotting sites.
Default is just hint, you can pass any keyword argument that `plt.scatter` accepts.
bond_kws : dict
Keyword arguments to pass to `LineCollection`/`Line3DCollection` for plotting bonds.
fmt_label : callable
If given, each site label is passed to it as a subclass of str 'Ga 1' with extra attributes `symbol` and `number` and a method `to_latex`.
You can show specific labels based on condition, e.g. `lambda lab: lab.to_latex() if lab.number in [1,5] else ''` will show 1st and 5th atom of each types.
It must return a string or a list/tuple of length 2 with first item as label and second item as dictionary of keywords to pass to `plt.text`.
plot_cell : bool
Default is True, plot unit cell with default settings.
To customize options, use `plot_cell = False` and do `POSCAR.splot_cell(ax = <return of splot_lattice>)`.
kwargs are passed to `splot_bz`.
.. tip::
Use `plt.style.use('ggplot')` for better 3D perception.
"""
if len(poscar_data.positions) < 1:
raise ValueError("Need at least 1 atom!")
# Plane fix
if plane and plane not in "xyzxzyx":
raise ValueError("plane expects in 'xyzxzyx' or None.")
if plane:
ind = "xyzxzyx".index(plane)
arr = [0, 1, 2, 0, 2, 1, 0]
ix, iy = arr[ind], arr[ind + 1]
poscar_data = _fix_sites(
poscar_data, tol=tol, eqv_sites=eqv_sites, translate=translate, origin=origin
)
blen = _get_bond_length(poscar_data, bond_length)
labels = poscar_data.labels
coords, pairs, dist = get_pairs(poscar_data.coords, r=blen)
pairs = _filter_pairs(labels, pairs, dist, bond_length)
if fmt_label is not None:
_validate_label_func(fmt_label,labels[0])
if plot_cell:
bz_data = serializer.CellData(
get_bz(poscar_data.basis, primitive=True).to_dict()
) # For correct vectors
ax = splot_bz(bz_data, plane=plane, ax=ax, **kwargs)
else:
ax = ax or ptk.get_axes(axes_3d=True if plane is None else False)
if kwargs:
print(f"Warning: Parameters {list(kwargs.keys())} are not used when `plot_cell = False`.")
uelems = poscar_data.types.to_dict()
_fcs = _fix_color_size(uelems, colors, sizes, 50)
sizes = [v['size'] for v in _fcs.values()]
colors = [v['color'] for v in _fcs.values()]
# Before doing other stuff, create something for legend.
if showlegend:
for key, c, s in zip(uelems.keys(), colors, sizes):
ax.scatter([], [], s=s, color=c, label=key, **site_kws) # Works both for 3D and 2D.
ptk.add_legend(ax)
# Now change colors and sizes to whole array size
colors = np.array(
[mplc.to_rgb(colors[i]) for i, vs in enumerate(uelems.values()) for v in vs]
)
sizes = np.array([sizes[i] for i, vs in enumerate(uelems.values()) for v in vs])
if np.any(pairs):
coords_p = coords[pairs] # paired points
_colors = colors[pairs] # Colors at pairs
coords_n = []
colors_n = []
for c_p, _c in zip(coords_p, _colors):
mid = np.mean(c_p, axis=0)
arr = np.concatenate([c_p[0], mid, mid, c_p[1]]).reshape((-1, 2, 3))
coords_n = [*coords_n, *arr] # Same shape
colors_n = [*colors_n, *_c] # same shape.
coords_n = np.array(coords_n)
colors_n = np.array(colors_n)
bond_kws = {
"alpha": 0.7,
"capstyle": "butt",
**bond_kws,
} # bond_kws overrides alpha and capstyle only
# 3D LineCollection by default, very fast as compared to plot one by one.
lc = Line3DCollection(coords_n, colors=colors_n, **bond_kws)
if plane and plane in "xyzxzyx": # Avoid None
lc = LineCollection(coords_n[:, :, [ix, iy]], colors=colors_n, **bond_kws)
ax.add_collection(lc)
ax.autoscale_view()
if not plane:
site_kws = {
**dict(alpha=0.7, depthshade=False),
**site_kws,
} # site_kws overrides alpha only
ax.scatter(
coords[:, 0], coords[:, 1], coords[:, 2], c=colors, s=sizes, **site_kws
)
if fmt_label:
for i, coord in enumerate(coords):
lab, textkws = fmt_label(_AtomLabel(labels[i])), {}
if isinstance(lab, (list, tuple)):
lab, textkws = lab
ax.text(*coord, lab, **textkws)
# Set aspect to same as data if cell plotted
if plot_cell:
ax.set_box_aspect(np.ptp(bz_data.vertices, axis=0))
elif plane in "xyzxzyx":
site_kws = {**dict(alpha=0.7, zorder=3), **site_kws}
(iz,) = [i for i in range(3) if i not in (ix, iy)]
zorder = coords[:, iz].argsort()
if plane in "yxzy": # Left handed
zorder = zorder[::-1]
ax.scatter(
coords[zorder][:, ix],
coords[zorder][:, iy],
c=colors[zorder],
s=sizes[zorder],
**site_kws,
)
if fmt_label:
labels = [labels[i] for i in zorder] # Reorder labels
for i, coord in enumerate(coords[zorder]):
lab, textkws = fmt_label(_AtomLabel(labels[i])), {}
if isinstance(lab, (list, tuple)):
lab, textkws = lab
ax.text(*coord[[ix, iy]], lab, **textkws)
# Set aspect to display real shape.
ax.set_aspect("equal")
ax.set_axis_off()
return ax
# Cell
def join_poscars(poscar_data, other, direction="c", tol=1e-2, system=None):
"""Joins two POSCARs in a given direction. In-plane lattice parameters are kept from
first poscar and out of plane basis vector of other is modified while volume is kept same.
Parameters
----------
other : type(self)
Other POSCAR to be joined with this POSCAR.
direction : str
The joining direction. It is general and can join in any direction along basis. Expect one of ['a','b','c'].
tol : float
Default is 0.01. It is used to bring sites near 1 to near zero in order to complete sites in plane.
Vasp relaxation could move a point, say at 0.00100 to 0.99800 which is not useful while merging sites.
system : str
If system is given, it is written on top of file. Otherwise, it is infered from atomic species.
"""
_poscar1 = _fix_sites(poscar_data, tol=tol, eqv_sites=False)
_poscar2 = _fix_sites(other, tol=tol, eqv_sites=False)
pos1 = _poscar1.positions.copy()
pos2 = _poscar2.positions.copy()
s1, s2 = 0.5, 0.5 # Half length for each.
a1, b1, c1 = np.linalg.norm(_poscar1.basis, axis=1)
a2, b2, c2 = np.linalg.norm(_poscar2.basis, axis=1)
basis = _poscar1.basis.copy() # Must be copied, otherwise change outside.
# Processing in orthogonal space since a.(b x c) = abc sin(theta)cos(phi), and theta and phi are same for both.
if direction in "cC":
c2 = (
(a2 * b2) / (a1 * b1) * c2
) # Conservation of volume for right side to stretch in c-direction.
netc = c1 + c2
s1, s2 = c1 / netc, c2 / netc
pos1[:, 2] = s1 * pos1[:, 2]
pos2[:, 2] = s2 * pos2[:, 2] + s1
basis[2] = netc * basis[2] / np.linalg.norm(basis[2]) # Update 3rd vector
elif direction in "bB":
b2 = (
(a2 * c2) / (a1 * c1) * b2
) # Conservation of volume for right side to stretch in b-direction.
netb = b1 + b2
s1, s2 = b1 / netb, b2 / netb
pos1[:, 1] = s1 * pos1[:, 1]
pos2[:, 1] = s2 * pos2[:, 1] + s1
basis[1] = netb * basis[1] / np.linalg.norm(basis[1]) # Update 2nd vector
elif direction in "aA":
a2 = (
(b2 * c2) / (b1 * c1) * a2
) # Conservation of volume for right side to stretch in a-direction.
neta = a1 + a2
s1, s2 = a1 / neta, a2 / neta
pos1[:, 0] = s1 * pos1[:, 0]
pos2[:, 0] = s2 * pos2[:, 0] + s1
basis[0] = neta * basis[0] / np.linalg.norm(basis[0]) # Update 1st vector
else:
raise Exception("direction expects one of ['a','b','c']")
scale = np.linalg.norm(basis[0])
u1 = _poscar1.types.to_dict()
u2 = _poscar2.types.to_dict()
u_all = ({**u1, **u2}).keys() # Union of unique atom types to keep track of order.
pos_all = []
i_all = []
for u in u_all:
_i_ = 0
if u in u1.keys():
_i_ = len(u1[u])
pos_all = [*pos_all, *pos1[u1[u]]]
if u in u2.keys():
_i_ = _i_ + len(u2[u])
pos_all = [*pos_all, *pos2[u2[u]]]
i_all.append(_i_)
i_all = np.cumsum([0, *i_all]) # Do it after labels
uelems = {_u: range(i_all[i], i_all[i + 1]) for i, _u in enumerate(u_all)}
sys = system or "".join(uelems.keys())
iscartesian = poscar_data.metadata.cartesian or other.metadata.cartesian
metadata = {
"cartesian": iscartesian,
"scale": scale,
"comment": "Modified by ipyvasp",
}
out_dict = {
"SYSTEM": sys,
"basis": basis,
"metadata": metadata,
"positions": np.array(pos_all),
"types": uelems,
}
return serializer.PoscarData(out_dict)
# Cell
def repeat_poscar(poscar_data, n, direction):
"""Repeat a given POSCAR.
Parameters
----------
n : int
Number of repetitions.
direction : str
Direction of repetition. Can be 'a', 'b' or 'c'.
"""
if not isinstance(n, (int, np.integer)) and n < 2:
raise ValueError("n must be an integer greater than 1.")
given_poscar = poscar_data
for i in range(1, n):
poscar_data = join_poscars(given_poscar, poscar_data, direction=direction)
return poscar_data
def scale_poscar(poscar_data, scale=(1, 1, 1), tol=1e-2):
"""Create larger/smaller cell from a given POSCAR. Can be used to repeat a POSCAR with integer scale values.
Parameters
----------
scale : tuple
Tuple of three values along (a,b,c) vectors. int or float values. If number of sites are not as expected in output,
tweak `tol` instead of `scale`. You can put a minus sign with `tol` to get more sites and plus sign to reduce sites.
tol : float
It is used such that site positions are blow `1 - tol`, as 1 belongs to next cell, not previous one.
.. note::
``scale = (2,2,2)`` enlarges a cell and next operation of ``(1/2,1/2,1/2)`` should bring original cell back.
.. warning::
A POSCAR scaled with Non-integer values should only be used for visualization purposes, Not for any other opration such as making supercells, joining POSCARs etc.
"""
if not isinstance(scale, (tuple, list)) or len(scale) != 3:
raise ValueError("scale must be a tuple of three values.")
ii, jj, kk = np.ceil(scale).astype(int) # Need int for joining.
if tuple(scale) == (1, 1, 1): # No need to scale.
return poscar_data
if ii >= 2:
poscar_data = repeat_poscar(poscar_data, ii, direction="a")
if jj >= 2:
poscar_data = repeat_poscar(poscar_data, jj, direction="b")
if kk >= 2:
poscar_data = repeat_poscar(poscar_data, kk, direction="c")
if np.all([s == int(s) for s in scale]):
return poscar_data # No need to prcess further in case of integer scaling.
new_poscar = poscar_data.to_dict() # Update in it
# Get clip fraction
fi, fj, fk = scale[0] / ii, scale[1] / jj, scale[2] / kk
# Clip at end according to scale, change length of basis as fractions.
pos = poscar_data.positions.copy() / np.array([fi, fj, fk]) # rescale for clip
basis = poscar_data.basis.copy()
for i, f in zip(range(3), [fi, fj, fk]):
basis[i] = f * basis[i] # Basis rescale for clip
new_poscar["basis"] = basis
new_poscar["metadata"]["scale"] = np.linalg.norm(basis[0])
new_poscar["metadata"]["comment"] = f"Modified by ipyvasp"
uelems = poscar_data.types.to_dict()
# Minus in below for block is because if we have 0-2 then 1 belongs to next cell not original.
positions, shift = [], 0
for key, value in uelems.items():
s_p = pos[value] # Get positions of key
s_p = s_p[(s_p < 1 - tol).all(axis=1)] # Get sites within tolerance
if len(s_p) == 0:
raise Exception(
f"No sites found for {key!r}, cannot scale down. Increase scale!"
)
uelems[key] = range(shift, shift + len(s_p))
positions = [*positions, *s_p] # Pick sites
shift += len(s_p) # Update for next element
new_poscar["types"] = uelems
new_poscar["positions"] = np.array(positions)
return serializer.PoscarData(new_poscar)
def set_boundary(poscar_data, a = [0,1], b = [0,1], c = [0,1]):
"View atoms in a given boundary along a,b,c directions."
for d, name in zip([a,b,c],'abc'):
if not isinstance(d,(list,tuple)) or len(d) != 2:
raise ValueError(f"{name} should be a list/tuple of type [min, max]")
if d[1] < d[0]:
raise ValueError(f"{name} should be in increasing order as [min, max]")
data = poscar_data.to_dict()
upos = {}
for key, value in poscar_data.types.items():
pos = data['positions'][value]
for i, (l,h), shift in zip(range(3), [a,b,c],np.eye(3)):
pos = np.concatenate([pos + shift*k for k in np.arange(np.floor(l), np.ceil(h))],axis=0)
pos = pos[(pos[:,i] >= l) & (pos[:,i] <= h)]
upos[key] = pos
data['positions'] = np.concatenate(list(upos.values()), axis = 0)
data['metadata']['eqv_fix'] = False
ranges = np.cumsum([0, *[len(v) for v in upos.values()]])
data['types'] = {key: range(i,j) for key, i,j in zip(upos,ranges[:-1],ranges[1:])}
del upos
return serializer.PoscarData(data)
def rotate_poscar(poscar_data, angle_deg, axis_vec):
"""Rotate a given POSCAR.
Parameters
----------
angle_deg : float
Rotation angle in degrees.
axis_vec : array_like
Vector (x,y,z) of axis about which rotation takes place. Axis passes through origin.
"""
rot = rotation(angle_deg=angle_deg, axis_vec=axis_vec)
p_dict = poscar_data.to_dict()
p_dict["basis"] = rot.apply(
p_dict["basis"]
) # Rotate basis so that they are transpose
p_dict["metadata"]["comment"] = f"Modified by ipyvasp"
return serializer.PoscarData(p_dict)
def set_zdir(poscar_data, hkl, phi=0):
"""Set z-direction of POSCAR along a given hkl direction and returns new data.
Parameters
----------
hkl : tuple
(h,k,l) of the direction along which z-direction is to be set.
Vector is constructed as h*a + k*b + l*c in cartesian coordinates.
phi: float
Rotation angle in degrees about z-axis to set a desired rotated view.
Returns
-------
New instance of poscar with z-direction set along hkl.
"""
if not isinstance(hkl, (list, tuple, np.ndarray)) and len(hkl) != 3:
raise ValueError("hkl must be a list, tuple or numpy array of length 3.")
p_dict = poscar_data.to_dict()
basis = p_dict["basis"]
zvec = to_R3(basis, [hkl])[0] # in cartesian coordinates
angle = np.arccos(
zvec.dot([0, 0, 1]) / np.linalg.norm(zvec)
) # Angle between zvec and z-axis
rot = rotation(
angle_deg=np.rad2deg(angle), axis_vec=np.cross(zvec, [0, 0, 1])
) # Rotation matrix
new_basis = rot.apply(basis) # Rotate basis so that zvec is along z-axis
p_dict["basis"] = new_basis
p_dict["metadata"]["comment"] = f"Modified by ipyvasp"
new_pos = serializer.PoscarData(p_dict)
if phi: # Rotate around z-axis
return rotate_poscar(new_pos, angle_deg=phi, axis_vec=[0, 0, 1])
return new_pos
def mirror_poscar(poscar_data, direction):
"Mirror a POSCAR in a given direction. Sometime you need it before joining two POSCARs"
poscar = poscar_data.to_dict() # Avoid modifying original
idx = "abc".index(direction) # Check if direction is valid
poscar["positions"][:, idx] = (
1 - poscar["positions"][:, idx]
) # Trick: Mirror by subtracting from 1. not by multiplying with -1.
return serializer.PoscarData(poscar) # Return new POSCAR
def convert_poscar(poscar_data, atoms_mapping, basis_factor):
"""Convert a POSCAR to a similar structure of other atomic types or same type with strained basis.
Parameters
----------
atoms_mapping : dict
A dictionary of {old_atom: new_atom} like {'Ga':'Al'} will convert GaAs to AlAs structure.
basis_factor : float
A scaling factor multiplied with basis vectors, single value (useful for conversion to another type)
or list of three values to scale along (a,b,c) vectors (useful for strained structures).
.. note::
This can be used to strain basis vectors uniformly only. For non-uniform strain, use :func:`ipyvasp.POSCAR.deform`.
"""
poscar_data = poscar_data.to_dict() # Avoid modifying original
poscar_data["types"] = {
atoms_mapping.get(k, k): v for k, v in poscar_data["types"].items()
} # Update types
basis = poscar_data["basis"].copy() # Get basis to avoid modifying original
if isinstance(basis_factor, (int, np.integer, float)):
poscar_data["basis"] = basis_factor * basis # Rescale basis
elif isinstance(basis_factor, (list, tuple, np.ndarray)):
if len(basis_factor) != 3:
raise Exception("basis_factor should be a list/tuple/array of length 3")
if np.ndim(basis_factor) != 1:
raise Exception(
"basis_factor should be a list/tuple/array of 3 int/float values"
)
poscar_data["basis"] = np.array(
[
basis_factor[0] * basis[0],
basis_factor[1] * basis[1],
basis_factor[2] * basis[2],
]
)
else:
raise Exception(
"basis_factor should be a list/tuple/array of 3 int/float values, got {}".format(
type(basis_factor)
)
)
poscar_data["SYSTEM"] = "".join(poscar_data["types"].keys()) # Update system name
return serializer.PoscarData(poscar_data) # Return new POSCAR
def transform_poscar(poscar_data, transformation, fill_factor=2, tol=1e-2):
"""Transform a POSCAR with a given transformation matrix or function that takes old basis and return target basis.
Use `get_TM(basis1, basis2)` to get transformation matrix from one basis to another or function to return new basis of your choice.
An example of transformation function is `lambda a,b,c: a + b, a-b, c` which will give a new basis with a+b, a-b, c as basis vectors.
You may find errors due to missing atoms in the new basis, use `fill_factor` and `tol` to include any possible site in new cell.
Examples
--------
- FCC primitive → 111 hexagonal cell: ``lambda a,b,c: (a-c,b-c,a+b+c) ~ [[1,0,-1],[0,1,-1],[1,1,1]]``
- FCC primitive → FCC unit cell: ``lambda a,b,c: (b+c -a,a+c-b,a+b-c) ~ [[-1,1,1],[1,-1,1],[1,1,-1]]``
- FCC unit cell → 110 tetragonal cell: ``lambda a,b,c: (a-b,a+b,c) ~ [[1,-1,0],[1,1,0],[0,0,1]]``
.. note::
This function keeps underlying lattice same. To apply strain, use `deform` function instead.
"""
if callable(transformation):
new_basis = np.array(transformation(*poscar_data.basis)) # mostly a tuple
if new_basis.shape != (3, 3):
raise Exception(
"transformation function should return a tuple equivalent to 3x3 matrix"
)
elif np.ndim(transformation) == 2 and np.shape(transformation) == (3, 3):
new_basis = np.matmul(transformation, poscar_data.basis)
else:
raise Exception(
"transformation should be a function that accept 3 arguemnts or 3x3 matrix"
)
if not isinstance(fill_factor,int):
raise TypeError(f'fill_factor should be int, got {type(fill_factor)}')
_p = range(-fill_factor, fill_factor + 1)
pos = np.concatenate([poscar_data.positions,[[i] for i,_ in enumerate(poscar_data.positions)]], axis=1) # keep track of index
pos = np.concatenate([pos + [*p,0] for p in product(_p,_p,_p)],axis=0) # increaser by fill_factor^3
pos[:,:3] = to_basis(new_basis, poscar_data.to_cartesian(pos[:,:3])) # convert to coords in this and to points in new
pos = pos[(pos[:,:3] <= 1 - tol).all(axis=1) & (pos[:,:3] >= -tol).all(axis=1)]
pos = pos[pos[:,-1].argsort()] # sort for species
new_poscar = poscar_data.to_dict() # Update in it
new_poscar["basis"] = new_basis
new_poscar["metadata"]["scale"] = np.linalg.norm(new_basis[0])
new_poscar["metadata"]["comment"] = f"Transformed by ipyvasp"
new_poscar["metadata"]["TM"] = get_TM(poscar_data.basis, new_basis) # save Transformation matrix
old_numbers = [len(v) for v in poscar_data.types.values()]
uelems, start = {}, 0
for k, v in poscar_data.types.items():
uelems[k] = range(start, start + len(pos[(pos[:,-1] >= v.start) & (pos[:,-1] < v.stop)]))
start = uelems[k].stop
# warn if crystal formula changes
new_numbers = [len(v) for v in uelems.values()]
ratio = (np.array(new_numbers)/old_numbers).round(4) # Round to avoid floating point errors,can cover 1 to 10000 atoms transformation
if len(np.unique(ratio)) != 1:
print(tcolor.rb(f"WARNING: Transformation failed, atoms proportion changed: {old_numbers} -> {new_numbers}."
" If your transformation is an allowed one for this structure, increase `fill_factor` or `tol`."))
new_poscar["types"] = uelems
new_poscar["positions"] = np.array(pos[:,:3])
return serializer.PoscarData(new_poscar)
def add_vaccum(poscar_data, thickness, direction, left=False):
"""Add vacuum to a POSCAR.
Parameters
----------
thickness : float
Thickness of vacuum in Angstrom.
direction : str
Direction of vacuum. Can be 'a', 'b' or 'c'.
left : bool
If True, vacuum is added to left of sites. By default, vacuum is added to right of sites.
"""
if direction not in "abc":
raise Exception("Direction must be a, b or c.")
poscar_dict = poscar_data.to_dict() # Avoid modifying original
basis = poscar_dict["basis"].copy() # Copy basis to avoid modifying original
pos = poscar_dict["positions"].copy() # Copy positions to avoid modifying original
idx = "abc".index(direction)
norm = np.linalg.norm(basis[idx]) # Get length of basis vector
s1, s2 = norm / (norm + thickness), thickness / (norm + thickness) # Get scaling
basis[idx, :] *= (thickness + norm) / norm # Add thickness to basis
poscar_dict["basis"] = basis
if left:
pos[:, idx] *= s2 # Scale down positions
pos[:, idx] += s1 # Add vacuum to left of sites
poscar_dict["positions"] = pos
else:
pos[:, idx] *= s1 # Scale down positions
poscar_dict["positions"] = pos
return serializer.PoscarData(poscar_dict) # Return new POSCAR
def transpose_poscar(poscar_data, axes=[1, 0, 2]):
"Transpose a POSCAR by switching basis from [0,1,2] -> `axes`. By Default, x and y are transposed."
if isinstance(axes, (list, tuple, np.ndarray)) and len(axes) == 3:
if not all(isinstance(i, (int, np.integer)) for i in axes):
raise ValueError("`axes` must be a list of three integers.")
poscar_data = poscar_data.to_dict() #
basis = poscar_data["basis"].copy() # Copy basis to avoid modifying original
positions = poscar_data[
"positions"
].copy() # Copy positions to avoid modifying original
poscar_data["basis"] = basis[axes] # Transpose basis
poscar_data["positions"] = positions[:, axes] # Transpose positions
return serializer.PoscarData(poscar_data) # Return new POSCAR
else:
raise Exception("`axes` must be a squence of length 3.")
def add_atoms(poscar_data, name, positions):
"Add atoms with a `name` to a POSCAR at given `positions` in fractional coordinates."
positions = np.array(positions)
if (not np.ndim(positions) == 2) or (not positions.shape[1] == 3):
raise ValueError("`positions` must be a 2D array of shape (n,3)")
new_pos = np.vstack([poscar_data.positions, positions]) # Add new pos
unique = poscar_data.types.to_dict() # avoid modifying original
unique[name] = range(len(poscar_data.positions), len(new_pos))
data = poscar_data.to_dict() # Copy data to avoid modifying original
data["types"] = unique # Update unique dictionary
data["positions"] = new_pos # Update positions
data["metadata"][
"comment"
] = f'{data["metadata"]["comment"]} + Added {name!r}' # Update comment
data["SYSTEM"] = "".join(data["types"].keys()) # Update system name
return serializer.PoscarData(data) # Return new POSCAR
def replace_atoms(poscar_data, func, name):
"""Replace atoms satisfying a `func(atom) -> bool` with a new `name`. Like `lambda a: a.symbol == 'Ga'`"""
data = poscar_data.to_dict() # Copy data to avoid modifying original
mask = _masked_data(poscar_data, func)
new_types = {**{k: [] for k in poscar_data.types.keys()}, name: []}
for k, vs in data["types"].items():
for idx in vs:
if idx in mask:
new_types[name].append(idx)
else:
new_types[k].append(idx)
data["positions"] = np.vstack([data["positions"][t] for t in new_types.values()])
idxs = np.cumsum([0, *map(len, new_types.values())])
data["types"] = {
k: range(idxs[i], idxs[i + 1])
for i, k in enumerate(new_types.keys())
if len(new_types[k]) != 0
}
data["SYSTEM"] = "".join(data["types"].keys()) # Update system name
return serializer.PoscarData(data) # Return new POSCAR
def sort_poscar(poscar_data, new_order):
"sort poscar with new_order list/tuple of species."
if not isinstance(new_order, (list, tuple)):
raise TypeError(f"new_order should be a list/tuple of types, got {type(new_order)}")
data = poscar_data.to_dict()
if not all([set(new_order).issubset(data["types"]), set(data["types"]).issubset(new_order)]):
raise ValueError(f"new_order should contain all existings types {list(data['types'])}")
data["types"] = {key:data["types"][key] for key in new_order}
data["positions"] = data["positions"][[i for tp in data["types"].values() for i in tp]]
idxs = np.cumsum([0, *map(len, data["types"].values())])
data["types"] = {
k: range(idxs[i], idxs[i + 1])
for i, k in enumerate(data["types"].keys())
if len(data["types"][k]) != 0
}
data["SYSTEM"] = "".join(data["types"].keys()) # Update system name
return serializer.PoscarData(data)
def remove_atoms(poscar_data, func, fillby=None):
"""Remove atoms that satisfy `func(atom) -> bool` on their fractional coordinates like `lambda a: all(a.p < 1/2)`.
`atom` passed to function is a namedtuple like `Atom(symbol,number,index,x,y,z)` which has extra attribute `p = array([x,y,z])`.
If `fillby` is given, it will fill the removed atoms with atoms from fillby POSCAR.
>>> remove_atoms(..., lambda a: sum((a.p - 0.5)**2) <= 0.25**2) # remove atoms in center of cell inside radius of 0.25
.. note::
The coordinates of fillby POSCAR are transformed to basis of given POSCAR, before filling.
So a good filling is only guaranteed if both POSCARs have smaller lattice mismatch.
"""
_validate_func(func) # need to validate for fillbay
data = poscar_data.to_dict() # Copy data to avoid modifying original
positions = data["positions"]
mask = _masked_data(poscar_data, lambda s: not func(s))
new_types = {k: [] for k in poscar_data.types.keys()}
for k, vs in data["types"].items():
for idx in vs:
if idx in mask:
new_types[k].append(idx)
if fillby:
if not isinstance(fillby, serializer.PoscarData):
raise ValueError("`fillby` must be instance of PoscarData class.")
filldata = fillby.to_dict()
positions = np.vstack(
[data["positions"], to_basis(poscar_data.basis, fillby.coords)]
) # update positions of fillby in given data basis, not fillby basis
def keep_pos(i, x, y, z): # keep positions in basis of given data
u, v, w = to_basis(poscar_data.basis, to_R3(fillby.basis, [[x, y, z]]))[0]
return bool(func(_Atom('', 0, 0, u, v, w)))
mask = _masked_data(fillby, keep_pos)
N_prev = len(data["positions"]) # before filling
new_types = {
**{k: [] for k in filldata["types"]},
**new_types,
} # Add new types from fillby but keep old types values
for k, vs in filldata["types"].items():
for idx in vs:
if idx in mask:
new_types[k].append(N_prev + idx)
data["positions"] = np.vstack([positions[t] for t in new_types.values()])
idxs = np.cumsum([0, *map(len, new_types.values())])
data["types"] = {
k: range(idxs[i], idxs[i + 1])
for i, k in enumerate(new_types.keys())
if len(new_types[k]) != 0
}
data["SYSTEM"] = "".join(data["types"].keys()) # Update system name
return serializer.PoscarData(data) # Return new POSCAR
def deform_poscar(poscar_data, deformation):
"""Deform a POSCAR by a deformation as 3x3 ArrayLike, or a function that takee basis and returns a 3x3 ArrayLike,
to be multiplied with basis (elementwise) and return a new POSCAR.
.. note::
This function can change underlying crystal structure if cell shape changes, to just change cell shape, use `transform` function instead.
"""
poscar_dict = poscar_data.to_dict() # make a copy
if callable(deformation):
try:
poscar_dict["basis"] = np.array(
deformation(*poscar_data.basis)
) # mostly tuple
except:
raise ValueError(
"`deformation` function must be a function(a,b,c) -> 3x3 matrix to multiply with basis."
)
else:
dmatrix = deformation
if not isinstance(dmatrix, np.ndarray):
dmatrix = np.array(dmatrix)
if dmatrix.shape != (3, 3):
raise ValueError(
"`deformation` must be a 3x3 matrix or a function(a,b,c) -> 3x3 matrix to multiply with basis."
)
# Update basis by elemetwise multiplication
poscar_dict["basis"] = poscar_data.basis * dmatrix
poscar_dict["metadata"][
"comment"
] = f'{poscar_data["metadata"]["comment"]} + Deformed POSCAR'
return serializer.PoscarData(poscar_dict) # Return new POSCAR
def view_poscar(poscar_data, **kwargs):
"View a POSCAR in a jupyter notebook. kwargs are passed to splot_lattice. After setting a view, you can do view.f(**view.kwargs) to get same plot in a cell."
def view(elev, azim, roll):
ax = splot_lattice(poscar_data, **kwargs)
ax.view_init(elev=elev, azim=azim, roll=roll)
elev = IntSlider(description='elev', min=0,max=180,value=30, continuous_update=False)
azim = IntSlider(description='azim', min=0,max=360,value=30, continuous_update=False)
roll = IntSlider(description='roll', min=0,max=360,value=0, continuous_update=False)
return interactive(view, elev=elev, azim=azim, roll=roll)