"""CIF I/O helpers used when building CLEASE settings.
The reader focuses on structure parsing (cell, basis, spacegroup metadata).
Concentration constraints are still meant to be defined by the user.
"""
from __future__ import annotations
from pathlib import Path
import re
from typing import Any
import warnings
from ase import Atoms
from ase.io import read
from clease.settings import CECrystal, Concentration
from clease.tools import wrap_and_sort_by_position
# Public API of this module; the underscore-prefixed names below are internal helpers.
__all__ = ("cecrystal_from_cif", "read_cif")

# Match the numeric prefix of a CIF occupancy token (captured as group 1).
# Leading whitespace is skipped, then an optional sign, an integer or decimal
# (with or without a leading digit), and an optional exponent are accepted,
# e.g. "1", "0.5", ".25", "1e-3".
# CIF values may include uncertainty suffixes like "0.5(2)"; we intentionally capture only
# the leading numeric part ("0.5") and ignore the parenthesized uncertainty.
# Tokens that contain no digit at the start, such as the CIF placeholders "?" and ".",
# do not match at all.
_NUM_RE = re.compile(r"^\s*([+-]?\d*\.?\d+(?:[eE][+-]?\d+)?)")
# Public API.
def cecrystal_from_cif(
    cif_path: str | Path,
    concentration: Concentration,
    *,
    index: int | str = 0,
    infer_spacegroup: bool = True,
    symprec: float = 1e-3,
    spacegroup_override: int | str | None = None,
    **ce_kwargs: Any,
):
    """Create a `CECrystal` settings object from a CIF file.

    The CIF supplies the geometry (cell parameters and scaled basis positions)
    and, when available, the spacegroup. The concentration model is never
    derived from CIF occupancy and must always be supplied by the caller.

    For disordered structures the intended workflow is: call ``read_cif(...)``
    to inspect the symbols and metadata, build
    ``Concentration(basis_elements=..., grouped_basis=...)`` from your own
    model assumptions, and then call
    ``cecrystal_from_cif(..., concentration=con, spacegroup_override=...)``.

    Args:
        cif_path: Location of the CIF file.
        concentration: User-defined concentration model covering all basis sites.
        index: Structure index forwarded to the ASE reader.
        infer_spacegroup: Attempt spglib inference when the CIF spacegroup
            information is incomplete.
        symprec: Tolerance handed to the spglib inference.
        spacegroup_override: Explicit spacegroup that takes precedence over
            the parsed/inferred one.
        **ce_kwargs: Additional keyword arguments passed through to `CECrystal`.

    Returns:
        The settings object produced by `CECrystal(...)`.

    Raises:
        ValueError: If neither the CIF metadata nor ``spacegroup_override``
            provides a spacegroup.
    """
    # A single read provides both the structure and every piece of metadata used below.
    structure, info = read_cif(
        path=cif_path,
        index=index,
        infer_spacegroup=infer_spacegroup,
        symprec=symprec,
    )

    # Crystallographic inputs in the form CECrystal expects.
    scaled_basis = structure.get_scaled_positions().tolist()
    cell_parameters = info["cellpar"]

    # An explicit override always wins over what was parsed or inferred.
    if spacegroup_override is not None:
        chosen_spacegroup = spacegroup_override
    else:
        chosen_spacegroup = info["spacegroup_number"]

    if chosen_spacegroup is None:
        raise ValueError("Could not determine spacegroup from CIF; pass spacegroup_override.")

    return CECrystal(
        concentration=concentration,
        spacegroup=chosen_spacegroup,
        basis=scaled_basis,
        cellpar=cell_parameters,
        **ce_kwargs,
    )
# Public API.
def read_cif(
    path: str | Path,
    index: int | str = 0,
    infer_spacegroup: bool = True,
    symprec: float = 1e-3,
) -> tuple[Atoms, dict[str, Any]]:
    """Read one CIF structure and return `(atoms, metadata)`.

    The returned `Atoms` is normalized with CLEASE's wrapped/sorted convention.
    Metadata is intended to support manual concentration setup before calling
    `cecrystal_from_cif(...)`. It includes ``cellpar``, ``spacegroup_number``,
    ``spacegroup_symbol``, occupancy diagnostics, and collected warnings.
    Collected warnings are also emitted through the `warnings` module.

    This reader does not translate partial occupancy into concentration
    constraints. Use the metadata as guidance and build ``Concentration``
    explicitly in your setup code.

    Args:
        path: Path to input CIF file. A leading ``~`` is expanded.
        index: CIF structure index passed to ASE read.
        infer_spacegroup: Try spglib inference when CIF info is incomplete.
        symprec: Tolerance used by spglib inference.

    Returns:
        Tuple of `(atoms, metadata)`.

    Raises:
        ValueError: If CIF reading fails, or if the read does not yield a
            single `Atoms` object (e.g. ``index`` selects several structures).
    """
    # Expand "~" once and use the expanded path everywhere, so the ASE read and
    # the occupancy text fallback never look at different locations.
    path_obj = Path(path).expanduser()
    try:
        atoms = read(str(path_obj), format="cif", index=index)
    except Exception as exc:  # pragma: no cover - ASE runtime specific
        raise ValueError(f"Could not read CIF '{path}': {exc}") from exc
    if not isinstance(atoms, Atoms):
        raise ValueError(f"CIF read returned {type(atoms)}; provide an index for one structure.")

    # Keep Atoms ordering/wrapping consistent with CLEASE conventions.
    atoms = wrap_and_sort_by_position(atoms)
    metadata = _build_metadata(
        atoms=atoms,
        path=path_obj,
        index=index,
        infer_spacegroup=infer_spacegroup,
        symprec=symprec,
    )
    _emit_warnings(metadata)
    return atoms, metadata
