Source code for ipyvasp.widgets
__all__ = [
"load_results",
"summarize",
"Files",
"PropsPicker",
"BandsWidget",
"KPathWidget",
]
import inspect, re
from pathlib import Path
from collections.abc import Iterable
from functools import partial
# Widgets Imports
from IPython.display import display
from ipywidgets import (
Layout,
Button,
HBox,
VBox,
Dropdown,
Text,
Stack,
SelectMultiple,
TagsInput,
)
# More imports
import numpy as np
import pandas as pd
import ipywidgets as ipw
import traitlets
import plotly.graph_objects as go
import dashlab as dl
# Internal imports
from . import utils as gu
from . import lattice as lat
from .core import serializer, parser as vp, plot_toolkit as ptk
from .utils import _sig_kwargs, _sub_doc, get_file_size
from ._enplots import _fmt_labels
[docs]
def summarize(files, func, **kwargs):
    """
    Run ``func`` on every path in ``files`` and collect the results in a dataframe.

    Parameters
    ----------
    files : Iterable of PathLike objects. A dictionary of {name: PathLike} pairs
        also works, and then ``name`` is what appears in the ``FILE`` column.
        Otherwise ``str(path)`` is used as the name.
    func : callable taking a path as single positional argument. Must return a dictionary.
    kwargs : passed to ``func`` on every call.

    Returns
    -------
    pandas.DataFrame with one row per file, one column per key returned by
    ``func`` plus a ``FILE`` column. Keys missing from a row are filled with
    ``''`` if the last value seen for that key was a str, otherwise ``None``.

    Raises
    ------
    TypeError : if ``func`` is not callable, ``files`` is not iterable, or
        ``func`` returns something other than a dictionary.
    KeyError : if ``func`` returns the reserved key ``FILE``.
    """
    if not callable(func):
        raise TypeError("Argument `func` must be a function.")

    # Files passes this check too because it defines __iter__
    if not isinstance(files, Iterable):
        raise TypeError("Argument `files` must be an iterable of PathLike objects")

    named_paths = files if isinstance(files, dict) else {str(p): p for p in files}

    rows = []
    for name, path in named_paths.items():
        result = func(path, **kwargs)
        if not isinstance(result, dict):
            raise TypeError("Function must return a dictionary to create DataFrame.")
        if "FILE" in result:
            raise KeyError(
                "FILE is a reserved key to store the file name for reference."
            )
        rows.append({**result, "FILE": name})  # file name goes last

    # One placeholder per key, chosen by the type of the last value seen for it
    placeholders = {}
    for row in rows:
        for key, value in row.items():
            placeholders[key] = '' if isinstance(value, str) else None

    columns = {}
    for key, fill in placeholders.items():
        columns[key] = [row.get(key, fill) for row in rows]
    return pd.DataFrame(columns)
def fix_signature(cls):
    """Class decorator that publishes the signature of ``cls.__init__`` as ``cls.__signature__``.

    Needed because subclassing VBox ruins the signature reported for the
    subclass. The class is modified in place and returned.
    """
    init_signature = inspect.signature(cls.__init__)
    setattr(cls, "__signature__", init_signature)
    return cls
