# Source code for bioneuralnet.network.pysmccnet.wrappers

import torch
import numpy as np
from tqdm import tqdm
from typing import List, Union, Optional, Dict

# internal imports
from .math_helpers import r_scale_torch, _splsda, r_scale
from .core import my_multi_cca
import statsmodels.api as sm

def get_can_weights_multi(
    X: List[torch.Tensor],
    Trait: Optional[torch.Tensor] = None,
    Lambda: Optional[Union[List[float], np.ndarray]] = None,
    cc_coef: Optional[Union[List[float], np.ndarray, torch.Tensor]] = None,
    no_trait: bool = True,
    trace: bool = False,
    trait_weight: bool = False,
) -> List[torch.Tensor]:
    """PyTorch version of get_can_weights_multi wrapper.

    Converts the per-matrix penalties in ``Lambda`` into the penalties passed
    to ``my_multi_cca`` (``max(1, sqrt(n_columns) * Lambda[i])``), optionally
    appends the scaled trait as an extra data block, and returns the
    canonical weights found under ``out['ws']``.

    Args:
        X (List[torch.Tensor]): List of input data matrices on target device.
        Trait (torch.Tensor | None): Optional trait data tensor. Required when
            ``no_trait`` is False. Its ``shape[1]`` is read, so it is
            presumably expected to be 2-D after scaling.
        Lambda (List[float] | np.ndarray | None): Penalty parameters, one per
            matrix in ``X``; required. Each must lie in (0, 1]. Entries beyond
            ``len(X)`` are ignored.
        cc_coef (List | np.ndarray | torch.Tensor | None): Pairwise correlation
            coefficients, forwarded unchanged to ``my_multi_cca``.
        no_trait (bool): If True, run unsupervised CCA; if False, include trait.
        trace (bool): If True, print trace info during optimization.
        trait_weight (bool): If True, return trait weights in output list.

    Returns:
        List[torch.Tensor]: List of weight tensors for each input matrix. When
        the trait is included and ``trait_weight`` is False the last entry
        (the trait block) is dropped.

    Raises:
        ValueError: If ``Lambda`` is missing, empty, shorter than ``X``,
            contains a value outside (0, 1] (NaN included), or if ``Trait`` is
            missing while ``no_trait`` is False.
    """
    # input validation
    if Lambda is None:
        raise ValueError("Lambda must be provided.")
    Lambda = np.atleast_1d(np.array(Lambda, dtype=float))
    if Lambda.size == 0:
        raise ValueError("Lambda must be provided.")

    # Comparing against the bounds directly (rather than |lam - 0.5| > 0.5)
    # also rejects NaN, for which every comparison is False.
    within_unit_interval = (Lambda >= 0) & (Lambda <= 1)
    if not np.all(within_unit_interval):
        raise ValueError("Invalid penalty parameter. Lambda1 needs to be between zero and one.")
    if np.any(Lambda == 0):
        raise ValueError("Invalid penalty parameter. Both Lambda1 and Lambda2 has to be greater than 0.")

    current_X = list(X)  # shallow copy so the caller's list is never extended
    if Lambda.size < len(current_X):
        raise ValueError(
            f"Lambda must provide one penalty per data matrix: "
            f"got {Lambda.size} for {len(current_X)} matrices."
        )

    # penalty calculation
    L = [
        max(1, np.sqrt(block.shape[1]) * Lambda[i])
        for i, block in enumerate(current_X)
    ]

    # the trait, when used, enters the CCA as one more data block
    if not no_trait:
        if Trait is None:
            raise ValueError("Trait must be provided if no_trait is False.")
        scaled_trait = r_scale_torch(Trait)
        current_X.append(scaled_trait)
        L.append(np.sqrt(scaled_trait.shape[1]))

    # cca execution
    out = my_multi_cca(current_X, penalty=L, cc_coef=cc_coef, trace=trace)

    # output extraction: the trait block's weights are last; drop them unless
    # the caller asked for them
    ws = out['ws']
    if not no_trait and not trait_weight:
        ws = ws[:-1]
    return ws
