# Source code for ipyvasp._enplots: band-structure and DOS plotting helpers (matplotlib and plotly).

import re
from collections.abc import Iterable

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
from matplotlib.collections import LineCollection
import plotly.graph_objects as go


# Inside packages import
from . import utils as gu
from .core.plot_toolkit import (
    adjust_axes,
    get_axes,
    add_text,
    add_legend,
    add_colorbar,
    color_cube,
)


def join_ksegments(kpath, *pairs):
    """Join a broken k-path by closing the gap between given adjacent points.

    Parameters
    ----------
    kpath : array-like
        Cumulative k-path coordinates.
    *pairs : tuple of (int, int)
        Adjacent indices ``(i, i + 1)``. Every point from ``i + 1`` onward is
        shifted back so that point ``i + 1`` lands on point ``i``.

    Returns
    -------
    list
        The adjusted path. When at least one pair is given, the path is also
        rescaled so that its last value equals the maximum of the input path.

    Raises
    ------
    ValueError
        If a pair does not consist of exactly two adjacent integer indices.
    """
    arr = np.array(kpath)
    original_max = arr.max()  # evaluated up front, exactly like before any shifting
    if not pairs:
        return list(arr)

    for pair in pairs:
        if len(pair) != 2:
            raise ValueError(f"{pair} should have exactly two indices.")
        for index in pair:
            if not isinstance(index, (int, np.integer)):
                raise ValueError(f"{pair} should have integers, got {index!r}.")

        left, right = pair
        if right - left != 1:
            raise ValueError(f"Indices in pair ({left}, {right}) are not adjacent.")

        gap = arr[right] - arr[left]  # scalar taken before the in-place shift
        arr[right:] -= gap

    # Stretch the shortened path back to its original extent.
    return list(original_max * arr / arr[-1])


# This is to verify things together and make sure they are working as expected.
def _validate_data(K, E, elim, kticks, interp):
    if np.ndim(E) != 2:
        raise ValueError("E must be a 2D array.")

    if np.shape(E)[0] != len(K):
        raise ValueError("Length of first dimension of E must be equal to length of K.")

    if isinstance(kticks, zip):
        kticks = list(kticks)  # otherwise it will be empty after first use
    elif kticks is None:
        kticks = []

    if not isinstance(kticks, (list, tuple)):
        raise ValueError(
            "kticks must be a list, tuple or zip consisting of (index, label) pairs. index must be an int or tuple of (i, i+1) to join broken path."
        )

    for k, v in kticks:
        if not isinstance(k, (np.integer, int)):
            raise ValueError("First item of pairs in kticks must be int")
        if not isinstance(v, str):
            raise ValueError("Second item of pairs in kticks must be str.")

    pairs = [
        (k - 1, k) for k, v in kticks if v.startswith("<=")
    ]  # Join broken path at these indices
    K = join_ksegments(K, *pairs)
    inds = [k for k, _ in kticks]

    xticks = (
        [K[i] for i in inds] if inds else None
    )  # Avoid turning off xticks if no kticks given
    xticklabels = (
        [v.replace("<=", "") for _, v in kticks] if kticks else None
    )  # clean up labels

    if elim and len(elim) != 2:
        raise ValueError("elim must be a list or tuple of length 2.")

    if interp and not isinstance(interp, (int, np.integer, list, tuple)):
        raise ValueError("interp must be an integer or a list/tuple of (n,k).")

    if isinstance(interp, (list, tuple)) and len(interp) != 2:
        raise ValueError("interp must be an integer or a list/tuple of (n,k).")

    return K, E, xticks, xticklabels


_docs = dict(
    params="""
    Parameters
    ----------""",
    kticks="""kticks : list
        List of pairs [(int, str),...] for indices of high symmetry k-points. 
        To join a broken path, use '<=' before symbol, e.g.  [(0, 'G'),(40, '<=K|M'), ...] 
        will join 40 back to 39. You can also use shortcut like zip([0,10,20],'GMK').""",
    interp="""interp : int or list/tuple
        If int, n is number of points to interpolate. If list/tuple, n is number of points and k is the order of spline.""",
    return_ax="""
    Returns
    -------
    matplotlib.pyplot.Axes""",
    return_fig="""
    Returns
    -------
    plotly.graph_objects.Figure""",
    K="""K : array-like
        Array of kpoints with shape (nkpts,)""",
    E="""E : array-like
        Array of eigenvalues with shape (nkpts, nbands)""",
    ax="""ax : matplotlib.pyplot.Axes
        Matplotlib axes to plot on. If None, a new figure and axes will be created.""",
    elim="""elim : list or tuple
        A list or tuple of length 2 for energy limits.""",
    pros="""pros : array-like
        Projections of shape (m,nk,nb), m is the number of projections. m <= 3 in rgb case.""",
    labels="""labels : list
        As many labels for as projections.""",
    colormap="""colormap : str
        A valid matplotlib colormap name.""",
    maxwidth="""maxwidth : float
        Maximum linewidth to which the projections line width will be scaled. Default is 3.""",
)


