import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, Optional, cast
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.manifold import TSNE
from .correlation import cluster_pca_correlation
from ..utils import get_logger
logger = get_logger(__name__)
[docs]
def plot_variance_distribution(df: pd.DataFrame, bins: int = 50):
    """
    Draw a histogram of the per-column variances of a DataFrame.

    Args:
        df (pd.DataFrame): Input data; ``df.var()`` is evaluated on it.
        bins (int): Number of histogram bins.

    Returns:
        matplotlib.figure.Figure: Figure holding the histogram.
    """
    feature_variances = df.var()
    logger.info("Computed variances for each feature.")

    figure, axis = plt.subplots(figsize=(8, 6))
    axis.hist(feature_variances, bins=bins, edgecolor="black")
    axis.set(
        title="Distribution of Feature Variances",
        xlabel="Variance",
        ylabel="Frequency",
    )

    logger.info("Variance distribution plot generated.")
    return figure
[docs]
def plot_variance_by_feature(df: pd.DataFrame):
    """
    Draw one marker per feature showing its variance.

    The x axis carries the index of ``df.var()`` (the feature names) with tick
    labels rotated 90 degrees; the y axis carries the variance values.

    Args:
        df (pd.DataFrame): Input data.

    Returns:
        matplotlib.figure.Figure: Figure holding the variance-per-feature plot.
    """
    feature_variances = df.var()
    logger.info("Computed variances for each feature for index plot.")

    figure, axis = plt.subplots(figsize=(10, 6))
    feature_names = feature_variances.index
    variance_values = feature_variances.values
    axis.plot(feature_names, variance_values, "o", markersize=4)
    axis.set(title="Variance per Feature", xlabel="Feature", ylabel="Variance")
    axis.tick_params(axis="x", rotation=90)

    logger.info("Variance vs. feature index plot generated.")
    return figure
[docs]
def plot_embeddings(embeddings, node_labels=None, legend_labels=None):
    """
    Plot the embeddings in 2D space using t-SNE.

    Args:
        embeddings (array-like): High-dimensional embedding data; it is converted
            with ``np.array`` and its first axis is treated as the sample axis.
        node_labels (array-like, Series or DataFrame, optional): Numeric labels used
            to color the points. For a DataFrame only the first column is used.
            When omitted, every point gets the same color value (0).
        legend_labels (list, optional): Labels for the legend corresponding to the
            unique node labels.

    Returns:
        None. The figure is shown with ``plt.show()``. Nothing is plotted when there
        are fewer than two samples, because the t-SNE perplexity would drop below 1.
    """
    X = np.array(embeddings)

    # Perplexity is capped at n_samples - 1 so that small inputs still work.
    perplexity = min(30, X.shape[0] - 1)
    if perplexity < 1:
        logger.info(f"Skipping plot: not enough samples ({X.shape[0]}) for TSNE.")
        return

    reducer = TSNE(n_components=2, init="pca", perplexity=perplexity)
    X_reduced = reducer.fit_transform(X)

    if node_labels is None:
        c_values = np.zeros(X.shape[0])
    elif hasattr(node_labels, "iloc"):
        # pandas input: a DataFrame has no ``to_frame`` method, so pick its first
        # column directly; a Series is already the single label column.
        if getattr(node_labels, "ndim", 1) == 2:
            label_column = node_labels.iloc[:, 0]
        else:
            label_column = node_labels
        c_values = np.array(label_column, dtype=float).flatten()
    else:
        c_values = np.array(node_labels, dtype=float).flatten()

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(
        X_reduced[:, 0], X_reduced[:, 1],
        c=c_values,
        cmap="viridis",
        s=60,
        alpha=0.9,
        edgecolor="k"
    )

    if legend_labels is not None:
        handles, _ = scatter.legend_elements(prop="colors")
        ax.legend(handles, legend_labels, title="Omics Type", loc="best")

    ax.invert_yaxis()
    # Use the converted array: ``embeddings`` may be a plain list without ``.shape``.
    ax.set_title(f"Embeddings in 2D space from {X.shape[1]}D")
    fig.tight_layout()
    plt.show()