def get_robust_weights_multi(
    X: List[torch.Tensor],
    Trait: Optional[torch.Tensor],
    Lambda: Union[List[float], np.ndarray],
    s: Optional[Union[List[float], np.ndarray]] = None,
    no_trait: bool = False,
    subsampling_num: int = 1000,
    cc_coef: Optional[np.ndarray] = None,
    trace: bool = False,
    trait_weight: bool = False,
) -> torch.Tensor:
    """PyTorch version of get_robust_weights_multi with subsampling loop.

    Each iteration draws a random subset of columns from every matrix in
    ``X`` (without replacement, via ``np.random.choice``), scales the subsets,
    computes canonical weights with ``get_can_weights_multi`` and scatters
    them back into a full-length vector in which unsampled features are zero.

    Args:
        X (List[torch.Tensor]): List of input data matrices on target device.
        Trait (torch.Tensor | None): Trait data tensor or None. When None the
            weights are always computed without a trait, whatever ``no_trait``
            says.
        Lambda (List[float] | np.ndarray): Penalty parameters for CCA/PLS, one
            per matrix in ``X``; each must lie in (0, 1].
        s (List[float] | np.ndarray | None): Subsampling proportions for each
            omics layer, each in (0, 1]. A single value is repeated for every
            matrix. Required despite the ``None`` default.
        no_trait (bool): If True, compute weights without using Trait
            information (unsupervised).
        subsampling_num (int): Number of subsampling iterations to perform.
        cc_coef (np.ndarray | None): Scaling coefficients for between-omics
            relationships.
        trace (bool): If True, print trace information during execution.
        trait_weight (bool): Accepted for interface compatibility. It is not
            read by this function: only the weights of the matrices in ``X``
            are written to the result, so the returned shape never changes.

    Returns:
        torch.Tensor: Matrix of weights with shape
        (total_features, subsampling_num) on the same device as X.

    Raises:
        ValueError: If ``X`` is empty, ``s`` is missing, any proportion or
            penalty is outside (0, 1] (NaN included), the lengths of ``s`` or
            ``Lambda`` do not match ``X``, or ``subsampling_num`` is below 1.
    """
    n_blocks = len(X)
    if n_blocks == 0:
        raise ValueError("X must contain at least one data matrix.")
    if s is None:
        raise ValueError("s (subsampling proportions) must be provided.")
    if subsampling_num < 1:
        raise ValueError("subsampling_num must be at least 1.")

    s = np.atleast_1d(np.array(s, dtype=float))
    if s.size == 1 and n_blocks > 1:
        s = np.repeat(s, n_blocks)
    Lambda = np.atleast_1d(np.array(Lambda, dtype=float))

    # validation checks
    # A single zero proportion must be rejected too: it would select no
    # features at all for that matrix.
    if np.any(s == 0):
        raise ValueError("Subsampling proportion needs to be greater than zero.")
    # Direct bound comparison so that NaN is rejected as well.
    if not np.all((s >= 0) & (s <= 1)):
        raise ValueError("Subsampling proportions can not exceed one.")
    if s.size != n_blocks:
        raise ValueError(
            f"s must provide one subsampling proportion per data matrix: "
            f"got {s.size} for {n_blocks} matrices."
        )
    if not np.all((Lambda > 0) & (Lambda <= 1)):
        raise ValueError("Invalid penalty parameter. Lambda1 needs to be between zero and one.")
    if Lambda.size < n_blocks:
        raise ValueError(
            f"Lambda must provide one penalty per data matrix: "
            f"got {Lambda.size} for {n_blocks} matrices."
        )

    # device management
    device = X[0].device
    dtype = X[0].dtype

    # setup dimensions
    p_data = np.array([x.shape[1] for x in X])
    p = int(np.sum(p_data))
    p_sub = np.ceil(p_data * s).astype(int)
    # start offset of each matrix's features inside the concatenated vector
    offsets = np.insert(np.cumsum(p_data), 0, 0)

    # without a trait there is nothing to supervise with
    effective_no_trait = no_trait or Trait is None

    # subsampling loop
    results = []
    iter_range = range(subsampling_num)
    if subsampling_num > 1:
        iter_range = tqdm(range(subsampling_num), desc="Robust Weights")

    for _ in iter_range:
        # sampling: sorted column indices per matrix
        samp = []
        for h in range(n_blocks):
            indices = np.random.choice(p_data[h], p_sub[h], replace=False)
            indices.sort()
            samp.append(indices)

        # subset and scale
        x_par = [r_scale_torch(X[h][:, samp[h]]) for h in range(n_blocks)]

        # compute weights
        out = get_can_weights_multi(
            x_par,
            Trait,
            Lambda,
            no_trait=effective_no_trait,
            trace=trace,
            cc_coef=cc_coef,
        )

        # reconstruct weight vector; features that were not sampled stay zero
        w = torch.zeros(p, device=device, dtype=dtype)
        for h, local_indices in enumerate(samp):
            global_indices = local_indices + int(offsets[h])
            idx = torch.tensor(global_indices, device=device, dtype=torch.long)
            val = out[h]
            if isinstance(val, torch.Tensor):
                val = val.flatten()
            else:
                val = torch.tensor(np.array(val).flatten(), device=device, dtype=dtype)
            w[idx] = val
        results.append(w)

    # stack results: one column per subsample
    beta = torch.stack(results, dim=1)
    return beta