@gu._fmt_doc(_docs)
def splot_bands(K, E, ax=None, elim=None, kticks=None, interp=None, **kwargs):
    """Plot band structure for a single spin channel and return the matplotlib axes which can be used to add other channel if spin polarized.
    {params}
    {K}
    {E}
    {ax}
    {elim}
    {kticks}
    {interp}

    kwargs are passed to matplotlib's command `plt.plot`. Unless given, color
    defaults to 'C0' and linewidth to 0.9.
    {return_ax}
    """
    K, E, xticks, xticklabels = _validate_data(K, E, elim, kticks, interp)
    if interp:
        nk = interp if isinstance(interp, (list, tuple)) else (interp, 3)
        K, E = gu.interpolate_data(K, E, *nk)

    # Float arrays so NaN can be inserted even when integer inputs were given.
    K = np.asarray(K, dtype=float)
    E = np.asarray(E, dtype=float)

    # A repeated k-value marks a joined (broken) path; a NaN there stops the line
    # from being drawn across the gap. Comparing consecutive points only (no
    # wrap-around from the last point to the first).
    breaks = np.flatnonzero(np.diff(K) == 0) + 1
    K = np.insert(K, breaks, np.nan)
    E = np.insert(E, breaks, np.nan, axis=0)

    ax = get_axes() if ax is None else ax
    if "color" not in kwargs and "c" not in kwargs:
        kwargs["color"] = "C0"  # default color from cycler to accommodate themes
    if "linewidth" not in kwargs and "lw" not in kwargs:
        kwargs["linewidth"] = 0.9  # default linewidth to make it look good

    lines = ax.plot(K, E, **kwargs)
    # Keep a single legend entry for the whole set of bands.
    for line in lines[1:]:
        line.set_label(None)

    adjust_axes(
        ax=ax,
        ylabel="Energy (eV)",
        xticks=xticks,
        xticklabels=xticklabels,
        xlim=[np.nanmin(K), np.nanmax(K)],  # K now contains NaN separators
        ylim=elim,
        vlines=True,
        top=True,
        right=True,
    )
    return ax