[docs]
@fix_signature
class Files:
    """Creates a batch of paths from a directory by glob pattern, or from a given list of files.

    This is a boilerplate abstraction to do analysis in multiple calculations simultaneously.

    Parameters
    ----------
    path_or_files : str or Path of a directory (current directory by default),
        an iterable of str/Path items, or an instance of Files. Items of an
        iterable that do not exist on disk are skipped with a printed notice.
        If a Files instance is given, its paths are reused as they are and all
        other arguments are ignored.
    glob : str, glob pattern, '*' by default. '**' is used for recursive glob. Not used if files supplied above.
    exclude : str, regular expression pattern to exclude files.
    files_only : bool, if True, returns only files.
    dirs_only : bool, if True, returns only directories.

    Returns
    -------
    Files instance.

    Notes
    -----
    Use methods on return such as ``summarize``, ``with_name``, ``filtered``, ``interact`` and others.

    >>> Files(root_1, glob_1,...).add(root_2, glob_2,...) # Fully flexible to chain

    **WARNING**: Don't use write operations on paths in files in batch mode, it can cause unrecoverable data loss.
    """

    def __init__(self, path_or_files = '.', glob = '*', exclude = None, files_only = False, dirs_only = False):
        if isinstance(path_or_files, Files):
            self._files = path_or_files._files
            return  # nothing else to do, paths are shared as they are

        if files_only and dirs_only:
            raise ValueError("files_only and dirs_only cannot be both True")

        files = []
        if isinstance(path_or_files, (str, Path)):
            files = Path(path_or_files).glob(glob)
        else:
            others = []
            for item in path_or_files:
                if isinstance(item, str):
                    item = Path(item)
                elif not isinstance(item, Path):
                    raise TypeError(f"Expected str or Path in sequence, got {type(item)}")

                if item.exists():
                    files.append(item)
                else:
                    others.append(str(item))

            if others:
                print(f"Skipping paths that do not exist: {list(set(others))}")

        # Filters are chained lazily and consumed once by sorted() below
        if exclude:
            files = (p for p in files if not re.search(exclude, str(p)))
        if files_only:
            files = (p for p in files if p.is_file())
        if dirs_only:
            files = (p for p in files if p.is_dir())

        self._files = tuple(sorted(files))

    def __str__(self):
        return '\n'.join(str(f) for f in self._files)

    def __repr__(self):
        if not self:
            return "Files()"
        show = ',\n'.join(f'  {f!r}' for f in self._files)
        return f"Files(\n{show}\n) {len(self._files)} items"

    def __getitem__(self, index):
        return self._files[index]

    def __iter__(self):
        return self._files.__iter__()

    def __len__(self):
        return len(self._files)

    def __bool__(self):
        return bool(self._files)

    def __add__(self, other):
        raise NotImplementedError("Use self.add method instead!")

    def with_name(self, name):
        "Change name of all files. Only keeps existing files."
        return self.__class__([f.with_name(name) for f in self._files])

    def filtered(self, include=None, exclude=None, files_only = False, dirs_only=False):
        "Filter all files by `include`/`exclude` regular expressions. Only keeps existing files."
        files = [p for p in self._files if re.search(include, str(p))] if include else self._files
        return self.__class__(files, exclude=exclude, dirs_only=dirs_only, files_only=files_only)

    def summarize(self, func, **kwargs):
        "Apply a func(path) -> dict and create a dataframe."
        return summarize(self._files, func, **kwargs)

    def load_results(self, exclude_keys=None):
        "Load result.json files from these paths into a dataframe, with optionally excluding keys."
        return load_results(self._files, exclude_keys=exclude_keys)

    def input_info(self, *tags):
        "Grab input information into a dataframe from POSCAR and INCAR. Provide INCAR tags (case-insensitive) to select only few of them."
        from .lattice import POSCAR

        def info(path, tags):
            p = POSCAR(path).data
            # Split on the first '=' only, so a value (or trailing comment) that
            # itself contains '=' still yields a (key, value) pair.
            lines = [[v.strip() for v in line.split('=', 1)]
                for line in path.with_name('INCAR').read_text().splitlines()
                if '=' in line]
            if tags:
                tags = [t.upper() for t in tags]  # can send lowercase tag
                lines = [(k, v) for k, v in lines if k in tags]
            d = {k: v for k, v in lines if not k.startswith('#')}
            d.update({k: len(v) for k, v in p.types.items()})
            d.update(zip(['a','b','c','v','alpha','beta','gamma'], [*p.norms, p.volume, *p.angles]))
            return d

        return self.with_name('POSCAR').summarize(info, tags=tags)

    def update(self, path_or_files, glob = '*', cleanup = True, exclude=None, **kwargs):
        """Update files inplace with similar parameters as initialization. If `cleanup=False`, older files are kept too.
        Useful for widgets such as BandsWidget to preserve their state while using `widget.files.update`."""
        old = () if cleanup else self._files
        self._files = self._unique(old, self.__class__(path_or_files, glob = glob, exclude=exclude, **kwargs)._files)

        if (dd := getattr(self, '_dd', None)):  # update dropdown
            old = dd.value
            dd.options = self._files
            if old in dd.options:
                dd.value = old  # keep previous selection if it survived the update

    def to_dropdown(self, description='File'):
        """
        Convert this instance to Dropdown. If there is only one file, adds an
        empty option to make that file switchable.
        Options of this dropdown are updated on calling `Files.update` method."""
        if hasattr(self, '_dd'):
            return self._dd  # already created
        options = self._files if len(self._files) != 1 else ['', *self._files]  # make single file work
        self._dd = Dropdown(description=description, options=options)
        return self._dd

    def add(self, path_or_files, glob = '*', exclude=None, **kwargs):
        """Add more files or with a different glob on top of existing files. Returns same instance.
        Useful to add multiple globbed files into a single chained call.

        >>> Files(root_1, glob_1,...).add(root_2, glob_2,...) # Fully flexible
        """
        self._files = self._unique(self._files, self.__class__(path_or_files, glob = glob, exclude=exclude, **kwargs)._files)
        return self

    def _unique(self, *files_tuples):
        # Merge tuples of paths into one sorted tuple without duplicates
        return tuple(np.unique(np.hstack(files_tuples)))

    @_sub_doc(dl.interactive)
    def interactive(self, *funcs, post_init:callable=None, **kwargs):
        if 'file' in kwargs:
            raise KeyError("file is a reserved keyword argument to select path to file!")

        has_file_param = False
        for func in funcs:
            if not callable(func):
                raise TypeError(f"Each item in *funcs should be callable, got {type(func)}")
            if 'file' in inspect.signature(func).parameters:
                has_file_param = True

        if funcs and not has_file_param:  # may be no func yet, that is test below
            raise KeyError("At least one of funcs should take 'file' as parameter, none got it!")

        return dl.interactive(*funcs, post_init=post_init, file = self.to_dropdown(), **kwargs)

    @_sub_doc(dl.interact)
    def interact(self, *funcs, post_init:callable=None, **kwargs):
        def inner(func):
            display(self.interactive(func, *funcs,
                post_init=post_init,
                **kwargs)
            )
            return func
        return inner

    def kpath_widget(self, height='400px'):
        "Get KPathWidget instance with these files."
        return KPathWidget(files = self.with_name('POSCAR'), height = height)

    def bands_widget(self, height='450px'):
        "Get BandsWidget instance with these files."
        return BandsWidget(files=self._files, height=height)

    def map(self, func, to_df=False):
        """Map files to a function that takes path as argument.
        If `to_df=True`, func may return a dict to create named columns, or just two columns will be created.
        Otherwise returns generator of elements `(path, func(path))`.
        If you need to operate on opened file pointer, use `.mapf` instead.

        >>> import ipyvasp as ipv
        >>> files = ipv.Files(...)
        >>> files.map(lambda path: ipv.read(path, '<pattern>',apply = lambda line: float(line.split()[0])))
        >>> files.map(lambda path: ipv.load(path), to_df=True)
        """
        if to_df:
            return self._try_return_df(func)
        return ((path, func(path)) for path in self._files)  # generator must

    def _try_return_df(self, func):
        # Prefer named columns via summarize; if that fails for any ordinary
        # reason (e.g. func does not return a dict) fall back to two columns.
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        try:
            return summarize(self._files, func)
        except Exception:
            return pd.DataFrame(((path, func(path)) for path in self._files))

    def mapf(self, func, to_df=False, mode='r', encoding=None):
        """Map files to a function that takes opened file pointer as argument. Opened files are automatically closed and should be in readonly mode.
        Load files content into a generator sequence of tuples like `(path, func(open(path)))` or DataFrame if `to_df=True`.
        If `to_df=True`, func may return a dict to create named columns, or just two columns will be created.
        If you need to operate on just path, use `.map` instead.

        Raises ValueError if `mode` is anything other than 'r' or 'rb'.

        >>> import json
        >>> import ipyvasp as ipv
        >>> files = ipv.Files(...)
        >>> files.mapf(lambda fp: json.load(fp,cls=ipv.DecodeToNumpy),to_df=True) # or use ipv.load(path) in map
        >>> files.mapf(lambda fp: ipv.take(fp, range(5))) # read first five lines
        >>> files.mapf(lambda fp: ipv.take(fp, range(-5,0))) # read last five lines
        >>> files.mapf(lambda fp: ipv.take(fp, -1, 1, float)) # read last line, second column as float
        """
        # Exact match is required: `mode in 'rb'` is a substring test and
        # would also let '' and 'b' through.
        if mode not in ('r', 'rb'):
            raise ValueError("Only 'r'/'rb' mode is allowed in this context!")

        def loader(path):
            with open(path, mode=mode, encoding=encoding) as f:
                return func(f)

        if to_df:
            return self._try_return_df(loader)
        return ((path, loader(path)) for path in self._files)  # generator must

    def stat(self):
        "Get files stat as DataFrame. Currently only size is supported."
        return self.summarize(lambda path: {"size": get_file_size(path)})
