thermoelasticsim.utils.utils 源代码

# 文件名: utils.py
# 作者: Gilbert Young
# 修改日期: 2025-07-11
# 文件描述: 实现张量转换工具类和数据收集类,并定义了一些常用的单位转换常量。

"""
工具模块

包含 TensorConverter 类用于张量与 Voigt 表示之间的转换,DataCollector 类用于收集模拟过程中的数据,
及一些常用的单位转换常量
"""

import logging
from concurrent.futures import ThreadPoolExecutor

import h5py
import numpy as np

logger = logging.getLogger(__name__)

# 保持原有的单位转换常量
AMU_TO_EVFSA2 = 104.3968445
EV_TO_GPA = 160.2176634
KB_IN_EV = 8.617332385e-5


[文档] class TensorConverter: """张量转换工具类,支持应力和应变张量的 Voigt 与 3x3 矩阵表示之间的相互转换。"""
[文档] @staticmethod def to_voigt(tensor: np.ndarray, tensor_type: str, tol: float = 1e-8) -> np.ndarray: """ 将一个对称的 3x3 张量转换为 Voigt 表示的 6 元素向量。 Parameters ---------- tensor : np.ndarray 形状为 (3, 3) 的张量。 tensor_type : str 张量类型,必须是 'stress' 或 'strain'。 tol : float, optional 用于检查张量对称性的容差。如果非对称分量差异大于此值,将记录警告。 Returns ------- np.ndarray 形状为 (6,) 的 Voigt 表示向量。 """ if tensor.shape != (3, 3): raise ValueError(f"输入张量必须是 3x3 矩阵,但得到形状 {tensor.shape}") if tensor_type not in {"stress", "strain"}: raise ValueError( f"无效的张量类型 '{tensor_type}',必须是 'stress' 或 'strain'" ) # 检查对称性 if not np.allclose(tensor, tensor.T, atol=tol): logger.warning("输入张量不对称。将使用其对称部分进行计算。") tensor = 0.5 * (tensor + tensor.T) factor = 2.0 if tensor_type == "strain" else 1.0 return np.array( [ tensor[0, 0], tensor[1, 1], tensor[2, 2], tensor[1, 2] * factor, tensor[0, 2] * factor, tensor[0, 1] * factor, ] )
[文档] @staticmethod def from_voigt(voigt: np.ndarray, tensor_type: str) -> np.ndarray: """ 将 Voigt 表示的 6 元素向量转换为 3x3 的对称张量。 Parameters ---------- voigt : np.ndarray 形状为 (6,) 的 Voigt 向量。 tensor_type : str 张量类型,必须是 'stress' 或 'strain'。 Returns ------- np.ndarray 形状为 (3, 3) 的对称张量。 """ if voigt.shape != (6,): raise ValueError( f"输入 Voigt 向量必须有 6 个元素,但得到 {voigt.shape[0]} 个" ) if tensor_type not in {"stress", "strain"}: raise ValueError( f"无效的张量类型 '{tensor_type}',必须是 'stress' 或 'strain'" ) factor = 0.5 if tensor_type == "strain" else 1.0 return np.array( [ [voigt[0], voigt[5] * factor, voigt[4] * factor], [voigt[5] * factor, voigt[1], voigt[3] * factor], [voigt[4] * factor, voigt[3] * factor, voigt[2]], ] )
[文档] class DataCollector: """ 数据收集器,用于记录和分析模拟轨迹 提供了系统状态的收集、统计和保存功能,支持异步数据处理和自动保存。 Parameters ---------- capacity : int, optional 存储容量限制,None 表示无限制 save_interval : int, optional 数据自动保存间隔,默认为 1000 步 use_threading : bool, optional 是否使用线程进行数据处理,默认为 False output_file : str, optional 数据保存的文件名,默认为 'trajectory.h5' """
[文档] def __init__( self, capacity: int | None = None, save_interval: int = 1000, use_threading: bool = False, output_file: str = "trajectory.h5", ): self.data = [] self._capacity = capacity self.save_interval = save_interval self.use_threading = use_threading self.output_file = output_file self._executor = ThreadPoolExecutor() if use_threading else None # 初始化统计数据 self.stats = { "temperature": [], "potential_energy": [], "kinetic_energy": [], "total_energy": [], "volume": [], "pressure": [], } logger.debug( f"Initialized DataCollector with capacity={capacity}, " f"save_interval={save_interval}, use_threading={use_threading}" )
[文档] def __del__(self): """确保线程池正确关闭""" if self._executor: self._executor.shutdown() logger.debug("Shut down thread executor")
[文档] def collect(self, cell, potential=None): """ 收集完整的系统状态数据 Parameters ---------- cell : Cell 当前的晶胞对象 potential : Potential, optional 势能对象,用于计算系统能量 """ if self.use_threading: self._executor.submit(self._collect_async, cell, potential) else: self._collect_sync(cell, potential)
def _collect_sync(self, cell, potential=None): """同步收集数据""" try: state = self._create_state_dict(cell, potential) self._process_state(state) except Exception as e: logger.error(f"Error collecting data: {e}") raise def _collect_async(self, cell, potential=None): """异步收集数据""" try: state = self._create_state_dict(cell, potential) self._process_state(state) except Exception as e: logger.error(f"Error in async data collection: {e}") def _create_state_dict(self, cell, potential=None) -> dict: """创建系统状态字典""" state = { "step": len(self.data), "time": len(self.data) * cell.dt if hasattr(cell, "dt") else None, "positions": np.array([atom.position.copy() for atom in cell.atoms]), "velocities": np.array([atom.velocity.copy() for atom in cell.atoms]), "forces": np.array([atom.force.copy() for atom in cell.atoms]), "cell_vectors": cell.lattice_vectors.copy(), "volume": cell.volume, "temperature": cell.calculate_temperature(), } if potential: try: state["potential_energy"] = potential.calculate_energy(cell) except Exception as e: logger.warning(f"Could not calculate potential energy: {e}") return state def _process_state(self, state: dict): """处理并存储状态数据""" if self._capacity and len(self.data) >= self._capacity: self.data.pop(0) self.data.append(state) self._update_stats(state) # 检查是否需要自动保存 if len(self.data) % self.save_interval == 0: self.save_trajectory() def _update_stats(self, state: dict): """更新统计数据""" for key in self.stats: if key in state: self.stats[key].append(state[key])
[文档] def save_trajectory(self): """将轨迹数据保存到 HDF5 文件""" try: with h5py.File(self.output_file, "w") as f: # 保存轨迹数据 for i, state in enumerate(self.data): grp = f.create_group(f"step_{i}") for key, value in state.items(): if isinstance(value, np.ndarray): grp.create_dataset(key, data=value) else: grp.attrs[key] = value # 保存统计数据 stats_grp = f.create_group("statistics") for key, value in self.stats.items(): if value: # 只保存非空数据 stats_grp.create_dataset(key, data=np.array(value)) logger.info(f"Trajectory saved to {self.output_file}") except Exception as e: logger.error(f"Error saving trajectory: {e}") raise
[文档] class NeighborList: """ 邻居列表类,用于生成和维护原子的邻居列表 提供高效的邻居搜索和更新功能,支持周期性边界条件。 Parameters ---------- cutoff : float 截断半径,单位为 Å skin : float, optional 皮肤厚度,默认为 0.3 Å """
[文档] def __init__(self, cutoff: float, skin: float = 0.3): if not isinstance(cutoff, int | float) or cutoff <= 0: raise ValueError("Cutoff must be a positive number") if not isinstance(skin, int | float) or skin < 0: raise ValueError("Skin must be non-negative") self.cutoff = float(cutoff) self.skin = float(skin) self.cutoff_with_skin = cutoff + skin self.neighbor_list = None self.last_positions = None self.cell = None self.last_update_step = 0 # 性能统计 self._build_count = 0 self._update_count = 0 self._last_build_time = 0 logger.debug(f"Initialized NeighborList with cutoff={cutoff}, skin={skin}")
[文档] def get_neighbor_stats(self) -> dict: """ 返回邻居列表的统计信息 Returns ------- dict 包含最小、最大、平均邻居数等统计信息的字典 """ if not self.neighbor_list: return {} neighbor_counts = [len(neighbors) for neighbors in self.neighbor_list] return { "min_neighbors": min(neighbor_counts), "max_neighbors": max(neighbor_counts), "avg_neighbors": np.mean(neighbor_counts), "std_neighbors": np.std(neighbor_counts), "build_count": self._build_count, "update_count": self._update_count, "last_build_time": self._last_build_time, }
def _validate_cutoff(self, box_size: np.ndarray): """验证截断半径的合理性""" min_box_length = np.min(box_size) if self.cutoff_with_skin > min_box_length / 2: logger.warning( f"Cutoff radius ({self.cutoff_with_skin:.3f}) is too large " f"compared to box size ({min_box_length / 2:.3f})" ) # 禁用自动调整以保持cutoff一致性(修复C44计算问题) logger.warning( "Cutoff adjustment disabled to maintain consistency with EAM potential. " "This may affect neighbor list accuracy but preserves virial theorem." ) # 注释掉自动调整代码 # self.cutoff = min_box_length / 3 # self.cutoff_with_skin = self.cutoff + self.skin # logger.info(f"Adjusted cutoff radius to {self.cutoff:.3f}") def _compute_optimal_grid_size(self, box_size: np.ndarray, num_atoms: int) -> float: """计算最优网格大小""" volume = np.prod(box_size) density = num_atoms / volume mean_atomic_spacing = (1 / density) ** (1 / 3) return max(self.cutoff + 0.5, mean_atomic_spacing)
[文档] def build(self, cell): """构建邻居列表""" try: import time start_time = time.time() positions = cell.get_positions() num_atoms = cell.num_atoms box_size = cell.get_box_lengths() cutoff = self.cutoff_with_skin self.cell = cell self._validate_cutoff(box_size) # 根据系统大小选择构建方法 if num_atoms < 64: self._build_brute_force(cell, positions, num_atoms, cutoff) else: self._build_with_grid(cell, positions, num_atoms, box_size, cutoff) self.last_positions = positions.copy() self._build_count += 1 self._last_build_time = time.time() - start_time logger.debug( f"Built neighbor list for {num_atoms} atoms in {self._last_build_time:.3f}s" ) except Exception as e: logger.error(f"Error building neighbor list: {e}") raise
def _build_brute_force(self, cell, positions, num_atoms, cutoff): """使用双重循环构建小系统的邻居列表。""" self.neighbor_list = [[] for _ in range(num_atoms)] cutoff_squared = cutoff**2 for i in range(num_atoms): for j in range(i + 1, num_atoms): rij = positions[j] - positions[i] # 应用最小镜像原则 if cell.pbc_enabled: rij = cell.minimum_image(rij) distance_squared = np.dot(rij, rij) if distance_squared < cutoff_squared: self.neighbor_list[i].append(j) self.neighbor_list[j].append(i) def _build_with_grid(self, cell, positions, num_atoms, box_size, cutoff): """改进的基于网格的邻居列表构建。""" # 验证截断半径 self._validate_cutoff(box_size) # 计算最优网格大小 grid_size = self._compute_optimal_grid_size(box_size, num_atoms) grid_dims = np.maximum(np.floor(box_size / grid_size).astype(int), 1) # 初始化网格和邻居列表 grid = {} self.neighbor_list = [[] for _ in range(num_atoms)] cutoff_squared = cutoff * cutoff # 将原子分配到网格 for idx, pos in enumerate(positions): grid_idx = tuple(((pos / grid_size) % grid_dims).astype(int)) grid.setdefault(grid_idx, []).append(idx) # 构建邻居列表 for grid_idx, atom_indices in grid.items(): # 获取相邻网格(考虑周期性边界条件) neighbor_cells = self._get_neighbor_cells(grid_idx, grid_dims) for i in atom_indices: pos_i = positions[i] # 检查相邻网格中的原子 for neighbor_idx in neighbor_cells: for j in grid.get(neighbor_idx, []): if j <= i: continue rij = positions[j] - pos_i if cell.pbc_enabled: rij = cell.minimum_image(rij) dist_squared = np.dot(rij, rij) if dist_squared < cutoff_squared: self.neighbor_list[i].append(j) self.neighbor_list[j].append(i) def _get_neighbor_cells(self, grid_idx, grid_dims): """获取相邻网格索引(考虑周期性边界条件)""" neighbor_cells = [] for dx in [-1, 0, 1]: for dy in [-1, 0, 1]: for dz in [-1, 0, 1]: if dx == dy == dz == 0: continue neighbor_idx = tuple( (np.array(grid_idx) + [dx, dy, dz]) % grid_dims ) neighbor_cells.append(neighbor_idx) return neighbor_cells
[文档] def need_refresh(self): """ 判断是否需要更新邻居列表。 Returns ------- bool 如果需要更新,返回 True;否则返回 False。 """ if self.last_positions is None: return True positions = self.cell.get_positions() displacements = positions - self.last_positions if self.cell.pbc_enabled: # 考虑 PBC 下的位移 displacements = np.array( [self.cell.minimum_image(disp) for disp in displacements] ) max_displacement = np.max(np.linalg.norm(displacements, axis=1)) return max_displacement > (self.skin * 0.5)
[文档] def update(self): """更新邻居列表,如果需要的话。""" if self.need_refresh(): self.build(self.cell)
[文档] def get_neighbors(self, atom_index): """ 获取指定原子的邻居列表。 Parameters ---------- atom_index : int 原子的索引。 Returns ------- list of int 邻居原子的索引列表。 """ if self.neighbor_list is None: self.build(self.cell) return self.neighbor_list[atom_index]
[文档] def debug_neighbor_distribution(self): """分析和打印邻居分布情况""" logger = logging.getLogger(__name__) if self.neighbor_list is None: logger.warning("Neighbor list not built yet") return neighbor_counts = [len(neighbors) for neighbors in self.neighbor_list] logger.info("Neighbor distribution:") logger.info(f"Min neighbors: {min(neighbor_counts)}") logger.info(f"Max neighbors: {max(neighbor_counts)}") logger.info(f"Average neighbors: {np.mean(neighbor_counts):.2f}") # 检查边界原子 positions = self.cell.get_positions() box_lengths = self.cell.get_box_lengths() for i, pos in enumerate(positions): is_boundary = any( abs(p / l - 0.5) > 0.35 for p, l in zip(pos, box_lengths, strict=False) ) if is_boundary: logger.info(f"Boundary atom {i}:") logger.info(f" Position: {pos}") logger.info(f" Number of neighbors: {len(self.neighbor_list[i])}") logger.info(f" Neighbors: {self.neighbor_list[i]}")
[文档] def get_atomic_mass(symbol: str) -> float: """ 根据原子符号获取原子质量(amu) Parameters ---------- symbol : str 原子符号,例如 'Al', 'Si' Returns ------- float 原子质量(amu) Raises ------ KeyError 如果原子符号不存在 """ atomic_masses = { "Al": 26.9815386, "Si": 28.0855, "C": 12.0107, # ... 其他元素 } try: return atomic_masses[symbol] except KeyError as e: raise KeyError(f"Atomic mass for symbol '{symbol}' not found.") from e