def _make_line_collection( maxwidth=3, colors_list=None, rgb=False, shadow=False, uniwidth=False, **pros_data ): """ Returns a tuple of line collections for each given projection data. Parametrs --------- maxwidth : Default is 3. Max linewidth is scaled to maxwidth if an int of float is given. uniwidth : Default is False. If True, linewidth is set to maxwidth/2 for all lines. Only works for rgb_lines. colors_list: List of colors for multiple lines, length equal to 3rd axis length of colors. rgb : Default is False. If True and np.shape(colors)[-1] == 3, RGB line collection is returned in a tuple of length 1. Tuple is just to support iteration. **pros_data: Output dictionary from `_fix_data` containing kpath, evals, colors arrays. """ if not isinstance(maxwidth, (int, np.integer, float)): raise ValueError("maxwidth must be an int or float.") if not pros_data: raise ValueError("No pros_data given.") else: kpath = pros_data.get("kpath") evals = pros_data.get("evals") pros = pros_data.get("pros") for a, t in zip([kpath, evals, pros], ["kpath", "evals", "pros"]): if not np.any(a): raise ValueError("Missing {!r} from output of `_fix_data()`".format(t)) # Average pros on two consecutive KPOINTS to get that patch color. colors = pros[1:, :, :] / 2 + pros[:-1, :, :] / 2 # Near kpoints avearge colors = colors.transpose((1, 0, 2)).reshape( (-1, np.shape(colors)[-1]) ) # Must before lws if rgb: # Single channel line widths lws = np.sum(colors, axis=1) # Sum over RGB else: # For separate lines lws = colors.T # .T to access in for loop. lws = 0.1 + maxwidth * lws / ( float(np.max(lws)) or 1 ) # Rescale to maxwidth, with a residual with 0.1 as must be visible. if np.any(colors_list): lc_colors = colors_list else: cmap = plt.cm.get_cmap("viridis") lc_colors = cmap(np.linspace(0, 1, np.shape(colors)[-1])) lc_colors = lc_colors[:, :3] # Skip Alpha # Reshaping data same as colors reshaped above, nut making line patches too. 
kgrid = np.repeat(kpath, np.shape(evals)[1], axis=0).reshape( (-1, np.shape(evals)[1]) ) narr = np.concatenate((kgrid, evals), axis=0).reshape((2, -1, np.shape(evals)[1])) marr = ( np.concatenate((narr[:, :-1, :], narr[:, 1:, :]), axis=0) .transpose() .reshape((-1, 2, 2)) ) # Make Line collection path_shadow = None if shadow: path_shadow = [ path_effects.SimpleLineShadow(offset=(0, -0.8), rho=0.2), path_effects.Normal(), ] if rgb and np.shape(colors)[-1] == 3: return ( LineCollection( marr, colors=colors, linewidths=(maxwidth / 2,) if uniwidth else lws, path_effects=path_shadow, ), ) else: lcs = [ LineCollection(marr, colors=_cl, linewidths=lw, path_effects=path_shadow) for _cl, lw in zip(lc_colors, lws) ] return tuple(lcs) # Further fix data for all cases which have projections def _fix_data(K, E, pros, labels, interp, rgb=False, **others): "Input pros must be [m,nk,nb], output is [nk,nb, m]. `others` must have shape [nk,nb] for occupancies or [nk,3] for kpoints" if np.shape(pros)[-2:] != np.shape(E): raise ValueError("last two dimensions of `pros` must have same shape as `E`") if np.ndim(pros) == 2: pros = np.expand_dims(pros, 0) # still as [m,nk,nb] if others: for k, v in others.items(): if np.shape(v)[0] != len(K): raise ValueError(f"{k} must have same length as K") if rgb and len(pros) > 3: raise ValueError("In RGB lines mode, pros.shape[-1] <= 3 should hold") # Should be after exapnding dims but before transposing if labels and len(labels) != len(pros): raise ValueError("labels must be same length as pros") pros = np.transpose(pros, (1, 2, 0)) # [nk,nb,m] now # Normalize overall data because colors are normalized to 0-1 min_max_pros = (np.min(pros), np.max(pros)) # For data scales to use later c_max = np.ptp(pros) if c_max > 0.0000001: # Avoid division error pros = (pros - np.min(pros)) / c_max data = {"kpath": K, "evals": E, "pros": pros, **others, "ptp": min_max_pros} if interp: nk = interp if isinstance(interp, (list, tuple)) else (interp, 3) min_d, max_d 
= np.min(pros), np.max(pros) # For cliping _K, E = gu.interpolate_data(K, E, *nk) pros = gu.interpolate_data(K, pros, *nk)[1].clip(min=min_d, max=max_d) data.update({"kpath": _K, "evals": E, "pros": pros}) for k, v in others.items(): data[k] = gu.interpolate_data(K, v, *nk)[1] # Handle kpath discontinuities X = data["kpath"] breaks = [i for i in range(0, len(X)) if X[i - 1] == X[i]] if breaks: data["kpath"] = np.insert(data["kpath"], breaks, np.nan) data["evals"] = np.insert(data["evals"], breaks, np.nan, axis=0) data["pros"] = np.insert( data["pros"], breaks, data["pros"][breaks], axis=0 ) # Repeat the same data to keep color consistent for ( key ) in ( others ): # don't use items here, as interpolation may have changed the shape data[key] = np.insert( data[key], breaks, data[key][breaks], axis=0 ) # Repeat here too return data
def _rgb_line_colors(colors, colormap, N):
    """Map normalized projections (nk, nb, m), m in 1..3, to RGB colours (nk, nb, 3)."""
    how_many = np.shape(colors)[-1]
    if how_many == 1:
        percent = colors[:, :, 0]
        percent = percent / (float(np.max(percent)) or 1)  # all-zero data stays 0, not NaN
        return plt.get_cmap(colormap or "copper", N)(percent)[:, :, :3]

    if how_many == 2:
        total = np.sum(colors, axis=2)
        percent = colors[:, :, 1] / np.where(total == 0, 1, total)  # second one is on top
        rgb = plt.get_cmap(colormap or "coolwarm", N)(percent)[:, :, :3]
        rgb[total == 0] = [0, 0, 0]  # black where there is no projection at all
        return rgb

    # Three projections: normalize each point, then quantize to N levels so the
    # lines match the discrete color cube.
    c_max = np.max(colors, axis=2, keepdims=True)
    c_max[c_max == 0] = 1
    colors = colors / c_max
    edges = np.linspace(0, 1, N, endpoint=True)
    for low, high in zip(edges[:-1], edges[1:]):
        colors[(colors >= low) & (colors < high)] = (low + high) / 2  # Center of square, as in color_cube

    # Weighted mix of the three colormap anchors A, B, C: (r*A + g*B + b*C) / (r + g + b).
    basis = plt.get_cmap(colormap or "brg", N)([0, 0.5, 1])[:, :3]  # rows are A, B, C
    weights = np.sum(colors, axis=2, keepdims=True)
    weights[weights == 0] = 1
    rgb = (colors @ basis) / weights

    # Normalize after picking colors from colormap as well to match the color_cube.
    c_max = np.max(rgb, axis=2, keepdims=True)
    c_max[c_max == 0] = 1
    return rgb / c_max