@fix_signature
class _PropPicker(VBox):
    """One projection picker: a tags input for atoms and another for orbitals.

    The ``props`` trait holds ``{'atoms': [...], 'orbs': [...], 'label': str}``
    when both inputs have a selection, and an empty dict otherwise.
    """
    props = traitlets.Dict({})

    def __init__(self, system_summary=None):
        super().__init__()
        self._atoms = TagsInput(
            description="Atoms", placeholder="Select atoms",
            allowed_tags=[], allow_duplicates = False,
        ).add_class('props-tags')
        self._orbs = TagsInput(
            description="Orbs", placeholder="Select orbitals",
            allowed_tags=[], allow_duplicates = False,
        ).add_class('props-tags')

        self.children = [self._atoms, self._orbs]
        self.layout.width = '100%'  # avoid horizontal collapse

        # tag -> indices lookups, filled in by _process
        self._atoms_map = {}
        self._orbs_map = {}

        for tags_input in (self._atoms, self._orbs):
            tags_input.observe(self._update_props, 'value')

        self._process(system_summary)

    def _update_props(self, change):
        """Rebuild the ``props`` trait from the currently selected tags."""
        def collect(tags, lookup):
            # Flatten indices of all known tags, unknown tags are skipped
            found = []
            for tag in tags:
                indices = lookup.get(tag, None)
                if indices is not None:
                    found.extend(indices)
            return found

        atoms = collect(self._atoms.value, self._atoms_map)
        orbs = collect(self._orbs.value, self._orbs_map)

        if not (atoms and orbs):
            self.props = {}
            return

        label = f"{'+'.join(self._atoms.value)} | {'+'.join(self._orbs.value)}"
        self.props = {'atoms': atoms, 'orbs': orbs, 'label': label}

    def _process(self, system_summary):
        """Read orbitals and atom types from ``system_summary`` and set allowed tags.

        Does nothing if ``system_summary`` is None or has no ``orbs`` attribute.
        """
        if system_summary is None or not hasattr(system_summary, "orbs"):
            return

        sorbs = system_summary.orbs
        orbs_map = self._orbs_map = {"All": range(len(sorbs)), "s": [0]}

        # p-orbitals
        if {"px", "py", "pz"}.issubset(sorbs):
            orbs_map["p"] = range(1, 4)
            orbs_map["px+py"] = [idx for idx, key in enumerate(sorbs) if key in ("px", "py")]
            for idx, name in enumerate(sorbs[1:4], start=1):
                orbs_map[name] = [idx]

        # d-orbitals
        if {"dxy", "dyz"}.issubset(sorbs):
            orbs_map["d"] = range(4, 9)
            for idx, name in enumerate(sorbs[4:9], start=4):
                orbs_map[name] = [idx]

        # f-orbitals
        if len(sorbs) == 16:
            orbs_map["f"] = range(9, 16)
            for idx, name in enumerate(sorbs[9:16], start=9):
                orbs_map[name] = [idx]

        # Extra orbitals beyond f
        if len(sorbs) > 16:
            for idx, name in enumerate(sorbs[16:], start=16):
                orbs_map[name] = [idx]

        self._orbs.allowed_tags = list(orbs_map.keys())

        # Atoms: everything, then each type, then each single atom as <type><n>
        types = system_summary.types.to_dict()
        atoms_map = {"All": range(system_summary.NIONS)}
        atoms_map.update(types)
        for elem, indices in types.items():
            for n, idx in enumerate(indices, 1):
                atoms_map[f"{elem}{n}"] = [idx]

        self._atoms_map = atoms_map
        self._atoms.allowed_tags = list(atoms_map.keys())
        self._update_props(None)  # trigger props update

    def update(self, system_summary):
        """Load new system data and keep those previous selections which are still allowed."""
        previous_atoms = self._atoms.value
        previous_orbs = self._orbs.value
        self._process(system_summary)

        allowed_atoms = self._atoms.allowed_tags
        self._atoms.value = [tag for tag in previous_atoms if tag in allowed_atoms]
        allowed_orbs = self._orbs.allowed_tags
        self._orbs.value = [tag for tag in previous_orbs if tag in allowed_orbs]
[docs]
@fix_signature
class PropsPicker(VBox):  # NOTE: remove New Later
    """
    A widget to pick atoms and orbitals for plotting.

    Parameters
    ----------
    system_summary : (Vasprun,Vaspout).summary
    N : int, default is 3, number of projections to pick.

    You can observe `projections` trait, a dictionary of
    ``{label: (atoms, orbs)}`` built from all non-empty pickers.
    """
    projections = traitlets.Dict({})

    def __init__(self, system_summary=None, N=3):
        super().__init__()
        self._N = N
        self._pickers = [_PropPicker(system_summary) for _ in range(N)]
        self.add_class("props-picker")

        # With three projections the pickers are named after RGB channels
        if N == 3:
            description, options = "Color", ["Red", "Green", "Blue"]
        else:
            description, options = "Projection", [str(i+1) for i in range(N)]

        self._picker = Dropdown(description=description, options=options)
        self._stack = Stack(children=self._pickers, selected_index=0)

        # Dropdown selection decides which picker of the stack is visible
        ipw.link((self._picker, 'index'), (self._stack, 'selected_index'))

        self.children = [self._picker, self._stack]

        # Any change of props in any picker refreshes the projections trait
        for picker in self._pickers:
            picker.observe(self._update_projections, names=['props'])

    def _update_projections(self, change):
        """Collect props of all non-empty pickers into the ``projections`` trait."""
        self.projections = {
            picker.props['label']: (picker.props['atoms'], picker.props['orbs'])
            for picker in self._pickers
            if picker.props  # only non-empty selections
        }

    def update(self, system_summary):
        """Pass new system data to every picker."""
        for picker in self._pickers:
            picker.update(system_summary)
def _clean_legacy_data(path):
    """Load data from ``path`` and rename old style keys like VBM to vbm.

    If none of the marker keys 'VBM', 'α' or 'vbm_k' is present, the loaded
    data is returned untouched. Otherwise keys are renamed, the cleaned data is
    written back to ``path`` as json and returned.
    """
    data = serializer.load(path.absolute())  # old data loaded

    legacy_markers = ('VBM', 'α', 'vbm_k')
    if all(marker not in data for marker in legacy_markers):
        return data  # already clean

    keys_map = {  # Old: New
        "SYSTEM": "sys",
        "VBM": "vbm",
        "CBM": "cbm",
        "VBM_k": "kvbm", "vbm_k": "kvbm",
        "CBM_k": "kcbm", "cbm_k": "kcbm",
        "E_gap": "gap",
        "\u0394_SO": "soc",
        "α": "alpha",
        "β": "beta",
        "γ": "gamma",
    }

    # Everything that is neither an old nor a new style key is kept as it is
    handled = {*keys_map.keys(), *keys_map.values()}
    new_data = {}
    for key, value in data.items():
        if key not in handled:
            new_data[key] = value

    for old, new in keys_map.items():
        # Old key wins, an already new style key is kept otherwise
        for source in (old, new):
            if source in data:
                new_data[new] = data[source]
                break

    # save cleaned data
    serializer.dump(new_data,format="json",outfile=path)
    return new_data
