"""Module for calculating correlation functions."""
from typing import Iterator, Tuple, Dict, Any
import logging
from ase.atoms import Atoms
from clease_cxx import PyCEUpdater
from .settings import ClusterExpansionSettings
from .tools import wrap_and_sort_by_position
from . import db_util
__all__ = ("CorrFunction", "ClusterNotTrackedError")

# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)

# Type alias for a collection of correlation functions: a dictionary that
# maps each correlation-function name to its floating-point value.
CF_T = Dict[str, float]
class ClusterNotTrackedError(Exception):
    """Raised when a correlation function is requested for a cluster name
    that the cluster expansion settings do not track."""
class CorrFunction:
    """Class for calculating the correlation functions.

    Parameters:
        settings (ClusterExpansionSettings): The settings object which defines the
            cluster expansion parameters.
    """

    def __init__(self, settings: ClusterExpansionSettings):
        # Assigning through the property runs the type check in the setter.
        self.settings = settings

    @property
    def settings(self) -> ClusterExpansionSettings:
        """The ClusterExpansionSettings object used for all calculations."""
        return self._settings

    @settings.setter
    def settings(self, value: Any) -> None:
        if not isinstance(value, ClusterExpansionSettings):
            raise TypeError(f"Setting must be a ClusterExpansionSettings object, got {value!r}")
        self._settings = value

    def connect(self, **kwargs):
        """Return a database connection obtained from the settings object.

        All keyword arguments are forwarded unchanged to ``self.settings.connect``.
        """
        return self.settings.connect(**kwargs)

    def get_cf(self, atoms) -> CF_T:
        """
        Calculate correlation functions for all possible clusters and return
        them in a dictionary format.

        Parameters:
            atoms (Atoms): The atoms object

        Returns:
            Dictionary mapping every name in ``settings.all_cf_names`` to its
            correlation function value.

        Raises:
            TypeError: If ``atoms`` is not an Atoms object.
        """
        if not isinstance(atoms, Atoms):
            raise TypeError("atoms must be an Atoms object")
        cf_names = self.settings.all_cf_names
        return self.get_cf_by_names(atoms, cf_names)

    def get_cf_by_names(self, atoms, cf_names) -> CF_T:
        """
        Calculate correlation functions of the specified clusters and return
        them in a dictionary format.

        Note that ``atoms`` is made the active template of the settings object
        as a side effect (see ``set_template``).

        Parameters:
            atoms: Atoms object
            cf_names: iterable of str
                names of correlation functions that will be calculated for
                the structure provided in atoms

        Raises:
            TypeError: If ``atoms`` is not an Atoms object.
            ClusterNotTrackedError: If any requested name is not in
                ``settings.all_cf_names``.
        """
        if not isinstance(atoms, Atoms):
            raise TypeError("atoms must be Atoms object")
        self.set_template(atoms)

        # Materialize once: the names are iterated several times below, and a
        # one-shot iterator would otherwise pass validation and then produce
        # empty dictionaries.
        cf_names = list(cf_names)
        self._confirm_cf_names_exists(cf_names)

        # Every ECI and initial CF value is set to 1.0 only to construct the
        # updater; the returned values come from calculate_cf_from_scratch.
        eci = {name: 1.0 for name in cf_names}
        cf = {name: 1.0 for name in cf_names}
        updater = PyCEUpdater(atoms, self.settings, cf, eci, self.settings.cluster_list)
        return updater.calculate_cf_from_scratch(atoms, cf_names)

    def reconfigure_single_db_entry(self, row_id: int) -> None:
        """Reconfigure a single DB entry. Assumes this is the initial structure,
        and will not check that.

        The structure is read from the row, wrapped and sorted by position, its
        correlation functions are recalculated, and the result is written to
        the ``cf_table_name`` table of that row.

        Parameters:
            row_id: int
                The ID of the row to be reconfigured.
        """
        with self.connect() as db:
            atoms = wrap_and_sort_by_position(db.get(id=row_id).toatoms())
            cf = self.get_cf(atoms)
            db_util.update_table(db, row_id, self.cf_table_name, cf)

    @property
    def cf_table_name(self) -> str:
        """Name of the table which holds the correlation functions.

        Built from the name of the settings' basis function type, e.g.
        ``"<basis_name>_cf"``.
        """
        return f"{self.settings.basis_func_type.name}_cf"

    def clear_cf_table(self) -> None:
        """Delete the external table which holds the correlation functions."""
        with self.connect() as db:
            db.delete_external_table(self.cf_table_name)

    def check_consistency_of_cf_table_entries(self):
        """Get IDs of the structures with inconsistent correlation functions.

        Only rows with ``struct_type=initial`` are inspected. A structure is
        consistent when the correlation function names stored in its
        ``cf_table_name`` table are exactly the names in
        ``settings.all_cf_names``. A row without the table counts as
        inconsistent (unless no names are tracked at all).

        Returns:
            list: IDs of the inconsistent rows, in selection order.
        """
        tab_name = self.cf_table_name
        cf_names = sorted(self.settings.all_cf_names)
        inconsistent = []  # (row id, row name) pairs
        # Use the connection as a context manager, like the other DB methods,
        # so it is released once the selection is done.
        with self.connect() as db:
            for row in db.select("struct_type=initial"):
                tab_entries = row.get(tab_name, {})
                if sorted(tab_entries) != cf_names:
                    # Record the name now instead of re-querying the DB later.
                    inconsistent.append((row.id, row.get("name")))

        inconsistent_ids = [row_id for row_id, _ in inconsistent]
        if inconsistent_ids:
            logger.warning(
                "%d inconsistent entries found in table %s",
                len(inconsistent_ids),
                tab_name,
            )
            for bad_id, name in inconsistent:
                logger.warning(" id: %s, name: %s", bad_id, name)
        else:
            logger.info("'%s' table has no inconsistent entries.", tab_name)
        return inconsistent_ids

    def set_template(self, atoms: Atoms) -> None:
        """Check the size of provided cell and set as the currently active
        template in the settings object.

        Parameters:
            atoms (Atoms):
                Unrelaxed structure
        """
        self.settings.set_active_template(atoms=atoms)

    def _cf_name_exists(self, cf_name):
        """Return True if the correlation function name exists. Otherwise False.

        Parameters:
            cf_name: str
                Correlation function name to check against
                ``settings.all_cf_names``.
        """
        return cf_name in self.settings.all_cf_names

    def _confirm_cf_names_exists(self, cf_names):
        """Raise ClusterNotTrackedError unless every name is tracked.

        Parameters:
            cf_names: iterable of str
                Correlation function names that must all be present in
                ``settings.all_cf_names``.
        """
        missing = set(cf_names).difference(self.settings.all_cf_names)
        if missing:
            # Keep the original message as a prefix; append the offending names
            # so the user knows which ones to fix.
            raise ClusterNotTrackedError(
                "The correlation function of non-existing cluster is "
                "requested, but the name does not exist in "
                "ClusterExpansionSettings. Check that the cutoffs are "
                "correct, and try to run reconfigure_settings. "
                f"Missing names: {', '.join(sorted(map(str, missing)))}"
            )
def format_selection(select_cond=None, default_struct_type="initial"):
    """DB selection formatter. Will default to selecting
    all initial structures if None is specified.

    Parameters:
        select_cond: iterable of selection conditions, or None
            Conditions to use for a DB selection. They are copied into a
            new list, so the caller's object is never shared with the result.
            An empty iterable yields an empty list (no default is applied).
        default_struct_type: str
            Value of ``struct_type`` to select when ``select_cond`` is None.

    Returns:
        list: The selection conditions, e.g. ``[("struct_type", "=", "initial")]``
        when called with no arguments.
    """
    if select_cond is None:
        return [("struct_type", "=", default_struct_type)]
    # A shallow copy, so later mutation of the result cannot leak back.
    return list(select_cond)