Source code for clease.settings.settings

"""Definition of ClusterExpansionSettings Class.

This module defines the base-class for storing the settings for performing
Cluster Expansion in different conditions.
"""
from __future__ import annotations
import logging
from copy import deepcopy
from typing import List, Dict, Optional, Union, Sequence, Set, Any

import numpy as np
from ase import Atoms
from ase.db import connect
from ase.db.core import Database

from clease.version import __version__
from clease.jsonio import jsonable
from clease.tools import get_size_from_cf_name, wrap_and_sort_by_position
from clease.cluster import ClusterManager, ClusterList, Cluster
from clease.basis_function import Polynomial, Trigonometric, BinaryLinear, BasisFunction
from clease.datastructures import TransMatrix
from .concentration import Concentration
from .template_filters import ValidConcentrationFilter
from .template_atoms import TemplateAtoms
from .atoms_manager import AtomsManager


__all__ = ("ClusterExpansionSettings",)

logger = logging.getLogger(__name__)


class PrimitiveCellNotFound(Exception):
    """Exception which is raised if the primitive isn't found in the database."""


[docs]@jsonable("ce_settings") class ClusterExpansionSettings: """Base class for all Cluster Expansion settings. Args: prim (Atoms): The primitive atoms object. concentration (Union[Concentration, dict]): Concentration object or dictionary specifying the basis elements and concentration range of constituting species. size (List[int] | None, optional): Size of the supercell (e.g., [2, 2, 2] for 2x2x2 cell). ``supercell_factor`` is ignored if both ``size`` and ``supercell_factor`` are specified. Defaults to None. supercell_factor (int, optional): Maximum multipilicity factor for limiting the size of supercell created from the primitive cell. ``supercell_factor`` is ignored if both `size` and `supercell_factor` are specified. Defaults to 27. db_name (str, optional): Name of the database file. Defaults to ``'clease.db'``. max_cluster_dia (Sequence[float], optional): A list of int or float containing the maximum diameter of clusters (in Å). Defaults to ``(5., 5., 5.)``, i.e. a 5 Å cutoff for 2-, 3-, and 4-body clusters. include_background_atoms (bool, optional): Whether background elements are to be included. An element is considered to be a background element, if there is only 1 possible species which be ever be placed in a given basis. Defaults to False. basis_func_type (str, optional): Type of basis function to use. Defaults to 'polynomial'. """ # pylint: disable=too-many-instance-attributes, too-many-public-methods # Keys which are important for saving/loading ARG_KEYS = ("prim_cell", "concentration") KWARG_KEYS = ( "size", "supercell_factor", "db_name", "max_cluster_dia", "include_background_atoms", "basis_func_type", ) # Other keys we want to input back into the loaded object, but after initialization OTHER_KEYS = ( "skew_threshold", # kwargs is a bookkeeping variable, for compatibility. # Just contains information about how the factory functions were called. "kwargs", ) def __init__( self, prim: Atoms, concentration: Union[Concentration, dict], size: Optional[List[int]] = None, supercell_factor: Optional[int] = 27, db_name: str = "clease.db", max_cluster_dia: Sequence[float] = (5.0, 5.0, 5.0), include_background_atoms: bool = False, basis_func_type="polynomial", ) -> None: self._include_background_atoms = include_background_atoms self._cluster_mng = None self._trans_matrix = None self._cluster_list = None self._basis_func_type = None self._prim_cell = None self.concentration = _get_concentration(concentration) self.basis_elements = deepcopy(self.concentration.basis_elements) self._check_first_elements() self.db_name = db_name self._set_prim_cell(prim) prim_mng = AtomsManager(self.prim_cell) prim_ind_by_basis = prim_mng.index_by_symbol([x[0] for x in self.basis_elements]) conc_filter = ValidConcentrationFilter(concentration, prim_ind_by_basis) self.atoms_mng = AtomsManager(None) # Construct the template atoms, this is a protected property # Access and changes to size and/or supercell factor are redirected to # this instance. self._template_atoms = TemplateAtoms( self.prim_cell, supercell_factor=supercell_factor, size=size, skew_threshold=40, filters=[conc_filter], ) self.set_active_template() self.max_cluster_dia = max_cluster_dia self.basis_func_type = basis_func_type if len(self.basis_elements) != self.num_basis: raise ValueError("list of elements is needed for each basis") # For storing the settings from the CLEASE factories. # Just for bookkeeping self.kwargs = {} @property def atoms(self) -> Atoms: """The currently active template.""" return self.atoms_mng.atoms @property def prim_cell(self) -> Atoms: """The primitive atoms object of the model.""" return self._prim_cell def _set_prim_cell(self, value: Atoms) -> None: """Set the primitive cell, ensure it gets properly tagged and stored. This should not be changed after initialization.""" self._prim_cell = value.copy() self._prim_cell.wrap() self._order_and_tag_prim_cell() self._ensure_primitive_exists() @property def max_cluster_dia(self) -> np.ndarray: """The maximum cluster diameter, expressed in a NumPy array starting from 2-body clusters at index 0. Diameters are given in units of Ångstrom. """ return self._max_cluster_dia @max_cluster_dia.setter def max_cluster_dia(self, value) -> None: # Ensure it's NumPy array, and make a copy. self._max_cluster_dia = np.array(value, dtype=float) self.clear_cache() @property def db_name(self) -> str: """Name of the underlaying data base.""" return self._db_name @db_name.setter def db_name(self, value: str) -> str: """Changing the DB name, needs to ensure the primitive cell exists in the new database as well.""" self._db_name = value if self.prim_cell is not None: # Ensure the primitive cell is stored in the new database. # None should only happen during initialization, where we let the # prim_cell setter store the primitive. self._ensure_primitive_exists()
[docs] def connect(self, **kwargs) -> Database: """Return the ASE connection object to the internal database.""" return connect(self.db_name, **kwargs)
@property def template_atoms(self) -> TemplateAtoms: return self._template_atoms @property def max_cluster_size(self): return len(self.max_cluster_dia) + 1 @property def all_elements(self) -> List[str]: return sorted([item for row in self.basis_elements for item in row]) @property def num_elements(self) -> int: return len(self.all_elements) @property def unique_elements(self) -> List[str]: return sorted(list(set(deepcopy(self.all_elements)))) @property def num_unique_elements(self) -> int: return len(self.unique_elements) @property def ref_index_trans_symm(self) -> List[int]: return [i[0] for i in self.index_by_sublattice] @property def skew_threshold(self): return self.template_atoms.skew_threshold @skew_threshold.setter def skew_threshold(self, threshold: int) -> None: """ Maximum acceptable skew level (ratio of max and min diagonal of the Niggli reduced cell) """ self.template_atoms.skew_threshold = threshold @property def background_indices(self) -> List[int]: """Get indices of the background atoms.""" # check if any basis consists of only one element type basis = [i for i, b in enumerate(self.basis_elements) if len(b) == 1] bkg_indices = [] for b_indx in basis: bkg_indices += self.index_by_basis[b_indx] return bkg_indices @property def non_background_indices(self) -> List[int]: """Indices of sites which are not background""" bkg = set(self.background_indices) all_indices = set(range(len(self.atoms))) # Remove all background indices return sorted(all_indices - bkg) @property def cluster_mng(self): if self._cluster_mng is None: kwargs = {} if not self.include_background_atoms: kwargs["background_syms"] = self.get_bg_syms() self._cluster_mng = ClusterManager(self.prim_cell, **kwargs) return self._cluster_mng @property def include_background_atoms(self) -> bool: return self._include_background_atoms @include_background_atoms.setter def include_background_atoms(self, value: bool) -> None: # pylint: disable=no-self-use msg = "The include_background_atoms setter has been removed in version 0.11.3.\n" msg += f"Please set 'include_background_atoms={value}' in the settings constructor, " msg += "instead." raise NotImplementedError(msg) @property def spin_dict(self) -> Dict[str, float]: return self.basis_func_type.spin_dict @property def basis_functions(self): return self.basis_func_type.basis_functions @property def ignore_background_atoms(self) -> bool: return not self.include_background_atoms @property def multiplicity_factor(self) -> Dict[str, float]: """Return the multiplicity factor of each cluster.""" num_sites_in_group = [len(x) for x in self.index_by_sublattice] return self.cluster_list.multiplicity_factors(num_sites_in_group) @property def all_cf_names(self) -> List[str]: num_bf = len(self.basis_functions) return self.cluster_list.get_all_cf_names(num_bf) @property def num_cf(self) -> int: """Return the number of correlation functions.""" return len(self.all_cf_names) @property def index_by_basis(self) -> List[List[int]]: first_symb_in_basis = [x[0] for x in self.basis_elements] return self.atoms_mng.index_by_symbol(first_symb_in_basis) @property def index_by_sublattice(self) -> List[List[int]]: return self.atoms_mng.index_by_tag() @property def num_basis(self) -> int: return len(self.basis_elements) @property def basis_func_type(self): return self._basis_func_type @property def size(self) -> Union[np.ndarray, None]: return self.template_atoms.size @size.setter def size(self, value): """Exposure of the template atoms size setter""" self.template_atoms.size = value @property def supercell_factor(self) -> Union[int, None]: return self.template_atoms.supercell_factor @supercell_factor.setter def supercell_factor(self, value) -> None: """Exposure of the template atoms supercell_factor setter""" self.template_atoms.supercell_factor = value
[docs] def get_sublattice_site_ratios(self) -> np.ndarray: """Return the ratios of number of sites per (grouped) sublattice""" # Number of sites per sublattice sites_per_basis = np.array([len(basis) for basis in self.index_by_basis]) num_sites = sites_per_basis.sum() return sites_per_basis / num_sites
@property def num_active_sublattices(self) -> int: """Number of active sublattices""" return sum(self.get_active_sublattices())
[docs] def get_active_sublattices(self) -> List[bool]: """List of booleans indicating if a (grouped) sublattice is active""" unique_no_bkg = self.unique_element_without_background() return [basis[0] in unique_no_bkg for basis in self.concentration.basis_elements]
@property def ignored_species_and_conc(self) -> Dict[str, float]: """ Return the ignored species and their concentrations normalised to the total number of atoms. """ unique_no_bkg = self.unique_element_without_background() # Find the concentration within grouped sublattices orig_basis = self.concentration.basis_elements # Concentration of sites within each sublattice ratios = self.get_sublattice_site_ratios() assert len(orig_basis) == len(ratios) ignored = {} for ratio, basis in zip(ratios, orig_basis): elem = basis[0] if elem not in unique_no_bkg: if len(basis) != 1: raise ValueError( ( "Ignored sublattice contains multiple elements -" "this does not make any sense" ) ) if elem not in ignored: ignored[elem] = ratio else: # This element is already on one of the ignored background here we # accumulate the concentration ignored[elem] += ratio return ignored @property def atomic_concentration_ratio(self) -> float: """ Ratio between true concentration (normalised to atoms) and the internal concentration used. For example, if one of the two basis is fully occupied, and hence ignored internally, the internal concentration is half of the actual atomic concentration. """ ratios = self.get_sublattice_site_ratios() # We only want to include active sublattices active_sublatt = self.get_active_sublattices() # Add up all ratios for the active sublattices return ratios[active_sublatt].sum() @basis_func_type.setter def basis_func_type(self, bf_type): """ Type of basis function to use. It should be one of "polynomial", "trigonometric" or "binary_linear" """ unique_element = self.unique_element_without_background() if isinstance(bf_type, BasisFunction): if bf_type.unique_elements != sorted(unique_element): raise ValueError( "Unique elements in BasisFunction instance " "is different from the one in settings", bf_type.unique_elements, sorted(unique_element), ) self._basis_func_type = bf_type elif isinstance(bf_type, str): if bf_type.lower() == "polynomial": self._basis_func_type = Polynomial(unique_element) elif bf_type.lower() == "trigonometric": self._basis_func_type = Trigonometric(unique_element) elif bf_type.lower() == "binary_linear": self._basis_func_type = BinaryLinear(unique_element) else: msg = f"basis function type {bf_type} is not supported." raise ValueError(msg) else: raise ValueError("basis_function has to be an instance of BasisFunction or a string")
[docs] def get_bg_syms(self) -> Set[str]: """ Return the symbols in the basis where there is only one element """ return set(x[0] for x in self.basis_elements if len(x) == 1)
[docs] def unique_element_without_background(self): """Remove background elements.""" if self.include_background_atoms: bg_sym = set() else: bg_sym = self.get_bg_syms() # Remove bg_syms that are also present in basis with more than one # element for elems in self.basis_elements: if len(elems) == 1: continue to_be_removed = set() for s in bg_sym: if s in elems: to_be_removed.add(s) bg_sym -= to_be_removed unique_elem = set() for x in self.basis_elements: unique_elem.update(x) return list(unique_elem - bg_sym)
[docs] def prepare_new_active_template(self, template): """Prepare necessary data structures when setting new template.""" logger.debug("Preparing new template in settings") self.atoms_mng.atoms = template self.clear_cache()
[docs] def set_active_template(self, atoms=None): """Set a new template atoms object.""" if atoms is not None: template = self.template_atoms.get_template_matching_atoms(atoms=atoms) else: template = self.template_atoms.weighted_random_template() template = wrap_and_sort_by_position(template) if atoms is not None: # Check that the positions of the generated template # matches the ones in the passed object atoms = wrap_and_sort_by_position(atoms) if not np.allclose(template.get_positions(), atoms.get_positions()): raise ValueError( f"Inconsistent positions. Passed object\n" f"{atoms.get_positions()}\nGenerated template" f"\n{template.get_positions()}" ) self.prepare_new_active_template(template)
def _order_and_tag_prim_cell(self): """ Add a tag to all the atoms in the unit cell to track the sublattice. Tags are added such that that the lowest tags corresponds to "active" sites, while the highest tags corresponds to background sites. An example is a system having three sublattices that can be occupied by more than one species, and two sublattices that can be occupied by only one species. The tags 0, 1, 2 will then be assigned to the three sublattices that can be occupied by more than one species, and the two remaining lattices will get the tag 3, 4. Re-orders the primitive cell in accordance to the order of the tags. """ bg_sym = self.get_bg_syms() tag = 0 # Tag non-background elements first for atom in self.prim_cell: if atom.symbol not in bg_sym: atom.tag = tag tag += 1 # Tag background elements for atom in self.prim_cell: if atom.symbol in bg_sym: atom.tag = tag tag += 1 # Rearange primitive cell in order of the tags. sorted_indx = np.argsort([atom.tag for atom in self.prim_cell]) # Set the primitive cell directly, without calling the setter. self._prim_cell = self.prim_cell[sorted_indx] def _ensure_primitive_exists(self) -> None: """Ensure that the primitive cell exists in the DB. Write it if it is missing.""" self.get_prim_cell_id(write_if_missing=True)
[docs] def get_prim_cell_id(self, write_if_missing=False) -> int: """Retrieve the ID of the primitive cell in the database. Raises a PrimitiveCellNotFound error if it is not found and write_if_missing is False. If ``write_if_missing`` is True a primitive cell is written to the database if it is missing. Returns the ID (an integer) of the row which corresponds to the primitive cell. """ with self.connect() as con: try: # Check if the primitive has already been written to the database uid = _get_prim_cell_id_from_connection(self.prim_cell, con) except PrimitiveCellNotFound: # Primitive wasn't found if write_if_missing: # Write it to the database. uid = con.write(self.prim_cell, name="primitive_cell") else: # We're not allowed to write to the database. Raise the error. raise # Ensure connection is closed before returning, to ensure that the primitive # has been written. return uid
@property def trans_matrix(self) -> TransMatrix: """Get the translation matrix, will be created upon request""" self.ensure_clusters_exist() return self._trans_matrix @property def cluster_list(self) -> ClusterList: """Get the cluster list, will be created upon request""" self.ensure_clusters_exist() return self._cluster_list
[docs] def clear_cache(self) -> None: """Clear the cached objects, due to a change e.g. in the template atoms""" logger.debug("Clearing the cache") self._trans_matrix = None self._cluster_list = None
[docs] def create_cluster_list_and_trans_matrix(self): """Prepares the internal cache objects by calculating cluster related properties""" logger.debug("Creating translation matrix and cluster list") self._cluster_list, self._trans_matrix = self.cluster_mng.build_all( self.atoms, self.max_cluster_dia, self.index_by_sublattice, )
[docs] def view_clusters(self) -> None: """Display all clusters along with their names.""" # pylint: disable=import-outside-toplevel from ase.gui.gui import GUI from ase.gui.images import Images figures = self.get_all_figures_as_atoms() images = Images() images.initialize(figures) gui = GUI(images, expr="") gui.show_name = True gui.run()
[docs] def get_all_figures_as_atoms(self) -> List[Atoms]: """Get the list of all possible figures, in their ASE Atoms representation.""" self.ensure_clusters_exist() return self.cluster_mng.get_figures()
[docs] def requires_build(self) -> bool: """Check if the cluster list and trans matrix exist. Returns True the cluster list and trans matrix needs to be built. """ return self._cluster_list is None or self._trans_matrix is None
[docs] def ensure_clusters_exist(self) -> None: """Ensure the cluster list and trans matrix has been populated. They are not calculated upon creaton of the settings instance, for performance reasons. They will be constructed if required. Nothing is done if the cache exists.""" if self.requires_build(): logger.debug("Triggered construction of clusters") self.create_cluster_list_and_trans_matrix() # It should not be possible for them to be None after this call. assert self._cluster_list is not None assert self._trans_matrix is not None
[docs] def get_all_templates(self): """ Return a list with all template atoms """ return self.template_atoms.get_all_templates()
[docs] def view_templates(self): """ Display all templates in the ASE GUi """ # pylint: disable=import-outside-toplevel from ase.visualize import view view(self.get_all_templates())
def _check_first_elements(self): basis_elements = self.basis_elements num_basis = self.num_basis # This condition can be relaxed in the future first_elements = [] for elements in basis_elements: first_elements.append(elements[0]) if len(set(first_elements)) != num_basis: raise ValueError("First element of different basis should not be the same.")
[docs] def todict(self) -> Dict: """Return a dictionary representation of the settings class. Example: >>> from clease.settings import CEBulk, Concentration >>> conc = Concentration([['Au', 'Cu']]) >>> settings = CEBulk(conc, crystalstructure='fcc', a=4.1) >>> dct = settings.todict() # Get the dictionary representation """ vars_to_save = self.ARG_KEYS + self.KWARG_KEYS + self.OTHER_KEYS dct = {"clease_version": str(__version__)} for key in vars_to_save: val = getattr(self, key) dct[key] = val return dct
[docs] @classmethod def from_dict(cls, dct: Dict[str, Any]) -> ClusterExpansionSettings: """Load a new ClusterExpansionSettings class from a dictionary representation. Example: >>> from clease.settings import CEBulk, Concentration, ClusterExpansionSettings >>> conc = Concentration([['Au', 'Cu']]) >>> settings = CEBulk(conc, crystalstructure='fcc', a=4.1) >>> dct = settings.todict() # Get the dictionary representation >>> # Remove the existing settings, perhaps due to being in a new environment >>> del settings >>> # Load in the settins from the dictionary representation >>> settings = ClusterExpansionSettings.from_dict(dct) """ dct = deepcopy(dct) # pop out the version. We can use this for compatibility checks if needed. dct.pop("clease_version", None) # Get the args and kwargs we expect in our function signature args = [dct.pop(key) for key in cls.ARG_KEYS] # Should we allow for missing kwargs keys? kwargs = {key: dct.pop(key) for key in cls.KWARG_KEYS} if "max_cluster_size" in dct: # For compatibility, since it's no longer in KWARG_KEYS # Avoid crashing. dct.pop("max_cluster_size") settings = cls(*args, **kwargs) for key in cls.OTHER_KEYS: # Populate "other" keys, which aren't *args or **kwargs value = dct.pop(key) setattr(settings, key, value) if dct: # We have some unexpected left-overs logger.warning("Unused items from dictionary: %s", dct) return settings
[docs] def clusters_table(self) -> str: """String with information about the clusters""" # We need the clusters to be constructed first. self.ensure_clusters_exist() mult_dict = self.multiplicity_factor columns = [ "Index", "Cluster Name", "Size", "Group", "Radius", "Figures", "Multiplicity", ] fmt = "| {:<5} | {:<12} | {:<4} | {:<5} | {:<6} | {:<7} | {:<12} |" header = fmt.format(*columns) rule = "-" * len(header) # Horizontal line of ----- lines = [rule, header, rule] for ii, cluster in enumerate(self.cluster_list.clusters): name = cluster.name size = f"{cluster.size:d}".center(4, " ") mult = f"{int(mult_dict[name]):d}".center(12, " ") radius = f"{cluster.diameter / 2:2.4f}" n_figures = f"{len(cluster.figures)}".center(7) group = f"{cluster.group:d}".center(5) index_s = f"{ii:d}".center(5) s = fmt.format(index_s, name, size, group, radius, n_figures, mult) lines.append(s) # Bottom horizontal line lines.append(rule) return "\n".join(lines)
[docs] def get_cluster_corresponding_to_cf_name(self, cf_name: str) -> Cluster: """Find the Cluster object which corresponds to a CF name. The cluster will not be specialized to the decoration number if such exists in the cf name. Example: >>> from clease.settings import CEBulk, Concentration >>> conc = Concentration([['Au', 'Cu']]) >>> settings = CEBulk(conc, crystalstructure='fcc', a=4.1) >>> cluster = settings.get_cluster_corresponding_to_cf_name("c1_0") >>> cluster.size 1 """ size = get_size_from_cf_name(cf_name) for cluster in self.cluster_list.clusters: # This CF name may contain the decoration number, but the cluster name doesn't. # Decoration number is last, so test everything before that is equal. if cluster.size == size and cf_name.startswith(cluster.name): return cluster raise RuntimeError(f"Didn't find cluster corresponding to name: {cf_name}")
def _get_concentration(concentration: Union[Concentration, dict]) -> Concentration: """Helper function to format the concentration""" if isinstance(concentration, Concentration): conc = concentration elif isinstance(concentration, dict): conc = Concentration.from_dict(concentration) else: raise TypeError("concentration has to be either dict or instance of Concentration") return conc def _get_prim_cell_id_from_connection(prim_cell: Atoms, connection: Database) -> int: """Retrieve the primitive cell ID from a database connection""" shape = prim_cell.cell.cellpar() for row in connection.select(name="primitive_cell"): loaded_shape = row.toatoms().cell.cellpar() if np.allclose(shape, loaded_shape): return row.id raise PrimitiveCellNotFound("The primitive cell was not found in the database.")