[docs]
def load_results(paths_list, exclude_keys=None):
    """Loads result.json from paths_list and returns a dataframe. Use exclude_keys to get subset of data.

    Parameters
    ----------
    paths_list : iterable of str or Path. For a directory, ``result.json``
        inside it is used; for a file, ``result.json`` next to it is used.
        Empty entries and paths which are neither a file nor a directory are skipped.
    exclude_keys : list or tuple of str, keys to drop from each loaded result.

    Returns
    -------
    pandas.DataFrame as built by ``summarize``. A result file that cannot be
    loaded contributes a row with only its ``FILE`` entry.

    Raises
    ------
    TypeError : if ``exclude_keys`` is not a list/tuple of str.
    """
    if exclude_keys is not None:
        if not isinstance(exclude_keys, (list,tuple)):
            raise TypeError(f"exclude_keys should be list of keys, got {type(exclude_keys)}")
        if not all(isinstance(key, str) for key in exclude_keys):
            raise TypeError("all keys in exclude_keys should be str!")

    excluded = frozenset(exclude_keys or ())

    # Drop empty entries BEFORE converting: Path('') is Path('.'), which is a
    # directory and would otherwise add a bogus ./result.json to the results.
    result_paths = []
    for path in (Path(p) for p in paths_list if p):
        if path.is_dir():
            result_paths.append(path / "result.json")
        elif path.is_file():
            result_paths.append(path.parent / "result.json")

    def load_data(path):
        # Best effort: a missing or unreadable file gives an empty dictionary,
        # but KeyboardInterrupt/SystemExit are not swallowed.
        try:
            data = _clean_legacy_data(path)
            return {k: v for k, v in data.items() if k not in excluded}
        except Exception:
            return {}

    return summarize(result_paths, load_data)
def _get_css(mode):
return {
'--jp-widgets-color': 'white' if mode == 'dark' else 'black',
'--jp-widgets-label-color': 'white' if mode == 'dark' else 'black',
'--jp-widgets-readout-color': 'white' if mode == 'dark' else 'black',
'--jp-widgets-input-color': 'white' if mode == 'dark' else 'black',
'--jp-widgets-input-background-color': '#222' if mode == 'dark' else '#f7f7f7',
'--jp-widgets-input-border-color': '#8988' if mode == 'dark' else '#ccc',
'--jp-layout-color2': '#555' if mode == 'dark' else '#ddd', # buttons
'--jp-ui-font-color1': 'whitesmoke' if mode == 'dark' else 'black', # buttons
'--jp-content-font-color1': 'white' if mode == 'dark' else 'black', # main text
'--jp-layout-color1': '#111' if mode == 'dark' else '#fff', # background
':fullscreen': {'min-height':'100vh'},
'background': 'var(--jp-widgets-input-background-color)', 'border-radius': '4px', 'padding':'4px 4px 0 4px',
'> *': {
'box-sizing': 'border-box',
'background': 'var(--jp-layout-color1)',
'border-radius': '4px', 'grid-gap': '8px', 'padding': '8px',
},
'.left-sidebar .sm': {
'flex-grow': 1,
'select': {'height': '100%',},
},
'.footer': {'overflow': 'auto','padding':0},
'.widget-vslider, .jupyter-widget-vslider': {'width': 'auto'}, # otherwise it spans too much area
'table': { # dataframe display sucks
'color':'var(--jp-content-font-color1)',
'background':'var(--jp-layout-color1)',
'tr': {
'^:nth-child(odd)': {'background':'var(--jp-widgets-input-background-color)',},
'^:nth-child(even)': {'background':'var(--jp-layout-color1)',},
},
},
'.props-picker': {
'background': 'var(--jp-widgets-input-background-color)', # make feels like single widget
'overflow-x': 'hidden', 'border-radius': '4px', 'padding': '4px',
},
'.props-tags': {
'background':'var(--jp-layout-color1)', 'border-radius': '4px', 'padding': '4px',
'> input': {'width': '100%'},
'> input::placeholder': {'color': 'var(--jp-ui-font-color1)'},
},
}
class _ThemedFigureInteract(dl.DashboardBase):
    "Creates self._fig (figure widget) and self._theme (toggle button) attributes for subclasses to use."

    def __init__(self, *args, **kwargs):
        self._fig = dl.patched_plotly(go.FigureWidget())
        self._theme = Button(icon='sun', description=' ', tooltip="Toggle Theme")
        super().__init__(*args, **kwargs)

        # Subclass has to hand both widgets back through _interactive_params
        has_fig = hasattr(self.params, 'fig')
        has_theme = hasattr(self.params, 'theme')
        if not (has_fig and has_theme):
            raise AttributeError("subclass must include already initialized "
                "{'fig': self._fig,'theme':self._theme} in returned dict of _interactive_params() method.")

        self._update_theme(self._fig, self._theme)  # fix theme at start
        self.observe(self._autosize_figs, names = 'isfullscreen')  # fix figurewidget problem

    def _autosize_figs(self, change):
        "Toggle autosize of every plotly FigureWidget found in params."
        for w in self.params:
            # type is matched by name of classes in mro, to avoid importing
            if not re.search('plotly.*FigureWidget', str(type(w).__mro__)):
                continue
            w.layout.autosize = False  # double trigger is important
            w.layout.autosize = True

    def _interactive_params(self):
        return {}

    def __init_subclass__(cls):
        # The callback must be redefined in the body of each subclass itself
        defined_here = '_update_theme' in cls.__dict__
        if not (defined_here and hasattr(cls._update_theme, '_is_interactive_callback')):
            raise AttributeError("implement _update_theme(self, fig, theme) decorated by @callback in subclass, "
                "which should only call super()._update_theme(fig, theme) in its body.")
        super().__init_subclass__()

    @dl.callback
    def _update_theme(self, fig, theme):
        # 'sun' icon showing means the switch goes to dark now
        if theme.icon == 'sun':
            next_icon, template = 'moon', "plotly_dark"
        else:
            next_icon, template = 'sun', "plotly_white"

        theme.icon = next_icon  # icon is not observed, so this is fine to do here
        fig.layout.template = template
        self.set_css()  # sets dark/light by itself, must come after icon is set
        fig.layout.autosize = True  # must

    def _set_css(self, main=None, center=None):  # overriding this is allowed, but not set_css
        # Icon was flipped before this call, so 'sun' now stands for light
        mode = "light" if self._theme.icon == 'sun' else 'dark'
        style = _get_css(mode)

        if main is not None:
            if not isinstance(main, dict):
                raise TypeError("main must be a dict or None, got: {}".format(type(main)))
            style = {**style, **main}  # main is allowed to override

        super()._set_css(style, center)

    @property
    def files(self):
        "Use self.files.update(...) to keep state of widget preserved with new files."
        # subclasses must set self._files, it is only checked on access here
        if hasattr(self, '_files'):
            return self._files
        raise AttributeError("self._files = Files(...) was never set!")
