Source code for clease.plot_post_process

from typing import List, Optional, Tuple

from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np

from clease import ConvexHull, Evaluate
from clease.evaluate import supports_alpha_cv

def plot_fit(
    evaluate: Evaluate, plot_args: Optional[dict] = None, interactive: bool = False
) -> Figure:
    """
    Plot the predicted (CE) energies against the calculated (DFT) energies.

    The figure also contains the diagonal ``y = x`` as a red line and a text
    box with the CV score and the RMSE, both multiplied by 1000 and labelled
    as meV/atom.

    :param evaluate: Evaluate object which provides the true and predicted
        energies, the scoring scheme, the CV score and the RMSE.
    :param plot_args: Optional dictionary of plot settings. Recognised keys:

        - "xlabel": x-axis label
        - "ylabel": y-axis label
        - "title": title of plot
    :param interactive: Add interactive elements to the plot.

    :return: Figure instance of plot
    """
    if plot_args is None:
        plot_args = {}
    X = evaluate.get_energy_true()
    Y = evaluate.get_energy_predict()
    xlabel = plot_args.get("xlabel", r"E$_{DFT}$ (eV/atom)")
    ylabel = plot_args.get("ylabel", r"E$_{CE}$ (eV/atom)")
    title = plot_args.get("title", f"Fit using {len(evaluate.e_dft)} data points.")

    # rmin, rmax set the plot range of x, y coordinate
    if np.size(X) != 0 and np.size(Y) != 0:
        # Build the combined array once instead of re-appending for every min/max.
        energies = np.append(X, Y)
        e_min = energies.min()
        e_max = energies.max()
        e_range = e_max - e_min
        if e_range == 0:
            # All energies are identical (e.g. a single data point). A zero
            # range would give rmin == rmax, and matplotlib warns about
            # identical axis limits, so fall back to a unit range.
            e_range = 1.0
        rmin = e_min - 0.05 * e_range
        rmax = e_max + 0.05 * e_range
    else:
        rmin = -10
        rmax = 10
    # Points on the diagonal, extending beyond the visible axis range.
    linear_fit = np.arange(rmin - 10, rmax + 10, 1)

    cv_name = evaluate.scoring_scheme.lower()
    if cv_name == "k-fold":
        # Figure out the k
        nsplits = evaluate.nsplits
        cv_name = f"{nsplits}-fold CV"
    cv = evaluate.get_cv_score() * 1000
    rmse = evaluate.rmse() * 1000

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    ax.axis([rmin, rmax, rmin, rmax])
    ax.text(
        0.95,
        0.01,
        f"{cv_name} = {cv:.3f} meV/atom\nRMSE = {rmse:.3f} meV/atom",
        verticalalignment="bottom",
        horizontalalignment="right",
        transform=ax.transAxes,
        fontsize=12,
    )
    ax.plot(linear_fit, linear_fit, "r")
    plot_line = ax.plot(X, Y, "bo", mfc="none")[0]

    if interactive:
        # Imported here so that the interactive module is only loaded on request.
        from clease.interactive_plot import AnnotatedAx, ShowStructureOnClick

        annotations = _make_annotations_plot_fit(evaluate)
        # Construct the annotated axis objects.
        annotated_ax = AnnotatedAx(
            ax,
            [plot_line],
            annotations,
            structure_names=[evaluate.names],
        )
        # Attach interactivity to the fig object.
        ShowStructureOnClick(fig, annotated_ax, evaluate.settings.db_name)

    return fig
