import torch
import numpy as np
import itertools
from typing import List, Union, Optional, Dict, Any
from .math_helpers import binary_search, soft, l2n, r_scale_torch
def my_get_crit(xlist: List[torch.Tensor], ws: List[torch.Tensor], pair_cc: Union[np.ndarray, torch.Tensor], cc_coef: Union[np.ndarray, torch.Tensor]) -> float:
    """Compute the weighted SmCCA objective for the current weights.

    For every column ``k`` of ``pair_cc`` (row 0 / row 1 hold the layer
    indices ``i`` / ``j``) the term ``(X_i w_i)^T (X_j w_j)`` is computed and
    flattened. All terms are then combined as a ``cc_coef``-weighted sum.

    Args:
        xlist (List[torch.Tensor]): Input data matrices.
        ws (List[torch.Tensor]): Weight vectors, one per matrix in ``xlist``.
            With single-column weights each pair contributes exactly one term.
        pair_cc (np.ndarray | torch.Tensor): ``(2, num_pairs)`` matrix of
            layer-index pairs.
        cc_coef (np.ndarray | torch.Tensor): Scaling coefficients, one per
            contributed term. Array-likes are converted with the data's
            floating dtype, so float64 data keeps full-precision coefficients.
            Tensors are moved to the data's device.

    Returns:
        float: Objective value (``0.0`` when ``pair_cc`` has no columns).
    """
    if isinstance(pair_cc, torch.Tensor):
        pair_cc = pair_cc.cpu().numpy()
    pair_cc = np.asarray(pair_cc)

    terms = []
    for k in range(pair_cc.shape[1]):
        i = int(pair_cc[0, k])
        j = int(pair_cc[1, k])
        proj_i = torch.matmul(xlist[i], ws[i])
        proj_j = torch.matmul(xlist[j], ws[j])
        terms.append(torch.matmul(proj_i.t(), proj_j).reshape(-1))

    if not terms:
        return 0.0

    crits_vec = torch.cat(terms)
    if isinstance(cc_coef, torch.Tensor):
        # Keep the caller's dtype, but make sure the devices agree.
        coef = cc_coef.to(crits_vec.device)
    else:
        # Use the data's precision instead of a hard-coded float32, which
        # would round coefficients such as 0.1 when the data is float64.
        coef_dtype = crits_vec.dtype if crits_vec.is_floating_point() else torch.get_default_dtype()
        coef = torch.as_tensor(cc_coef, device=crits_vec.device, dtype=coef_dtype)
    return torch.sum(crits_vec * coef).item()
def my_get_cors(xlist: List[torch.Tensor], ws: List[torch.Tensor], pair_cc: Union[np.ndarray, torch.Tensor], cc_coef: Union[np.ndarray, torch.Tensor]) -> float:
    """Compute the ``cc_coef``-weighted sum of pairwise Pearson correlations.

    For every column ``k`` of ``pair_cc`` the projections ``X_i w_i`` and
    ``X_j w_j`` are flattened and their Pearson correlation is computed.
    A pair whose correlation is undefined contributes 0. This happens when
    either projection is constant (zero denominator) or the result is NaN.

    Args:
        xlist (List[torch.Tensor]): Input data matrices.
        ws (List[torch.Tensor]): Weight vectors, one per matrix in ``xlist``.
        pair_cc (np.ndarray | torch.Tensor): ``(2, num_pairs)`` matrix of
            layer-index pairs.
        cc_coef (np.ndarray | torch.Tensor): Scaling coefficients, one per
            pair. Array-likes are converted with the correlations' dtype.
            Tensors are moved to their device.

    Returns:
        float: Total weighted correlation (``0.0`` when there are no pairs).
    """
    if isinstance(pair_cc, torch.Tensor):
        pair_cc = pair_cc.cpu().numpy()
    pair_cc = np.asarray(pair_cc)

    ccs = []
    for k in range(pair_cc.shape[1]):
        i = int(pair_cc[0, k])
        j = int(pair_cc[1, k])
        u = torch.matmul(xlist[i], ws[i]).flatten()
        v = torch.matmul(xlist[j], ws[j]).flatten()
        u_centered = u - torch.mean(u)
        v_centered = v - torch.mean(v)
        numerator = torch.sum(u_centered * v_centered)
        denominator = torch.sqrt(torch.sum(u_centered ** 2)) * torch.sqrt(torch.sum(v_centered ** 2))
        corr = numerator / denominator
        # Undefined correlation (constant projection / NaN) counts as 0.
        # zeros_like keeps the data's dtype and device, unlike a default-dtype literal.
        if denominator == 0 or torch.isnan(corr):
            corr = torch.zeros_like(corr)
        ccs.append(corr)

    if not ccs:
        return 0.0

    ccs_vec = torch.stack(ccs)
    if isinstance(cc_coef, torch.Tensor):
        coef = cc_coef.to(ccs_vec.device)
    else:
        # Match the correlations' precision instead of forcing float32.
        coef = torch.as_tensor(cc_coef, device=ccs_vec.device, dtype=ccs_vec.dtype)
    return torch.sum(ccs_vec * coef).item()