[docs]
@fix_signature
class BandsWidget(_ThemedFigureInteract):
"""Visualize band structure from VASP calculation. You can click on the graph to get the data such as VBM, CBM, etc.
You can observe three traits:
- file: Currently selected file
- clicked_data: Last clicked point data, which can be directly passed to a dataframe.
- selected_data: Last selection of points within a box or lasso, which can be directly passed to a dataframe and plotted accordingly.
- You can use `self.files.update` method to change source files without effecting state of widget.
- You can also use `self.iplot`, `self.splot` with `self.kws` to get static plts of current state, and self.results to get a dataframe.
- You can use store_clicks to provide extra names of points you want to click and save data, besides default ones.
"""
file = traitlets.Any(allow_none=True)
clicked_data = traitlets.Dict(allow_none=True)
selected_data = traitlets.Dict(allow_none=True)
def __init__(self, files, height="600px", store_clicks=None):
self.add_class("BandsWidget")
self._kb_fig = go.FigureWidget() # for extra stuff
self._kb_fig.update_layout(margin=dict(l=40, r=0, b=40, t=40, pad=0)) # show compact
self._files = Files(files)
self._bands = None
self._kws = {}
self._result = {}
self._extra_clicks = ()
if store_clicks is not None:
if not isinstance(store_clicks, (list,tuple)):
raise TypeError("store_clicks should be list of names "
f"of point to be stored from click on figure, got {type(store_clicks)}")
for name in store_clicks:
if not isinstance(name, str) or not name.isidentifier():
raise ValueError(f"items in store_clicks should be a valid python variable name, got {name!r}")
if name in ["vbm", "cbm", "so_max", "so_min"]:
raise ValueError(f"{name!r} already exists in default click points!")
reserved = "gap soc v a b c alpha beta gamma direct".split()
if name in reserved:
raise ValueError(f"{name!r} conflicts with reserved keys {reserved}")
self._extra_clicks += tuple(store_clicks)
super().__init__() # after extra clicks
traitlets.dlink((self.params.file,'value'),(self, 'file'))
traitlets.dlink((self.params.fig,'clicked'),(self, 'clicked_data'))
traitlets.dlink((self.params.fig,'selected'),(self, 'selected_data'))
self.set_layout(
left_sidebar=[
'head','file','krange','kticks','brange', 'ppicks',
HBox(children=self.gather('theme','button')), 'kb_fig',
],
center=['hdata','fig','cpoint'], footer = ['*out'], # all outputs
right_sidebar = ['showft'],
pane_widths=['25em',1,'2em'], pane_heights=[0,1,0], # footer only has uselessoutputs
height=height
)
@traitlets.validate('selected_data','clicked_data')
def _flatten_dict(self, proposal):
data = proposal['value']
if data is None: return None # allow None stuff
if not isinstance(data, dict):
raise traitlets.TraitError(f"Expected a dict for selected_data, got {type(data)}")
_data = {k:v for k,v in data.items() if k != 'customdata' and 'indexes' not in k}
_data.update(pd.DataFrame(data.get('customdata',{})).to_dict(orient='list'))
return _data # since we know customdata, we can flatten dict
@dl.callback
def _update_theme(self, fig, theme):
super()._update_theme(fig, theme)
self._kb_fig.layout.template = fig.layout.template
self._kb_fig.layout.autosize = True
def _interactive_params(self):
return dict(
fig = self._fig, theme = self._theme, # include theme and fig
kb_fig = self._kb_fig, # show selected data
head = ipw.HTML("<b>Band Structure Visualizer</b>"),
file = self.files.to_dropdown(),
ppicks = PropsPicker(),
button = Button(description="Update Graph", icon= 'update'),
krange = ipw.IntRangeSlider(description="kpoints",min=0, max=1,value=[0,1], tooltip="Includes non-zero weight kpoints"),
kticks = Text(description="kticks", tooltip="0 index maps to minimum value of kpoints slider."),
brange = ipw.IntRangeSlider(description="bands",min=1, max=1), # number, not index
cpoint = ipw.ToggleButtons(description="Select from options and click on figure to store data points",
value=None, options=["vbm", "cbm", *self._extra_clicks]).add_class('content-width-button'), # the point where clicked
showft = ipw.IntSlider(description = 'h', orientation='vertical',min=0,max=50, value=0,tooltip="outputs area's height ratio"),
cdata = 'fig.clicked',
projs = 'ppicks.projections', # for visual feedback on button
sdata = '.selected_data',
hdata = ipw.HTML(), # to show data in one place
)
@dl.callback('out-selected')
def _plot_data(self, kb_fig, sdata):
kb_fig.data = [] # clear in any case to avoid confusion
if not sdata: return # no change
df = pd.DataFrame(sdata)
if 'r' in sdata:
arr = df[['r','g','b']].to_numpy()
arr[arr == ''] = 0
arr, fmt = arr / (arr.max() or 1), lambda v : int(v*255) # color norms
df['color'] = [f"rgb({fmt(r)},{fmt(g)},{fmt(b)})" for r,g,b in arr]
else:
df['color'] = sdata['occ']
df['msize'] = df['occ']*7 + 10
cdata = (df[["ys","occ","r","g","b"]] if 'r' in sdata else df[['ys','occ']]).to_numpy()
rgb_temp = '<br>orbs: (%{customdata[2]},%{customdata[3]},%{customdata[4]})' if 'r' in sdata else ''
kb_fig.add_trace(go.Scatter(x=df.nk, y = df.nb, mode = 'markers', marker = dict(size=df.msize,color=df.color), customdata=cdata))
kb_fig.update_traces(hovertemplate=f"nk: %{{x}}, nb: %{{y}})<br>en: %{{customdata[0]:.4f}}<br>occ: %{{customdata[1]:.4f}}{rgb_temp}<extra></extra>")
kb_fig.update_layout(template = self._fig.layout.template, autosize=True,
title = "Selected Data", showlegend=False,coloraxis_showscale=False,
margin=dict(l=40, r=0, b=40, t=40, pad=0),font=dict(family="stix, serif", size=14)
)
@dl.callback('out-data')
def _load_data(self, file):
if not file: return # First time not available
self._bands = (
vp.Vasprun(file) if file.parts[-1].endswith('xml') else vp.Vaspout(file)
).bands
self.params.ppicks.update(self.bands.source.summary)
self.params.krange.max = self.bands.source.summary.NKPTS - 1
self.params.krange.tooltip = f"Includes {self.bands.source.get_skipk()} non-zero weight kpoints"
self.bands.source.set_skipk(0) # full range to view for slider flexibility after fix above
self._kws['kpairs'] = [self.params.krange.value,]
if (ticks := ", ".join(
f"{k}:{v}" for k, v in self.bands.get_kticks()
)): # Do not overwrite if empty
self.params.kticks.value = ticks
self.params.brange.max = self.bands.source.summary.NBANDS
if self.bands.source.summary.LSORBIT:
self.params.cpoint.options = ["vbm", "cbm", "so_max", "so_min", *self._extra_clicks]
else:
self.params.cpoint.options = ["vbm", "cbm",*self._extra_clicks]
if (path := file.parent / "result.json").is_file():
self._result = _clean_legacy_data(path)
pdata = self.bands.source.poscar.data
self._result.update(
{
"sys": pdata.SYSTEM, "v": round(pdata.volume, 4),
**{k: round(v, 4) for k, v in zip("abc", pdata.norms)},
**{k: round(v, 4) for k, v in zip(["alpha","beta","gamma"], pdata.angles)},
}
)
self._show_data(self._result) # Load into view
@dl.callback
def _toggle_footer(self, showft):
self.pane_heights = [0,100 - showft, showft]
@dl.callback
def _set_krange(self, krange):
self._kws["kpairs"] = [krange,]
@dl.callback
def _warn_update(self, file, kticks, brange, krange, projs):
self.params.button.description = "🔴 Update Graph"
@dl.callback('out-graph')
def _update_graph(self, fig, button):
if not self.bands: return # First time not available
fig.layout.autosize = True # must
hsk = [
[v.strip() for v in vs.split(":")]
for vs in self.params.kticks.value.split(",")
]
kmin, kmax = self.params.krange.value or [0,0]
kticks = [(int(vs[0]), vs[1])
for vs in hsk # We are going to pick kticks silently in given range
if len(vs) == 2 and abs(int(vs[0])) <= (kmax - kmin) # handle negative indices too
] or None
_bands = None
if self.params.brange.value:
l, h = self.params.brange.value
_bands = range(l-1, h) # from number to index
self._kws = {**self._kws, "kticks": kticks, "bands": _bands}
ISPIN = self.bands.source.summary.ISPIN
if self.params.ppicks.projections:
self._kws["projections"] = self.params.ppicks.projections
_fig = self.bands.iplot_rgb_lines(**self._kws, name="Up" if ISPIN == 2 else "")
if ISPIN == 2:
self.bands.iplot_rgb_lines(**self._kws, spin=1, name="Down", fig=fig)
self.iplot = partial(self.bands.iplot_rgb_lines, **self._kws)
self.splot = partial(self.bands.splot_rgb_lines, **self._kws)
else:
self._kws.pop("projections",None) # may be previous one
_fig = self.bands.iplot_bands(**self._kws, name="Up" if ISPIN == 2 else "")
if self.bands.source.summary.ISPIN == 2:
self.bands.iplot_bands(**self._kws, spin=1, name="Down", fig=fig)
self.iplot = partial(self.bands.iplot_bands, **self._kws)
self.splot = partial(self.bands.splot_bands, **self._kws)
ptk.iplot2widget(_fig, fig, template=fig.layout.template)
fig.clicked = {} # avoid data from previous figure
fig.selected = {} # avoid data from previous figure
button.description = "Update Graph" # clear trigger
@dl.callback('out-click')
def _click_save_data(self, cdata):
if self.params.cpoint.value is None: return # at reset-
data_dict = self._result.copy() # Copy old data
if cdata: # No need to make empty dict
key = self.params.cpoint.value
if key:
y = round(float(*cdata['ys']) + self.bands.data.ezero, 6) # Add ezero
if not key in self._extra_clicks:
data_dict[key] = y # Assign value back
if not key.startswith("so_"): # not spin-orbit points
cst, = cdata.get('customdata',[{}]) # single item
kp = [cst.get(f"k{n}", None) for n in 'xyz']
kp = tuple([round(k,6) if k else k for k in kp])
if key in ("vbm","cbm"):
data_dict[f"k{key}"] = kp
else: # user points, stor both for reference
data_dict[key] = {"k":kp,"e":y}
if data_dict.get("vbm", None) and data_dict.get("cbm", None):
data_dict["gap"] = np.round(data_dict["cbm"] - data_dict["vbm"], 6)
if data_dict.get("so_max", None) and data_dict.get("so_min", None):
data_dict["soc"] = np.round(
data_dict["so_max"] - data_dict["so_min"], 6
)
self._result.update(data_dict) # store new data
self._show_and_save(self._result, f"{key} = {data_dict[key]}")
self.params.cpoint.value = None # Reset to None to avoid accidental click at end
def _show_data(self, data, last_click=None):
"Show data in html widget, no matter where it was called."
keys = "sys vbm cbm gap direct soc v a b c alpha beta gamma".split()
data = {key:data[key] for key in keys if key in data} # show only standard data
kv, kc = [self._result.get(k,[None]*3) for k in ('kvbm','kcbm')]
data['direct'] = (kv == kc) if None not in kv else False
# Add a caption to the table
caption = f"<caption style='caption-side:bottom; opacity:0.7;'><code>{last_click or 'clicked data is shown here'}</code></caption>"
headers = "".join(f"<th>{key}</th>" for key in data.keys())
values = "".join(f"<td>{format(value, '.4f') if isinstance(value, float) else value}</td>" for value in data.values())
self.params.hdata.value = f"""<table border='1' style='width:100%;max-width:100% !important;border-collapse:collapse;'>
{caption}<tr>{headers}</tr>\n<tr>{values}</tr></table>"""
def _show_and_save(self, data_dict, last_click=None):
self._show_data(data_dict,last_click=last_click)
if self.file:
serializer.dump(data_dict,format="json",
outfile=self.file.parent / "result.json")
[docs]
def results(self, exclude_keys=None):
    "Build a dataframe from the result.json file found in each folder, optionally leaving out `exclude_keys`."
    # The options of the file dropdown are the paths handed to `load_results`.
    paths = self.params.file.options
    return load_results(paths, exclude_keys=exclude_keys)