def get_robust_weights_single_binary(
    X1: np.ndarray,
    Trait: np.ndarray,
    Lambda1: float,
    s1: float = 0.7,
    subsampling_num: int = 1000,
    K: int = 3,
) -> np.ndarray:
    """Compute aggregated sparse PLS-DA canonical weights for single omics data with binary phenotype.

    Each iteration subsamples features, runs sparse PLS-DA (``_splsda``) on
    the scaled subset, fits a binomial GLM of the trait on the latent scores,
    and combines the absolute PLS weights with the absolute GLM coefficients
    into one unit-norm weight vector for that subsample.

    Args:
        X1 (np.ndarray): Input data matrix of shape (n, p1).
        Trait (np.ndarray): Binary phenotype vector (0/1) of shape (n,). Any
            array with n elements is accepted; it is flattened.
        Lambda1 (float): LASSO penalty parameter for SPLSDA; between 0 and 1.
            Passed to ``_splsda`` as ``eta``.
        s1 (float): Proportion of features to subsample per iteration; must
            lie in (0, 1].
        subsampling_num (int): Number of subsampling iterations.
        K (int): Number of latent components for PLS-DA.

    Returns:
        np.ndarray: Weight matrix of shape (p1, subsampling_num). A column is
        all zeros when the GLM fit failed for that subsample.

    Raises:
        ValueError: If ``X1`` is not 2-D, ``Trait`` does not have one value
            per row of ``X1``, ``s1`` is outside (0, 1], or
            ``subsampling_num`` is below 1.
    """
    X1 = np.array(X1, dtype=float)
    Trait = np.array(Trait).flatten()

    if X1.ndim != 2:
        raise ValueError("X1 must be a 2-D matrix of shape (n, p1).")
    if Trait.size != X1.shape[0]:
        raise ValueError(
            f"Trait must have one value per row of X1: "
            f"got {Trait.size} values for {X1.shape[0]} rows."
        )
    # Same proportion checks as the multi-omics wrapper; `not (s1 > 0)` is
    # also True for NaN.
    if not (s1 > 0):
        raise ValueError("Subsampling proportion needs to be greater than zero.")
    if s1 > 1:
        raise ValueError("Subsampling proportions can not exceed one.")
    if subsampling_num < 1:
        raise ValueError("subsampling_num must be at least 1.")

    p1 = X1.shape[1]
    p1_sub = int(np.ceil(s1 * p1))

    results = []
    iter_range = range(subsampling_num)
    if subsampling_num > 1:
        iter_range = tqdm(range(subsampling_num), desc="Single Binary Weights")

    for _ in iter_range:
        # subsample features
        samp1 = np.sort(np.random.choice(p1, p1_sub, replace=False))

        # scale subsampled data
        x1_par = r_scale(X1[:, samp1])

        # run sparse pls-da
        out = _splsda(x=x1_par, y=Trait, K=K, eta=Lambda1, kappa=0.5, scale_x=False)
        T_scores = out['T']
        W_weights = out['W']
        A_indices = out['A']

        # fit logistic regression on latent factors
        try:
            model = sm.GLM(Trait, T_scores, family=sm.families.Binomial())
            glm_coefs = model.fit(disp=0).params
        except Exception:
            # Best effort: a subsample whose GLM cannot be fitted contributes
            # an all-zero column instead of aborting the whole run.
            results.append(np.zeros(p1))
            continue

        # compute weights for the features _splsda reported in out['A']
        u = np.zeros(p1_sub)
        u[A_indices] = np.abs(W_weights) @ np.abs(glm_coefs)

        # normalize
        norm_val = np.linalg.norm(u[A_indices])
        if norm_val > 0:
            u[A_indices] = u[A_indices] / norm_val

        # scatter back to full feature space
        w = np.zeros(p1)
        w[samp1] = u
        results.append(w)

    beta = np.column_stack(results)
    return beta