def my_update_w(xlist: List[torch.Tensor], i: int, K: int, sumabsthis: float, ws: List[torch.Tensor], ws_final: List[torch.Tensor], pair_cc: Union[np.ndarray, torch.Tensor], cc_coef: Union[np.ndarray, torch.Tensor, List[float]], type: str = "standard") -> torch.Tensor:
    """Compute the updated weight vector for layer ``i`` (hybrid GPU/CPU).

    Phase 1 (on the data's device) accumulates, over every pair in
    ``pair_cc`` that involves layer ``i`` (partner ``j``)::

        coef * (X_i^T X_j w_j - W_i D (W_j^T w_j))

    ``W`` are the ``ws_final`` matrices of previously extracted components.
    ``D`` is ``W_i^T X_i^T X_j W_j`` with its off-diagonal entries zeroed and
    its shape preserved.

    Phase 2 (on CPU) soft-thresholds that vector with the threshold returned
    by ``binary_search`` and normalizes it with ``l2n``. Both helpers come
    from ``.math_helpers``.

    Args:
        xlist (List[torch.Tensor]): Input data matrices.
        i (int): Index of the layer to update.
        K (int): Number of layers (kept for API compatibility; unused here).
        sumabsthis (float): Sparsity penalty (L1 constraint) for this layer.
        ws (List[torch.Tensor]): Current weight vectors.
        ws_final (List[torch.Tensor]): Weights of previously extracted
            components (zeros before any component is stored).
        pair_cc (np.ndarray | torch.Tensor): ``(2, num_pairs)`` index pairs.
        cc_coef (np.ndarray | torch.Tensor | List[float]): One scaling
            coefficient per pair.
        type (str): Analysis type; only ``"standard"`` is supported.

    Returns:
        torch.Tensor: Updated weight vector on ``xlist[i]``'s device/dtype.

    Raises:
        ValueError: If the coefficient count does not match the number of
            pairs, if layer ``i`` appears in no pair, or if ``type`` is not
            ``"standard"``.
    """
    Xi = xlist[i]
    device = Xi.device
    dtype = Xi.dtype

    if isinstance(cc_coef, torch.Tensor):
        coef = cc_coef.detach().reshape(-1).cpu().tolist()
    else:
        coef = np.asarray(cc_coef).reshape(-1).tolist()
    if isinstance(pair_cc, torch.Tensor):
        pair_cc = pair_cc.cpu().numpy()
    pair_cc = np.asarray(pair_cc)
    num_pairs = pair_cc.shape[1]
    if len(coef) != num_pairs:
        raise ValueError(f"Expected {num_pairs} coefficients (one per pair), got {len(coef)}.")

    # phase 1: matrix operations (data device)
    tots = None
    for x in range(num_pairs):
        a = int(pair_cc[0, x])
        b = int(pair_cc[1, x])
        if i not in (a, b):
            continue
        j = b if a == i else a
        Xj = xlist[j]
        diagmat = torch.matmul(
            torch.matmul(ws_final[i].t(), Xi.t()),
            torch.matmul(Xj, ws_final[j]),
        )
        # Zero the off-diagonal entries WITHOUT changing the shape.
        # torch.diag(torch.diag(.)) would collapse a non-square (ncomp, 1)
        # matrix (against a 1x1 fixed weight) to (1, 1) and break the matmul below.
        keep = torch.eye(diagmat.shape[0], diagmat.shape[1], device=diagmat.device, dtype=diagmat.dtype).bool()
        diagmat = diagmat.masked_fill(~keep, 0)
        term1 = torch.matmul(Xi.t(), torch.matmul(Xj, ws[j]))
        term2 = torch.matmul(ws_final[i], torch.matmul(diagmat, torch.matmul(ws_final[j].t(), ws[j])))
        y = (term1 - term2) * coef[x]
        tots = y if tots is None else tots + y

    if tots is None:
        raise ValueError(f"Layer {i} does not appear in any pair of pair_cc.")
    if type != "standard":
        raise ValueError("Current version requires all element types to be standard (not ordered).")

    # phase 2: scalar optimization (cpu)
    tots_cpu = tots.detach().cpu().numpy()
    lam = binary_search(tots_cpu, sumabsthis)
    numerator = soft(tots_cpu, lam)
    w_cpu = numerator / l2n(numerator)
    return torch.tensor(w_cpu, device=device, dtype=dtype)