@property
def source(self):
    "Data source object (such as Vasprun or Vaspout) behind the loaded bands."
    loaded_bands = self.bands  # raises if nothing has been loaded yet
    return loaded_bands.source
@property
def bands(self):
    "Bands object initialized by the widget; raises ValueError while no data is loaded."
    loaded = self._bands
    if loaded:
        return loaded
    raise ValueError("No data loaded by BandsWidget yet!")
@property
def kws(self):
    "Keyword arguments currently selected through the GUI."
    selected = self._kws
    return selected
[docs]
@fix_signature
class KPathWidget(_ThemedFigureInteract):
"""
Interactively build a kpath for bandstructure calculation.
After initialization and display:
- Select a POSCAR file from "File:" dropdown menu. It will update the figure.
- Add points to select box on left by clicking on plot points. When done with points click on Lock to avoid adding more points.
- To update point(s), select point(s) from the select box and click on a scatter point in figure or use KPOINT input to update it manually, e.g. if a point is not available on plot.
- Add labels to the points by typing in the "Labels" box such as "Γ,X" or "Γ 5,X" that will add 5 points in interval.
- To break the path between two points "Γ" and "X" type "Γ 0,X" in the "Labels" box, zero means no points in interval.
- You can use `self.files.update` method to change source files without affecting state of widget.
- You can observe `self.file` trait to get current file selected and plot something, e.g. lattice structure.
"""
# Trait holding the currently selected file; `__init__` links it to the value of the file dropdown.
file = traitlets.Any(None, allow_none=True)
@property
def poscar(self):
    "POSCAR object of the currently loaded file, or None before any file is loaded."
    current = self._poscar
    return current