def _build_metadata(
    atoms: Atoms,
    path: Path,
    index: int | str,
    infer_spacegroup: bool,
    symprec: float,
) -> dict[str, Any]:
    """Collect the metadata dictionary returned alongside the structure.

    Args:
        atoms: Structure whose cell, info and occupancy are summarized.
        path: CIF location; recorded as ``source_path`` and handed to the
            occupancy extraction helper.
        index: Structure index, recorded verbatim.
        infer_spacegroup: Whether to call the spglib-based inference when the
            number or symbol is missing from ``atoms.info``.
        symprec: Tolerance forwarded to the inference helper.

    Returns:
        Dict with the keys ``source_path``, ``index``, ``natoms``, ``cellpar``,
        ``spacegroup_number``, ``spacegroup_symbol``, ``occupancy``,
        ``has_partial_occupancy`` and ``warnings``.
    """
    collected: list[str] = []

    # Occupancy diagnostics are gathered first; they feed the warning further down.
    occ_info = _extract_occupancy_info(atoms=atoms, path=path)

    # Values already present in atoms.info win; inference only fills the gaps.
    number, symbol = _spacegroup_from_info(atoms.info)
    something_missing = number is None or symbol is None
    if infer_spacegroup and something_missing:
        inferred_number, inferred_symbol, problem = _infer_spacegroup(
            atoms=atoms, symprec=symprec
        )
        if problem is not None:
            collected.append(problem)
        number = inferred_number if number is None else number
        symbol = inferred_symbol if symbol is None else symbol

    partial = occ_info["has_partial_occupancy"]
    if partial:
        collected.append(
            "CIF has partial occupancy; define concentration constraints explicitly."
        )

    return {
        "source_path": str(path),
        "index": index,
        "natoms": len(atoms),
        "cellpar": [float(x) for x in atoms.cell.cellpar()],
        "spacegroup_number": number,
        "spacegroup_symbol": symbol,
        "occupancy": occ_info,
        "has_partial_occupancy": partial,
        "warnings": collected,
    }
def _spacegroup_from_info(info: dict[str, Any]) -> tuple[int | None, str | None]:
"""Get spacegroup number/symbol from `atoms.info` when available."""
sg_num: int | None = None
sg_sym: str | None = None
sg_obj = info.get("spacegroup")
if sg_obj is not None:
no_attr = getattr(sg_obj, "no", None)
sym_attr = getattr(sg_obj, "symbol", None)
if no_attr is not None:
sg_num = int(no_attr)
if sym_attr is not None:
sg_sym = str(sym_attr)
if sg_num is None:
raw_num = info.get("spacegroup_number")
if raw_num is not None:
try:
sg_num = int(raw_num)
except (TypeError, ValueError):
pass
if sg_sym is None:
raw_sym = info.get("spacegroup_symbol")
if raw_sym is not None:
sg_sym = str(raw_sym)
return sg_num, sg_sym
def _infer_spacegroup(atoms: Atoms, symprec: float) -> tuple[int | None, str | None, str | None]:
    """Ask spglib for the spacegroup of `atoms`.

    Args:
        atoms: Structure to analyse; its cell array, scaled positions and
            atomic numbers are handed to spglib.
        symprec: Tolerance passed to ``spglib.get_symmetry_dataset``.

    Returns:
        ``(number, symbol, warning_text)``. On success ``warning_text`` is
        ``None``. When spglib cannot be imported, the call raises, or no
        dataset comes back, number and symbol are ``None`` and
        ``warning_text`` describes the problem.
    """
    # spglib is imported lazily so a missing install becomes a warning text, not a crash.
    try:
        import spglib
    except Exception:
        return None, None, "spglib not available; could not infer spacegroup"

    try:
        spg_cell = (atoms.cell.array, atoms.get_scaled_positions(), atoms.numbers)
        dataset = spglib.get_symmetry_dataset(spg_cell, symprec=symprec)
    except Exception as exc:  # pragma: no cover - spglib runtime specific
        return None, None, f"spacegroup inference failed: {exc}"

    if dataset is None:
        return None, None, "spacegroup inference returned no dataset"

    number = _dataset_get(dataset, "number")
    symbol = _dataset_get(dataset, "international")
    return (
        None if number is None else int(number),
        None if symbol is None else str(symbol),
        None,
    )
