GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine
GALAX is a graph-augmented language model designed for explainable target prioritization in precision medicine. It combines three key components:
By jointly leveraging multi-omics features, protein–protein interactions, and disease–target associations, GALAX provides an interpretable framework for CRISPR target prioritization across diverse cancer cell lines. To support benchmarking and reproducibility, we also introduce the Target-QA dataset.
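```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
import os
import torch

# 1. Load the GALAX language model and its tokenizer from the Hugging Face Hub
model_id = "FuhaiLiAiLab/GALAX"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",   # place weights on available devices (requires `accelerate`)
    torch_dtype="auto",  # use the dtype stored in the checkpoint
)

# 2. Access the graph foundation model shipped in the same repository
repo_path = snapshot_download(model_id)  # downloads or reuses a local copy of the repo
combined_model_path = os.path.join(repo_path, "best_combined_model.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Note: PyTorch 2.6+ defaults torch.load to weights_only=True. If this file holds a
# full pickled object rather than plain tensors, loading may need weights_only=False
# (only do this for checkpoints you trust).
best_combined_model = torch.load(combined_model_path, map_location=device)
```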
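The model card does not state whether `best_combined_model.pt` stores a full pickled module or a plain mapping of tensors, so it can help to inspect what `torch.load` returned before using it. The following is a minimal sketch; `describe_checkpoint` is a hypothetical helper (not part of the GALAX repository) that relies only on duck typing and makes no assumption about the checkpoint's actual layout.

```python
def describe_checkpoint(obj, max_keys=5):
    """Summarize an object returned by torch.load (hypothetical helper).

    'mapping' -> a dict, e.g. a state dict or a checkpoint dict
    'module'  -> an object exposing a callable state_dict()
    'other'   -> anything else
    """
    if isinstance(obj, dict):
        keys = list(obj.keys())
        return {"kind": "mapping", "num_entries": len(keys), "sample_keys": keys[:max_keys]}
    if callable(getattr(obj, "state_dict", None)):
        keys = list(obj.state_dict().keys())
        return {"kind": "module", "num_entries": len(keys), "sample_keys": keys[:max_keys]}
    return {"kind": "other", "type": type(obj).__name__}
```

After loading, `print(describe_checkpoint(best_combined_model))` shows which case applies. If the result is a mapping, the matching model class has to be instantiated before the weights can be loaded into it.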
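GALAX achieves the highest precision and recall overall and on both LUAD and BRCA (Table 1), as well as the highest overall and LUAD Hit@10 and Hit@5 (Table 2). On BRCA, it ties L3-FT(QA) + Omics on Hit@5 (0.8889) and trails G-Retriever + pre-GAT on Hit@10 (0.8500 vs. 0.8667).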
Table 1. Precision and Recall across datasets
| Model | Overall Precision ↑ | Overall Recall ↑ | LUAD Precision ↑ | LUAD Recall ↑ | BRCA Precision ↑ | BRCA Recall ↑ |
|---|---|---|---|---|---|---|
| M2T | 0.0016 | 0.0011 | 0.0020 | 0.0014 | 0.0000 | 0.0000 |
| GAT | 0.0006 ± 0.0000 | 0.0006 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0033 ± 0.0000 | 0.0033 ± 0.0000 |
| L3 + Omics | 0.0071 ± 0.0032 | 0.0013 ± 0.0002 | 0.0079 ± 0.0137 | 0.0005 ± 0.0008 | 0.0020 ± 0.0035 | 0.0017 ± 0.0029 |
| L3 + Omics + KG | 0.0125 ± 0.0032 | 0.0029 ± 0.0003 | 0.0014 ± 0.0025 | 0.0010 ± 0.0017 | 0.0073 ± 0.0068 | 0.0033 ± 0.0029 |
| L3-FT(Med) + Omics | 0.0179 ± 0.0045 | 0.0133 ± 0.0064 | 0.0091 ± 0.0018 | 0.0105 ± 0.0044 | 0.0110 ± 0.0086 | 0.0106 ± 0.0075 |
| L3-FT(Med) + Omics + KG | 0.0158 ± 0.0030 | 0.0058 ± 0.0011 | 0.0081 ± 0.0071 | 0.0024 ± 0.0017 | 0.0149 ± 0.0057 | 0.0050 ± 0.0000 |
| L3-FT(QA) + Omics | 0.5250 ± 0.0282 | 0.4959 ± 0.0435 | 0.5201 ± 0.0408 | 0.4905 ± 0.0532 | 0.5074 ± 0.0498 | 0.4856 ± 0.0570 |
| L3-FT(QA) + Omics + KG | 0.5185 ± 0.0240 | 0.4908 ± 0.0402 | 0.5214 ± 0.0242 | 0.4952 ± 0.0432 | 0.4856 ± 0.0395 | 0.4656 ± 0.0436 |
| G-Retriever + pre-GAT | 0.4763 ± 0.0004 | 0.3929 ± 0.0063 | 0.4642 ± 0.0181 | 0.3881 ± 0.0264 | 0.4414 ± 0.0099 | 0.3772 ± 0.0010 |
| GALAX | 0.5472 ± 0.0053 | 0.5332 ± 0.0031 | 0.5345 ± 0.0185 | 0.5157 ± 0.0043 | 0.5608 ± 0.0031 | 0.5533 ± 0.0033 |
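To make the precision and recall columns concrete, here is a minimal sketch of how the two metrics could be computed for a single sample. It assumes set overlap between predicted and ground-truth target genes; the exact evaluation protocol behind the table may differ. The function name `precision_recall` and the gene lists are illustrative and are not taken from Target-QA.

```python
def precision_recall(predicted, reference):
    """Set-overlap precision and recall for one sample (illustrative sketch)."""
    pred, ref = set(predicted), set(reference)
    true_positives = len(pred & ref)
    # Guard against empty sets to avoid division by zero.
    precision = true_positives / len(pred) if pred else 0.0
    recall = true_positives / len(ref) if ref else 0.0
    return precision, recall


# Illustrative gene lists, not real benchmark data.
predicted = ["KRAS", "EGFR", "TP53", "MYC"]
reference = ["KRAS", "EGFR", "BRAF"]
precision, recall = precision_recall(predicted, reference)
print(f"precision={precision:.4f} recall={recall:.4f}")
```

In this toy case two of the four predictions are correct and two of the three reference genes are recovered. Dataset-level scores would then be averaged over samples under the same assumption.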
Table 2. Hit@10 and Hit@5 across datasets
| Model | Overall Hit@10 ↑ | Overall Hit@5 ↑ | LUAD Hit@10 ↑ | LUAD Hit@5 ↑ | BRCA Hit@10 ↑ | BRCA Hit@5 ↑ |
|---|---|---|---|---|---|---|
| M2T | 0.0029 | 0.0000 | 0.0000 | 0.0000 | 0.0000 | 0.0000 |
| GAT | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics | 0.0021 ± 0.0037 | 0.0032 ± 0.0055 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 |
| L3 + Omics + KG | 0.0122 ± 0.0033 | 0.0085 ± 0.0037 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0056 ± 0.0096 | 0.0111 ± 0.0192 |
| L3-FT(Med) + Omics | 0.0122 ± 0.0072 | 0.0116 ± 0.0097 | 0.0000 ± 0.0000 | 0.0000 ± 0.0000 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(Med) + Omics + KG | 0.0132 ± 0.0040 | 0.0106 ± 0.0048 | 0.0048 ± 0.0082 | 0.0095 ± 0.0165 | 0.0111 ± 0.0192 | 0.0000 ± 0.0000 |
| L3-FT(QA) + Omics | 0.8693 ± 0.0157 | 0.8889 ± 0.0168 | 0.8667 ± 0.0218 | 0.8476 ± 0.0165 | 0.8389 ± 0.0096 | 0.8889 ± 0.0509 |
| L3-FT(QA) + Omics + KG | 0.8529 ± 0.0153 | 0.8794 ± 0.0114 | 0.8048 ± 0.0541 | 0.7905 ± 0.0436 | 0.8222 ± 0.0347 | 0.8778 ± 0.0192 |
| G-Retriever + pre-GAT | 0.8550 ± 0.0046 | 0.8804 ± 0.0037 | 0.8524 ± 0.0165 | 0.8857 ± 0.0000 | 0.8667 ± 0.0000 | 0.8667 ± 0.0000 |
| GALAX | 0.8815 ± 0.0033 | 0.9249 ± 0.0048 | 0.8810 ± 0.0082 | 0.9238 ± 0.0436 | 0.8500 ± 0.0441 | 0.8889 ± 0.0839 |
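The Hit@k columns can be illustrated with a short sketch. It assumes Hit@k is the fraction of the top-k ranked predictions that appear in the ground-truth set. This reading is consistent with Hit@5 exceeding Hit@10 in several rows, which a binary "any hit in the top k" definition could not produce, but the paper's exact definition may differ. The function name `hit_at_k` and the gene lists are illustrative.

```python
def hit_at_k(ranked_predictions, reference, k):
    """Fraction of the top-k predictions found in the reference set.

    Assumed definition for illustration only. Dividing by k means that
    returning fewer than k predictions is penalized.
    """
    if k <= 0:
        raise ValueError("k must be a positive integer")
    ref = set(reference)
    top_k = ranked_predictions[:k]
    return sum(1 for gene in top_k if gene in ref) / k


# Illustrative ranking; "GENE_X" is a placeholder, not a real benchmark entry.
ranked = ["KRAS", "EGFR", "GENE_X", "TP53", "MYC", "BRAF"]
reference = ["KRAS", "EGFR", "TP53", "BRAF"]
print(hit_at_k(ranked, reference, 5))
print(hit_at_k(ranked, reference, 2))
```

Under this definition a model whose strongest predictions are ranked first scores higher at smaller k, which matches the pattern in the table.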
If you use this model, please cite:
```bibtex
@article{zhang2025galax,
  title   = {GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine},
  author  = {Zhang, Heming and Huang, Di and Li, Wenyu and Province, Michael and Chen, Yixin and Payne, Philip and Li, Fuhai},
  journal = {arXiv preprint arXiv:2509.20935},
  year    = {2025},
  doi     = {10.48550/arXiv.2509.20935},
  url     = {https://arxiv.org/abs/2509.20935}
}
```