r"""
DPMON: Optimized Network Embedding and Fusion for Disease Prediction.
This module implements an end-to-end Graph Neural Network (GNN) pipeline
integrating network topology with subject-level omics data.
References:
Hussein, S. et al. (2024), "Learning from Multi-Omics Networks to
Enhance Disease Prediction: An Optimized Network Embedding and
Fusion Approach" IEEE BIBM.
Algorithm:
The pipeline consists of three distinct phases:
Phase 1: Task-Aware Embedding Generation
1. Construct a multi-omics network.
2. Initialize node features using clinical correlation vectors.
3. Pass graph through a GNN (GAT/GCN/GIN).
Phase 2: Dimensionality Reduction
Compress embeddings into scalar weights via AutoEncoder/MLP.
Phase 3: Fusion and Prediction
Fuse embeddings with subject-level data via element-wise
multiplication (Feature Reweighting).
Notes:
The embedding space is optimized dynamically using the loss function:
.. math::
L_{total} = L_{classification} + \lambda L_{regularization}
The fusion acts as a **Network-Guided Attention Mechanism**,
amplifying features that are topologically central.
"""
from __future__ import annotations
import os
import re
import logging
import statistics
import tempfile
import shutil
import pandas as pd
import numpy as np
import networkx as nx
from pathlib import Path
from typing import Optional, List, Tuple,Dict, Any
try:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data
except ModuleNotFoundError:
raise ImportError(
"DPMON (Disease Prediction using Multi-Omics Networks) requires PyTorch Geometric. "
"Please install it by following the instructions at: "
"https://bioneuralnet.readthedocs.io/en/latest/installation.html"
)
try:
import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune import Checkpoint
from ray.tune.error import TuneError
from ray.tune.stopper import TrialPlateauStopper
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.basic_variant import BasicVariantGenerator
os.environ["TUNE_DISABLE_IPY_WIDGETS"] = "1"
os.environ["TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S"] = "0"
os.environ["RAY_DEDUP_LOGS"] = "0"
for logger_name in ("ray", "raylet", "ray.train.session", "ray.tune", "torch_geometric"):
logging.getLogger(logger_name).setLevel(logging.WARNING)
except ModuleNotFoundError:
raise ImportError(
"DPMON (Disease Prediction using Multi-Omics Networks) requires Ray Tune"
"Please install it by following the instructions at: "
"https://bioneuralnet.readthedocs.io/en/latest/installation.html"
)
from sklearn.model_selection import train_test_split,StratifiedKFold,RepeatedStratifiedKFold
from sklearn.preprocessing import label_binarize
from scipy.stats import pointbiserialr
from sklearn.metrics import f1_score, roc_auc_score, recall_score,precision_score,average_precision_score, matthews_corrcoef
from bioneuralnet.utils import set_seed
from bioneuralnet.network_embedding import GCN, GAT, SAGE, GIN
from ..utils import get_logger
logger= get_logger(__name__)
[docs]
class DPMON:
"""DPMON (Disease Prediction using Multi-Omics Networks) end-to-end pipeline for multi-omics disease prediction.
Instead of node-level MSE regression, DPMON aggregates node embeddings with patient-level omics data and feeds them to a downstream classification head (e.g., a softmax layer with CrossEntropyLoss) for sample-level disease prediction. This end-to-end setup leverages both local (node-level) and global (patient-level) network information.
Attributes:
adjacency_matrix (pd.DataFrame): Adjacency matrix of the feature-level network; index/columns are feature names.
omics_list (List[pd.DataFrame] | pd.DataFrame): List of omics data matrices or a single merged omics DataFrame (samples x features).
phenotype_data (pd.DataFrame | pd.Series): Phenotype labels used for supervision.
clinical_data (Optional[pd.DataFrame]): Optional clinical covariates (samples x clinical features); may be None.
phenotype_col (str): Column name in phenotype_data that stores the target labels.
model (str): GNN backbone; one of {"GCN", "GAT", "SAGE", "GIN"}.
gnn_hidden_dim (int): Hidden dimension size of GNN layers.
gnn_layer_num (int): Number of stacked GNN layers.
gnn_dropout (float): Dropout rate applied within the GNN.
gnn_activation (str): Non-linear activation used in GNN layers (e.g., "relu").
dim_reduction (str): Dimensionality reduction strategy for omics input (e.g., "ae" for autoencoder).
ae_encoding_dim (int): Encoding dimension of the autoencoder bottleneck if dim_reduction="ae".
nn_hidden_dim1 (int): Hidden dimension of the first fully connected layer in the downstream classifier.
nn_hidden_dim2 (int): Hidden dimension of the second fully connected layer in the downstream classifier.
num_epochs (int): Number of training epochs per run.
repeat_num (int): Number of repeated training runs (for repeated train/test splits or repeated CV).
n_folds (int): Number of folds to use when cv=True.
lr (float): Learning rate for the optimizer.
weight_decay (float): L2 weight decay (regularization) coefficient.
tune (bool): If True, perform hyperparameter tuning before final training.
tune_trials (int): Number of trials to perform if tune=True.
gpu (bool): If True, use GPU if available.
cv (bool): If True, use K-fold cross-validation; otherwise use repeated train/test splits.
cuda (int): CUDA device index to use when gpu=True.
seed (int): Random seed for reproducibility.
seed_trials (bool): If True, use a fixed seed for hyperparameter sampling to ensure reproducibility across trials.
output_dir (Path): Directory where logs, checkpoints, and results are written.
"""
def __init__(
self,
adjacency_matrix: pd.DataFrame,
omics_list: List[pd.DataFrame],
phenotype_data: pd.DataFrame,
clinical_data: Optional[pd.DataFrame] = None,
correlation_mode: str = "abs_pearson",
model: str = "GAT",
phenotype_col: str = "phenotype",
gnn_hidden_dim: int = 16,
gnn_layer_num: int = 4,
gnn_dropout: float = 0.1,
gnn_activation: str = "relu",
dim_reduction: str = "ae",
ae_architecture: str = "original",
ae_encoding_dim: int = 8,
nn_hidden_dim1: int = 16,
nn_hidden_dim2: int = 8,
num_epochs: int = 100,
repeat_num: int = 1,
n_folds: int = 5,
lr: float = 1e-1,
weight_decay: float = 1e-4,
gat_heads: int = 1,
tune: bool = False,
tune_trials: int = 20,
gpu: bool = False,
cv: bool = False,
cuda: int = 0,
seed: int = 1804,
seed_trials: bool = False,
output_dir: Optional[str] = None,
):
if adjacency_matrix.empty:
raise ValueError("Adjacency matrix cannot be empty.")
if isinstance(omics_list, list):
if not omics_list or any(df.empty for df in omics_list):
raise ValueError("All provided omics data files must be non-empty.")
self.combined_omics_input = pd.concat(omics_list, axis=1)
elif isinstance(omics_list, pd.DataFrame):
if omics_list.empty:
raise ValueError("Provided omics DataFrame cannot be empty.")
self.combined_omics_input = omics_list
else:
raise TypeError("omics_list must be a pandas DataFrame or a list of DataFrames.")
if isinstance(phenotype_data, pd.DataFrame):
if phenotype_data.empty or phenotype_col not in phenotype_data.columns:
raise ValueError(f"Phenotype DataFrame must have a '{phenotype_col}' column.")
self.phenotype_series = phenotype_data[phenotype_col]
elif isinstance(phenotype_data, pd.Series):
if phenotype_data.empty:
raise ValueError("Phenotype Series cannot be empty.")
self.phenotype_series = phenotype_data
else:
raise TypeError("phenotype_data must be a pandas DataFrame or Series.")
if clinical_data is not None and clinical_data.empty:
logger.warning(
"Clinical data provided is empty => treating as None => random features."
)
clinical_data = None
self.adjacency_matrix = adjacency_matrix
self.omics_list = omics_list
self.phenotype_data = phenotype_data
self.clinical_data = clinical_data
self.phenotype_col = phenotype_col
self.model = model
self.gnn_hidden_dim = gnn_hidden_dim
self.gnn_layer_num = gnn_layer_num
self.gnn_dropout = gnn_dropout
self.gnn_activation = gnn_activation
self.dim_reduction = dim_reduction
self.ae_encoding_dim = ae_encoding_dim
self.nn_hidden_dim1 = nn_hidden_dim1
self.nn_hidden_dim2 = nn_hidden_dim2
self.num_epochs = num_epochs
self.repeat_num = repeat_num
self.n_folds = n_folds
self.lr = lr
self.weight_decay = weight_decay
self.tune = tune
self.tune_trials = tune_trials
self.gpu = gpu
self.cuda = cuda
self.seed = seed
self.seed_trials = seed_trials
self.cv = cv
self.correlation_mode = correlation_mode
self.ae_architecture= ae_architecture
self.gat_heads =gat_heads
if output_dir is None:
self.output_dir = Path(os.getcwd()) / "dpmon"
else:
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Output directory set to: {self.output_dir}")
logger.info(f"Initialized DPMON with model: {self.model}")
[docs]
def run(self) -> Tuple[pd.DataFrame, object, torch.Tensor | None]:
"""Execute the DPMON pipeline.
This method aligns the graph and omics features, optionally performs hyperparameter tuning, and then trains and evaluates the chosen GNN model using either K-fold cross-validation (cv=True) or repeated train/test splits (cv=False). It returns prediction outputs, a metrics/config object, and optionally the learned embeddings.
Returns:
Tuple[pd.DataFrame, object, torch.Tensor | None]: A tuple (predictions_df, metrics, embeddings) where:
predictions_df (pd.DataFrame): If cv=False, per-sample predictions with actual vs predicted labels; if cv=True, aggregated CV performance or fold-level results depending on the backend
metrics (object): Dictionary or configuration object containing evaluation metrics and, when tuning is enabled, information about the selected hyperparameters.
embeddings (torch.Tensor | None): Learned embedding tensor (e.g., node or sample embeddings) if produced by the training routine, otherwise None.
"""
set_seed(self.seed)
logger.info(f"Random seed set to: {self.seed}")
dpmon_params = {
"model": self.model,
"phenotype_col": self.phenotype_col,
"gnn_hidden_dim": self.gnn_hidden_dim,
"gnn_layer_num": self.gnn_layer_num,
"gnn_dropout":self.gnn_dropout,
"gnn_activation":self.gnn_activation,
"dim_reduction": self.dim_reduction,
"ae_encoding_dim": self.ae_encoding_dim,
"nn_hidden_dim1": self.nn_hidden_dim1,
"nn_hidden_dim2": self.nn_hidden_dim2,
"num_epochs": self.num_epochs,
"n_folds": self.n_folds,
"repeat_num": self.repeat_num,
"lr": self.lr,
"weight_decay": self.weight_decay,
"gpu": self.gpu,
"cuda": self.cuda,
"tune": self.tune,
"tune_trials": self.tune_trials,
"seed": self.seed,
"seed_trials": self.seed_trials,
"correlation_mode": self.correlation_mode,
"ae_architecture": self.ae_architecture,
"gat_heads": self.gat_heads,
}
graph_nodes = set(self.adjacency_matrix.index)
omics_features = set(self.combined_omics_input.columns)
common_features = list(graph_nodes.intersection(omics_features))
if not common_features:
raise ValueError("No common features found between adjacency matrix and omics data.")
dropped_graph_nodes = len(graph_nodes) - len(common_features)
dropped_omics_features = len(omics_features) - len(common_features)
if dropped_graph_nodes > 0 or dropped_omics_features > 0:
logger.info(
f"Graph/omics mismatch. Aligning to {len(common_features)} common features. "
f"Dropped {dropped_graph_nodes} from graph and {dropped_omics_features} from omics. "
"To prevent this, ensure data is pre-aligned."
)
self.adjacency_matrix = self.adjacency_matrix.loc[common_features, common_features]
combined_omics = self.combined_omics_input[common_features]
phenotype_series = self.phenotype_series.rename(self.phenotype_col)
if self.phenotype_col not in combined_omics.columns:
combined_omics = combined_omics.merge(
phenotype_series,
left_index=True,
right_index=True,
)
else:
logger.warning(f"Column '{self.phenotype_col}' already exists in omics data. Using existing column.")
predictions_df, metrics, embeddings = run_standard_training(
dpmon_params,
self.adjacency_matrix,
combined_omics,
self.clinical_data,
seed=self.seed,
cv=self.cv,
output_dir=self.output_dir
)
logger.info("DPMON run completed.")
return predictions_df, metrics, embeddings
[docs]
def setup_device(gpu, cuda):
if gpu:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
logger.debug(f"Using GPU {cuda}")
else:
logger.warning(f"GPU {cuda} requested but not available, using CPU")
else:
device = torch.device("cpu")
logger.debug("Using CPU")
return device
[docs]
def slice_omics_datasets(omics_dataset: pd.DataFrame, adjacency_matrix: pd.DataFrame, phenotype_col: str = "phenotype") -> List[pd.DataFrame]:
logger.debug("Slicing omics dataset based on network nodes.")
omics_network_nodes_names = adjacency_matrix.index.tolist()
# Clean omics dataset columns
clean_columns = []
for node in omics_dataset.columns:
node_clean = re.sub(r"[^0-9a-zA-Z_]", ".", node)
if not node_clean[0].isalpha():
node_clean = "X" + node_clean
clean_columns.append(node_clean)
omics_dataset.columns = clean_columns
missing_nodes = set(omics_network_nodes_names) - set(omics_dataset.columns)
if missing_nodes:
logger.error(f"Nodes missing in omics data: {missing_nodes}")
raise ValueError("Missing nodes in omics dataset.")
selected_columns = omics_network_nodes_names + [phenotype_col]
return [omics_dataset[selected_columns]]
[docs]
def prepare_node_features(
adjacency_matrix: pd.DataFrame,
omics_datasets: List[pd.DataFrame],
clinical_data: Optional[pd.DataFrame],
phenotype_col: str,
correlation_mode: str = "abs_pearson",
) -> List[Data]:
"""Build node-level features and return a PyTorch Geometric graph.
Args:
adjacency_matrix: Symmetric adjacency matrix (node names as index/columns).
omics_datasets: List of omics matrices (samples x features); first element used.
clinical_data: Clinical covariates for correlation-based node features; may be None.
phenotype_col: Column name storing phenotype labels (dropped from features).
correlation_mode: How to compute node features from clinical correlations.
- "abs_pearson": Absolute Pearson correlation, no transforms = DPMON.
- "adaptive": Mixed correlation types + Fisher-Z + standardization.
Returns:
List[Data]: Single-element list with a PyG Data object.
"""
logger.debug(f"Building PyG Data object (correlation_mode={correlation_mode}).")
network_features = adjacency_matrix.columns
omics_data = omics_datasets[0]
if phenotype_col in omics_data.columns:
omics_feature_df = omics_data.drop(columns=[phenotype_col])
else:
omics_feature_df = omics_data
nodes = sorted(network_features.intersection(omics_feature_df.columns))
if len(nodes) == 0:
raise ValueError("No common features found between the network and omics data.")
omics_filtered = omics_feature_df[nodes]
network_filtered = adjacency_matrix.loc[nodes, nodes]
logger.info(f"Building graph with {len(nodes)} common features.")
G = nx.from_pandas_adjacency(network_filtered)
self_loops = list(nx.selfloop_edges(G))
if self_loops:
G.remove_edges_from(self_loops)
logger.debug(f"Removed {len(self_loops)} self-loop edges.")
if clinical_data is not None and not clinical_data.empty:
clinical_cols = list(clinical_data.columns)
common_index = clinical_data.index.intersection(omics_filtered.index)
if common_index.empty:
raise ValueError("No common indices between omics and clinical data.")
node_features_dict = {}
for node in nodes:
vec = pd.to_numeric(omics_filtered[node].loc[common_index], errors="coerce")
vec = vec.dropna()
corr_vector = {}
for cvar in clinical_cols:
clinical_series = clinical_data[cvar].loc[common_index]
common_valid = vec.index.intersection(clinical_series.dropna().index)
vec_aligned = vec.loc[common_valid]
clinical_aligned = clinical_series.loc[common_valid].astype("float64")
if clinical_aligned.nunique() <= 1 or vec_aligned.nunique() <= 1 or len(vec_aligned) < 2:
corr_vector[cvar] = 0.0
continue
if correlation_mode == "abs_pearson":
# OG DPMON: abs(Pearson correlation)
try:
corr_val = abs(vec_aligned.corr(clinical_aligned))
if pd.isna(corr_val):
corr_val = 0.0
except Exception:
corr_val = 0.0
corr_vector[cvar] = corr_val
elif correlation_mode == "adaptive":
# OPTION 2: mixed types + Fisher-Z
vec_is_binary = vec_aligned.nunique() == 2
clinical_is_binary = clinical_aligned.nunique() == 2
try:
if vec_is_binary and clinical_is_binary:
corr_val = matthews_corrcoef(vec_aligned, clinical_aligned)
elif vec_is_binary or clinical_is_binary:
corr_val, _ = pointbiserialr(clinical_aligned, vec_aligned)
if pd.isna(corr_val):
corr_val = 0.0
else:
corr_val = vec_aligned.corr(clinical_aligned)
if pd.isna(corr_val):
corr_val = 0.0
except Exception as e:
logger.debug(f"Correlation failed for {node}-{cvar}: {e}")
corr_val = 0.0
# Fisher-Z transform
if pd.isna(corr_val) or corr_val == 0.0:
z = 0.0
else:
r_clip = np.clip(corr_val, -0.999999, 0.999999)
z = np.arctanh(r_clip)
corr_vector[cvar] = z
else:
raise ValueError(f"Unknown correlation_mode: {correlation_mode}")
node_features_dict[node] = corr_vector
node_features_df = pd.DataFrame.from_dict(node_features_dict, orient="index")
node_features_df = node_features_df.fillna(0.0)
if correlation_mode == "adaptive":
# standardize only in adaptive mode DPMON uses raw.
std_vals = node_features_df.std()
std_vals = std_vals.replace(0, 1e-8)
node_features_df = (node_features_df - node_features_df.mean()) / std_vals
logger.info(f"Node feature matrix shape: {node_features_df.shape} (mode={correlation_mode})")
else:
# No clinical data -> generate random features as fallback
logger.warning("No clinical data provided. Using random node features.")
rng = np.random.default_rng(1998)
node_features_df = pd.DataFrame(
rng.standard_normal((len(nodes), 7)),
index=nodes,
columns=[f"rand_{i}" for i in range(7)],
)
# convert to PyG Data
x = torch.tensor(node_features_df.values, dtype=torch.float)
node_mapping = {name: i for i, name in enumerate(nodes)}
G_mapped = nx.relabel_nodes(G, node_mapping)
edges_list = list(G_mapped.edges())
if not edges_list:
logger.warning("Graph has no edges after self-loop removal.")
edge_index = torch.zeros((2, 0), dtype=torch.long)
edge_weight = torch.zeros(0, dtype=torch.float)
else:
edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
# Make bidirectional
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
weights = []
for _, _, d in G_mapped.edges(data=True):
weights.append(d.get("weight", 1.0))
edge_weight = torch.tensor(weights, dtype=torch.float)
edge_weight = torch.cat([edge_weight, edge_weight], dim=0)
data = Data(x=x, edge_index=edge_index, edge_attr=edge_weight)
return [data]
[docs]
def run_standard_training(dpmon_params, adjacency_matrix, combined_omics, clinical_data, seed, cv=False, output_dir=None):
phenotype_col = dpmon_params["phenotype_col"]
correlation_mode = dpmon_params["correlation_mode"]
device = setup_device(dpmon_params["gpu"], dpmon_params["cuda"])
omics_dataset = slice_omics_datasets(combined_omics, adjacency_matrix, phenotype_col)
omics_dataset = omics_dataset[0]
if not cv:
logger.info(f"Running in standard mode (cv=False) with {dpmon_params['repeat_num']} repeats.")
test_accuracies = []
all_predictions_list = []
best_accuracy = 0.0
best_model_state = None
best_predictions_df = None
f1_macros = []
f1_weighteds = []
recalls = []
aucs = []
auprs = []
X = omics_dataset.drop([phenotype_col], axis=1)
Y = omics_dataset[phenotype_col]
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=seed, stratify=Y)
if clinical_data is None:
clinical_data_full = pd.DataFrame(index=X.index)
else:
clinical_data_full = clinical_data.reindex(X.index)
clinical_train = clinical_data_full.loc[X_train.index]
if dpmon_params['tune']:
clinical_train_tune = clinical_data_full.loc[X_train.index]
best_config = run_hyperparameter_tuning(
X_train, y_train,
adjacency_matrix,
clinical_train_tune,
dpmon_params
)
dpmon_params.update(best_config)
logger.info(f"Best config: {best_config}")
logger.info("Building 'clean' graph features for standard run using train split")
omics_train_fold_list = [X_train.join(y_train.rename(phenotype_col))]
omics_network_tg = prepare_node_features(
adjacency_matrix,
omics_train_fold_list,
clinical_train,
phenotype_col,
correlation_mode
)[0].to(device)
X_train_tensor = torch.FloatTensor(X_train.values).to(device)
y_train_tensor = torch.LongTensor(y_train.values).to(device)
X_test_tensor = torch.FloatTensor(X_test.values).to(device)
y_test_tensor = torch.LongTensor(y_test.values).to(device)
train_labels_dict = {
"labels": y_train_tensor,
"omics_network": omics_network_tg
}
for i in range(dpmon_params["repeat_num"]):
logger.info(f"Training iteration {i+1}/{dpmon_params['repeat_num']}")
model = NeuralNetwork(
model_type=dpmon_params["model"],
gnn_input_dim=omics_network_tg.x.shape[1],
gnn_hidden_dim=dpmon_params["gnn_hidden_dim"],
gnn_layer_num=dpmon_params["gnn_layer_num"],
dim_reduction=dpmon_params["dim_reduction"],
ae_encoding_dim=dpmon_params["ae_encoding_dim"],
ae_architecture=dpmon_params["ae_architecture"],
nn_input_dim=X_train_tensor.shape[1],
nn_hidden_dim1=dpmon_params["nn_hidden_dim1"],
nn_hidden_dim2=dpmon_params["nn_hidden_dim2"],
nn_output_dim=Y.nunique(),
gnn_dropout=dpmon_params["gnn_dropout"],
gnn_activation=dpmon_params["gnn_activation"],
gat_heads=dpmon_params["gat_heads"]
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=dpmon_params["lr"], weight_decay=dpmon_params["weight_decay"])
model = train_model(
model, criterion, optimizer,
X_train_tensor, train_labels_dict, dpmon_params["num_epochs"]
)
model.eval()
with torch.no_grad():
predictions, _, _ = model(X_test_tensor, omics_network_tg)
_, predicted = torch.max(predictions, 1)
probs = torch.softmax(predictions, dim=1)
y_test_np = y_test_tensor.cpu().numpy()
predicted_np = predicted.cpu().numpy()
probs_np = probs.cpu().numpy()
accuracy = (predicted == y_test_tensor).sum().item() / len(y_test_tensor)
f1_ma = f1_score(y_test_np, predicted_np, average='macro', zero_division=0)
f1_wt = f1_score(y_test_np, predicted_np, average='weighted', zero_division=0)
recall = recall_score(y_test_np, predicted_np, average='macro', zero_division=0)
try:
n_classes = probs_np.shape[1]
if n_classes == 2:
auc_score = roc_auc_score(y_test_np, probs_np[:, 1])
aupr = average_precision_score(y_test_np, probs_np[:, 1])
else:
auc_score = roc_auc_score(y_test_np, probs_np, multi_class='ovr', average='macro')
aupr = 0.0
except:
auc_score, aupr = 0.0, 0.0
logger.info(f"Iteration {i+1} Results:")
logger.info(f" Accuracy: {accuracy:.4f}")
logger.info(f" F1 Macro: {f1_ma:.4f}")
logger.info(f" F1 Weighted: {f1_wt:.4f}")
logger.info(f" Recall: {recall:.4f}")
logger.info(f" AUC: {auc_score:.4f}")
logger.info(f" AUPR: {aupr:.4f}\n")
test_accuracies.append(accuracy)
f1_macros.append(f1_ma)
f1_weighteds.append(f1_wt)
recalls.append(recall)
aucs.append(auc_score)
auprs.append(aupr)
if accuracy > best_accuracy:
best_accuracy = accuracy
best_model_state = model.state_dict()
if test_accuracies:
def get_stats(data_list):
avg = statistics.mean(data_list) if data_list else 0.0
std = statistics.stdev(data_list) if len(data_list) > 1 else 0.0
return avg, std
metrics_to_report = {
'Accuracy': test_accuracies,
'F1 Macro': f1_macros,
'F1 Weighted': f1_weighteds,
'Recall': recalls,
'AUC': aucs,
'AUPR': auprs
}
summary_rows = []
for name, values in metrics_to_report.items():
avg, std = get_stats(values)
summary_rows.append({'Metric': name, 'Average': avg, 'StdDev': std})
metrics_df = pd.DataFrame(summary_rows)
logger.info("--- Standard Run Final Summary (avg across repeats) ---")
for _, row in metrics_df.iterrows():
logger.info(f"Avg {row['Metric']}: {row['Average']:.4f} +/- {row['StdDev']:.4f}")
logger.info("------------------------------------------------------\n")
else:
metrics_df = pd.DataFrame()
if output_dir and best_model_state is not None:
model_save_path = os.path.join(output_dir, "best_model_standard_run.pt")
try:
torch.save(best_model_state, model_save_path)
logger.info(f"Successfully saved best model state to: {model_save_path}")
except Exception as e:
logger.error(f"Failed to save best model: {e}")
return best_predictions_df, all_predictions_list, metrics_df
else:
n_folds = dpmon_params["n_folds"]
logger.info(f"Running in Cross-Validation mode (cv=True) with {n_folds} folds.")
# these are to track the best model across folds and then save it
best_global_fold_accuracy = 0.0
best_global_fold_f1 = 0.0
best_global_model_state = None
best_global_embeddings = None
fold_accuracies = []
fold_f1_macros = []
fold_f1_weighteds = []
fold_auprs = []
fold_aucs = []
fold_recalls = []
fold_precisions = []
fold_best_configs = []
all_fold_results = []
X = omics_dataset.drop([phenotype_col], axis=1)
Y = omics_dataset[phenotype_col]
if clinical_data is None:
clinical_data_full = pd.DataFrame(index=X.index)
else:
clinical_data_full = clinical_data.reindex(X.index)
repeat_num_val = dpmon_params.get("repeat_num", 1)
total_splits = n_folds * repeat_num_val
if repeat_num_val > 1:
skf = RepeatedStratifiedKFold(
n_splits=n_folds,
n_repeats=repeat_num_val,
random_state=seed
)
logger.info(f"CV Setup: Repeated K-Fold ({n_folds}x{repeat_num_val} = {total_splits} splits total).")
else:
# fallback to single Stratified kfold
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
logger.info(f"CV Setup: Standard {n_folds}-fold split.")
for fold, (train_index, test_index) in enumerate(skf.split(X, Y)):
current_repeat = fold // n_folds + 1
current_fold = fold % n_folds + 1
if repeat_num_val > 1:
logger.info(f"Starting Repeat {current_repeat}/{repeat_num_val}, Fold {current_fold}/{n_folds} (Total Split {fold + 1}/{total_splits})")
else:
logger.info(f"Starting Fold {current_fold}/{n_folds}")
X_train, X_test = X.iloc[train_index], X.iloc[test_index]
y_train, y_test = Y.iloc[train_index], Y.iloc[test_index]
if dpmon_params['tune']:
best_config = run_hyperparameter_tuning(
X_train, y_train,
adjacency_matrix,
clinical_data_full.iloc[train_index],
dpmon_params
)
dpmon_params.update(best_config)
logger.info(f"Fold {fold+1} best config: {best_config}")
#save params
fold_record = best_config.copy()
fold_record['Fold'] = fold + 1
fold_best_configs.append(fold_record)
clinical_train = clinical_data_full.iloc[train_index]
clinical_test = clinical_data_full.iloc[test_index]
logger.info(f"Building graph features for Fold {fold+1} using train split only")
omics_train_fold_list = [X_train.join(y_train.rename(phenotype_col))]
omics_network_tg = prepare_node_features(
adjacency_matrix,
omics_train_fold_list,
clinical_train,
phenotype_col,
correlation_mode
)[0].to(device)
X_train_tensor = torch.FloatTensor(X_train.values).to(device)
y_train_tensor = torch.LongTensor(y_train.values).to(device)
X_test_tensor = torch.FloatTensor(X_test.values).to(device)
y_test_tensor = torch.LongTensor(y_test.values).to(device)
train_labels_dict = {
"labels": y_train_tensor,
"omics_network": omics_network_tg
}
model = NeuralNetwork(
model_type=dpmon_params["model"],
gnn_input_dim=omics_network_tg.x.shape[1],
gnn_hidden_dim=dpmon_params["gnn_hidden_dim"],
gnn_layer_num=dpmon_params["gnn_layer_num"],
ae_encoding_dim=dpmon_params["ae_encoding_dim"],
ae_architecture=dpmon_params["ae_architecture"],
nn_input_dim=X_train_tensor.shape[1],
nn_hidden_dim1=dpmon_params["nn_hidden_dim1"],
nn_hidden_dim2=dpmon_params["nn_hidden_dim2"],
nn_output_dim=Y.nunique(),
gnn_dropout=dpmon_params["gnn_dropout"],
gnn_activation=dpmon_params["gnn_activation"],
dim_reduction=dpmon_params["dim_reduction"],
gat_heads=dpmon_params["gat_heads"]
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=dpmon_params["lr"], weight_decay=dpmon_params["weight_decay"])
model = train_model(model, criterion, optimizer,X_train_tensor, train_labels_dict, dpmon_params["num_epochs"])
model.eval()
logger.info(f"Evaluating model for Fold {fold+1} on test set")
with torch.no_grad():
predictions, _, node_emb = model(X_test_tensor, omics_network_tg)
_, predicted = torch.max(predictions, 1)
probs = torch.softmax(predictions, dim=1)
y_test_np = y_test_tensor.cpu().numpy()
predicted_np = predicted.cpu().numpy()
probs_np = probs.cpu().numpy()
accuracy = (predicted == y_test_tensor).sum().item() / len(y_test_tensor)
f1_ma = f1_score(y_test_np, predicted_np, average='macro', zero_division=0)
f1_wt = f1_score(y_test_np, predicted_np, average='weighted', zero_division=0)
recall = recall_score(y_test_np, predicted_np, average='macro', zero_division=0)
precision = precision_score(y_test_np, predicted_np, average='macro', zero_division=0)
try:
n_classes = probs_np.shape[1]
# binary
if n_classes == 2:
# Ususinge probability of positive
auc_score = roc_auc_score(y_test_np, probs_np[:, 1])
aupr = average_precision_score(y_test_np, probs_np[:, 1])
logger.debug(f"Binary | AUC: {auc_score:.4f}, AUPR: {aupr:.4f}")
else:
auc_score = roc_auc_score(y_test_np, probs_np, multi_class='ovr', average='macro')
y_test_bin = label_binarize(y_test_np, classes=range(n_classes))
aupr_scores = []
for i in range(n_classes):
# checking if class exists in test set
if np.sum(y_test_bin[:, i]) > 0:
ap = average_precision_score(y_test_bin[:, i], probs_np[:, i])
aupr_scores.append(ap)
aupr = np.mean(aupr_scores) if aupr_scores else 0.0
logger.debug(f"Multiclass | AUC: {auc_score:.4f}, AUPR: {aupr:.4f}")
except Exception as e:
logger.error(f"Could not calculate AUC/AUPR: {e}")
import traceback
logger.error(traceback.format_exc())
auc_score = 0.0
aupr = 0.0
fold_predictions = {
'accuracy': accuracy,
'f1_ma': f1_ma,
'f1_wt': f1_wt,
'aupr': aupr,
'auc': auc_score,
'recall': recall,
'precision': precision
}
if accuracy > best_global_fold_accuracy:
best_global_fold_accuracy = accuracy
best_global_model_state = model.state_dict()
best_global_embeddings = node_emb.detach().cpu()
# Should be a parameter that way we can decide which model to optemize/save
# if f1_ma > best_global_fold_f1:
# best_global_fold_f1 = f1_ma
# best_global_model_state = model.state_dict()
# best_global_embeddings = node_emb.detach().cpu()
if fold_predictions:
fold_accuracies.append(fold_predictions['accuracy'])
fold_f1_macros.append(fold_predictions['f1_ma'])
fold_f1_weighteds.append(fold_predictions['f1_wt'])
fold_auprs.append(fold_predictions['aupr'])
fold_aucs.append(fold_predictions['auc'])
fold_recalls.append(fold_predictions['recall'])
fold_precisions.append(fold_predictions["precision"])
all_fold_results.append(fold_predictions)
logger.info(f"Fold {fold+1} results:")
logger.info(f" Accuracy: {accuracy:.4f}")
logger.info(f" F1 Macro: {f1_ma:.4f}")
logger.info(f" F1 Weighted: {f1_wt:.4f}")
logger.info(f" Recall: {recall:.4f}")
logger.info(f" Precision: {precision:.4f}")
logger.info(f" AUC: {auc_score:.4f}")
logger.info(f" AUPR: {aupr:.4f}\n")
if dpmon_params['gpu']:
torch.cuda.empty_cache()
logger.debug(f"Clearing cuda cache for fold {fold+1} \n")
avg_acc = statistics.mean(fold_accuracies) if fold_accuracies else 0.0
std_acc = statistics.stdev(fold_accuracies) if len(fold_accuracies) > 1 else 0.0
avg_f1_ma = statistics.mean(fold_f1_macros) if fold_f1_macros else 0.0
std_f1_ma = statistics.stdev(fold_f1_macros) if len(fold_f1_macros) > 1 else 0.0
avg_f1_wt = statistics.mean(fold_f1_weighteds) if fold_f1_weighteds else 0.0
std_f1_wt = statistics.stdev(fold_f1_weighteds) if len(fold_f1_weighteds) > 1 else 0.0
avg_aupr = statistics.mean(fold_auprs) if fold_auprs else 0.0
std_aupr = statistics.stdev(fold_auprs) if len(fold_auprs) > 1 else 0.0
avg_auc = statistics.mean(fold_aucs) if fold_aucs else 0.0
std_auc = statistics.stdev(fold_aucs) if len(fold_aucs) > 1 else 0.0
avg_recall = statistics.mean(fold_recalls) if fold_recalls else 0.0
std_recall = statistics.stdev(fold_recalls) if len(fold_recalls) > 1 else 0.0
avg_precision = statistics.mean(fold_precisions) if fold_precisions else 0.0
std_precision = statistics.stdev(fold_precisions) if len(fold_precisions) > 1 else 0.0
metrics_df = pd.DataFrame({
'Metric': ['Accuracy', 'F1 Macro', 'F1 Weighted', 'Recall', 'Precision', 'AUC', 'AUPR'],
'Average': [avg_acc, avg_f1_ma, avg_f1_wt, avg_recall, avg_precision, avg_auc, avg_aupr],
'StdDev': [std_acc, std_f1_ma, std_f1_wt, std_recall, std_precision, std_auc, std_aupr]
})
#final_cv_predictions_df = pd.concat(cv_predictions_list, ignore_index=True)
if output_dir and best_global_model_state is not None:
model_save_path = os.path.join(output_dir, "best_cv_model.pt")
try:
torch.save(best_global_model_state, model_save_path)
logger.info(f"Successfully saved global best CV model state to: {model_save_path}")
except Exception as e:
logger.error(f"Failed to save best CV model: {e}")
if output_dir and best_global_embeddings is not None:
emb_save_path = os.path.join(output_dir, "best_cv_model_embds.pt")
try:
torch.save(best_global_embeddings, emb_save_path)
logger.info(f"Successfully saved global best CV model state to: {emb_save_path}")
except Exception as e:
logger.error(f"Failed to save best CV model: {e}")
if output_dir and fold_best_configs:
try:
results_df = pd.DataFrame(all_fold_results)
config_save_path = os.path.join(output_dir, "cv_tuning_results.csv")
config_df = pd.DataFrame(fold_best_configs)
combined = pd.concat([results_df,config_df], axis=1)
combined.to_csv(config_save_path)
logger.info(f"Successfully saved CV tuning parameters to: {config_save_path}")
except Exception as e:
logger.warning(f"Failed to save CV tuning parameters: {e}")
logger.info("Cross-Validation Results:\n")
logger.info(f" Avg Accuracy: {avg_acc:.4f} +/- {std_acc:.4f}")
logger.info(f" Avg F1 Macro: {avg_f1_ma:.4f} +/- {std_f1_ma:.4f}")
logger.info(f" Avg F1 Weighted: {avg_f1_wt:.4f} +/- {std_f1_wt:.4f}")
logger.info(f" Avg Recall: {avg_recall:.4f} +/- {std_recall:.4f}")
logger.info(f" Avg Precision: {avg_precision:.4f} +/- {std_precision:.4f}")
logger.info(f" Avg AUC: {avg_auc:.4f} +/- {std_auc:.4f}")
logger.info(f" Avg AUPR: {avg_aupr:.4f} +/- {std_aupr:.4f}")
return pd.DataFrame(), metrics_df, best_global_embeddings
[docs]
def run_hyperparameter_tuning(X_train, y_train, adjacency_matrix, clinical_data, dpmon_params) -> Dict[str, Any]:
"""Run Ray Tune hyperparameter search with inner k-fold CV.
Each trial trains one model per inner fold, epoch-synchronised,
and reports the mean validation metrics. Asha early-stops on
the averaged signal, which is far more stable than a single split.
Args:
X_train: Training features for this outer fold (pd.DataFrame).
y_train: Training labels for this outer fold (pd.Series).
adjacency_matrix: Feature-level adjacency matrix.
clinical_data: Clinical covariates for the training fold.
dpmon_params: Full DPMON parameter dictionary.
Returns:
Dict with the best hyperparameter configuration.
"""
#os.environ["TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S"] = "0"
#os.environ["TUNE_DISABLE_IPY_WIDGETS"] = "1"
device = setup_device(dpmon_params["gpu"], dpmon_params["cuda"])
phenotype_col = dpmon_params["phenotype_col"]
correlation_mode = dpmon_params["correlation_mode"]
combined_omics_fold = X_train.join(y_train.rename(phenotype_col))
omics_dataset = slice_omics_datasets(
combined_omics_fold, adjacency_matrix, phenotype_col
)
omics_train_fold_list = [omics_dataset[0]]
omics_network_tg = prepare_node_features(
adjacency_matrix,
omics_train_fold_list,
clinical_data,
phenotype_col,
correlation_mode,
)[0].to(device)
pipeline_configs = {
"gnn_layer_num": tune.choice([2, 3, 4]),
"gnn_hidden_dim": tune.choice([32, 64]),
"lr": tune.loguniform(1e-4, 8e-4),
"weight_decay": tune.loguniform(1e-5, 5e-3),
"nn_hidden_dim1": tune.choice([128, 256]),
"nn_hidden_dim2": tune.choice([64]),
"ae_encoding_dim": tune.choice([4, 8]),
"ae_architecture": tune.choice(["original", "dynamic"]),
"num_epochs": tune.choice([128, 256]),
"gnn_dropout": tune.choice([0.4, 0.5, 0.6]),
"gnn_activation": tune.choice(["relu", "elu"]),
"dim_reduction": tune.choice(["ae", "linear", "mlp"]),
"gat_heads": tune.choice([1, 2]),
}
# prepare inner k-fold splits and push to ray object store
omics_data = omics_dataset[0]
X = omics_data.drop([phenotype_col], axis=1)
Y = omics_data[phenotype_col]
n_inner_folds = dpmon_params.get("tune_inner_folds", 5)
inner_cv = StratifiedKFold(
n_splits=n_inner_folds, shuffle=True, random_state=dpmon_params["seed"]
)
# pre-tensor on every inner fold and store in ray object store, this is so trials fetch data zero-copy instead of re-splitting.
fold_data_refs = []
for tr_idx, val_idx in inner_cv.split(X, Y):
y_tr = Y.iloc[tr_idx].values
fold_tensors = {
"X_train": torch.FloatTensor(X.iloc[tr_idx].values),
"y_train": torch.LongTensor(y_tr),
"X_val": torch.FloatTensor(X.iloc[val_idx].values),
"y_val": torch.LongTensor(Y.iloc[val_idx].values),
}
fold_data_refs.append(ray.put(fold_tensors))
omics_network_ref = ray.put(omics_network_tg.cpu())
logger.info(f"Inner CV: {n_inner_folds} folds | X shape: {X.shape} | Graph nodes: {omics_network_tg.x.shape}")
# pre-compute dims
gnn_input_dim = omics_network_tg.x.shape[1]
nn_input_dim = X.shape[1]
nn_output_dim = Y.nunique()
model_type = dpmon_params["model"]
# trial function trains k models epoch-sync
def tune_train_fn(config):
device_inner = setup_device(dpmon_params["gpu"], dpmon_params["cuda"])
omics_net = ray.get(omics_network_ref).to(device_inner)
# load every inner fold onto device
folds = []
for ref in fold_data_refs:
fd = ray.get(ref)
fold_dict = {
"X_train": fd["X_train"].to(device_inner),
"y_train": fd["y_train"].to(device_inner),
"X_val": fd["X_val"].to(device_inner),
"y_val": fd["y_val"].to(device_inner),
"criterion": nn.CrossEntropyLoss()
}
folds.append(fold_dict)
# one model + optimizer per inner fold
models, optimizers = [], []
for _ in range(len(folds)):
m = NeuralNetwork(
model_type=model_type,
gnn_input_dim=gnn_input_dim,
gnn_hidden_dim=config["gnn_hidden_dim"],
gnn_layer_num=config["gnn_layer_num"],
gnn_dropout=config["gnn_dropout"],
gnn_activation=config["gnn_activation"],
dim_reduction=config["dim_reduction"],
ae_encoding_dim=config["ae_encoding_dim"],
ae_architecture=config["ae_architecture"],
gat_heads=config["gat_heads"],
nn_input_dim=nn_input_dim,
nn_hidden_dim1=config["nn_hidden_dim1"],
nn_hidden_dim2=config["nn_hidden_dim2"],
nn_output_dim=nn_output_dim,
).to(device_inner)
o = optim.Adam(
m.parameters(),
lr=config["lr"],
weight_decay=config["weight_decay"],
)
models.append(m)
optimizers.append(o)
# epoch-sync training across all inner folds
for epoch in range(config["num_epochs"]):
epoch_val_losses = []
epoch_val_accs = []
epoch_train_losses = []
epoch_val_f1s = []
epoch_val_auprs = []
for fi, fold in enumerate(folds):
# train step
models[fi].train()
optimizers[fi].zero_grad()
out, _, _ = models[fi](fold["X_train"], omics_net)
loss = fold["criterion"](out, fold["y_train"])
loss.backward()
optimizers[fi].step()
epoch_train_losses.append(loss.item())
# val step
models[fi].eval()
with torch.no_grad():
val_out, _, _ = models[fi](fold["X_val"], omics_net)
vl = fold["criterion"](val_out, fold["y_val"]).item()
_, preds = torch.max(val_out, 1)
probs = torch.softmax(val_out, dim=1)
y_true_np = fold["y_val"].cpu().numpy()
predicted_np = preds.cpu().numpy()
probs_np = probs.cpu().numpy()
va = (preds == fold["y_val"]).sum().item() / fold["y_val"].size(0)
f1_mac = f1_score(y_true_np, predicted_np, average='macro', zero_division=0)
try:
n_classes = probs_np.shape[1]
if n_classes == 2:
aupr = average_precision_score(y_true_np, probs_np[:, 1])
else:
y_bin = label_binarize(y_true_np, classes=range(n_classes))
aupr_scores = []
for i in range(n_classes):
if np.sum(y_bin[:, i]) > 0:
aupr_scores.append(average_precision_score(y_bin[:, i], probs_np[:, i]))
aupr = np.mean(aupr_scores) if aupr_scores else 0.0
except:
aupr = 0.0
epoch_val_losses.append(vl)
epoch_val_accs.append(va)
epoch_val_f1s.append(f1_mac)
epoch_val_auprs.append(aupr)
composite = 0.5 * float(np.mean(epoch_val_accs)) + 0.5 * float(np.mean(epoch_val_f1s))
# report mean metrics
metrics = {
"val_loss": float(np.mean(epoch_val_losses)),
"val_accuracy": float(np.mean(epoch_val_accs)),
"val_f1_macro": float(np.mean(epoch_val_f1s)),
"val_aupr": float(np.mean(epoch_val_auprs)),
"train_loss": float(np.mean(epoch_train_losses)),
}
ckpt_dir = "trial_checkpoint"
os.makedirs(ckpt_dir, exist_ok=True)
torch.save(
{"epoch": epoch, "model_state": models[0].state_dict()},
os.path.join(ckpt_dir, "checkpoint.pt"),
)
tune.report(
metrics=metrics,
checkpoint=Checkpoint.from_directory(ckpt_dir),
)
# launch Ray Tune
num_samples = dpmon_params.get("tune_trials", 20)
seed_trials = dpmon_params.get("seed_trials", False)
max_retries = 4
if seed_trials:
logger.debug(f"seed_trials=True: fixed seed {dpmon_params['seed']}")
else:
logger.debug("seed_trials=False: random hyperparameter sampling")
scheduler = ASHAScheduler(grace_period=50, reduction_factor=2)
stopper = TrialPlateauStopper(
metric="val_f1_macro", mode="max",
num_results=20, metric_threshold=0.01, grace_period=50,
)
def short_dirname_creator(trial):
return f"T{trial.trial_id}"
#dummy reporter that fixes Ray-rune screen-clearing for jupyter notebooks
class SilentReporter(CLIReporter):
def should_report(self, trials, done=False):
return False
def report(self, trials, done, *sys_info):
pass
use_gpu = bool(dpmon_params.get("gpu", False)) and torch.cuda.is_available()
if dpmon_params.get("gpu", False) and not torch.cuda.is_available():
logger.warning("gpu=True but CUDA not available; running on CPU.")
cpu_per_trial = 2
gpu_per_trial = 0.2 if use_gpu else 0.0
for attempt in range(max_retries):
try:
search_alg = None
if seed_trials:
search_alg = BasicVariantGenerator(
random_state=np.random.RandomState(dpmon_params["seed"])
)
tuner = tune.Tuner(
tune.with_resources(
tune_train_fn,
resources={"cpu": cpu_per_trial, "gpu": gpu_per_trial},
),
param_space=pipeline_configs,
tune_config=tune.TuneConfig(
metric="val_f1_macro",
mode="max",
num_samples=num_samples,
scheduler=scheduler,
search_alg=search_alg,
trial_dirname_creator=short_dirname_creator,
),
run_config=tune.RunConfig(
name="tune_dp",
verbose=0,
log_to_file=True,
stop=stopper,
storage_path=os.path.expanduser("~/ray_results"),
sync_config=tune.SyncConfig(sync_artifacts=False),
checkpoint_config=tune.CheckpointConfig(
num_to_keep=1,
checkpoint_score_attribute="val_f1_macro",
checkpoint_score_order="max",
),progress_reporter=SilentReporter(),
),
)
results = tuner.fit()
break
except TuneError as e:
msg = str(e)
if "Trials did not complete" not in msg and "OutOfMemoryError" not in msg:
raise
new_num_samples = max(1, num_samples // 2)
if use_gpu:
new_gpu_per_trial = min(1.0, gpu_per_trial + 0.2)
else:
new_gpu_per_trial = 0.0
if new_num_samples == num_samples and new_gpu_per_trial == gpu_per_trial:
logger.error("Cannot reduce num_samples or increase gpu_per_trial any further. Aborting.")
raise
logger.warning(
f"Ray Tune failed (attempt {attempt + 1}). "
f"Adjusting resources -> num_samples: {num_samples} to {new_num_samples}, "
f"gpu_per_trial: {gpu_per_trial:.2f} to {new_gpu_per_trial:.2f}."
)
num_samples = new_num_samples
gpu_per_trial= new_gpu_per_trial
else:
raise RuntimeError("Hyperparameter tuning failed after max retries.")
# extract best config
best_result = results.get_best_result(metric="val_f1_macro", mode="max")
best_config = best_result.config
logger.info(f"Best trial config: {best_config}")
logger.info(f"Best trial val_accuracy: {best_result.metrics.get('val_accuracy'):.4f}")
logger.info(f"Best trial val_loss: {best_result.metrics.get('val_loss'):.4f}")
logger.info(f"Best trial val_f1_macro: {best_result.metrics.get('val_f1_macro'):.4f}")
logger.info(f"Best trial val_aupr: {best_result.metrics.get('val_aupr'):.4f}")
# cleanup
try:
tune_dir = os.path.expanduser("~/ray_results/tune_dp")
if os.path.exists(tune_dir):
shutil.rmtree(tune_dir)
logger.debug(f"Cleaned up tuning directory: {tune_dir}")
except Exception as e:
logger.warning(f"Could not clean up tuning directory: {e}")
return best_config
[docs]
def train_model(model, criterion, optimizer, train_features, train_labels, epoch_num):
network = train_labels["omics_network"]
labels = train_labels["labels"]
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch_num, eta_min=1e-6)
model.train()
for epoch in range(epoch_num):
optimizer.zero_grad()
outputs, _, _ = model(train_features, network)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
scheduler.step()
if (epoch + 1) % 100 == 0 or epoch == 0:
logger.debug(f"Epoch [{epoch+1}/{epoch_num}], Loss: {loss.item():.4f}")
return model
[docs]
class NeuralNetwork(nn.Module):
"""Core DPMON model combining GNN feature weighting and sample-level prediction.
When using GAT with heads > 1, the GNN output is hidden_dim * heads.
"""
def __init__(
self,
model_type,
gnn_input_dim,
gnn_hidden_dim,
gnn_layer_num,
ae_encoding_dim,
nn_input_dim,
nn_hidden_dim1,
nn_hidden_dim2,
nn_output_dim,
gnn_dropout: float = 0.,
gnn_activation: str = "relu",
dim_reduction: str = "ae",
ae_architecture: str = "original",
gat_heads: int = 1,
):
super().__init__()
self.model_type = model_type
if model_type == "GCN":
self.gnn = GCN(
input_dim=gnn_input_dim,
hidden_dim=gnn_hidden_dim,
layer_num=gnn_layer_num,
final_layer="none",
dropout=gnn_dropout,
activation=gnn_activation,
)
gnn_out_dim = gnn_hidden_dim
elif model_type == "GAT":
self.gnn = GAT(
input_dim=gnn_input_dim,
hidden_dim=gnn_hidden_dim,
layer_num=gnn_layer_num,
final_layer="none",
dropout=gnn_dropout,
activation=gnn_activation,
heads=gat_heads,
)
# GAT output dim = hidden_dim * heads
gnn_out_dim = gnn_hidden_dim * gat_heads
elif model_type == "SAGE":
self.gnn = SAGE(
input_dim=gnn_input_dim,
hidden_dim=gnn_hidden_dim,
layer_num=gnn_layer_num,
final_layer="none",
dropout=gnn_dropout,
activation=gnn_activation,
)
gnn_out_dim = gnn_hidden_dim
elif model_type == "GIN":
self.gnn = GIN(
input_dim=gnn_input_dim,
hidden_dim=gnn_hidden_dim,
output_dim=gnn_hidden_dim,
layer_num=gnn_layer_num,
final_layer="none",
dropout=gnn_dropout,
activation=gnn_activation,
)
gnn_out_dim = gnn_hidden_dim
else:
raise ValueError(f"Unsupported GNN model type: {model_type}")
if dim_reduction == "ae":
self.autoencoder = AutoEncoder(
input_dim=gnn_out_dim, encoding_dim=1, architecture=ae_architecture
)
self.projection = nn.Identity()
elif dim_reduction == "linear":
self.autoencoder = AutoEncoder(
input_dim=gnn_out_dim, encoding_dim=ae_encoding_dim, architecture=ae_architecture
)
self.projection = ScalarProjection(encoding_dim=ae_encoding_dim)
elif dim_reduction == "mlp":
self.autoencoder = AutoEncoder(
input_dim=gnn_out_dim, encoding_dim=ae_encoding_dim, architecture=ae_architecture
)
self.projection = MLPProjection(encoding_dim=ae_encoding_dim, hidden_dim=8)
else:
raise ValueError(f"Unsupported dim_reduction: {dim_reduction}")
self.predictor = DownstreamTaskNN(
nn_input_dim, nn_hidden_dim1, nn_hidden_dim2, nn_output_dim
)
[docs]
def forward(self, omics_dataset, omics_network_tg, clinical_tensor=None):
# GNN embeddings
omics_network_nodes_embedding = self.gnn.get_embeddings(omics_network_tg)
# compress
omics_network_nodes_embedding_ae = self.autoencoder(omics_network_nodes_embedding)
# project to scalar weights
feature_weights = self.projection(omics_network_nodes_embedding_ae)
# reweight the original omics data (element-wise multiplication)
omics_dataset_with_embeddings = torch.mul(
omics_dataset,
feature_weights.expand(omics_dataset.shape[1], omics_dataset.shape[0]).t(),
)
# leaving clinical_tensor parameter as pontential addition after dicussion with colleagues.
# this way we would be making the final prediction on the scaled omics dataset + the clinical data.
if clinical_tensor is not None and clinical_tensor.shape[1] > 0:
predictor_input = torch.cat([omics_dataset_with_embeddings, clinical_tensor], dim=1)
else:
predictor_input = omics_dataset_with_embeddings
# predict
predictions = self.predictor(predictor_input)
return predictions, omics_dataset_with_embeddings, omics_network_nodes_embedding
"""
DPMON AutoEncoder & NeuralNetwork:
1. AutoEncoder: support both a hardcoded 3-layer encoder (input -> 8 -> 4 encoding_dim) and a 2-layer version (input -> input//2 -> encoding_dim).
2. NeuralNetwork: Supports a `correlation_mode` passthrough for prepare_node_features.
"""
[docs]
class AutoEncoder(nn.Module):
"""Compresses high-dimensional node embeddings into a lower-dimensional latent space.
Args:
input_dim: Input feature dimension (gnn_hidden_dim).
encoding_dim: Output latent dimension.
architecture: original or dynamic. "original" (input -> 8 -> 4 encoding_dim). "dynamic" (input -> input//2 -> encoding_dim).
"""
def __init__(self, input_dim: int, encoding_dim: int, architecture: str = "original"):
super().__init__()
if architecture == "original":
if encoding_dim == 1:
# fixed bottleneck
self.encoder = nn.Sequential(
nn.Linear(input_dim, 8),
nn.ReLU(),
nn.Linear(8, 4),
nn.ReLU(),
nn.Linear(4, 1),
)
else:
# EX1 if Tune picks 4: Flow is Input -> 8 -> 4. EX2: Tune picks 8: Flow is Input -> 16 -> 8
intermediate_dim = encoding_dim * 2
if intermediate_dim > input_dim:
intermediate_dim = input_dim
self.encoder = nn.Sequential(
nn.Linear(input_dim, intermediate_dim),
nn.ReLU(),
nn.Linear(intermediate_dim, encoding_dim),
)
elif architecture == "dynamic":
if encoding_dim == 1:
# 3-step funnel
h1 = max(input_dim // 2, 8)
h2 = max(input_dim // 4, 4)
self.encoder = nn.Sequential(
nn.Linear(input_dim, h1),
nn.ReLU(),
nn.Linear(h1, h2),
nn.ReLU(),
nn.Linear(h2, 1),
)
else:
# 2 step funnel for linear projections
hidden_dim = max(input_dim // 2, encoding_dim * 2)
if hidden_dim > input_dim:
hidden_dim = input_dim
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, encoding_dim),
)
else:
raise ValueError(f"Unknown architecture: {architecture}")
[docs]
def forward(self, x):
return self.encoder(x)
[docs]
class ScalarProjection(nn.Module):
def __init__(self, encoding_dim):
super().__init__()
self.proj = nn.Linear(encoding_dim, 1)
[docs]
def forward(self, x):
return self.proj(x)
[docs]
class MLPProjection(nn.Module):
def __init__(self, encoding_dim, hidden_dim=8):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(encoding_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
)
[docs]
def forward(self, x):
return self.mlp(x)
[docs]
class DownstreamTaskNN(nn.Module):
"""MLP for final prediction - outputs raw logits."""
def __init__(self, input_size, hidden_dim1, hidden_dim2, output_dim):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_dim1)
self.bn1 = nn.BatchNorm1d(hidden_dim1)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
self.bn2 = nn.BatchNorm1d(hidden_dim2)
self.relu2 = nn.ReLU()
self.fc3 = nn.Linear(hidden_dim2, output_dim)
[docs]
def forward(self, x):
x = self.fc1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.fc2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.fc3(x)
return x