def __init__(self, files, height="450px"):
    "Set up state, build the interactive parameters and arrange them in the layout."
    self.add_class("KPathWidget")
    # State filled in later by the callbacks.
    self._poscar = None
    self._oldclick = None
    self._kpoints = {}
    self._files = Files(files)  # set name _files to ensure access to files
    super().__init__()

    # Keep the `file` trait in sync with the dropdown selection (one direction only).
    traitlets.dlink((self.params.file, 'value'), (self, 'file'))

    buttons = self.gather('lock', 'delp', 'theme')
    toolbar = HBox(children=buttons, layout=Layout(min_height="24px"))
    layout_spec = dict(
        left_sidebar=['head', 'file', toolbar, 'info', 'sm', 'out-kpt', 'kpt', 'out-lab', 'lab'],
        center=['fig'],
        footer=['*out', '!out-lab', '!out-kpt'],  # all outputs except those prefixed with !
        pane_widths=['25em', 1, 0],
        pane_heights=[0, 1, 0],  # footer only has useless outputs
        height=height,
    )
    self.set_layout(**layout_spec)
def _show_info(self, text, color='skyblue'):
self.params.info.value = f'<span style="color:{color}">{text}</span>'
def _interactive_params(self):
    "Return the mapping of parameter names to the widgets/values driving the callbacks."
    header = ipw.HTML("<b>K-Path Builder</b>")
    file_dropdown = self.files.to_dropdown()  # auto updatable on files.update
    points_box = SelectMultiple(description="KPOINTS", options=[], layout=Layout(width="auto"))
    labels_text = Text(description="Labels", continuous_update=True)
    kpoint_text = Text(description="KPOINT", continuous_update=False)
    delete_btn = Button(description=" ", icon='trash', tooltip="Delete Selected Points")
    lock_btn = Button(description=" ", icon='unlock', tooltip="Lock/Unlock adding more points")
    info_html = ipw.HTML()  # concise information in one place
    return {
        'fig': self._fig,  # include theme and fig
        'theme': self._theme,
        'head': header,
        'file': file_dropdown,
        'sm': points_box,
        'lab': labels_text,
        'kpt': kpoint_text,
        'delp': delete_btn,
        'click': 'fig.clicked',
        'lock': lock_btn,
        'info': info_html,
    }
@dl.callback('out-fig')
def _update_fig(self, file, fig):
    "Load the selected POSCAR, draw its Brillouin zone into `fig` and add an empty path trace."
    if not file:
        return  # empty one
    from ipyvasp.lattice import POSCAR  # to avoid circular import

    self._poscar = POSCAR(file)
    bz_figure = self._poscar.iplot_bz(fill=False, color="red")
    template = self.params.fig.layout.template
    ptk.iplot2widget(bz_figure, fig, template)
    fig.layout.autosize = True  # must

    # Placeholder trace named "path"; it is filled in later when points are selected.
    path_trace = go.Scatter3d(
        x=[], y=[], z=[],
        mode="lines+text",
        name="path",
        text=[],
        hoverinfo="none",  # don't let it block other points
        textfont_size=18,
    )
    with fig.batch_animate():
        fig.add_trace(path_trace)
    self._show_info("Click points on plot to store for kpath.")
@dl.callback('out-click')
def _click(self, click):
    "Turn a click on a high-symmetry point into a new or updated k-point."
    # Setting the value of the select box (done for one-click convenience) triggers this
    # callback again, so the same click is ignored when it comes back.
    is_new = click != self._oldclick
    if not is_new:
        return
    trace_ids = click.get('trace_indexes', [])
    if not trace_ids:
        return
    self._oldclick = click  # for next time

    traces = self.params.fig.data  # click depends on fig, so accessing here
    hsk_hits = [traces[i] for i in trace_ids if 'HSK' in traces[i].name]
    if not hsk_hits:
        return

    picked = [*click.get('xs', []), *click.get('ys', []), *click.get('zs', [])]
    if not picked:
        return
    kp = self._poscar.bz.to_fractional(picked)  # reciprocal space
    if self.params.sm.value:
        self._set_kpt(kp)  # this updates plot back as well
    elif self.params.lock.icon == "unlock":  # only add when open
        self._add_point(kp)
@dl.callback('out-kpt')
def _take_kpt(self, kpt):
    "Print the input hint and apply the typed KPOINT text to the selected point(s)."
    hint = "Add kpoint e.g. 0,1,3 at selection(s)"
    print(hint)
    self._set_kpt(kpt)
@dl.callback('out-lab')
def _set_lab(self, lab):
    "Print the input hint and apply the typed labels text to the points."
    # Label and number are separated by a space (that is how `get_kpoints` parses them),
    # so the hint must not advertise the `label:number` form.
    print("Add label[ number] e.g. X 5,Y,L 9")
    self._add_label(lab)
@dl.callback
def _update_theme(self, fig, theme):
    "Delegate theme changes to the parent class, registered here as a callback of this class."
    parent = super()
    return parent._update_theme(fig, theme)
@dl.callback
def _respond_click(self, lock, delp):
    """
    React to the lock and delete buttons.

    - Lock button: toggles the lock icon and reports the new state.
    - Delete button: removes the selected points unless the widget is locked.
      The hint to select points is shown only when nothing is selected.
    """
    lock_btn = self.params.lock
    if lock.clicked:
        lock_btn.icon = 'lock' if lock_btn.icon == 'unlock' else 'unlock'
        self._show_info(f"{lock_btn.icon}ed adding/deleting kpoints!")
    elif delp.clicked:
        if lock_btn.icon != 'unlock':  # Do not delete locked
            self._show_info("cannot delete point when locked!", 'red')
            return
        sm = self.params.sm
        # Snapshot the selection first: reassigning options below may change `sm.value`.
        selected = tuple(sm.value)
        if not selected:
            self._show_info("Select point(s) to delete")
            return
        for v in selected:  # one value at a time is important to update selection properly
            sm.options = [opt for opt in sm.options if opt[1] != v]
            self._update_selection()  # update plot as well