[docs]
def plot_network(adjacency_matrix, weight_threshold=0.0, show_labels=False, show_edge_weights=False, layout="kamada"):
    """
    Plot a network graph from an adjacency matrix and return a node index table.

    Edges whose absolute weight is below ``weight_threshold`` are dropped (only when
    the threshold is positive), nodes left without any edge are removed, and the
    remaining graph is drawn with node sizes scaled by degree and edge widths scaled
    by absolute weight. A text box on the figure reports node and edge counts before
    and after filtering. Nodes are numbered from 1 in graph order; the returned table
    maps those numbers back to the node names.

    Args:
        adjacency_matrix (pd.DataFrame): The adjacency matrix of the network.
        weight_threshold (float): Minimum absolute weight to keep an edge (default: 0.0).
        show_labels (bool): Whether to draw the numeric node index labels.
        show_edge_weights (bool): Whether to draw the raw (signed) edge weights.
        layout (str): ``"spectral"`` or ``"spring"``; any other value (including the
            default ``"kamada"``) selects the Kamada-Kawai layout.

    Returns:
        pd.DataFrame: Indexed by "Index", with columns "Omic" (node name) and
        "Degree", sorted by degree in descending order.
    """
    full_G = nx.from_pandas_adjacency(adjacency_matrix)
    total_nodes = full_G.number_of_nodes()
    total_edges = full_G.number_of_edges()

    G = full_G.copy()

    # Layout and edge widths work on magnitudes, so store |weight| once per edge.
    for _, _, data in G.edges(data=True):
        data["abs_weight"] = abs(data.get("weight", 0.0))

    if weight_threshold > 0:
        weak_edges = [
            (u, v)
            for u, v, data in G.edges(data=True)
            if data["abs_weight"] < weight_threshold
        ]
        G.remove_edges_from(weak_edges)

    # Materialize the isolates first: the graph must not change while iterating them.
    G.remove_nodes_from(list(nx.isolates(G)))

    current_nodes = list(G.nodes())
    current_edges = G.number_of_edges()

    index_mapping = {node: position for position, node in enumerate(current_nodes, start=1)}
    indexed_labels = {node: str(position) for node, position in index_mapping.items()}

    degrees = dict(G.degree())
    max_degree = max(degrees.values()) if degrees else 1
    node_sizes = [150 + (degrees[node] / max_degree) * 300 for node in G.nodes()]

    edge_weights = [data["abs_weight"] for _, _, data in G.edges(data=True)]
    edge_widths = []
    if edge_weights:
        min_w = min(edge_weights)
        max_w = max(edge_weights)
        # Avoid dividing by zero when every edge has the same weight.
        span = max_w - min_w if max_w > min_w else 1e-6
        edge_widths = [2 + 4 * (w - min_w) / span for w in edge_weights]

    if layout == "spectral":
        pos = nx.spectral_layout(G)
    elif layout == "spring":
        pos = nx.spring_layout(G, weight="abs_weight", seed=117, iterations=50)
    else:
        pos = nx.kamada_kawai_layout(G, weight="abs_weight")

    fig, ax_graph = plt.subplots(figsize=(14, 8))
    nx.draw_networkx_nodes(G, pos, node_size=node_sizes, node_color="gold", edgecolors="black", linewidths=1.5, alpha=0.9, ax=ax_graph)
    nx.draw_networkx_edges(G, pos, alpha=0.9, width=edge_widths, edge_color="black", ax=ax_graph)

    if show_edge_weights and edge_weights:
        # Labels show the original signed weight, not the magnitude used for layout.
        formatted_edge_labels = {
            edge: f"{weight:.4f}"
            for edge, weight in nx.get_edge_attributes(G, 'weight').items()
        }
        nx.draw_networkx_edge_labels(G, pos, edge_labels=formatted_edge_labels, font_size=9, ax=ax_graph)

    if show_labels:
        nx.draw_networkx_labels(G, pos, labels=indexed_labels, font_size=11, font_color="black", ax=ax_graph)

    ax_graph.set_xticks([])
    ax_graph.set_yticks([])
    ax_graph.set_frame_on(False)
    ax_graph.set_title("Network Visualization", fontsize=16)
    ax_graph.axis("off")

    summary_text = (
        "\n"
        f"Full Cluster Nodes: {total_nodes}\n"
        f"Full Cluster Edges: {total_edges}\n"
        f"Filtered Nodes: {len(current_nodes)}\n"
        f"Filtered Edges: {current_edges}\n"
    )
    ax_graph.text(0.9, 1.05, summary_text, transform=ax_graph.transAxes, fontsize=14, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    plt.show()

    mapping_data = [(index_mapping[node], node, degrees[node]) for node in current_nodes]
    mapping_df = pd.DataFrame(mapping_data, columns=["Index", "Omic", "Degree"])
    mapping_df = mapping_df.sort_values(by="Degree", ascending=False).set_index("Index")
    return mapping_df
[docs]
def find_omics_modality(mapping_df, dfs, source_names=None):
    """
    Map features in the mapping DataFrame to their omics source.

    Each feature listed in ``mapping_df["Omic"]`` is looked up among the column
    names of ``dfs``. When a feature appears in several DataFrames, the last one
    in ``dfs`` wins. Features not found in any DataFrame get the source "Unknown".
    The input ``mapping_df`` is not modified.

    Args:
        mapping_df (pd.DataFrame): DataFrame with an "Omic" column listing feature names.
        dfs (list[pd.DataFrame]): List of DataFrames, each representing an omics modality.
        source_names (list[str], optional): Names corresponding to each DataFrame in
            ``dfs``. If None or empty, each DataFrame's ``name`` attribute is used
            when set, otherwise ``"source_<position>"``.

    Returns:
        pd.DataFrame: Copy of ``mapping_df`` with an additional "Source" column.
    """
    if not source_names:
        source_names = []
        for i, df in enumerate(dfs):
            src_name = getattr(df, "name", None)
            # DataFrame attribute access also resolves column names, so a frame
            # with a column called "name" would hand back that column here
            # instead of a label; treat that case as "no name set".
            if src_name is None or isinstance(src_name, (pd.Series, pd.DataFrame)):
                src_name = f"source_{i}"
            source_names.append(src_name)

    feature_to_source = {
        feat: src_name
        for src_name, df in zip(source_names, dfs)
        for feat in df.columns
    }

    out = mapping_df.copy()
    out["Source"] = out["Omic"].map(feature_to_source).fillna("Unknown")
    return out
[docs]
def compare_clusters(clusters1: list, clusters2: list, pheno: pd.DataFrame, label1: str = "Method 1", label2: str = "Method 2") -> pd.DataFrame:
    """Compare two cluster sets via phenotype correlation.

    Clusters for each method are sorted by size (number of columns, largest first;
    ties keep their input order), the top ``min(len(clusters1), len(clusters2))``
    are paired by rank, and each cluster is scored with ``cluster_pca_correlation``.
    Pairs for which either correlation comes back as ``None`` are left out of the
    result. A line plot of correlation vs. cluster index is displayed for both
    methods, with each point annotated by the size reported for that cluster.

    Args:
        clusters1 (list[pd.DataFrame]): Clusters for the first method; one DataFrame per cluster (samples x features).
        clusters2 (list[pd.DataFrame]): Clusters for the second method; same format as ``clusters1``.
        pheno (pd.DataFrame): Phenotype data, passed unchanged to ``cluster_pca_correlation``.
        label1 (str): Label for the first method (used in legend and result columns).
        label2 (str): Label for the second method (used in legend and result columns).

    Returns:
        pd.DataFrame: One row per retained pair with columns "Cluster",
        "<label1> Size", "<label1> Correlation", "<label2> Size" and
        "<label2> Correlation".
    """
    def largest_first(clusters):
        # sorted() is stable, so equally sized clusters keep their input order.
        return sorted(clusters, key=lambda cluster: cluster.shape[1], reverse=True)

    k = min(len(clusters1), len(clusters2))
    top_clusters1 = largest_first(clusters1)[:k]
    top_clusters2 = largest_first(clusters2)[:k]

    # compute correlations
    results = []
    for rank, (df1, df2) in enumerate(zip(top_clusters1, top_clusters2), start=1):
        size1, corr1 = cluster_pca_correlation(df1, pheno)
        size2, corr2 = cluster_pca_correlation(df2, pheno)
        if corr1 is not None and corr2 is not None:
            results.append((f"Cluster_{rank}", size1, corr1, size2, corr2))

    col_size1 = f"{label1} Size"
    col_corr1 = f"{label1} Correlation"
    col_size2 = f"{label2} Size"
    col_corr2 = f"{label2} Correlation"
    df_results = pd.DataFrame(results, columns=["Cluster", col_size1, col_corr1, col_size2, col_corr2])

    # plotting
    fig, ax = plt.subplots(figsize=(10, 5))
    x_positions = df_results.index + 1
    ax.plot(x_positions, df_results[col_corr1], marker="o", linestyle="-", label=label1, color="blue")
    ax.plot(x_positions, df_results[col_corr2], marker="s", linestyle="--", label=label2, color="red")

    def annotate_size(x, corr, size, color):
        # Write the cluster size just above its point on the line.
        ax.text(
            x,
            corr + 0.05,
            f"{size}",
            ha="center",
            fontsize=10,
            color=color,
            fontweight="bold",
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
        )

    for i, row in df_results.iterrows():
        annotate_size(i + 1, row[col_corr1], row[col_size1], "blue")
        annotate_size(i + 1, row[col_corr2], row[col_size2], "red")

    ax.set_xticks(range(1, len(df_results) + 1))
    ax.set_xlabel("Cluster Index")
    ax.set_ylabel("Correlation")
    ax.set_title(f"Cluster correlation: {label1} vs {label2}")
    ax.legend()
    ax.grid(True)
    fig.tight_layout(pad=3)
    plt.show()
    return df_results
[docs]
def plot_multiple_metrics(
    metrics: dict[str, dict[str, tuple[float, float]]],
    title_map: Optional[dict[str, str]] = None,
    ylabel_map: Optional[dict[str, str]] = None,
    filename: Optional[Union[str, Path]] = None
) -> None:
    """
    Plot multiple grouped performance metrics as side-by-side bar plots.

    One subplot is drawn per metric, with one bar per group. Error bars extend
    downward by the full error value and upward by the error clipped so that
    ``mean + error`` does not exceed 1.0. All subplots share a y axis fixed to
    the range 0.30-1.05.

    Parameters:
        metrics (dict): A nested dictionary of the form {metric_name: {group_name: (mean, std)}}.
        title_map (dict, optional): Maps each metric name to a custom plot title.
        ylabel_map (dict, optional): Maps each metric name to a Y-axis label.
        filename (str or Path, optional): If provided, saves the figure to this path
            instead of showing it.

    Returns:
        None. The figure is either saved to ``filename`` or shown, then closed.

    Raises:
        ValueError: If ``metrics`` is empty.
    """
    if not metrics:
        # Without this, plt.subplots(1, 0) fails with an unrelated-looking message.
        raise ValueError("metrics must contain at least one metric to plot.")

    logger.info(f"Plotting multiple metrics: {list(metrics.keys())}")
    n = len(metrics)
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 5.5), sharey=True)
    if n == 1:
        # A single subplot comes back as a bare Axes rather than an array.
        axes = [axes]

    bar_width = 0.7
    for ax, (metric, scores) in zip(axes, metrics.items()):
        groups = list(scores.keys())
        means_list: list[float] = []
        errs_list: list[float] = []
        for group in groups:
            mean, err = scores[group]
            means_list.append(mean)
            errs_list.append(err)
        means = np.array(means_list, dtype=float)
        errs = np.array(errs_list, dtype=float)

        x = np.arange(len(groups))
        # Wrap the palette index so more than ten groups do not all share one color.
        group_colors = [cm.tab10(i % cm.tab10.N) for i in range(len(groups))]

        # Upper whisker is capped so that the error bar never rises above 1.0.
        upper_err = np.maximum(0, np.minimum(errs, 1.0 - means))
        asymmetric_err = np.array([errs, upper_err])

        bars = ax.bar(x, means, bar_width, yerr=asymmetric_err, capsize=4, color=group_colors, alpha=0.95, linewidth=0)
        for bar, mean, upper in zip(bars, means, upper_err):
            bar_x = bar.get_x() + bar.get_width() / 2
            # Place the value label just above the top of the error bar.
            bar_y = mean + upper + 0.01
            ax.text(bar_x, bar_y, f"{bar.get_height():.3f}", ha='center', va='bottom', fontsize=9, fontweight="bold")

        ax.set_xticks(x)
        ax.set_xticklabels(groups, fontsize=11, rotation=25, ha="right")
        title = title_map.get(metric, metric) if title_map else metric
        ylabel = ylabel_map.get(metric, metric) if ylabel_map else metric
        ax.set_title(title, fontsize=14, pad=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_ylim(0.30, 1.05)
        ax.grid(True, axis="y", linestyle="--", alpha=0.3)

    fig.tight_layout(pad=2.0)
    if filename:
        fig.savefig(str(filename), dpi=300, bbox_inches="tight")
        logger.info(f"Saved combined figure to {filename}")
    else:
        plt.show()
    plt.close(fig)