def _mcca_coef_tensor(cc_coef: Optional[Union[List[float], np.ndarray, torch.Tensor]], num_cc: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Return the pair coefficients as a flat tensor on ``device``, validating the count."""
    if cc_coef is None:
        return torch.ones(num_cc, device=device, dtype=dtype)
    if isinstance(cc_coef, torch.Tensor):
        coef = cc_coef.to(device=device)
    else:
        coef = torch.tensor(cc_coef, device=device, dtype=dtype)
    coef = coef.reshape(-1)
    if coef.numel() != num_cc:
        raise ValueError(f"Invalid coefficients. Provide {num_cc} values.")
    return coef


def _mcca_svd_init(xlist: List[torch.Tensor], n_layers: int, ncomponents: int) -> List[torch.Tensor]:
    """Initialize the first ``n_layers`` weight matrices with leading right singular vectors."""
    return [
        torch.linalg.svd(xlist[i], full_matrices=False).Vh.T[:, :ncomponents]
        for i in range(n_layers)
    ]


def _mcca_resolve_penalty(penalty: Optional[Union[float, List[float], np.ndarray]], type_vec: np.ndarray, K: int) -> np.ndarray:
    """Expand ``penalty`` to a per-layer array (default 4 for "standard" layers)."""
    if penalty is None:
        penalty = np.full(K, np.nan)
        for k in range(K):
            if type_vec[k] == "standard":
                penalty[k] = 4
    if np.ndim(penalty) == 0:
        return np.full(K, penalty)
    return np.array(penalty)


def my_multi_cca(xlist: List[torch.Tensor], penalty: Optional[Union[float, List[float], np.ndarray]] = None, ws: Optional[List[torch.Tensor]] = None, niter: int = 25, type: str = "standard", ncomponents: int = 1, trace: bool = True, standardize: bool = True, cc_coef: Optional[Union[List[float], np.ndarray, torch.Tensor]] = None) -> Dict[str, Any]:
    """Sparse multiple canonical correlation analysis (SmCCA) in PyTorch.

    If the last matrix in ``xlist`` has a single column (or is 1-D), it is
    treated as a phenotype. Its weight is fixed at ``[[1.0]]`` and only the
    first ``K - 1`` layers are optimized. Otherwise all ``K`` layers are
    optimized. Every unordered pair of layers is used, weighted by ``cc_coef``.

    Args:
        xlist (List[torch.Tensor]): Data matrices (on the target device).
        penalty (float | List[float] | np.ndarray | None): Per-layer L1
            penalties. A scalar is broadcast; ``None`` uses 4 for
            "standard" layers.
        ws (List[torch.Tensor] | None): Initial weights. They are recomputed
            by SVD if ``None`` or if any optimized layer has fewer than
            ``ncomponents`` columns.
        niter (int): Maximum iterations per component.
        type (str): Analysis type; only "standard" is supported.
        ncomponents (int): Number of canonical components to extract.
        trace (bool): If True, print iteration progress.
        standardize (bool): If True, standardize inputs with ``r_scale_torch``.
        cc_coef (List | np.ndarray | torch.Tensor | None): One coefficient
            per layer pair (``K*(K-1)/2`` values); defaults to all ones.

    Returns:
        Dict[str, Any]: The following keys:
            - ``ws``: final weights.
            - ``ws_init``: initial weights.
            - ``K``: number of layers.
            - ``type``: per-layer types.
            - ``penalty``: per-layer penalties.
            - ``cors``: weighted correlation per component.

    Raises:
        ValueError: On fewer than two matrices, a non-"standard" ``type``,
            a wrong number of coefficients, a type vector of the wrong length
            (no-phenotype mode), or too few penalties.
    """
    K = len(xlist)
    if K < 2:
        raise ValueError("At least two data matrices are required.")
    device = xlist[0].device
    dtype = xlist[0].dtype

    last_dims = xlist[-1].shape
    last_ncol = last_dims[1] if len(last_dims) > 1 else 1
    has_pheno = last_ncol <= 1
    # Number of layers whose weights are optimized (phenotype weight stays fixed).
    n_free = K - 1 if has_pheno else K

    pair_cc = np.array(list(itertools.combinations(range(K), 2))).T
    num_cc = pair_cc.shape[1]
    cc_coef = _mcca_coef_tensor(cc_coef, num_cc, device, dtype)

    if isinstance(type, str):
        if type != "standard":
            raise ValueError("Phenotype data must be continuous/standard.")
        type_vec = np.array([type] * K)
    else:
        type_vec = np.array(type)
        if not has_pheno and len(type_vec) != K:
            raise ValueError("Type must be vector of length 1 or K.")

    if standardize:
        xlist = [r_scale_torch(x) for x in xlist]

    if ws is not None and any(ws[i].shape[1] < ncomponents for i in range(n_free)):
        ws = None
    if ws is None:
        ws = _mcca_svd_init(xlist, n_free, ncomponents)
        if has_pheno:
            # phenotype weight fixed at 1.0
            ws.append(torch.tensor([[1.0]], device=device, dtype=dtype))
    ws_init = [w.clone() for w in ws]

    penalty = _mcca_resolve_penalty(penalty, type_vec, K)
    if penalty.size < n_free:
        raise ValueError(f"Provide a penalty for each of the first {n_free} data matrices.")

    # Optimized layers start from zero; a fixed phenotype weight keeps its initial value.
    ws_final = [
        torch.zeros((xlist[i].shape[1], ncomponents), device=device, dtype=dtype) if i < n_free else ws_init[i].clone()
        for i in range(K)
    ]

    cors = []
    for comp in range(ncomponents):
        ws_curr = [ws_init[i][:, comp].reshape(-1, 1) for i in range(n_free)]
        if has_pheno:
            ws_curr.append(torch.tensor([[1.0]], device=device, dtype=dtype))
        curiter = 1
        crit_old = -10.0
        crit = -20.0
        # crit_old != 0 is tested BEFORE the relative change: Python float
        # division by zero raises ZeroDivisionError.
        while curiter <= niter and crit_old != 0 and abs(crit_old - crit) / abs(crit_old) > 0.001:
            crit_old = crit
            crit = my_get_crit(xlist, ws_curr, pair_cc, cc_coef)
            if trace:
                print(curiter, end=" ", flush=True)
            curiter += 1
            for i in range(n_free):
                ws_curr[i] = my_update_w(xlist, i, K, penalty[i], ws_curr, ws_final, pair_cc, cc_coef, type=type_vec[i])
        if trace:
            print("")
        for i in range(n_free):
            ws_final[i][:, comp] = ws_curr[i].flatten()
        cors.append(my_get_cors(xlist, ws_curr, pair_cc, cc_coef))

    return {
        "ws": ws_final,
        "ws_init": ws_init,
        "K": K,
        "type": type_vec,
        "penalty": penalty,
        "cors": cors
    }