@gu._fmt_doc(_docs)
def splot_rgb_lines(
    K,
    E,
    pros,
    labels,
    ax=None,
    elim=None,
    kticks=None,
    interp=None,
    maxwidth=3,
    uniwidth=False,
    colormap=None,
    colorbar=True,
    N=9,
    shadow=False,
):
    """Plot projected band structure for a given projections.
    {params}
    {K}
    {E}
    {pros}
    {labels}
    {ax}
    {elim}
    {kticks}
    {interp}
    {maxwidth}
    uniwidth : bool
        If True, use same linewidth for all patches to maxwidth/2. Otherwise, use linewidth proportional to projection value.
    {colormap}
    colorbar : bool
        If True, add colorbar, otherwise add attribute to ax to add colorbar or color cube later
    N : int
        Number of colors in colormap
    shadow : bool
        If True, add shadow to lines
    {return_ax}
        If colorbar is False, the returned ax has additional attributes:
        .add_colorbar() : Add colorbar that represents most recent plot
        .color_cube() : Add color cube that represents most recent plot if `pros` is 3 components
    """
    K, E, xticks, xticklabels = _validate_data(K, E, elim, kticks, interp)
    ax = get_axes() if ax is None else ax

    # (nk,), (nk, nb), (nk, nb, m) at this point
    pros_data = _fix_data(K, E, pros, labels, interp, rgb=True)
    how_many = np.shape(pros_data["pros"])[-1]
    pros_data["pros"] = _rgb_line_colors(pros_data["pros"], colormap, N)

    (line_coll,) = _make_line_collection(
        **pros_data,
        rgb=True,
        colors_list=None,
        maxwidth=maxwidth,
        shadow=shadow,
        uniwidth=uniwidth,
    )
    ax.add_collection(line_coll)
    ax.autoscale_view()
    adjust_axes(
        ax,
        xticks=xticks,
        xticklabels=xticklabels,
        xlim=[min(K), max(K)],
        ylim=elim,
        vlines=True,
        top=True,
        right=True,
    )

    # Colorbar / color cube that describe the mapping used above.
    cmap = colormap or ("copper" if how_many == 1 else "brg" if how_many == 3 else "coolwarm")
    if how_many == 1:
        ticks = np.linspace(*pros_data["ptp"], 5, endpoint=True)
        ticklabels = [f"{t:4.2f}" for t in ticks]
    elif how_many == 3:
        ticks, ticklabels = None, labels
    else:
        ticks, ticklabels = [0, 1], labels

    if colorbar:
        if how_many < 3:
            cax = add_colorbar(
                ax,
                N=N,
                vertical=True,
                ticklabels=ticklabels,
                ticks=ticks,
                cmap_or_clist=cmap,
            )
            if how_many == 1:
                cax.set_title(labels[0])
        else:
            color_cube(ax, colormap=colormap or "brg", labels=labels, N=N)
    else:
        # Attach deferred helpers so a colorbar or color cube can be added later.
        def recent_colorbar(cax=None, tickloc="right", vertical=True, digits=2, fontsize=8):
            return add_colorbar(
                ax=ax,
                cax=cax,
                cmap_or_clist=cmap,
                N=N,
                ticks=ticks,
                ticklabels=ticklabels,
                tickloc=tickloc,
                vertical=vertical,
                digits=digits,
                fontsize=fontsize,
            )

        ax.add_colorbar = recent_colorbar

        def recent_color_cube(loc=(0.67, 0.67), size=0.3, color="k", fontsize=10):
            return color_cube(
                ax=ax,
                colormap=cmap,
                labels=labels,
                N=N,
                loc=loc,
                size=size,
                color=color,
                fontsize=fontsize,
            )

        ax.color_cube = recent_color_cube
    return ax
