# ROMAN — Scene Graph Node Relevance Classification
Given a 3D scene graph and a natural-language navigation constraint (e.g., "avoid the kitchen"), classify each node as relevant or non-relevant for cost-function generation used by a PRM path planner.
Two model families are provided:

- Binary-relevance (V11T, V12, V13) — a single per-node probability.
- Three-objective (V11ALL, V12ALL, V13ALL) — a per-node tuple
  (relevance, weight ∈ [−1, +1], sigma ≥ 0.8) consumed by the PRM Gaussian
  cost field f(x) = Σ_{i: rel_i=1} w_i · exp(−‖x − p_i‖² / (2σ_i²)).

Architectural detail for each family lives in MODEL_ARCHITECTURES.md
(a PDF is also provided).
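As a concrete reference for the formula above, here is a minimal pure-Python sketch of the Gaussian cost field; the node dictionaries (keys `rel`, `w`, `sigma`, `p`) are a hypothetical representation of the per-node model outputs, not the repo's actual data structures:

```python
import math

def cost(x, nodes):
    """f(x) = sum over relevant nodes of w_i * exp(-||x - p_i||^2 / (2 sigma_i^2)).

    x: query point as an (x, y, z) tuple.
    nodes: list of dicts with keys rel (0/1), w, sigma, p (3-tuple centroid).
    """
    total = 0.0
    for n in nodes:
        if n["rel"] != 1:          # only relevant nodes contribute
            continue
        d2 = sum((a - b) ** 2 for a, b in zip(x, n["p"]))
        total += n["w"] * math.exp(-d2 / (2 * n["sigma"] ** 2))
    return total
```

At a node's centroid the exponential is 1, so that node contributes exactly its weight w_i; the PRM planner evaluates this field at candidate roadmap states.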
## Models in this repo
### V12 — Graph neural network on enriched V4E features
A lightweight GNN on 1262-dim enriched node features (label and parent-room embeddings, tiled instruction, node type, relative spatial features, floor, material, affordance).
| Model | Path | Params | Test F1 | HO F1 |
|---|---|---|---|---|
| SAGEConv (binary) | v12/SAGE/best.pt | 1.25 M | 0.924 | 0.949 |
| GCNConv (binary) | v12/GCN/best.pt | 0.80 M | 0.921 | 0.946 |
| V12ALL-SAGE (rel+w+σ) | v12all/SAGE/best.pt | 1.25 M | 0.916 | 0.941 |
| V12ALL-GCN (rel+w+σ) | v12all/GCN/best.pt | 0.80 M | 0.905 | 0.926 |
### V11T — ModernBERT token-span classifier
ModernBERT-base backbone with per-node span-pooling on a single linearized
scene sequence ([CLS] c [SEP] n_1 | … | n_N [SEP], max 1024 tokens).
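The linearization above can be sketched at the character level; this is an illustrative helper (the name `linearize` and the exact whitespace around separators are assumptions), returning per-node character spans from which token spans can be derived after tokenization:

```python
def linearize(constraint, node_texts):
    """Build '[CLS] c [SEP] n_1 | ... | n_N [SEP]' and record each node's
    character span, so per-node token spans can be derived later."""
    parts = ["[CLS] ", constraint, " [SEP] "]
    spans = []
    pos = sum(len(p) for p in parts)
    for i, n in enumerate(node_texts):
        if i:
            parts.append(" | ")   # node separator
            pos += 3
        spans.append((pos, pos + len(n)))
        parts.append(n)
        pos += len(n)
    parts.append(" [SEP]")
    return "".join(parts), spans
```

Span-pooling then averages the token embeddings inside each node's span to produce one vector per node for classification.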
| Model | Path | Trainable | Test F1 | HO F1 |
|---|---|---|---|---|
| LoRA r=8 (binary) | v11t/lora_r8/model_best.pt | 1.7 M | 0.910 | 0.949 |
| LoRA r=16 (binary) | v11t/lora_r16/model_best.pt | 3.4 M | 0.897 | 0.951 |
| Full fine-tune (binary) | v11t/full/model_best.pt | 149.0 M | 0.907 | 0.958 |
| V11ALL-LoRA r=8 (rel+w+σ) | v11all/lora_r8/model_best.pt | 1.7 M | 0.889 | 0.920 |
| V11ALL-Full (rel+w+σ) | v11all/full/model_best.pt | 149.0 M | 0.929 | 0.990 |
### V13 — ModernBERT retrieval cascade
Pair-wise classifier over (constraint, node_text) pairs. Node text is a
self-describing JSON-style snippet
({type: object, name: chair, room: kitchen, floor: B, size: 0.5x0.8x0.4m, material: wood}).
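The snippet can be assembled from a node record like so; this is an illustrative sketch (the helper name `node_to_text` and the input dict layout are assumptions; the field order follows the example above):

```python
def node_to_text(node):
    """Render a node as the self-describing snippet fed to the pair classifier.
    'size' is assumed to be a (w, d, h) triple in metres."""
    sx, sy, sz = node["size"]
    return ("{type: %s, name: %s, room: %s, floor: %s, "
            "size: %.1fx%.1fx%.1fm, material: %s}" % (
                node["type"], node["name"], node["room"],
                node["floor"], sx, sy, sz, node["material"]))
```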
| Model | Path | Trainable | Test F1 | HO F1 |
|---|---|---|---|---|
| Bi-Encoder LoRA r=8 (binary, recall-focused) | v13/bi_encoder_lora_r8/model_best.pt | 1.7 M | 0.946 | 0.972 |
| Cross-Encoder Full (binary) | v13/cross_encoder_full/model_best.pt | 149.0 M | 0.956 | 0.975 |
| V13ALL Bi-Encoder (rel, recall-focused) | v13all/bi_encoder_lora_r8_recall/model_best.pt | 1.7 M | 0.946 | 0.972 |
| V13ALL Cross-Encoder (rel+w+σ) | v13all/cross_encoder_full/model_best.pt | 149.0 M | 0.956 | 0.975 |
| V13ALL Cascade (bi ∧ cross) | same two files, gated at inference | — | 0.954 | 0.974 |
The bi-encoder uses a boosted pos_weight = 1.5 · √(#neg/#pos) and F2-score
threshold tuning, so it passes ≈ 95 % of true positives to the cross-encoder.
The cascade prediction is the logical intersection of the two stages,
trading ~0.002 F1 for higher precision (0.97) and lower MAE_w (0.038).
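A minimal sketch of the two recall-oriented pieces just described; the helper names and the grid-search tuning loop are illustrative, not the repo's actual training code:

```python
import math

def boosted_pos_weight(n_pos, n_neg, boost=1.5):
    # pos_weight for BCE-with-logits: 1.5 * sqrt(#neg / #pos)
    return boost * math.sqrt(n_neg / n_pos)

def tune_threshold_f2(probs, labels):
    """Pick the decision threshold maximizing F2 (recall-weighted F-score)."""
    best_t, best_f2 = 0.5, -1.0
    for i in range(1, 100):
        t = i / 100
        tp = sum(1 for p, y in zip(probs, labels) if p > t and y == 1)
        fp = sum(1 for p, y in zip(probs, labels) if p > t and y == 0)
        fn = sum(1 for p, y in zip(probs, labels) if p <= t and y == 1)
        if tp == 0:
            continue
        prec, rec = tp / (tp + fp), tp / (tp + fn)
        f2 = 5 * prec * rec / (4 * prec + rec)   # beta=2 weights recall
        if f2 > best_f2:
            best_t, best_f2 = t, f2
    return best_t
```

With the dataset's ≈ 7.6 % positive rate, the boost keeps the bi-encoder biased toward recall, which is what a first-stage filter wants.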
## Three-objective regression quality (V*ALL)
| Model | Test MAE_w | Test MAE_σ | HO MAE_w | HO MAE_σ |
|---|---|---|---|---|
| V12ALL-SAGE | 0.047 | 0.038 | 0.039 | 0.035 |
| V12ALL-GCN | 0.060 | 0.056 | 0.049 | 0.051 |
| V11ALL-LoRA r=8 | 0.087 | 0.069 | 0.047 | 0.059 |
| V11ALL-Full | 0.059 | 0.042 | 0.041 | 0.037 |
| V13ALL Cross-Encoder | 0.047 | 0.048 | 0.042 | 0.046 |
| V13ALL Cascade | 0.038 | 0.046 | 0.040 | 0.045 |
w ∈ [−1, +1] (tanh), σ ∈ [0.8, +∞) (softplus + 0.8). Regression MAE
is computed only on true-positive nodes (where rel_gt = 1).
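The two output-head squashings can be written out explicitly; this is a pure-Python illustration of the mapping (the function names are hypothetical, and the real models apply the equivalent torch ops):

```python
import math

def decode_w(w_raw):
    # tanh bounds the weight head to [-1, +1]
    return math.tanh(w_raw)

def decode_sigma(s_raw):
    # softplus(x) = log(1 + e^x) >= 0, shifted so sigma >= 0.8;
    # for large s_raw, softplus(s_raw) ~ s_raw (overflow guard omitted)
    return math.log1p(math.exp(s_raw)) + 0.8
```

The 0.8 floor keeps every Gaussian in the cost field from collapsing to a near-zero-width spike that a PRM sample could slip past.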
## Dataset
dataset/v4_dedup_enriched_1024/ — a pre-tokenized HF Dataset
(ModernBERT tokenizer, enriched text mode, max_length = 1024).
Source: train_ready_v4_dedup.jsonl (7 911 records, 88 Matterport3D scenes,
≈ 783 K nodes, 7.6 % relevant).
## Quick start (from HF Hub)
### 1. Install dependencies

```bash
pip install torch torch-geometric huggingface_hub sentence-transformers \
    transformers peft datasets
```
### 2. Download models

```bash
# CLI: entire repo
hf download Catkamakura/roman-scene-graph --local-dir roman-scene-graph

# CLI: specific model only
hf download Catkamakura/roman-scene-graph v13all/cross_encoder_full/model_best.pt \
    --local-dir roman-scene-graph
```

```python
# Python: entire repo
from huggingface_hub import snapshot_download
snapshot_download("Catkamakura/roman-scene-graph", local_dir="roman-scene-graph")

# Python: specific file
from huggingface_hub import hf_hub_download
path = hf_hub_download("Catkamakura/roman-scene-graph",
                       "v13all/cross_encoder_full/model_best.pt")
```
### 3. Clone the code repo

The checkpoints need the source code to run:

```bash
git clone <your-code-repo-url> roman-code
cd roman-code
```
### 4. V12 GNN inference (binary)

```python
import torch
from huggingface_hub import hf_hub_download

import text_encoders
from SceneGraphDatasetV4E import SceneGraphDatasetV4E
from train_v12_sage_vs_gcn import SceneGraphSAGE, forward_with_tiled_instr

model_path = hf_hub_download("Catkamakura/roman-scene-graph", "v12/SAGE/best.pt")
model = SceneGraphSAGE(in_channels=1262, hidden_channels_arr=[256]*3, out_channels=64)
model.load_state_dict(torch.load(model_path, weights_only=True))
model.eval()

te = text_encoders.DictTextEncoder(
    "embeddings/sentence-transformers/all-MiniLM-L6-v2_embeddings_False.pkl")
ds = SceneGraphDatasetV4E("your_input.jsonl", text_encoder=te,
                          include_parent_room=True)
ds.encode_all_node_features(); ds.all_graphs_make_x()

with torch.no_grad():
    scores = forward_with_tiled_instr(model, ds[0], "cpu")
relevant = torch.sigmoid(scores) > 0.5
```
### 5. V12ALL GNN inference (relevance + weight + sigma)

```python
import torch
from huggingface_hub import hf_hub_download

from train_v12all import SceneGraphSAGE_ALL

model_path = hf_hub_download("Catkamakura/roman-scene-graph",
                             "v12all/SAGE/best.pt")
model = SceneGraphSAGE_ALL(in_channels=1262, hidden_channels_arr=[256]*3,
                           out_channels=64)
model.load_state_dict(torch.load(model_path, weights_only=True))
model.eval()

# Same data prep as V12 (SceneGraphDatasetV4E); then:
with torch.no_grad():
    # x = [graph.x ; tiled instruction] as in V12
    rel_logits, w_hat, sigma_hat = model(x, graph.edge_index)
relevant = torch.sigmoid(rel_logits) > 0.5
# w_hat in [-1, +1] (tanh), sigma_hat >= 0.8 (softplus + 0.8)
```
### 6. V11T ModernBERT inference (binary, token-span)

```python
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from datasets import Dataset

from model_modernbert import SceneGraphModernBERT

model_path = hf_hub_download("Catkamakura/roman-scene-graph",
                             "v11t/full/model_best.pt")
ds_root = snapshot_download("Catkamakura/roman-scene-graph",
                            allow_patterns="dataset/v4_dedup_enriched_1024/*",
                            local_dir="roman-scene-graph")

model = SceneGraphModernBERT(backbone="answerdotai/ModernBERT-base",
                             train_mode="full")
model.load_state_dict(torch.load(model_path, weights_only=True))
model.eval()

dataset = Dataset.load_from_disk(
    "roman-scene-graph/dataset/v4_dedup_enriched_1024")
```
### 7. V11ALL ModernBERT inference (3 objectives)

```python
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

from model_modernbert_all import SceneGraphModernBERTALL, apply_lora_all

tok = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
backbone = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
model = SceneGraphModernBERTALL(backbone)
# For the LoRA r=8 variant, wrap the backbone before loading weights:
# model, *_ = apply_lora_all(model, r=8)

ckpt = hf_hub_download("Catkamakura/roman-scene-graph",
                       "v11all/full/model_best.pt")
model.load_state_dict(torch.load(ckpt, weights_only=True), strict=False)
model.eval()

# Inputs: input_ids [1, T], attention_mask [1, T],
# node_spans = list of (start, end) per node
rel_logits, w_hat, sigma_hat = model(input_ids, attention_mask,
                                     node_spans=[spans])
```
### 8. V13 / V13ALL cross-encoder inference (per-pair)

```python
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

from train_v13all import CrossEncoderALL

ckpt = hf_hub_download("Catkamakura/roman-scene-graph",
                       "v13all/cross_encoder_full/model_best.pt")
backbone = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = CrossEncoderALL(backbone)
model.load_state_dict(torch.load(ckpt, weights_only=True))
model.eval()

constraint = "avoid the kitchen"
node_text = "{type: object, name: chair, room: kitchen, floor: B, size: 0.5x0.8x0.4m}"
enc = tokenizer(constraint, node_text, return_tensors="pt",
                padding="max_length", max_length=128, truncation=True)
with torch.no_grad():
    rel_logits, w_hat, sigma_hat = model(**enc)
```
### 9. V13ALL cascade (bi-filter → cross-score)

```python
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

from train_v13all import BiEncoder, CrossEncoderALL, apply_lora

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

# Bi-encoder (LoRA)
bi_ckpt = hf_hub_download("Catkamakura/roman-scene-graph",
                          "v13all/bi_encoder_lora_r8_recall/model_best.pt")
bi = BiEncoder(apply_lora(AutoModel.from_pretrained("answerdotai/ModernBERT-base")))
bi.load_state_dict(torch.load(bi_ckpt, weights_only=True))
bi.eval()

# Cross-encoder (full fine-tune)
cr_ckpt = hf_hub_download("Catkamakura/roman-scene-graph",
                          "v13all/cross_encoder_full/model_best.pt")
cr = CrossEncoderALL(AutoModel.from_pretrained("answerdotai/ModernBERT-base"))
cr.load_state_dict(torch.load(cr_ckpt, weights_only=True))
cr.eval()

TAU_BI, TAU_CR = 0.40, 0.50  # recommended cascade thresholds

def score(constraint, node_texts):
    # Stage 1: bi-encoder filter
    q = tokenizer(constraint, return_tensors="pt", padding="max_length",
                  max_length=128, truncation=True)
    d = tokenizer(node_texts, return_tensors="pt", padding="max_length",
                  max_length=128, truncation=True)
    with torch.no_grad():
        s_bi = bi(q["input_ids"], q["attention_mask"],
                  d["input_ids"], d["attention_mask"])
    p_bi = torch.sigmoid(s_bi)

    # Stage 2: cross-encoder score (on survivors, or on all pairs if GPU budget allows)
    pairs = tokenizer([constraint] * len(node_texts), node_texts,
                      return_tensors="pt", padding="max_length",
                      max_length=128, truncation=True)
    with torch.no_grad():
        rel, w, sigma = cr(**pairs)

    # Cascade: logical intersection of the two stages
    final_rel = (p_bi > TAU_BI) & (torch.sigmoid(rel) > TAU_CR)
    return final_rel, w, sigma
```
### 10. Using a local download (no internet after the first fetch)

```bash
hf download Catkamakura/roman-scene-graph --local-dir roman-scene-graph
```

Then load from roman-scene-graph/&lt;path&gt; instead of using hf_hub_download.
## Training

```bash
# V12 binary
python train_v12_sage_vs_gcn.py --model both \
    --dataset_path training_data_v2/train_ready_v4_dedup.jsonl \
    --dictFile embeddings/sentence-transformers/all-MiniLM-L6-v2_embeddings_False.pkl

# V12ALL (3 objectives)
./run_v12all.sh

# V11T binary
./run_v11_token_v4.sh

# V11ALL (3 objectives)
./run_v11all.sh

# V13 binary (bi-encoder + cross-encoder)
./run_v13.sh

# V13ALL (3 objectives; runs bi-encoder, then cross-encoder, then cascade eval)
./run_v13all.sh
```
## Repo structure

```
Catkamakura/roman-scene-graph/
├── README.md
├── v11t/{lora_r8, lora_r16, full}/                           (binary, token-span)
├── v11all/{lora_r8, full}/                                   (3 objectives, token-span)
├── v12/{SAGE, GCN}/                                          (binary, GNN)
├── v12all/{SAGE, GCN}/                                       (3 objectives, GNN)
├── v13/{bi_encoder_lora_r8, cross_encoder_full}/             (binary, retrieval)
├── v13all/{bi_encoder_lora_r8_recall, cross_encoder_full}/   (3 obj., retrieval)
└── dataset/v4_dedup_enriched_1024/                           (pre-tokenized HF Dataset)
```
## Code dependencies

| File | Purpose |
|---|---|
| SceneGraphDatasetV4E.py | V4E enriched dataset loader (floor/material/affordance) |
| train_v12_sage_vs_gcn.py | V12 binary training + SceneGraphSAGE |
| train_v12all.py | V12ALL training + SceneGraphSAGE_ALL, SceneGraphGCN_ALL |
| model_v2.py | SceneGraphGCNv2 (reused as the V12 GCN baseline) |
| model_modernbert.py | V11T SceneGraphModernBERT (binary) |
| model_modernbert_all.py | V11ALL SceneGraphModernBERTALL (3 objectives) |
| SceneGraphDatasetBERT.py | V11T dataset loader (token-span) |
| train_v11.py, train_v11all.py | V11T / V11ALL training scripts |
| prepare_hf_dataset.py, prepare_hf_dataset_all.py | Build pre-tokenized HF datasets |
| train_v13_retrieval.py | V13 binary (bi-encoder + cross-encoder) training |
| train_v13all.py | V13ALL training + cascade eval |
| prepare_retrieval_dataset.py, prepare_retrieval_dataset_all.py | Build retrieval pair datasets |
| text_encoders/ | DictTextEncoder for pre-computed embeddings |
| loss_v2.py, loss_retrieval.py | Loss functions |
| embeddings/sentence-transformers/all-MiniLM-L6-v2_embeddings_False.pkl | Pre-computed label/constraint embeddings (V12 only) |
| MODEL_ARCHITECTURES.md / .tex / .pdf | Per-model architecture reference with vertical diagrams |