Pretrained Models

This directory contains pretrained variational autoencoder (VAE) and reconstruction network models developed in Work Package 3 (WP3) of the EVENFLOW EU project.

Each model was trained independently on a pre-processed version of one bulk RNA-Seq TCGA dataset, either KIRC or BRCA (see the data availability entry in the respective section below).

Available Models

KIRC (Kidney Renal Clear Cell Carcinoma)

Location: KIRC/

Data availability: Zenodo

Model Files:

  • 20250321_VAE_idim8516_md512_feat256mse_relu.pth - VAE weights
  • network_reconstruction.pth - Reconstruction network weights
  • network_dims.csv - Network architecture specifications
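
The VAE file names appear to encode the main hyperparameters (input, middle and latent dimensions, loss and activation). The following is a minimal sketch of how they could be recovered from a file name. The naming pattern is an assumption inferred from the two VAE files in this directory, and `parse_vae_filename` is a hypothetical helper, not part of renalprog.

```python
import re

# Assumed pattern, inferred from the file names listed in this directory:
# <date>_VAE_idim<input>_md<middle>_feat<latent><loss>_<activation>.pth
VAE_NAME_PATTERN = re.compile(
    r"(?P<date>\d{8})_VAE_idim(?P<input_dim>\d+)_md(?P<mid_dim>\d+)"
    r"_feat(?P<latent_dim>\d+)(?P<loss>[a-z]+)_(?P<activation>[a-z]+)\.pth"
)


def parse_vae_filename(name):
    """Return the hyperparameters encoded in a VAE weights file name."""
    match = VAE_NAME_PATTERN.fullmatch(name)
    if match is None:
        raise ValueError(f"Unrecognized VAE file name: {name}")
    info = match.groupdict()
    # Convert the dimension fields from strings to integers.
    for key in ("input_dim", "mid_dim", "latent_dim"):
        info[key] = int(info[key])
    return info


kirc_info = parse_vae_filename("20250321_VAE_idim8516_md512_feat256mse_relu.pth")
print(kirc_info)
```

The values parsed this way can be cross-checked against the Model Specifications of each cancer type.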

Model Specifications:

  • Input dimension: 8,516 genes
  • VAE architecture:
    • Middle dimension: 512
    • Latent dimension: 256
    • Loss function: MSE
    • Activation: ReLU
  • Reconstruction network: [8954, 3512, 824, 3731, 8954]
  • Training: Beta-VAE with 3 cycles, 600 epochs total

BRCA (Breast Invasive Carcinoma)

Location: BRCA/

Data availability: Zenodo

Model Files:

  • 20251209_VAE_idim8954_md1024_feat512mse_relu.pth - VAE weights
  • network_reconstruction.pth - Reconstruction network weights
  • network_dims.csv - Network architecture specifications

Model Specifications:

  • Input dimension: 8,954 genes
  • VAE architecture:
    • Middle dimension: 1,024
    • Latent dimension: 512
    • Loss function: MSE
    • Activation: ReLU
  • Reconstruction network: [8954, 3104, 790, 4027, 8954]
  • Training: Beta-VAE with 3 cycles, 600 epochs total
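
To give a sense of model size, the sketch below counts the trainable parameters implied by a list of layer dimensions such as the one above. It assumes the reconstruction network is a plain stack of fully connected layers with biases, which this README does not confirm; `count_dense_parameters` is a hypothetical helper.

```python
def count_dense_parameters(layer_dims, bias=True):
    """Count weights (and optionally biases) of a stack of dense layers."""
    total = 0
    # Each consecutive pair of dimensions defines one dense layer.
    for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:]):
        total += n_in * n_out
        if bias:
            total += n_out
    return total


# Layer dimensions listed for the BRCA reconstruction network.
brca_dims = [8954, 3104, 790, 4027, 8954]
print(f"Approximate parameter count: {count_dense_parameters(brca_dims):,}")
```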

Usage

Loading Models in Python

The VAE and NetworkReconstruction classes used below are provided by the renalprog package, which must be installed first.

import torch
import pandas as pd
import json
from pathlib import Path
import huggingface_hub as hf
from renalprog.modeling.train import VAE, NetworkReconstruction

# Configuration
cancer_type = "KIRC"  # or "BRCA"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================================================================
# Load VAE Model
# ============================================================================

# Download VAE config
vae_config_path = hf.hf_hub_download(
    repo_id="gprolcastelo/evenflow_models",
    filename=f"{cancer_type}/config.json"
)

# Load configuration
with open(vae_config_path, "r") as f:
    vae_config = json.load(f)

print(f"VAE Configuration: {vae_config}")

# Download VAE model weights
if cancer_type == "KIRC":
    vae_filename = "KIRC/20250321_VAE_idim8516_md512_feat256mse_relu.pth"
elif cancer_type == "BRCA":
    vae_filename = "BRCA/20251209_VAE_idim8954_md1024_feat512mse_relu.pth"
else:
    raise ValueError(f"Unknown cancer type: {cancer_type}")

vae_model_path = hf.hf_hub_download(
    repo_id="gprolcastelo/evenflow_models",
    filename=vae_filename
)

# Initialize and load VAE
model_vae = VAE(
    input_dim=vae_config["INPUT_DIM"],
    mid_dim=vae_config["MID_DIM"],
    features=vae_config["LATENT_DIM"]
).to(device)

checkpoint_vae = torch.load(vae_model_path, map_location=device, weights_only=False)
model_vae.load_state_dict(checkpoint_vae)
model_vae.eval()

print(f"VAE model loaded successfully from {cancer_type}")

# ============================================================================
# Load Reconstruction Network
# ============================================================================

# Download network dimensions
network_dims_path = hf.hf_hub_download(
    repo_id="gprolcastelo/evenflow_models",
    filename=f"{cancer_type}/network_dims.csv"
)

# Load network dimensions
network_dims = pd.read_csv(network_dims_path)
layer_dims = network_dims.values.tolist()[0]

print(f"Reconstruction Network dimensions: {layer_dims}")

# Download reconstruction network weights
recnet_model_path = hf.hf_hub_download(
    repo_id="gprolcastelo/evenflow_models",
    filename=f"{cancer_type}/network_reconstruction.pth"
)

# Initialize and load Reconstruction Network
model_recnet = NetworkReconstruction(layer_dims=layer_dims).to(device)
checkpoint_recnet = torch.load(recnet_model_path, map_location=device, weights_only=False)
model_recnet.load_state_dict(checkpoint_recnet)
model_recnet.eval()

print(f"Reconstruction Network loaded successfully from {cancer_type}")

# ============================================================================
# Use the models
# ============================================================================

# Example: Apply VAE to your data
# your_data = torch.tensor(your_data_array).float().to(device)
# with torch.no_grad():
#     vae_output = model_vae(your_data)
#     recnet_output = model_recnet(vae_output)

Citation

โš ๏ธ Warning
This citation is temporary. It will be updated when a pre-print is released.

If you use these pretrained models, please cite:

@software{renalprog2024,
  title = {RenalProg: A Deep Learning Framework for Kidney Cancer Progression Modeling},
  author = {Guillermo Prol-Castelo and Elina Syrri and Nikolaos Manginas and Vasileos Manginas and Nikos Katzouris and Davide Cirillo and George Paliouras and Alfonso Valencia},
  year = {2025},
  url = {https://github.com/gprolcas/renalprog},
  note = {Preprint in preparation}
}

Training Details

These models were trained using:

  • Random seed: 2023
  • Train/test split: 80/20
  • Optimizer: Adam
  • Learning rate: 1e-4
  • Batch size: 8
  • Beta annealing (for VAE): 3 cycles with 0.5 ratio
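
The beta annealing setting can be illustrated with a short schedule sketch. It assumes the common cyclical scheme in which, within each cycle, beta rises linearly from 0 to its maximum over the first `ratio` fraction of the cycle and then stays constant. The maximum value (`beta_max=1.0`) and the function name are assumptions, and the schedule actually used in renalprog may differ.

```python
def cyclical_beta_schedule(n_epochs=600, n_cycles=3, ratio=0.5, beta_max=1.0):
    """Return one beta value per epoch for a cyclical annealing schedule."""
    period = n_epochs / n_cycles   # epochs per cycle (200 with the defaults)
    ramp = period * ratio          # epochs spent increasing beta in each cycle
    betas = []
    for epoch in range(n_epochs):
        position = epoch % period  # position inside the current cycle
        betas.append(beta_max * min(1.0, position / ramp))
    return betas


betas = cyclical_beta_schedule()
# Beta restarts from 0 at the beginning of every cycle.
print(betas[0], betas[50], betas[150], betas[200])
```

With 600 epochs and 3 cycles, each cycle under this assumption spans 200 epochs: 100 epochs of linear increase followed by 100 epochs at the maximum value.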

Model Performance

KIRC Model:

  • Reconstruction loss (test): ~1.1

BRCA Model:

  • Reconstruction loss (test): ~0.9

License

These pretrained models are provided under the Apache 2.0 license.

Contact

For questions about the pretrained models, please:

  1. Check the documentation
  2. Open an issue on GitHub
  3. Contact the authors

Last Updated: December 2025

Version: 1.0.0-alpha