@gu._fmt_doc(_docs)
def splot_color_lines(
    K,
    E,
    pros,
    labels,
    axes=None,
    elim=None,
    kticks=None,
    interp=None,
    maxwidth=3,
    colormap=None,
    shadow=False,
    showlegend=True,
    xyc_label=(0.2, 0.85, "black"),  # x, y, color only if showlegend = False
    **kwargs,
):
    """Plot projected band structure for a given projections.
    {params}
    {K}
    {E}
    {pros}
    {labels}
    axes : matplotlib.axes.Axes or list of Axes
        Number of axes should be 1 or equal to the number of projections to plot separately. If None, creates new axes.
    {elim}
    {kticks}
    {interp}
    {maxwidth}
    {colormap}
        If None, viridis is used. An unknown name also falls back to viridis with a printed notice.
    shadow : bool
        If True, add shadow to lines
    showlegend : bool
        If True, add legend, otherwise adds a label to the plot.
    xyc_label : list or tuple
        List of (x, y, color) for the label. Used only if showlegend = False

    kwargs are passed to `add_legend` (overriding its defaults here).

    Returns
    -------
    One or as many matplotlib.axes.Axes as given by `axes` parameter.
    """
    K, E, xticks, xticklabels = _validate_data(K, E, elim, kticks, interp)
    pros_data = _fix_data(K, E, pros, labels, interp, rgb=False)

    if colormap is None:
        c_map = plt.get_cmap("viridis")  # default, no notice needed
    elif colormap not in plt.colormaps():
        c_map = plt.get_cmap("viridis")
        print(
            "colormap = {!r} not exists, falling back to default color map.".format(
                colormap
            )
        )
    else:
        c_map = plt.get_cmap(colormap)

    n_pros = pros_data["pros"].shape[-1]  # Output pros data has shape (nk, nb, projections)
    colors = c_map(np.linspace(0, 1, n_pros))

    if not np.any([axes]):
        axes = get_axes()
    axes = np.array([axes]).ravel()  # Safe list any axes size
    if len(axes) == 1:
        axes = [axes[0]] * n_pros  # Same axes for every projection
    elif len(axes) != n_pros:
        raise ValueError("Number of axes should be 1 or same as number of projections")

    lcs = _make_line_collection(
        maxwidth=maxwidth, colors_list=colors, rgb=False, shadow=shadow, **pros_data
    )
    for ax, lc in zip(axes, lcs):
        ax.add_collection(lc)
    for ax in axes:
        ax.autoscale_view()

    if showlegend:
        # Default values for legend_kwargs are overwritten by **kwargs
        legend_kws = {
            "ncol": 4,
            "anchor": (0, 1.05),
            "handletextpad": 0.5,
            "handlelength": 1,
            "fontsize": "small",
            "frameon": False,
            **kwargs,
        }
        add_legend(
            ax=axes[0], colors=colors, labels=labels, widths=maxwidth, **legend_kws
        )
    else:
        xs, ys, text_color = xyc_label
        for ax, lab in zip(axes, labels):
            add_text(ax, xs=xs, ys=ys, colors=text_color, txts=lab)

    for ax in axes:
        adjust_axes(
            ax=ax,
            xticks=xticks,
            xticklabels=xticklabels,
            xlim=[min(K), max(K)],
            ylim=elim,
            vlines=True,
            top=True,
            right=True,
        )
    return axes
def _fix_dos_data(energy, dos_arrays, labels, colors, interp): if colors is not None: if len(colors) != len(labels): raise ValueError("If colors is given,they must have same length as labels.") if len(dos_arrays) != len(labels): raise ValueError("dos_arrays and labels must have same length.") for i, arr in enumerate(dos_arrays): if len(energy) != len(arr): raise ValueError( f"array {i+1} in dos_arrays must have same length as energy." ) if len(dos_arrays) < 1: raise ValueError("dos_arrays must have at least one array.") if interp and not isinstance(interp, (int, np.integer, list, tuple)): raise ValueError("interp must be an integer or a list/tuple of (n,k).") if isinstance(interp, (list, tuple)) and len(interp) != 2: raise ValueError("interp must be an integer or a list/tuple of (n,k).") if interp: nk = ( interp if isinstance(interp, (list, tuple)) else (interp, 3) ) # default spline order is 3. en, arr1 = gu.interpolate_data(energy, dos_arrays[0], nk) arrays = [arr1] for a in dos_arrays[1:]: arrays.append(gu.interpolate_data(energy, a, nk)[1]) return en, arrays, labels, colors return energy, dos_arrays, labels, colors
@gu._fmt_doc(_docs)
def splot_dos_lines(
    energy,
    dos_arrays,
    labels,
    ax=None,
    elim=None,
    colormap="tab10",
    colors=None,
    fill=True,
    vertical=False,
    stack=False,
    interp=None,
    showlegend=True,
    legend_kws=None,
    **kwargs,
):
    """Plot density of states (DOS) lines.
    {params}
    energy : array-like, shape (n,)
    dos_arrays : list of array_like, each of shape (n,) or array-like (m,n)
    labels : list of str, length = len(dos_arrays) should hold.
    {ax}
    {elim}
    {colormap}
    colors : list of str, length = len(dos_arrays) should hold if given, and will override colormap.
    fill : bool, default True, if True, fill the area under the DOS lines.
    vertical : bool, default False, if True, plot DOS lines vertically.
    stack : bool, default False, if True, stack the DOS lines. Only works for horizontal plots.
    {interp}
    showlegend : bool, default True, if True, show legend.
    legend_kws : dict, optional. Anything that `ipyvasp.add_legend` accepts; overrides the
        defaults (ncol=4, anchor=(0, 1.0), compact, frameless). Only used if showlegend is True.

    keyword arguments are passed to matplotlib.axes.Axes.plot or matplotlib.axes.Axes.stackplot;
    any color given via 'c' or 'color' is ignored in favor of colors/colormap.
    {return_ax}"""
    # validate data before plotting.
    energy, dos_arrays, labels, colors = _fix_dos_data(
        energy, dos_arrays, labels, colors, interp
    )
    if colors is None:
        colors = plt.get_cmap(colormap)(np.linspace(0, 1, len(labels)))
    if ax is None:
        ax = get_axes()

    # Per-line colors are controlled by `colors`; drop conflicting user colors.
    kwargs.pop("c", None)
    kwargs.pop("color", None)

    if stack:
        if vertical:
            raise NotImplementedError("stack is not supported for vertical plots.")
        ax.stackplot(energy, *dos_arrays, labels=labels, colors=colors, **kwargs)
    else:
        for arr, label, color in zip(dos_arrays, labels, colors):
            if fill:
                fill_func = ax.fill_betweenx if vertical else ax.fill_between
                fill_func(energy, arr, color=mpl.colors.to_rgba(color, 0.4))
            if vertical:
                ax.plot(arr, energy, label=label, color=color, **kwargs)
            else:
                ax.plot(energy, arr, label=label, color=color, **kwargs)

    if showlegend:
        legend_opts = {
            "ncol": 4,
            "anchor": (0, 1.0),
            "handletextpad": 0.5,
            "handlelength": 1,
            "fontsize": "small",
            "frameon": False,
            **(legend_kws or {}),
        }
        add_legend(ax, **legend_opts)  # Labels are picked from plot

    elim = elim if elim is not None else []
    kws = dict(ylim=elim) if vertical else dict(xlim=elim)
    xlabel, ylabel = "Energy (eV)", "DOS"
    if vertical:
        xlabel, ylabel = ylabel, xlabel
    adjust_axes(ax, xlabel=xlabel, ylabel=ylabel, **kws)
    return ax
