Visualization¶
MDTerp provides four built-in visualization functions for analyzing results.
Feature Importance Bar Plot¶
Plot mean feature importance as a horizontal bar chart for a transition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
result_path |
str |
Path to MDTerp_results_all.pkl. |
required |
feature_names_path |
str |
Path to MDTerp_feature_names.npy. |
required |
transition |
str |
Transition key (e.g., "0_1"). |
required |
show_std |
bool |
Show standard deviation error bars (default: False). |
False |
top_n |
Optional[int] |
Only show the top N features. None shows all non-zero. |
None |
save_path |
Optional[str] |
If provided, save the figure to this path. |
None |
figsize |
tuple |
Figure size as (width, height) tuple. |
(8, 8) |
display_names |
Optional[dict] |
Optional dict mapping original feature names to display names (e.g., LaTeX-formatted strings). Names not in the dict are kept as-is. |
None |
Returns:
| Type | Description |
|---|---|
Figure |
matplotlib Figure object. |
Source code in MDTerp/visualization.py
def plot_feature_importance(
result_path: str,
feature_names_path: str,
transition: str,
show_std: bool = False,
top_n: Optional[int] = None,
save_path: Optional[str] = None,
figsize: tuple = (8, 8),
display_names: Optional[dict] = None,
) -> plt.Figure:
"""
Plot mean feature importance as a horizontal bar chart for a transition.
Args:
result_path: Path to MDTerp_results_all.pkl.
feature_names_path: Path to MDTerp_feature_names.npy.
transition: Transition key (e.g., "0_1").
show_std: Show standard deviation error bars (default: False).
top_n: Only show the top N features. None shows all non-zero.
save_path: If provided, save the figure to this path.
figsize: Figure size as (width, height) tuple.
display_names: Optional dict mapping original feature names to display
names (e.g., LaTeX-formatted strings). Names not in the dict are
kept as-is.
Returns:
matplotlib Figure object.
"""
feature_names = np.load(feature_names_path, allow_pickle=True)
summary = transition_summary(result_path, importance_coverage=1.0)
if transition not in summary:
raise ValueError(
f"Transition '{transition}' not found. "
f"Available: {list(summary.keys())}"
)
mean_imp = summary[transition][0]
std_imp = summary[transition][1]
nonzero_mask = mean_imp > 0
ordered_indices = np.argsort(mean_imp)[::-1]
ordered_indices = ordered_indices[nonzero_mask[ordered_indices]]
if top_n is not None:
ordered_indices = ordered_indices[:top_n]
ordered_mean = mean_imp[ordered_indices]
ordered_std = std_imp[ordered_indices]
ordered_names = _apply_display_names(feature_names[ordered_indices], display_names)
fig, ax = plt.subplots(figsize=figsize)
y_pos = np.arange(len(ordered_mean))
if show_std:
ax.barh(y_pos, ordered_mean, xerr=ordered_std, capsize=4,
color='steelblue', edgecolor='black', linewidth=0.5)
else:
ax.barh(y_pos, ordered_mean,
color='steelblue', edgecolor='black', linewidth=0.5)
ax.set_yticks(y_pos)
ax.set_yticklabels(ordered_names, fontsize=12)
ax.set_xlabel('Feature Importance', fontsize=14)
ax.set_title(
f'Feature Importance for Transition {transition}\n'
f'Coverage: {int(100 * np.sum(ordered_mean))}%',
fontsize=16,
)
ax.invert_yaxis()
fig.tight_layout()
if save_path:
fig.savefig(save_path, dpi=150, bbox_inches='tight')
return fig
Importance Heatmap¶
Plot a heatmap of feature importance across all transitions.
Rows are features, columns are transitions, color intensity represents mean importance. Only features with non-zero importance in at least one transition are shown.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
result_path |
str |
Path to MDTerp_results_all.pkl. |
required |
feature_names_path |
str |
Path to MDTerp_feature_names.npy. |
required |
importance_coverage |
float |
Filter features by cumulative importance per transition (default: 0.8). |
0.8 |
save_path |
Optional[str] |
If provided, save the figure to this path. |
None |
figsize |
Optional[tuple] |
Figure size. Auto-scaled if None. |
None |
display_names |
Optional[dict] |
Optional dict mapping original feature names to display names (e.g., LaTeX-formatted strings). |
None |
Returns:
| Type | Description |
|---|---|
Figure |
matplotlib Figure object. |
Source code in MDTerp/visualization.py
def plot_importance_heatmap(
result_path: str,
feature_names_path: str,
importance_coverage: float = 0.8,
save_path: Optional[str] = None,
figsize: Optional[tuple] = None,
display_names: Optional[dict] = None,
) -> plt.Figure:
"""
Plot a heatmap of feature importance across all transitions.
Rows are features, columns are transitions, color intensity represents
mean importance. Only features with non-zero importance in at least one
transition are shown.
Args:
result_path: Path to MDTerp_results_all.pkl.
feature_names_path: Path to MDTerp_feature_names.npy.
importance_coverage: Filter features by cumulative importance
per transition (default: 0.8).
save_path: If provided, save the figure to this path.
figsize: Figure size. Auto-scaled if None.
display_names: Optional dict mapping original feature names to display
names (e.g., LaTeX-formatted strings).
Returns:
matplotlib Figure object.
"""
feature_names = np.load(feature_names_path, allow_pickle=True)
summary = transition_summary(result_path, importance_coverage=importance_coverage)
transitions = sorted(summary.keys())
n_features = len(feature_names)
imp_matrix = np.zeros((n_features, len(transitions)))
for j, trans in enumerate(transitions):
imp_matrix[:, j] = summary[trans][0]
active_mask = np.any(imp_matrix > 0, axis=1)
active_indices = np.where(active_mask)[0]
imp_matrix = imp_matrix[active_indices, :]
active_names = _apply_display_names(feature_names[active_indices], display_names)
if figsize is None:
figsize = (max(6, len(transitions) * 2), max(6, len(active_names) * 0.4))
fig, ax = plt.subplots(figsize=figsize)
im = ax.imshow(imp_matrix, aspect='auto', cmap='YlOrRd', interpolation='nearest')
ax.set_xticks(np.arange(len(transitions)))
ax.set_xticklabels(transitions, fontsize=12, rotation=45, ha='right')
ax.set_yticks(np.arange(len(active_names)))
ax.set_yticklabels(active_names, fontsize=10)
ax.set_xlabel('Transition', fontsize=14)
ax.set_ylabel('Feature', fontsize=14)
ax.set_title('Feature Importance Across Transitions', fontsize=16)
for i in range(imp_matrix.shape[0]):
for j in range(imp_matrix.shape[1]):
val = imp_matrix[i, j]
if val > 0.01:
ax.text(j, i, f'{val:.2f}', ha='center', va='center',
fontsize=8, color='black' if val < 0.5 else 'white')
fig.colorbar(im, ax=ax, label='Importance', shrink=0.8)
fig.tight_layout()
if save_path:
fig.savefig(save_path, dpi=150, bbox_inches='tight')
return fig
Unfaithfulness Curve¶
Plot the unfaithfulness vs number of features curve for a single point.
Shows how surrogate model quality improves as more features are included in the linear explanation, visualizing the TERP free energy trade-off.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
result_dir |
str |
Directory containing per-point result files. |
required |
transition |
str |
Transition key (e.g., "0_1"). |
required |
point_index |
int |
Point index within the transition. |
required |
save_path |
Optional[str] |
If provided, save the figure to this path. |
None |
figsize |
tuple |
Figure size. |
(8, 5) |
Returns:
| Type | Description |
|---|---|
Figure |
matplotlib Figure object. |
Source code in MDTerp/visualization.py
def plot_unfaithfulness_curve(
result_dir: str,
transition: str,
point_index: int,
save_path: Optional[str] = None,
figsize: tuple = (8, 5),
) -> plt.Figure:
"""
Plot the unfaithfulness vs number of features curve for a single point.
Shows how surrogate model quality improves as more features are included
in the linear explanation, visualizing the TERP free energy trade-off.
Args:
result_dir: Directory containing per-point result files.
transition: Transition key (e.g., "0_1").
point_index: Point index within the transition.
save_path: If provided, save the figure to this path.
figsize: Figure size.
Returns:
matplotlib Figure object.
"""
filename = f"{transition}_point{point_index}_result.npz"
filepath = os.path.join(result_dir, filename)
if not os.path.exists(filepath):
raise FileNotFoundError(
f"Result file not found: {filepath}. "
f"Ensure keep_checkpoints=True when running MDTerp."
)
data = np.load(filepath, allow_pickle=True)
unfaithfulness = data['unfaithfulness_all']
fig, ax = plt.subplots(figsize=figsize)
k_values = np.arange(1, len(unfaithfulness) + 1)
ax.plot(k_values, unfaithfulness, 'o-', color='steelblue',
linewidth=2, markersize=6)
ax.set_xlabel('Number of Features (k)', fontsize=14)
ax.set_ylabel('Unfaithfulness (1 - |r|)', fontsize=14)
ax.set_title(
f'Unfaithfulness Curve\n'
f'Transition {transition}, Point {point_index}',
fontsize=16,
)
ax.set_xticks(k_values)
ax.grid(True, alpha=0.3)
fig.tight_layout()
if save_path:
fig.savefig(save_path, dpi=150, bbox_inches='tight')
return fig
Per-Point Variability¶
Plot per-point importance variability for the top features in a transition.
Shows a strip plot where each dot represents one point's importance value for a given feature, revealing how consistent the explanations are.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
result_path |
str |
Path to MDTerp_results_all.pkl. |
required |
feature_names_path |
str |
Path to MDTerp_feature_names.npy. |
required |
transition |
str |
Transition key (e.g., "0_1"). |
required |
top_n |
int |
Number of top features to show (default: 5). |
5 |
save_path |
Optional[str] |
If provided, save the figure to this path. |
None |
figsize |
tuple |
Figure size. |
(10, 6) |
display_names |
Optional[dict] |
Optional dict mapping original feature names to display names (e.g., LaTeX-formatted strings). |
None |
Returns:
| Type | Description |
|---|---|
Figure |
matplotlib Figure object. |
Source code in MDTerp/visualization.py
def plot_point_variability(
result_path: str,
feature_names_path: str,
transition: str,
top_n: int = 5,
save_path: Optional[str] = None,
figsize: tuple = (10, 6),
display_names: Optional[dict] = None,
) -> plt.Figure:
"""
Plot per-point importance variability for the top features in a transition.
Shows a strip plot where each dot represents one point's importance value
for a given feature, revealing how consistent the explanations are.
Args:
result_path: Path to MDTerp_results_all.pkl.
feature_names_path: Path to MDTerp_feature_names.npy.
transition: Transition key (e.g., "0_1").
top_n: Number of top features to show (default: 5).
save_path: If provided, save the figure to this path.
figsize: Figure size.
display_names: Optional dict mapping original feature names to display
names (e.g., LaTeX-formatted strings).
Returns:
matplotlib Figure object.
"""
feature_names = np.load(feature_names_path, allow_pickle=True)
with open(result_path, 'rb') as f:
all_results = pickle.load(f)
importances = []
for sample_idx, (trans, imp) in all_results.items():
if trans == transition:
importances.append(np.array(imp))
if not importances:
raise ValueError(
f"No results found for transition '{transition}'. "
f"Available: {list(set(v[0] for v in all_results.values()))}"
)
imp_array = np.array(importances)
mean_imp = np.mean(imp_array, axis=0)
top_indices = np.argsort(mean_imp)[::-1][:top_n]
fig, ax = plt.subplots(figsize=figsize)
positions = np.arange(top_n)
for i, feat_idx in enumerate(top_indices):
values = imp_array[:, feat_idx]
jitter = np.random.uniform(-0.15, 0.15, size=len(values))
ax.scatter(
positions[i] + jitter, values,
alpha=0.6, s=40, color='steelblue', edgecolors='black',
linewidth=0.5,
)
ax.scatter(positions[i], mean_imp[feat_idx],
marker='D', s=80, color='red', zorder=5)
ax.set_xticks(positions)
tick_names = _apply_display_names(feature_names[top_indices], display_names)
ax.set_xticklabels(tick_names, fontsize=12, rotation=30, ha='right')
ax.set_ylabel('Feature Importance', fontsize=14)
ax.set_title(
f'Per-Point Importance Variability\n'
f'Transition {transition} (red = mean)',
fontsize=16,
)
ax.grid(True, axis='y', alpha=0.3)
fig.tight_layout()
if save_path:
fig.savefig(save_path, dpi=150, bbox_inches='tight')
return fig