def _add_point(self, kpt):
sm = self.params.sm
sm.options = [*sm.options, ("⋮", len(sm.options))]
# select to receive point as well, this somehow makes infinit loop issues,
# but need to work, so self._oldclick is used to check in _click callback
sm.value = (sm.options[-1][1],)
self._set_kpt(kpt) # add point, label and plot back
def _set_kpt(self,kpt):
point = kpt
if isinstance(kpt, str) and kpt:
if len(kpt.split(",")) != 3: return # Enter at incomplete input
point = [float(v) for v in kpt.split(",")] # kpt is value widget
if not isinstance(point,(list, tuple,np.ndarray)): return # None etc
if len(point) != 3:
raise ValueError("Expects KPOINT of 3 floats")
self._kpoints.update({v: point for v in self.params.sm.value})
label = "{:>8.4f} {:>8.4f} {:>8.4f}".format(*point)
self.params.sm.options = [
(label, value) if value in self.params.sm.value else (lab, value)
for (lab, value) in self.params.sm.options
]
self._add_label(self.params.lab.value) # Re-adjust labels and update plot as well
def _add_label(self, lab):
labs = [" ⋮ " for _ in self.params.sm.options] # as much as options
for idx, (_, lb) in enumerate(zip(self.params.sm.options, (lab or "").split(","))):
labs[idx] = labs[idx] + lb # don't leave empty anyhow
self.params.sm.options = [
(v.split("⋮")[0].strip() + lb, idx)
for (v, idx), lb in zip(self.params.sm.options, labs)
]
self._update_selection() # Update plot in both cases, by click or manual input
[docs]
def get_kpoints(self):
    """
    Returns kpoints list including labels and numbers in intervals if given.

    Each item is a tuple ``(x, y, z[, label][, number])`` in the order of the
    select box; options that have no stored k-point are skipped. The text after
    "⋮" in an option is read as "label", "number" or "label number".

    Raises ValueError if that text has more than two words or the second word
    is not an integer.
    """
    kpts = []
    for lab, idx in self.params.sm.options:  # order and existence is important
        if idx not in self._kpoints:
            continue
        pieces = lab.split("⋮")
        words = pieces[1].strip().split() if len(pieces) > 1 else []  # no separator, no label
        point = self._kpoints[idx]
        if len(words) == 2:
            try:
                number = int(words[1])
            except ValueError as err:
                raise ValueError(
                    f"Number of points after label {words[0]!r} should be an integer, got {words[1]!r}"
                ) from err
            kpts.append(tuple([*point, words[0], number]))  # label, number
        elif len(words) == 1:
            try:
                extra = int(words[0])  # number
            except ValueError:
                extra = words[0]  # label
            kpts.append(tuple([*point, extra]))
        elif len(words) == 0:
            kpts.append(tuple(point))
        else:
            raise ValueError(
                "Label and number should be separated by space or only one of them should be present"
            )
    return kpts
[docs]
def get_coords_labels(self):
    "Returns tuple of (coordinates, labels) to directly plot; a NaN row with an empty label follows each point whose number is 0."
    points = self.get_kpoints()
    if points:
        cartesian = self.poscar.bz.to_cartesian([p[:3] for p in points]).tolist()
    else:
        cartesian = []

    coords, labels = [], []
    for xyz, p in zip(cartesian, points):
        has_label = len(p) >= 4 and isinstance(p[3], str)
        coords.append(xyz)
        labels.append(p[3] if has_label else "")

        # The number sits last: after the label (5 items) or directly after the coordinates.
        if len(p) == 5:
            number = p[4]
        elif len(p) == 4 and isinstance(p[3], int):
            number = p[3]
        else:
            number = ""
        if isinstance(number, int) and number == 0:
            # Zero points in the interval: the NaN row interrupts the line at this spot.
            coords.append([np.nan, np.nan, np.nan])
            labels.append("")
    return np.array(coords), labels
def _update_selection(self):
    "Redraw the 'path' trace of the figure from the currently stored k-points and labels."
    coords, labels = self.get_coords_labels()
    # Force an (N, 3) shape so an empty selection can be sliced too. The path must be
    # redrawn even when it is empty (all points deleted) or when the only point is the
    # origin; a `coords.any()` guard would skip both cases and leave a stale path.
    coords = np.asarray(coords, dtype=float).reshape(-1, 3)
    xs, ys, zs = coords[:, 0], coords[:, 1], coords[:, 2]
    fig = self.params.fig
    with fig.batch_animate():
        for trace in fig.data:
            if "path" in (trace.name or ""):  # a trace without a name is never the path
                trace.x = xs
                trace.y = ys
                trace.z = zs
                # convert latex to html equivalent
                trace.text = _fmt_labels(labels) if labels else []
[docs]
@_sub_doc(lat.get_kpath, {"kpoints :.*n :": "n :", "rec_basis :.*\n\n": "\n\n"})
@_sig_kwargs(lat.get_kpath, ("kpoints", "rec_basis"))
def get_kpath(self, n=5, **kwargs):
    # No docstring here on purpose: presumably `_sub_doc` derives it from `lat.get_kpath` (confirm in utils).
    selected_points = self.get_kpoints()
    return self.poscar.get_kpath(selected_points, n=n, **kwargs)
[docs]
def iplot(self):
    "Returns a new plotly figure built from the data and layout of the widget figure, disconnected from it."
    live = self.params.fig
    return go.Figure(data=live.data, layout=live.layout)
[docs]
def splot(self, plane=None, fmt_label=lambda x: x, plot_kws={}, **kwargs):
    """
    Same as `ipyvasp.lattice.POSCAR.splot_bz` except it also plots path on BZ.

    Parameters
    ----------
    plane : str of plane like 'xy' to plot in 2D or None to plot in 3D
    fmt_label : function, should take a string and return a string or (str, dict) of which dict is passed to `plt.text`.
    plot_kws : dict of keyword arguments for `plt.plot` for kpath.

    kwargs are passed to `ipyvasp.lattice.POSCAR.splot_bz`.

    Returns the axes created by `splot_bz`. Raises TypeError if `plot_kws` is not a dict.
    """
    # `plot_kws` is only unpacked, never mutated, so the shared `{}` default is harmless here.
    if not isinstance(plot_kws, dict):
        raise TypeError("plot_kws should be a dict")

    ax = self.poscar.splot_bz(plane=plane, **kwargs)
    # Coordinates must be fetched before they are converted to the reciprocal basis.
    coords, labels = self.get_coords_labels()
    if np.size(coords):  # nothing selected yet, only the BZ is drawn
        kpoints = self.poscar.to_basis(coords, reciprocal=True)
        self.poscar.splot_kpath(
            kpoints, labels=labels, fmt_label=fmt_label, **plot_kws
        )  # plots on ax automatically
    return ax
# Should be at end: `fix_signature` is applied as a class decorator above, so its name
# can only be removed from the module namespace after the last decorated class.
del fix_signature # no more need