# PLOTLY PLOTS def _format_rgb_data( K, E, pros, labels, interp, occs, kpoints, maxwidth=10, indices=None ): "Transform data to 1D for rgb lines to plot effectently. Output is a dictionary." data = _fix_data(K, E, pros, labels, interp, rgb=True, occs=occs, kpoints=kpoints) # Note that data['pros'] is normalized to 0-1 rgb = np.zeros( (*np.shape(data["evals"]), 3) ) # Initialize rgb array, because there could be less than three channels if data["pros"].shape[2] == 3: rgb = data["pros"] elif data["pros"].shape[2] == 2: rgb[:, :, :2] = data["pros"] # Normalized overall color data labels = [*labels, ""] elif data["pros"].shape[2] == 1: rgb[:, :, :1] = data["pros"] # Normalized overall color data labels = [*labels, "", ""] # Since normalized data is Y = (X - X_min)/(X_max - X_min), so X = Y*(X_max - X_min) + X_min is the actual data. low, high = data["ptp"] data["norms"] = np.round( rgb * (high - low) + low, 3 ) # Read actual data back from normalized data. if data["pros"].shape[2] == 2: data["norms"][:, :, 2] = np.nan # Avoid wrong info here elif data["pros"].shape[2] == 1: data["norms"][:, :, 1:] = np.nan lws = np.sum(rgb, axis=2) # Sum of all colors lws = maxwidth * lws / (float(np.max(lws)) or 1) # Normalize to maxwidth data["widths"] = ( 0.0001 + lws ) # should be before scale colors, almost zero size of a data point with no contribution. # Now scale colors to 1 at each point. cl_max = np.max(data["pros"], axis=2) cl_max[cl_max == 0.0] = 1 # avoid divide by zero. Contributions are 4 digits only. data["pros"] = (rgb / cl_max[:, :, np.newaxis] * 255).astype( int ) # Normalized per point and set rgb data back to data. if indices is None: # make sure indices are in range indices = range(np.shape(data["evals"])[1]) # Now process data to make single data for faster plotting. 
K, E, C, S, CDATA = [], [], [], [], [] for i, b in enumerate(indices): K = [*K, *data["kpath"], np.nan] E = [*E, *data["evals"][:, i], np.nan] C = [ *C, *[f"rgb({r},{g},{b})" for (r, g, b) in data["pros"][:, i, :]], "rgb(0,0,0)", ] S = [*S, *data["widths"][:, i], data["widths"][-1, i]] CDATA = [*CDATA , *[ { "nk":j+1, **{f"k{u}":v for u,v in zip("xyz",xyz)}, "nb":b+1, "occ":occ, **{c:"" if np.isnan(v) else v for c,v in zip("rgb",rgb)} } for (j, xyz), occ, rgb in zip( enumerate(data["kpoints"]), data["occs"][:, i],data["norms"][:, i] ) ], {k:np.nan for k in ("nk","kx","ky","kz","nb","occ","r","g","b")}] return { "K": K, "E": E, "C": C, "S": S, "CDATA": CDATA, "labels": labels, } # K, energy, marker color, marker size, text, labels that get changed def _fmt_labels(ticklabels): if isinstance(ticklabels, Iterable): labels = [ re.sub( r"\$\_\{(.*)\}\$|\$\_(.*)\$", r"<sub>\1\2</sub>", lab, flags=re.DOTALL ) for lab in ticklabels ] # will match _{x} or _x not both at the same time. return [ re.sub( r"\$\^\{(.*)\}\$|\$\^(.*)\$", r"<sup>\1\2</sup>", lab, flags=re.DOTALL ) for lab in labels ] return ticklabels _hover_temp = { # keep order same "xy":"(%{x}, %{y})", "k": "<br>K<sub>%{customdata.nk}</sub>: %{customdata.kx:.3f} %{customdata.ky:.3f} %{customdata.kz:.3f}", "b":"Band: %{customdata.nb}, Occ: %{customdata.occ:.4f}" }
@gu._fmt_doc(_docs)
def iplot_bands(
    K, E, occs=None, fig=None, elim=None, kticks=None, interp=None, title=None, **kwargs
):
    """Plot band structure using plotly.
    {params}
    {K}
    {E}
    fig : plotly.graph_objects.Figure
        If not given, create a new figure.
    {elim}
    {kticks}
    {interp}
    title : str, title of plot

    kwargs are passed to plotly.graph_objects.Scatter
    {return_fig}"""
    if isinstance(K, dict):  # Provided by Bands class, don't do it yourself
        K, indices = K["K"], K["indices"]
    else:
        K, indices = K, range(np.shape(E)[1])  # Assume K is provided by user

    K, E, xticks, xticklabels = _validate_data(K, E, elim, kticks, interp)
    # Reuse the rgb formatter with mock projections/occupations; only K and E are plotted.
    data = _format_rgb_data(
        K,
        E,
        [E],  # don't let it fail if no projections
        ["X"],
        interp,
        E if occs is None else occs,
        # Placeholder kpoints: one row per k-point holding its own path coordinate,
        # so hover data never shows values belonging to other k-points.
        np.column_stack((K, K, K)),
        maxwidth=1,
        indices=indices,
    )
    K, E = data["K"], data["E"]  # both contain NaN band separators
    if fig is None:
        fig = go.Figure()

    kwargs = {
        "mode": "markers + lines",  # markers make points selectable by box select
        "marker": dict(size=0.1),
        "hovertemplate": "<br>".join(_hover_temp.values()),
        # rgb entries are meaningless here, keep only k/band/occupation data
        "customdata": [
            {k: v for k, v in d.items() if k not in ("r", "g", "b")} for d in data["CDATA"]
        ],
        **kwargs,
    }
    fig.add_trace(go.Scatter(x=K, y=E, **kwargs))
    fig.update_layout(
        template="plotly_white",
        title=(title or ""),  # Do not set autosize = False, need to be responsive in widgets boxes
        margin=go.layout.Margin(l=60, r=50, b=40, t=75, pad=0),
        yaxis=go.layout.YAxis(
            title_text="Energy (eV)", range=elim or [np.nanmin(E), np.nanmax(E)]
        ),
        xaxis=go.layout.XAxis(
            ticktext=_fmt_labels(xticklabels),
            tickvals=xticks,
            tickmode="array",
            range=[np.nanmin(K), np.nanmax(K)],
        ),
        font=dict(family="stix, serif", size=14),
    )
    return fig