def plot_fit_residual(
    evaluate: Evaluate, plot_args: Optional[dict] = None, interactive: bool = False
) -> Figure:
    """
    Plot the difference between predicted and calculated (DFT) energies.

    The figure has two panels sharing the y-axis: the left panel shows the
    difference (predicted minus DFT energy, multiplied by 1000 and labelled
    as meV/atom) against the DFT energy, and the right panel shows a
    horizontal histogram of the same differences using 30 bins.

    :param evaluate: Evaluate object which provides the true and predicted
        energies.
    :param plot_args: Optional dictionary of plot settings. Recognised keys:

        - "xlabel": x-axis label of the histogram panel. The x-axis label
          of the left panel is fixed.
        - "ylabel": y-axis label
        - "title": title of plot
    :param interactive: Add interactive elements to the plot.

    :return: Figure instance of plot
    """
    if plot_args is None:
        plot_args = {}
    X = evaluate.get_energy_true()
    Y = evaluate.get_energy_predict() - X  # eV/atom
    Y *= 1000  # meV/atom
    xlabel = plot_args.get("xlabel", "#OCC")
    # The plotted quantity is predicted minus DFT, so the default label
    # has to state the same sign convention.
    ylabel = plot_args.get("ylabel", r"$E_{pred} - E_{DFT}$ (meV/atom)")
    title = plot_args.get("title", "Residual (v)")

    gridspec_kw = {"wspace": 0.0, "width_ratios": [5, 1]}
    fig, ax = plt.subplots(ncols=2, sharey=True, gridspec_kw=gridspec_kw)
    ax[0].axhline(0, ls="--")
    plot_line = ax[0].plot(X, Y, "v", mfc="none")[0]
    ax[0].set_xlabel(r"$E_{DFT}$ (eV/atom)")
    ax[0].set_ylabel(ylabel)
    ax[0].set_title(title)

    hist, bin_edges = np.histogram(Y, bins=30)
    # All bins have the same width, so the first one gives the bar height.
    h = bin_edges[1] - bin_edges[0]
    ax[1].barh(bin_edges[:-1], hist, height=h, color="#bdbdbd")
    ax[1].set_xlabel(xlabel)
    ax[1].spines["right"].set_visible(False)
    ax[1].spines["top"].set_visible(False)

    if interactive:
        # Imported here so that the interactive module is only loaded on request.
        from clease.interactive_plot import AnnotatedAx, ShowStructureOnClick

        annotations = _make_annotations_plot_fit(evaluate)
        # Construct the annotated axis objects.
        annotated_ax = AnnotatedAx(
            ax[0],
            [plot_line],
            annotations,
            structure_names=[evaluate.names],
        )
        # Attach interactivity to the fig object.
        ShowStructureOnClick(fig, annotated_ax, evaluate.settings.db_name)

    return fig
def plot_eci(
    evaluate: Evaluate,
    plot_args: Optional[dict] = None,
    ignore_sizes=(),
    interactive: bool = False,
) -> Figure:
    """
    Plot the ECI values against the cluster diameter, one series per cluster size.

    :param evaluate: Evaluate object; its ``get_eci_by_size()`` result supplies
        the "distance", "eci", "name" and "radius" entries for every size.
    :param plot_args: Optional dictionary of plot settings. Recognised keys:

        - "xlabel": x-axis label
        - "ylabel": y-axis label
        - "title": title of plot
        - "sizes": list of int to include n-body cluster in plot
    :param ignore_sizes: list of ints
        Sizes listed in this list will not be plotted.
        E.g. ``ignore_sizes=[0]`` will exclude the 0-body cluster.
        Default is to not ignore any clusters.
    :param interactive: Add interactive elements to the plot.

    :return: Figure instance of plot
    """
    options = {} if plot_args is None else plot_args

    # Mapping from cluster size to its eci, name, distance (and radius) data.
    eci_by_size = evaluate.get_eci_by_size()
    x_text = options.get("xlabel", r"Cluster diameter ($n^{th}$ nearest neighbor)")
    y_text = options.get("ylabel", "ECI (eV/atom)")
    heading = options.get("title", "Plot ECI")
    wanted_sizes = options.get("sizes", list(eci_by_size.keys()))

    marker_cycle = ["o", "v", "x", "D", "^", "h", "s", "p"]
    plotted_lines = []
    all_annotations = []

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.axhline(0.0, ls="--", color="grey")
    ax.xaxis.set_major_locator(plt.MultipleLocator(1))
    ax.set_title(heading)
    ax.set_xlabel(x_text)
    ax.set_ylabel(y_text)

    for size in wanted_sizes:
        if not ignore_sizes or size not in ignore_sizes:
            entry = eci_by_size[size]
            # Add 1, as NN starts from 1 and not 0
            diameters = np.array(entry["distance"]) + 1
            eci_values = entry["eci"]
            symbol = markers_for_size = marker_cycle[size % len(marker_cycle)]
            drawn = ax.plot(
                diameters,
                eci_values,
                label=f"{size}-body",
                marker=symbol,
                mfc="none",
                ls="",
                markersize=8,
            )

            # All data for the annotations is at hand here, so build them
            # right away when an interactive plot is requested.
            if interactive:
                plotted_lines.append(drawn[0])
                texts = []
                for dist, name, eci, radius in zip(
                    diameters, entry["name"], eci_values, entry["radius"]
                ):
                    texts.append(
                        f"Size: {size}\nDiameter: {dist:d}\nName: {name}\n"
                        f"Radius: {radius:.3f} Å\nECI: {eci:.4f} eV/atom"
                    )
                all_annotations.append(texts)

    ax.legend()

    if interactive:
        # Imported here so that the interactive module is only loaded on request.
        from clease.interactive_plot import AnnotatedAx, InteractivePlot

        # Construct the annotated axis objects.
        annotated_ax = AnnotatedAx(ax, plotted_lines, all_annotations)
        # Attach interactivity to the fig object.
        InteractivePlot(fig, annotated_ax)

    return fig