def _dataset_get(dataset: Any, key: str) -> Any:
"""Read a key from dict-like or object-like spglib datasets."""
if isinstance(dataset, dict):
return dataset.get(key)
if hasattr(dataset, key):
return getattr(dataset, key)
return None
def _extract_occupancy_info(atoms: Atoms, path: Path) -> dict[str, Any]:
    """Gather occupancy values and flag whether any site is partially occupied.

    Values found in ``atoms.info['occupancy']`` are used when present.
    Otherwise the CIF text at `path` is parsed as a fallback, and a failure of
    that fallback is reported in the result instead of being raised.

    Args:
        atoms: Structure whose ``info`` dict may carry occupancy data.
        path: CIF file used for the text fallback.

    Returns:
        Dict with ``source`` (``"atoms.info"`` or ``"cif_text"``), ``values``
        (list of floats), ``has_partial_occupancy`` (True when any value
        differs from 1.0 by more than 1e-8) and ``warning`` (fallback failure
        text or ``None``).
    """
    fallback_problem: str | None = None
    occ_values = _occupancies_from_info(atoms.info.get("occupancy"))

    if occ_values:
        origin = "atoms.info"
    else:
        origin = "cif_text"
        try:
            # Fallback for CIFs whose occupancy did not survive into atoms.info.
            occ_values = _read_occupancies_from_cif_text(path)
        except Exception as exc:  # pragma: no cover - parsing fallback path
            fallback_problem = f"could not parse occupancy values from CIF text: {exc}"

    partial = any(abs(v - 1.0) > 1e-8 for v in occ_values)
    return {
        "source": origin,
        "values": occ_values,
        "has_partial_occupancy": partial,
        "warning": fallback_problem,
    }
def _occupancies_from_info(occupancy_map: Any) -> list[float]:
"""Extract occupancy values from `atoms.info['occupancy']`."""
if not isinstance(occupancy_map, dict):
return []
values: list[float] = []
for site_map in occupancy_map.values():
if not isinstance(site_map, dict):
continue
for occ in site_map.values():
try:
values.append(float(occ))
except (TypeError, ValueError):
continue
return values
def _read_occupancies_from_cif_text(path: Path) -> list[float]:
    """Parse `_atom_site_occupancy` values from CIF loop blocks.

    This is a lightweight, line-based scan rather than a full CIF parser: each
    data row is split on whitespace, so quoted values containing spaces and
    rows wrapped over several lines are not supported.

    The ``loop_`` keyword and the data names are matched case-insensitively,
    a trailing comment on a header line is ignored, and comment lines inside
    the loop data are skipped instead of ending the loop.

    Args:
        path: CIF file to scan. It is decoded as UTF-8 with undecodable bytes
            replaced.

    Returns:
        Occupancy values in file order, taken from every loop that declares
        ``_atom_site_occupancy``. Tokens without a numeric prefix are skipped.

    Raises:
        OSError: If the file cannot be read.
    """
    occupancy_tag = "_atom_site_occupancy"
    lines = path.read_text(encoding="utf-8", errors="replace").splitlines()
    n_lines = len(lines)
    values: list[float] = []
    i = 0
    while i < n_lines:
        # CIF reserved words are case-insensitive, so "LOOP_" opens a loop as well.
        if lines[i].strip().lower() != "loop_":
            i += 1
            continue

        # Enter one CIF loop block and collect its column headers.
        i += 1
        headers: list[str] = []
        while i < n_lines:
            header = lines[i].strip()
            if not header.startswith("_"):
                break
            # Keep only the data name itself (drops any trailing comment) and
            # normalise its case, since CIF data names are case-insensitive.
            headers.append(header.split()[0].lower())
            i += 1

        # Skip loops that do not define occupancy; the outer loop re-examines line i.
        if occupancy_tag not in headers:
            continue

        # Parse data rows using the occupancy column index from this loop.
        occ_col = headers.index(occupancy_tag)
        while i < n_lines:
            row = lines[i].strip()
            if row.startswith("#"):
                # A comment between rows does not terminate the loop data.
                i += 1
                continue
            # A blank line, a new loop, a new data block, or a new data name ends this loop.
            if not row or row.lower().startswith(("loop_", "data_", "_")):
                break
            tokens = row.split()
            if len(tokens) > occ_col:
                # Accept values like 0.5 and uncertainty forms like 0.5(2).
                occ = _parse_float_with_uncertainty(tokens[occ_col])
                if occ is not None:
                    values.append(occ)
            i += 1
    return values
def _parse_float_with_uncertainty(value: str) -> float | None:
    """Convert a CIF numeric token such as `0.5` or `0.5(2)` to a float.

    Args:
        value: Raw token text.

    Returns:
        The float value of the leading numeric part matched by ``_NUM_RE``, or
        ``None`` when the token has no such part or the conversion fails.
    """
    found = _NUM_RE.match(value)
    if found is None:
        return None
    numeric_prefix = found.group(1)
    try:
        return float(numeric_prefix)
    except ValueError:
        return None
def _emit_warnings(metadata: dict[str, Any]) -> None:
"""Emit collected warnings."""
occ_warning = metadata.get("occupancy", {}).get("warning")
if occ_warning:
warnings.warn(occ_warning, stacklevel=2)
for msg in metadata.get("warnings", []):
warnings.warn(msg, stacklevel=2)