Source code for bioneuralnet.datasets.dataset_loader
from pathlib import Path
import pandas as pd
[docs]
class DatasetLoader:
    """Load a pre-packaged multi-omics dataset from the package.

    Options for ``dataset_name`` (surrounding whitespace is stripped and the
    name is lower-cased before it is matched):

    - "example": Synthetic example.
    - "monet": Synthetic example.
    - "brca": Breast invasive carcinoma.
    - "lgg": Brain Lower Grade Glioma.
    - "kipan": Pan-kidney carcinoma.

    Tables are available through item access (``loader["clinical"]``),
    membership tests (``"clinical" in loader``) and iteration over table
    names (``for name in loader``).

    Args:
        dataset_name (str): Name of the dataset to load.

    Attributes:
        dataset_name (str): Normalized dataset name.
        base_dir (Path): Directory where the dataset folders live.
        data (dict[str, pd.DataFrame]): Mapping from table name to loaded DataFrame.

    Raises:
        FileNotFoundError: If ``base_dir / dataset_name`` is not a directory,
            or if one of the dataset's CSV files is missing.
        ValueError: If the folder exists but the name is not a known dataset.
    """

    # "brca", "lgg" and "kipan" ship the same five tables.
    _COHORT_TABLES: tuple[str, ...] = (
        "mirna",
        "target",
        "clinical",
        "rna",
        "methylation",
    )

    # Dataset name -> (table names in load order, use first column as index).
    # Every table is read from "<table>.csv" inside the dataset folder.
    _DATASETS: dict[str, tuple[tuple[str, ...], bool]] = {
        "example": (("X1", "X2", "Y", "clinical"), True),
        # "monet" is read with pandas' default index: its first column stays
        # a regular column rather than becoming the row index.
        "monet": (("gene", "mirna", "phenotype", "rppa", "clinical"), False),
        "brca": (_COHORT_TABLES, True),
        "lgg": (_COHORT_TABLES, True),
        "kipan": (_COHORT_TABLES, True),
    }

    def __init__(self, dataset_name: str):
        self.dataset_name = dataset_name.strip().lower()
        # Dataset folders are looked up next to this module.
        self.base_dir = Path(__file__).parent
        self.data: dict[str, pd.DataFrame] = {}
        self._load_data()

    def __getitem__(self, key):
        """Return the table stored under ``key``; raises ``KeyError`` if absent."""
        return self.data[key]

    def __contains__(self, key) -> bool:
        """Return True when a table named ``key`` has been loaded."""
        # Without this, ``in`` falls back to calling ``__getitem__(0)`` and
        # raises ``KeyError: 0`` instead of answering the membership test.
        return key in self.data

    def __iter__(self):
        """Iterate over the loaded table names, in load order."""
        # Without this, iteration falls back to integer indexing through
        # ``__getitem__`` and fails with ``KeyError: 0``.
        return iter(self.data)

    def _load_data(self) -> None:
        """Internal helper to populate ``self.data`` from CSV files for the given dataset.

        Raises:
            FileNotFoundError: If the dataset folder (or one of its CSV files)
                does not exist.
            ValueError: If the folder exists but the dataset name is unknown.
        """
        folder = self.base_dir / self.dataset_name
        # The folder check intentionally runs before the name check, so the
        # exception raised for a given bad name stays the same as before.
        if not folder.is_dir():
            raise FileNotFoundError(f"Dataset folder '{folder}' not found.")

        spec = self._DATASETS.get(self.dataset_name)
        if spec is None:
            raise ValueError(f"Dataset '{self.dataset_name}' is not recognized.")

        tables, first_column_is_index = spec
        read_options = {"index_col": 0} if first_column_is_index else {}
        self.data = {
            table: pd.read_csv(folder / f"{table}.csv", **read_options)
            for table in tables
        }

    @property
    def shape(self) -> dict[str, tuple[int, int]]:
        """Dictionary mapping each table name to its (n_rows, n_cols) shape."""
        return {name: df.shape for name, df in self.data.items()}