def plot_cv(evaluate: Evaluate, plot_args: Optional[dict] = None) -> Figure:
    """
    Plot the CV score as a function of alpha.

    The CV scores are multiplied by 1000 and labelled as meV/atom. The entry
    with the lowest CV score is reported in a text box and marked with a red
    circle.

    :param evaluate: Evaluate object; its ``cv_scores`` entries are read via
        the "alpha" and "cv" keys.
    :param plot_args: Optional dictionary of plot settings. Recognised keys:

        - "xlabel": x-axis label
        - "ylabel": y-axis label
        - "title": title of plot

    :return: Figure instance of plot
    :raises ValueError: If ``supports_alpha_cv`` rejects ``evaluate.scheme``.
    """
    if not supports_alpha_cv(evaluate.scheme):
        raise ValueError(f"Scheme {evaluate.scheme!r} doesn't support alpha CV.")

    options = plot_args if plot_args is not None else {}
    cv_records = evaluate.cv_scores

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title(options.get("title", "CV score vs. alpha"))
    ax.set_ylabel(options.get("ylabel", "CV score (meV/atom)"))
    ax.set_xlabel(options.get("xlabel", "alpha"))

    # Record with the lowest CV score, used for the text box.
    best = min(cv_records, key=lambda record: record["cv"])

    alphas = [record["alpha"] for record in cv_records]
    # Convert from eV/atom to meV/atom
    scores = [record["cv"] * 1000 for record in cv_records]
    ax.plot(alphas, scores)

    summary = (
        f"min. CV score:\n"
        f"alpha = {best['alpha']:.10f} \n"
        f"CV = {best['cv'] * 1000.0:.3f}"
        f" meV/atom"
    )
    ax.text(
        0.65,
        0.01,
        summary,
        verticalalignment="bottom",
        horizontalalignment="left",
        transform=ax.transAxes,
        fontsize=10,
    )

    # Display the point with the best CV score as a red circle
    best_idx = np.argmin(scores)
    ax.plot(alphas[best_idx], scores[best_idx], "ro", markersize=5)
    return fig