def get_robust_weights_multi_binary(
    X: List[torch.Tensor],
    Y: Union[np.ndarray, torch.Tensor],
    between_discriminate_ratio: Optional[Union[List[float], np.ndarray]] = None,
    subsampling_percent: Optional[Union[List[float], np.ndarray]] = None,
    cc_coef: Optional[np.ndarray] = None,
    lambda_between: Optional[Union[List[float], np.ndarray]] = None,
    lambda_pheno: Optional[float] = None,
    subsampling_num: int = 1000,
    ncomp_pls: int = 3,
    eval_classifier: bool = False,
    test_data: Optional[List[Union[np.ndarray, torch.Tensor]]] = None,
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    """PyTorch version of get_robust_weights_multi_binary using hybrid GPU/CPU execution.

    Step 1 computes between-omics canonical weights with
    ``get_robust_weights_multi`` (no trait). Step 2 runs sparse PLS-DA against
    the binary phenotype on the features step 1 selected, in NumPy on the CPU.

    Args:
        X (List[torch.Tensor]): List of omics data matrices (torch tensors on
            target device).
        Y (np.ndarray | torch.Tensor): Binary phenotype vector.
        between_discriminate_ratio (List[float] | np.ndarray | None): Ratio for
            weighting between-omics vs omics-phenotype contributions. Two
            values; defaults to ``[1, 1]``. Only used when ``eval_classifier``
            is False.
        subsampling_percent (List[float] | np.ndarray | None): Proportion of
            features to subsample per iteration.
        cc_coef (np.ndarray | None): Pairwise correlation coefficients.
        lambda_between (List[float] | np.ndarray | None): Penalty terms for
            between-omics CCA; required.
        lambda_pheno (float | None): Penalty term for omics-phenotype PLS.
            Required when ``eval_classifier`` is False.
        subsampling_num (int): Number of subsampling iterations.
        ncomp_pls (int): Number of latent components for PLS.
        eval_classifier (bool): If True, return projections for classifier
            evaluation instead of weights.
        test_data (List[np.ndarray | torch.Tensor] | None): Test data matrices
            required if eval_classifier is True.

    Returns:
        np.ndarray | Dict: Weight matrix (if eval_classifier=False) or
        dictionary containing train/test projections under the keys
        ``'out_train'`` and ``'out_test'``. In weight mode with
        ``subsampling_num > 1``, subsamples whose phenotype weights are all
        zero or contain NaN are dropped, so the matrix may have fewer than
        ``subsampling_num`` columns. In classifier mode both projections are
        all zeros if the PLS-DA step raised.

    Raises:
        ValueError: If ``lambda_between`` is missing; if, in weight mode,
            ``lambda_pheno`` is missing or ``between_discriminate_ratio`` does
            not hold two values with a non-zero sum; or if, in classifier
            mode, ``subsampling_num`` is not 1 or ``test_data`` is missing.
    """
    import warnings

    if between_discriminate_ratio is None:
        between_discriminate_ratio = [1, 1]
    between_discriminate_ratio = np.array(between_discriminate_ratio, dtype=float)

    if lambda_between is None:
        raise ValueError("lambda_between must be provided.")
    lambda_between = np.array(lambda_between)

    # Validate mode-specific arguments before the expensive subsampling step
    # instead of discovering the problem only after it has run.
    if eval_classifier:
        if subsampling_num != 1:
            raise ValueError("Subsampling number must be 1 when evaluating the classifier.")
        if test_data is None:
            raise ValueError("test_data must be provided when eval_classifier=True.")
    else:
        # float(None) would otherwise raise inside the per-subsample try below
        # and be swallowed, silently yielding all-zero weights.
        if lambda_pheno is None:
            raise ValueError("lambda_pheno must be provided.")
        if between_discriminate_ratio.size < 2:
            raise ValueError("between_discriminate_ratio must contain two values.")
        if np.sum(between_discriminate_ratio) == 0:
            raise ValueError("between_discriminate_ratio must not sum to zero.")

    # ensure Y is numpy
    if isinstance(Y, torch.Tensor):
        Y_np = Y.cpu().numpy().flatten()
    else:
        Y_np = np.array(Y).flatten()

    # step 1: between-omics smcca (gpu)
    between_omics_weight = get_robust_weights_multi(
        X,
        Trait=None,
        Lambda=lambda_between,
        s=subsampling_percent,
        no_trait=True,
        cc_coef=cc_coef,
        subsampling_num=subsampling_num,
    )

    # move to cpu for pls-da
    between_omics_weight = between_omics_weight.cpu().numpy()

    # column-bind all omics
    X_all = np.hstack([x.cpu().numpy() for x in X])

    # feature type index: which omics matrix each concatenated column came from
    type_index = np.concatenate([np.full(x.shape[1], h) for h, x in enumerate(X)])

    if not eval_classifier:
        # branch a: network construction
        n_subsamples = between_omics_weight.shape[1]
        omics_pheno_weight = np.zeros_like(between_omics_weight)
        eta = float(lambda_pheno)
        trait_column = Y_np.reshape(-1, 1)
        n_failed = 0

        for iii in tqdm(range(n_subsamples), desc="Omics-Phenotype PLS"):
            selected_indices = np.flatnonzero(between_omics_weight[:, iii] != 0)
            if len(selected_indices) == 0:
                continue

            try:
                # cpu: small matrix pls-da on the features selected in step 1
                Ws_pheno = get_robust_weights_single_binary(
                    X1=X_all[:, selected_indices],
                    Trait=trait_column,
                    Lambda1=eta,
                    s1=1.0,
                    subsampling_num=1,
                    K=ncomp_pls,
                )
            except Exception:
                # Best effort: a failing subsample keeps zero weights and is
                # dropped below; the count is reported once after the loop.
                n_failed += 1
                continue

            omics_pheno_weight[selected_indices, iii] = Ws_pheno.flatten()

            # normalize per data type
            for j in range(len(X)):
                type_mask = type_index == j
                norm_val = np.linalg.norm(omics_pheno_weight[type_mask, iii])
                if norm_val > 0:
                    omics_pheno_weight[type_mask, iii] /= norm_val

        if n_failed:
            warnings.warn(
                f"Omics-phenotype PLS-DA failed on {n_failed} of {n_subsamples} "
                f"subsamples; their weights were left at zero.",
                RuntimeWarning,
                stacklevel=2,
            )

        # zero out between-omics where pheno is zero. This must run after the
        # loop: columns not yet processed are still all zero in
        # omics_pheno_weight and would otherwise wipe their between-omics
        # weights before they are used for feature selection.
        between_omics_weight[omics_pheno_weight == 0] = 0

        # remove zero/nan columns
        if subsampling_num > 1:
            unusable = np.all(omics_pheno_weight == 0, axis=0) | np.any(
                np.isnan(omics_pheno_weight), axis=0
            )
            if unusable.any():
                keep_cols = ~unusable
                between_omics_weight = between_omics_weight[:, keep_cols]
                omics_pheno_weight = omics_pheno_weight[:, keep_cols]

        # aggregate: convex combination of the two weight matrices
        ratio_sum = np.sum(between_discriminate_ratio)
        w1 = between_discriminate_ratio[0] / ratio_sum
        w2 = between_discriminate_ratio[1] / ratio_sum
        cc_weight = w1 * between_omics_weight + w2 * omics_pheno_weight
        return cc_weight

    # branch b: classifier evaluation
    selected_indices = np.flatnonzero(between_omics_weight[:, 0] != 0)
    X_subset = X_all[:, selected_indices]

    # stacked once; needed by both the success path and the fallback
    X_all_test = np.hstack(
        [t.cpu().numpy() if isinstance(t, torch.Tensor) else np.array(t) for t in test_data]
    )

    try:
        out = _splsda(
            x=r_scale(X_subset),
            y=Y_np,
            K=ncomp_pls,
            eta=lambda_pheno,
            kappa=0.5,
            scale_x=False,
        )
        # expand the weights of the features in out['A'] to all selected features
        out_data = np.zeros((X_subset.shape[1], ncomp_pls))
        out_data[out['A'], :] = out['W']

        # process test data
        X_subset_test = X_all_test[:, selected_indices]
        out_test = X_subset_test @ out_data
        out_train = out['T']
        return {'out_train': out_train, 'out_test': out_test}
    except Exception as e:
        # Best effort: report the failure and return zero projections of the
        # expected shapes.
        print(f"Caught an error: {e}")
        out_train = np.zeros((X_all.shape[0], ncomp_pls))
        out_test = np.zeros((X_all_test.shape[0], ncomp_pls))
        return {'out_train': out_train, 'out_test': out_test}