@gu._fmt_doc(_docs)
def iplot_rgb_lines(
    K,
    E,
    pros,
    labels,
    occs,
    kpoints,
    fig=None,
    elim=None,
    kticks=None,
    interp=None,
    maxwidth=10,
    mode="markers + lines",
    title=None,
    **kwargs,
):
    """Interactive plot of band structure with rgb data points using plotly.
    {params}
    {K}
    {E}
    {pros}
    {labels}
    {elim}
    {kticks}
    {interp}
    occs : array-like, shape (nk,nb)
    kpoints : array-like, shape (nk,3)
    fig : plotly.graph_objects.Figure, if not provided, a new figure will be created
    maxwidth : float, maximum linewidth, 10 by default
    mode : str, plotly mode, 'markers + lines' by default, see modes in `plotly.graph_objects.Scatter`.
    title : str, title of the figure, labels are added to the end of the title.

    kwargs are passed to `plotly.graph_objects.Scatter`; marker color and size are
    always taken from the projections.
    {return_fig}"""
    if isinstance(K, dict):  # Provided by Bands class, don't do it yourself
        K, indices = K["K"], K["indices"]
    else:
        K, indices = K, range(np.shape(E)[1])  # Assume K is provided by user

    K, E, xticks, xticklabels = _validate_data(K, E, elim, kticks, interp)
    data = _format_rgb_data(
        K, E, pros, labels, interp, occs, kpoints, maxwidth=maxwidth, indices=indices
    )
    K, E, C, S, labels = [data[key] for key in "K E C S labels".split()]
    if fig is None:
        fig = go.Figure()

    kwargs.pop("marker_color", None)  # Provided by C
    kwargs.pop("marker_size", None)  # Provided by S
    kwargs.update(
        {
            "marker": {
                "line_color": "rgba(0,0,0,0)",  # marker edge should be free
                **kwargs.get("marker", {}),
                "color": C,
                "size": S,
            },
            "hovertemplate": "<br>".join(
                [
                    _hover_temp["xy"],
                    "<br>Projection: [{}, {}, {}]".format(*labels),  # padded to 3 labels
                    "Value: [%{customdata.r}, %{customdata.g}, %{customdata.b}]",
                    _hover_temp["k"],
                    _hover_temp["b"],
                ]
            ),
            "customdata": data["CDATA"],  # need for selection and hover template
        }
    )
    fig.add_trace(go.Scatter(x=K, y=E, mode=mode, **kwargs))

    # Only real labels in the title; padding entries are empty strings.
    title_labels = ", ".join(lab for lab in labels if lab)
    fig.update_layout(
        template="plotly_white",
        title=(title or "") + "[" + title_labels + "]",  # Do not set autosize = False, need to be responsive in widgets boxes
        margin=go.layout.Margin(l=60, r=50, b=40, t=75, pad=0),
        yaxis=go.layout.YAxis(
            title_text="Energy (eV)", range=elim or [np.nanmin(E), np.nanmax(E)]
        ),
        xaxis=go.layout.XAxis(
            ticktext=_fmt_labels(xticklabels),
            tickvals=xticks,
            tickmode="array",
            range=[np.nanmin(K), np.nanmax(K)],
        ),
        font=dict(family="stix, serif", size=14),
    )
    return fig