def plot_convex_hull(evaluate: Evaluate, interactive: bool = False) -> Figure:
    """Plot the convex hull of an evaluate object.

    The DFT points are drawn by ``ConvexHull.plot``; the formation energies
    of the predicted (CE) energies are then drawn into the same figure with
    the marker "x".

    Args:
        evaluate (Evaluate): The Evaluate object to draw the convex hull from.
        interactive (bool, optional): Plot as an interactive figure?. Defaults to False.

    Returns:
        Figure: The figure holding both the DFT and the CE points.
    """
    predicted = evaluate.get_energy_predict()
    hull = ConvexHull(evaluate.settings.db_name, select_cond=evaluate.select_cond)

    # Create the figure with the DFT points. The returned Line2D objects
    # (the scatter points) are kept for the interactive mode.
    fig, dft_lines = hull.plot(return_lines=True)

    # Concentration with respect to the total number of atoms, one entry per frame.
    frame_concs = evaluate.atomic_concentrations

    # Per species, the list of its concentration in every frame
    # (0.0 where the species is missing from a frame).
    concs = {species: [] for species in hull._unique_elem}
    for frame in frame_concs:
        for species, series in concs.items():
            series.append(frame.get(species, 0.0))

    form_en = []
    for conc, energy in zip(frame_concs, predicted.tolist()):
        form_en.append(hull.get_formation_energy(conc, energy))

    # Draw the CE energies, and keep the lines for the interactive mode.
    _, ce_lines = hull.plot(
        fig=fig,
        concs=concs,
        energies=form_en,
        marker="x",
        return_lines=True,
    )

    fig.suptitle("Convex hull DFT (o), CE (x)")

    if interactive:
        # Imported here so that the interactive module is only loaded on request.
        from clease.interactive_plot import AnnotatedAx, ShowStructureOnClick

        axes = fig.get_axes()
        annotations = _make_annotations_hull(evaluate)

        # One annotated axis per subplot, pairing its DFT and CE lines.
        annotated_axes = [
            AnnotatedAx(
                ax,
                (dft_lines[pos], ce_lines[pos]),
                annotations,
                structure_names=[evaluate.names, evaluate.names],
            )
            for pos, ax in enumerate(axes)
        ]
        ShowStructureOnClick(fig, annotated_axes, evaluate.settings.db_name)

    return fig
def _make_annotations_hull(evaluate: Evaluate) -> Tuple[List[str], List[str]]:
    """Helper function to make annotations for the interactive convex hull plot.

    :param evaluate: Evaluate object providing ``names``, ``row_ids`` and the
        true and predicted energies.

    :return: Two lists of annotation strings with one entry per structure:
        the first describes the DFT energies, the second the CE energies.
    """
    e_pred = evaluate.get_energy_predict()
    e_dft = evaluate.get_energy_true()

    # names, row_ids and the energy arrays are walked in lockstep;
    # they are expected to hold one entry per structure.
    an_dft = [
        f"DB ID: {row_id}\nName: {name}\nE(DFT): {en:.4f} eV/atom"
        for row_id, name, en in zip(evaluate.row_ids, evaluate.names, e_dft)
    ]
    an_ce = [
        f"DB ID: {row_id}\nName: {name}\nE(CE): {e_ce:.4f} eV/atom"
        for row_id, name, e_ce in zip(evaluate.row_ids, evaluate.names, e_pred)
    ]
    return an_dft, an_ce


def _make_annotations_plot_fit(evaluate: Evaluate) -> Tuple[List[str]]:
    """Helper function to make annotations for the interactive fit plots.

    :param evaluate: Evaluate object providing ``names``, ``row_ids`` and the
        true and predicted energies.

    :return: A 1-tuple holding the list of annotation strings, one entry per
        structure, each stating both the DFT and the CE energy.
    """
    e_pred = evaluate.get_energy_predict()
    e_dft = evaluate.get_energy_true()

    annotations = [
        f"DB ID: {row_id}\nName: {name}\nE(DFT): {en:.4f} eV/atom\nE(CE): {e_ce:.4f} eV/atom"
        for row_id, name, en, e_ce in zip(evaluate.row_ids, evaluate.names, e_dft, e_pred)
    ]
    # Wrapped in a tuple: one annotation list for the single plotted line.
    return (annotations,)