# 文件名: visualization.py
# 作者: Gilbert Young
# 修改日期: 2025-03-27
# 文件描述: 实现晶体结构和应力-应变关系的可视化工具。
"""
可视化模块
包含 Visualizer 类,用于可视化晶胞结构和应力-应变关系
"""
import logging
import os
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter
from sklearn.metrics import r2_score
from thermoelasticsim.core.structure import Cell
# Logging configuration applied to the root logger when this module is imported.
_LOG_FORMAT = "%(asctime)s %(levelname)s: %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)

# Logger named after this module.
logger = logging.getLogger(__name__)

# Switch matplotlib's interactive mode off to avoid GUI start-up warnings.
plt.ioff()
class Visualizer:
    """Plotting and animation helpers for crystal structures and stress-strain data.

    The class keeps no state; every method builds its own matplotlib figure.
    """

    def __init__(self):
        """Create a visualizer (there is nothing to initialise)."""

    # ------------------------------------------------------------------
    # Pure helpers (no matplotlib / logging involved)
    # ------------------------------------------------------------------
    @staticmethod
    def _padded_limits(lower, upper, margin=0.05):
        """Return axis limits that enclose ``[lower, upper]`` with some padding.

        Parameters
        ----------
        lower, upper : float
            Smallest and largest data value (``lower <= upper`` is expected).
        margin : float, optional
            Fraction of the data range added on each side, default 0.05.

        Returns
        -------
        tuple of float
            ``(low, high)`` limits.  When the range is zero the padding falls
            back to ``margin * abs(value)`` (or ``margin`` itself for a value
            of zero) so the limits never coincide.
        """
        lower = float(lower)
        upper = float(upper)
        span = upper - lower
        if span > 0:
            pad = margin * span
        elif lower != 0:
            pad = margin * abs(lower)
        else:
            pad = margin
        return lower - pad, upper + pad

    @staticmethod
    def _cubic_view_limits(points, margin=0.05):
        """Return the corners of a cube that encloses all ``points``.

        The cube is centred on the bounding box of the points (not on their
        mean, which would clip outliers of a skewed distribution).

        Parameters
        ----------
        points : array_like
            Coordinates, one point per row, three columns.
        margin : float, optional
            Fraction of the cube edge added on each side, default 0.05.

        Returns
        -------
        low, high : numpy.ndarray
            Arrays of length 3 with the lower and upper limit per axis.  For
            empty input or coincident points a unit cube is returned.
        """
        pts = np.asarray(points, dtype=float)
        if pts.size == 0:
            return np.full(3, -0.5), np.full(3, 0.5)
        lower = pts.min(axis=0)
        upper = pts.max(axis=0)
        center = (lower + upper) / 2
        half = (upper - lower).max() / 2
        # A zero extent (e.g. a single atom) would give identical limits.
        half = half * (1 + 2 * margin) if half > 0 else 0.5
        return center - half, center + half

    @staticmethod
    def _positions_by_symbol(atoms):
        """Group atom positions by element symbol.

        Parameters
        ----------
        atoms : iterable
            Objects exposing ``symbol`` and ``position`` attributes.

        Returns
        -------
        dict
            Maps each symbol (in first-seen order) to an ``(n, 3)`` float
            array with the positions of the atoms carrying that symbol.
        """
        grouped = {}
        for atom in atoms:
            grouped.setdefault(atom.symbol, []).append(atom.position)
        return {
            symbol: np.asarray(positions, dtype=float)
            for symbol, positions in grouped.items()
        }

    # ------------------------------------------------------------------
    # Internal plumbing
    # ------------------------------------------------------------------
    def _ensure_directory_exists(self, filepath):
        """Create the parent directory of ``filepath`` if it has one.

        Parameters
        ----------
        filepath : str
            Path of the file that is about to be written.
        """
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)
            logger.debug("Ensured directory exists: %s", directory)

    def _apply_cubic_view(self, ax, points):
        """Set the limits of the 3D axes ``ax`` to a cube around ``points``."""
        low, high = self._cubic_view_limits(points)
        ax.set_xlim(low[0], high[0])
        ax.set_ylim(low[1], high[1])
        ax.set_zlim(low[2], high[2])

    # ------------------------------------------------------------------
    # Static plots
    # ------------------------------------------------------------------
    def plot_cell_structure(self, cell_structure: "Cell", show=True):
        """Draw the atoms and lattice vectors of a cell in a 3D plot.

        Atoms are drawn with one ``scatter`` call per element so that every
        element has a single colour and a single legend entry.  A cell with
        one element is drawn in blue without a legend.

        Parameters
        ----------
        cell_structure : Cell
            Object providing ``atoms`` (each with ``symbol`` and ``position``)
            and ``lattice_vectors``.
        show : bool, optional
            Call ``plt.show()`` before returning, default True.

        Returns
        -------
        fig : matplotlib.figure.Figure
            The created figure.
        ax : Axes3D
            The 3D axes holding the plot.
        """
        logger.info("Plotting crystal structure.")
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(111, projection="3d")

        grouped = self._positions_by_symbol(cell_structure.atoms)
        single_element = len(grouped) == 1
        # With several elements matplotlib's colour cycle assigns one colour
        # per scatter call, i.e. one colour per element.
        style = {"color": "b"} if single_element else {}
        for symbol, positions in grouped.items():
            ax.scatter(
                positions[:, 0],
                positions[:, 1],
                positions[:, 2],
                label=symbol,
                s=50,
                **style,
            )
        if single_element:
            logger.debug(
                "Plotted atoms with single element: %s", next(iter(grouped))
            )
        else:
            logger.debug("Plotted atoms with multiple elements.")

        # Lattice vectors are drawn as arrows starting at the origin.
        # NOTE(review): rows of ``lattice_vectors`` are used here while the
        # animation methods of this class use columns - confirm the Cell
        # convention.
        origin = np.zeros(3)
        lattice_vectors = cell_structure.lattice_vectors
        for number, vec in enumerate(lattice_vectors[:3], start=1):
            ax.quiver(
                *origin, *vec, color="r", arrow_length_ratio=0.1, linewidth=1.5
            )
            logger.debug("Plotted lattice vector v%d: %s", number, vec)

        ax.set_xlabel("X (Å)")
        ax.set_ylabel("Y (Å)")
        ax.set_zlabel("Z (Å)")
        ax.set_title("Crystal Structure")

        # A legend is only informative when there is more than one element.
        if len(grouped) > 1:
            ax.legend()

        self.set_axes_equal(ax)

        if show:
            plt.show()
        logger.info("Crystal structure plot completed.")
        return fig, ax

    def set_axes_equal(self, ax):
        """Give the three axes of a 3D plot the same data range.

        Every axis is widened symmetrically to the largest of the three
        current ranges.

        Parameters
        ----------
        ax : Axes3D
            The 3D axes to adjust.
        """
        logger.debug("Setting equal aspect ratio for axes.")
        limits = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
        origin = limits.min(axis=1)
        size = limits.max(axis=1) - origin
        max_size = size.max()
        # Shift each origin so the shorter axes grow equally on both sides.
        origin = origin - (max_size - size) / 2
        ax.set_xlim3d(origin[0], origin[0] + max_size)
        ax.set_ylim3d(origin[1], origin[1] + max_size)
        ax.set_zlim3d(origin[2], origin[2] + max_size)
        logger.debug("Axes limits set to origin: %s, size: %s", origin, max_size)

    def plot_stress_strain_multiple(
        self, strain_data: np.ndarray, stress_data: np.ndarray, show=True
    ):
        """Scatter-plot stress against strain, one subplot per component.

        Column ``i`` of ``stress_data`` is plotted against column ``i`` of
        ``strain_data`` for the six components 11, 22, 33, 23, 13, 12.

        Parameters
        ----------
        strain_data : numpy.ndarray
            Strain values with six columns.
        stress_data : numpy.ndarray
            Stress values with six columns.
        show : bool, optional
            Call ``plt.show()`` before returning, default True.

        Returns
        -------
        fig : matplotlib.figure.Figure
            The created figure.
        axes : numpy.ndarray
            Flat array with the six subplot axes.
        """
        logger.info("Plotting multiple stress-strain relationships.")
        strain = np.asarray(strain_data)
        stress = np.asarray(stress_data)

        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
        axes = axes.flatten()
        cmap = plt.get_cmap("tab10")
        components = ["11", "22", "33", "23", "13", "12"]

        for index, (ax, component) in enumerate(zip(axes, components)):
            ax.scatter(
                strain[:, index],
                stress[:, index],
                color=cmap(index),
                s=50,
                label=f"Component {component}",
                alpha=0.7,
            )
            ax.set_xlabel(f"Strain {component}")
            ax.set_ylabel(f"Stress {component} (GPa)")
            ax.set_title(f"Stress {component} vs Strain {component}")
            ax.legend()
            ax.grid(True)
            logger.debug("Plotted stress-strain for component %s.", component)

        plt.tight_layout()
        if show:
            plt.show()
        logger.info("Multiple stress-strain plots completed.")
        return fig, axes

    # ------------------------------------------------------------------
    # Animations
    # ------------------------------------------------------------------
    def create_optimization_animation(
        self, trajectory, filename, title="Optimization", pbc=True, show=True
    ):
        """Animate an optimisation trajectory and save it with Pillow.

        Besides the animation, a text dump of every frame is written next to
        it as ``<filename without extension>_trajectory.log``.

        Parameters
        ----------
        trajectory : list of dict
            One dict per frame with the keys ``'positions'``, ``'volume'``
            and ``'lattice_vectors'``.
        filename : str
            Output path of the animation (written with ``PillowWriter``).
        title : str, optional
            Title prefix shown above the plot, default "Optimization".
        pbc : bool, optional
            Accepted for interface compatibility; currently not used.
        show : bool, optional
            Call ``plt.show()`` after saving, default True.

        Raises
        ------
        ValueError
            If ``trajectory`` is empty.
        """
        n_frames = len(trajectory)
        if n_frames == 0:
            raise ValueError("trajectory must contain at least one frame")

        logger.info("Creating optimization animation: %s", filename)
        self._ensure_directory_exists(filename)

        # Plain-text dump of the trajectory for later inspection.
        log_filename = os.path.splitext(filename)[0] + "_trajectory.log"
        with open(log_filename, "w", encoding="utf-8") as log_file:
            for index, frame_data in enumerate(trajectory):
                log_file.write(f"Frame {index}:\n")
                log_file.write(f"Positions:\n{frame_data['positions']}\n")
                log_file.write(f"Volume: {frame_data['volume']:.3f}\n")
                log_file.write(
                    f"Lattice vectors:\n{frame_data['lattice_vectors']}\n"
                )
                log_file.write("-" * 50 + "\n")
        logger.debug("Saved trajectory log to %s", log_filename)

        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(111, projection="3d")

        def draw_arrow(vec=(0, 0, 0)):
            """Draw one lattice-vector arrow from the origin."""
            return ax.quiver(
                0, 0, 0, vec[0], vec[1], vec[2], color="r", arrow_length_ratio=0.1
            )

        scatter = ax.scatter([], [], [], color="b", s=20)
        quivers = [draw_arrow() for _ in range(3)]

        # Info boxes sit above the plotting area on a translucent background
        # so they do not cover the atoms.
        box = {"facecolor": "white", "alpha": 0.6, "edgecolor": "none"}
        text_volume = ax.text2D(
            0.05, 1.08, "", transform=ax.transAxes, bbox=dict(box)
        )
        text_lattice = ax.text2D(
            0.05, 1.02, "", transform=ax.transAxes, bbox=dict(box)
        )
        text_atoms = ax.text2D(
            0.05, 0.96, "", transform=ax.transAxes, bbox=dict(box)
        )

        def init():
            # ``_offsets3d`` is a private attribute, assigned directly to
            # move the points of the existing 3D scatter.
            scatter._offsets3d = ([], [], [])
            for quiver in quivers:
                quiver.remove()
            quivers[:] = [draw_arrow() for _ in range(3)]
            text_volume.set_text("")
            text_lattice.set_text("")
            text_atoms.set_text("")
            return scatter, *quivers, text_volume, text_lattice, text_atoms

        def update(frame):
            data = trajectory[frame]
            positions = np.asarray(data["positions"], dtype=float)
            lattice_vectors = np.asarray(data["lattice_vectors"], dtype=float)

            scatter._offsets3d = (positions[:, 0], positions[:, 1], positions[:, 2])

            # Quiver artists cannot be updated in place, so they are redrawn.
            # NOTE(review): columns of ``lattice_vectors`` are treated as the
            # lattice vectors here, whereas plot_cell_structure uses rows -
            # confirm the convention.
            columns = lattice_vectors.T
            for i, vec in enumerate(columns):
                quivers[i].remove()
                quivers[i] = draw_arrow(vec)

            text_volume.set_text(f"Volume: {data['volume']:.2f} Å³")
            lattice_str = "\n".join(
                f"v{i + 1}: [{vec[0]:.2f}, {vec[1]:.2f}, {vec[2]:.2f}]"
                for i, vec in enumerate(columns)
            )
            text_lattice.set_text(f"Lattice Vectors:\n{lattice_str}")
            text_atoms.set_text(f"Number of Atoms: {positions.shape[0]}")

            ax.set_title(f"{title} - Step {frame + 1}/{n_frames}")

            # Keep every atom in view with the same scale on all axes.
            self._apply_cubic_view(ax, positions)
            return scatter, *quivers, text_volume, text_lattice, text_atoms

        ani = FuncAnimation(
            fig,
            update,
            frames=n_frames,
            init_func=init,
            blit=False,
            interval=200,
        )
        ani.save(filename, writer=PillowWriter(fps=5))
        logger.info("Saved optimization animation to %s", filename)

        if show:
            plt.show()
        plt.close(fig)

    def create_stress_strain_animation(
        self, strain_data, stress_data, filename, show=True
    ):
        """Animate the build-up of the six stress-strain curves.

        Frame ``k`` shows the first ``k`` data rows, so the animation starts
        empty and ends with the complete curves.

        Parameters
        ----------
        strain_data : numpy.ndarray
            Strain values with six columns.
        stress_data : numpy.ndarray
            Stress values with six columns.
        filename : str
            Output path.  The file is written with ``PillowWriter``, so an
            image format such as ``.gif`` should be used.
        show : bool, optional
            Call ``plt.show()`` after saving, default True.

        Raises
        ------
        ValueError
            If the data arrays are empty.
        """
        strain = np.asarray(strain_data, dtype=float)
        stress = np.asarray(stress_data, dtype=float)
        n_points = len(strain)
        if strain.size == 0 or stress.size == 0:
            raise ValueError("strain_data and stress_data must not be empty")

        logger.info("Creating stress-strain animation: %s", filename)
        self._ensure_directory_exists(filename)

        fig, ax = plt.subplots(figsize=(10, 6))
        components = ["11", "22", "33", "23", "13", "12"]
        lines = [ax.plot([], [], label=f"Stress {comp}")[0] for comp in components]

        # Pad by a fraction of the data range; scaling min/max by a factor
        # would cut off data whenever both bounds have the same sign.
        ax.set_xlim(*self._padded_limits(strain.min(), strain.max()))
        ax.set_ylim(*self._padded_limits(stress.min(), stress.max()))
        ax.set_xlabel("Strain")
        ax.set_ylabel("Stress (GPa)")
        ax.set_title("Stress-Strain Relationship Animation")
        ax.legend()
        ax.grid(True)

        def init():
            for line in lines:
                line.set_data([], [])
            return lines

        def update(frame):
            for i, line in enumerate(lines):
                line.set_data(strain[:frame, i], stress[:frame, i])
            ax.set_title(
                f"Stress-Strain Relationship Animation - Frame {frame}/{n_points}"
            )
            return lines

        ani = FuncAnimation(
            fig,
            update,
            frames=n_points + 1,  # frame 0 is the empty plot
            init_func=init,
            blit=False,
            interval=100,
        )
        plt.tight_layout()
        ani.save(filename, writer=PillowWriter(fps=5))
        logger.info("Saved stress-strain relationship animation to %s", filename)

        if show:
            plt.show()
        plt.close(fig)

    def create_simulation_animation(self, trajectory_data, filename, interval=100):
        """Animate a molecular-dynamics trajectory and save it with Pillow.

        The left panel shows atoms and lattice vectors, the right panel shows
        temperature and energy up to and including the current frame.

        Parameters
        ----------
        trajectory_data : list of dict
            One dict per frame with the keys

            - ``positions``: array_like with one atom per row,
            - ``cell``: object exposing ``lattice_vectors``,
            - ``time``: float,
            - ``temperature``: float,
            - ``energy``: float.
        filename : str
            Output path of the animation (written with the "pillow" writer).
        interval : int, optional
            Delay between frames in milliseconds, default 100.

        Raises
        ------
        ValueError
            If ``trajectory_data`` is empty.
        """
        n_frames = len(trajectory_data)
        if n_frames == 0:
            raise ValueError("trajectory_data must contain at least one frame")

        logger.info("Creating simulation animation: %s", filename)
        self._ensure_directory_exists(filename)

        fig = plt.figure(figsize=(15, 8))
        ax_structure = fig.add_subplot(121, projection="3d")
        ax_props = fig.add_subplot(122)

        scatter = ax_structure.scatter([], [], [], c="b", marker="o")
        quivers = [
            ax_structure.quiver(0, 0, 0, 0, 0, 0, color="r") for _ in range(3)
        ]
        (temp_line,) = ax_props.plot([], [], "r-", label="Temperature (K)")
        (energy_line,) = ax_props.plot([], [], "b-", label="Energy (eV)")

        times = [frame["time"] for frame in trajectory_data]
        temperatures = [frame["temperature"] for frame in trajectory_data]
        energies = [frame["energy"] for frame in trajectory_data]

        ax_structure.set_xlabel("X (Å)")
        ax_structure.set_ylabel("Y (Å)")
        ax_structure.set_zlabel("Z (Å)")
        ax_props.set_xlabel("Time (ps)")
        ax_props.set_ylabel("Value")
        ax_props.legend()
        ax_props.grid(True)

        # ``Line2D.set_data`` does not rescale the axes, so the property
        # panel gets fixed limits that cover the whole trajectory.
        ax_props.set_xlim(*self._padded_limits(min(times), max(times)))
        ax_props.set_ylim(
            *self._padded_limits(
                min(min(temperatures), min(energies)),
                max(max(temperatures), max(energies)),
            )
        )

        def init():
            scatter._offsets3d = ([], [], [])
            temp_line.set_data([], [])
            energy_line.set_data([], [])
            return scatter, temp_line, energy_line

        def update(frame):
            data = trajectory_data[frame]
            positions = np.asarray(data["positions"], dtype=float)
            lattice_vectors = np.asarray(data["cell"].lattice_vectors, dtype=float)

            # ``_offsets3d`` is a private attribute, assigned directly to
            # move the points of the existing 3D scatter.
            scatter._offsets3d = (positions[:, 0], positions[:, 1], positions[:, 2])
            self._apply_cubic_view(ax_structure, positions)

            # NOTE(review): columns of ``lattice_vectors`` are treated as the
            # lattice vectors here, whereas plot_cell_structure uses rows -
            # confirm the convention.
            for i, quiver in enumerate(quivers):
                quiver.remove()
                vec = lattice_vectors[:, i]
                quivers[i] = ax_structure.quiver(
                    0, 0, 0, vec[0], vec[1], vec[2], color="r"
                )

            # Include the sample that belongs to the frame being shown.
            shown = frame + 1
            temp_line.set_data(times[:shown], temperatures[:shown])
            energy_line.set_data(times[:shown], energies[:shown])
            return scatter, temp_line, energy_line, *quivers

        # Blitting brings nothing when the animation is only saved, and the
        # arrows are re-created on every frame anyway.
        anim = FuncAnimation(
            fig,
            update,
            frames=n_frames,
            init_func=init,
            blit=False,
            interval=interval,
        )
        anim.save(filename, writer="pillow")
        logger.info("Saved simulation animation to %s", filename)
        plt.close(fig)