@gu._fmt_doc(_docs)  # fills {elim}, {colormap}, {interp}, {return_fig} below
def iplot_dos_lines(
    energy,
    dos_arrays,
    labels,
    fig=None,
    elim=None,
    colormap="tab10",
    colors=None,
    fill=True,
    vertical=False,
    stack=False,
    mode="lines",
    interp=None,
    **kwargs,
):
    """
    Plot density of states (DOS) lines.

    Parameters
    ----------
    energy : array-like, shape (n,)
    dos_arrays : list of array_like, each of shape (n,) or array-like (m,n)
    labels : list of str, length = len(dos_arrays) should hold.
    fig : plotly.graph_objects.Figure, if not provided, a new figure will be created
    {elim}
    {colormap}
    colors : list of str, length = len(dos_arrays) should hold if given, and will override colormap. Should be valid CSS colors.
    fill : bool, default True, if True, fill the area under the DOS lines.
    vertical : bool, default False, if True, plot DOS lines vertically.
    mode : str, default 'lines', plotly mode, see modes in `plotly.graph_objects.Scatter`.
    stack : bool, default False, if True, stack the DOS lines. Only works for horizontal plots.
    {interp}

    keyword arguments are passed to `plotly.graph_objects.Scatter`.
    {return_fig}"""
    energy, dos_arrays, labels, colors = _fix_dos_data(
        energy, dos_arrays, labels, colors, interp
    )
    if fig is None:
        fig = go.Figure()
        # Do not set autosize = False, need to be responsive in widgets boxes
        fig.update_layout(
            margin=go.layout.Margin(l=60, r=50, b=40, t=75, pad=0),
            font=dict(family="stix, serif", size=14),
        )

    ylim = [min(elim), max(elim)] if elim else [min(energy), max(energy)]

    if colors is None:
        # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
        # NOTE(review): 2 * len(labels) samples are drawn but only the first len(labels)
        # are used by zip below, i.e. the first half of the colormap -- confirm intent.
        _colors = plt.get_cmap(colormap)(np.linspace(0, 1, 2 * len(labels)))
        colors = [
            "rgb({},{},{})".format(*[int(255 * x) for x in c[:3]]) for c in _colors
        ]

    if vertical:
        if stack:
            raise NotImplementedError("stack is not supported for vertical plots")
        _fill = "tozerox" if fill else None
        fig.update_yaxes(range=ylim, title="Energy (eV)")
        fig.update_xaxes(title="DOS")
        for arr, label, color in zip(dos_arrays, labels, colors):
            fig.add_trace(
                go.Scatter(
                    y=energy,
                    x=arr,
                    line_color=color,
                    fill=_fill,
                    mode=mode,
                    name=label,
                    **kwargs,
                )
            )
    else:
        extra_args = {"stackgroup": "one"} if stack else {}
        _fill = "tozeroy" if fill else None
        fig.update_xaxes(range=ylim, title="Energy (eV)")
        fig.update_yaxes(title="DOS")
        for arr, label, color in zip(dos_arrays, labels, colors):
            fig.add_trace(
                go.Scatter(
                    x=energy,
                    y=arr,
                    line_color=color,
                    fill=_fill,
                    mode=mode,
                    name=label,
                    **kwargs,
                    **extra_args,
                )
            )
    return fig