
ROMAN — Scene Graph Node Relevance Classification

Given a 3D scene graph and a natural-language navigation constraint (e.g., "avoid the kitchen"), classify each node as relevant or non-relevant for cost-function generation used by a PRM path planner.

Two model families are provided:

  • Binary-relevance (V11T, V12, V13) — single per-node probability.
  • Three-objective (V11ALL, V12ALL, V13ALL) — per-node tuple (relevance, weight ∈ [-1, +1], sigma ≥ 0.8) consumed by the PRM Gaussian cost field f(x) = Σ_{i: rel_i=1} w_i · exp(-‖x − p_i‖² / (2σ_i²)).

Architectural details for each family live in MODEL_ARCHITECTURES.md (a PDF version is also provided).
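The cost field above can be evaluated in a few lines of plain Python. This is an illustrative sketch only: the node-dict layout (`rel`, `w`, `sigma`, `p`) is a hypothetical stand-in for whatever structures the planner actually uses.

```python
import math

def cost_field(x, nodes):
    """Evaluate f(x) = sum_{i: rel_i=1} w_i * exp(-||x - p_i||^2 / (2 * sigma_i^2)).
    `nodes` is a hypothetical list of dicts with keys rel, w, sigma, p."""
    total = 0.0
    for n in nodes:
        if n["rel"] != 1:
            continue  # non-relevant nodes do not contribute to the field
        d2 = sum((a - b) ** 2 for a, b in zip(x, n["p"]))
        total += n["w"] * math.exp(-d2 / (2.0 * n["sigma"] ** 2))
    return total

nodes = [
    {"rel": 1, "w": 1.0, "sigma": 0.8, "p": (0.0, 0.0, 0.0)},   # penalized region
    {"rel": 0, "w": -1.0, "sigma": 0.8, "p": (2.0, 0.0, 0.0)},  # filtered out
]
print(cost_field((0.0, 0.0, 0.0), nodes))  # 1.0: full weight at the node's center
```

A positive w_i makes the planner avoid the node's neighborhood; the contribution decays to zero a few σ_i away from p_i.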

Models in this repo

V12 — Graph neural network on enriched V4E features

Lightweight GNN on 1262-dim enriched node features (label & parent-room embeddings, tiled instruction, node-type, relative spatial, floor, material, affordance).

Model                  Path                 Params  Test F1  HO F1
SAGEConv (binary)      v12/SAGE/best.pt     1.25 M  0.924    0.949
GCNConv (binary)       v12/GCN/best.pt      0.80 M  0.921    0.946
V12ALL-SAGE (rel+w+σ)  v12all/SAGE/best.pt  1.25 M  0.916    0.941
V12ALL-GCN (rel+w+σ)   v12all/GCN/best.pt   0.80 M  0.905    0.926

V11T — ModernBERT token-span classifier

ModernBERT-base backbone with per-node span-pooling on a single linearized scene sequence ([CLS] c [SEP] n_1 | … | n_N [SEP], max 1024 tokens).
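The span-pooling step can be sketched as follows. Mean pooling over each node's token span is an assumption here (the repo may pool differently), and the toy `hidden` array stands in for ModernBERT's final hidden states.

```python
import numpy as np

def pool_spans(hidden, spans):
    """Collapse backbone token states [T, H] into one vector per node by
    mean-pooling each node's (start, end) token span (end exclusive)."""
    return np.stack([hidden[s:e].mean(axis=0) for s, e in spans])

hidden = np.arange(12, dtype=float).reshape(6, 2)  # toy [T=6, H=2] token states
spans = [(0, 2), (3, 6)]                           # token spans for two nodes
per_node = pool_spans(hidden, spans)               # shape [2, 2]
print(per_node)  # rows [1., 2.] and [8., 9.]
```

Each pooled vector is then fed to the classification (and, for V11ALL, regression) heads, so one forward pass scores all N nodes in the scene.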

Model                      Path                          Trainable  Test F1  HO F1
LoRA r=8 (binary)          v11t/lora_r8/model_best.pt    1.7 M      0.910    0.949
LoRA r=16 (binary)         v11t/lora_r16/model_best.pt   3.4 M      0.897    0.951
Full fine-tune (binary)    v11t/full/model_best.pt       149.0 M    0.907    0.958
V11ALL-LoRA r=8 (rel+w+σ)  v11all/lora_r8/model_best.pt  1.7 M      0.889    0.920
V11ALL-Full (rel+w+σ)      v11all/full/model_best.pt     149.0 M    0.929    0.990

V13 — ModernBERT retrieval cascade

Pair-wise classifier over (constraint, node_text) pairs. Node text is a self-describing JSON-style snippet ({type: object, name: chair, room: kitchen, floor: B, size: 0.5x0.8x0.4m, material: wood}).
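A hypothetical helper showing how such a snippet could be assembled from a node record; the field order mirrors the example above, but the repo's actual serializer may differ.

```python
def node_to_text(node):
    """Serialize a scene-graph node dict into the JSON-style snippet format.
    Illustrative sketch; field order copied from the example in this card."""
    keys = ("type", "name", "room", "floor", "size", "material")
    parts = [f"{k}: {node[k]}" for k in keys if k in node]
    return "{" + ", ".join(parts) + "}"

node = {"type": "object", "name": "chair", "room": "kitchen",
        "floor": "B", "size": "0.5x0.8x0.4m", "material": "wood"}
print(node_to_text(node))
# {type: object, name: chair, room: kitchen, floor: B, size: 0.5x0.8x0.4m, material: wood}
```

Because each node is a short self-describing string, the (constraint, node_text) pair fits comfortably within the 128-token budget used at inference.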

Model                                         Path                                            Trainable  Test F1  HO F1
Bi-Encoder LoRA r=8 (binary, recall-focused)  v13/bi_encoder_lora_r8/model_best.pt            1.7 M      0.946    0.972
Cross-Encoder Full (binary)                   v13/cross_encoder_full/model_best.pt            149.0 M    0.956    0.975
V13ALL Bi-Encoder (rel, recall-focused)       v13all/bi_encoder_lora_r8_recall/model_best.pt  1.7 M      0.946    0.972
V13ALL Cross-Encoder (rel+w+σ)                v13all/cross_encoder_full/model_best.pt         149.0 M    0.956    0.975
V13ALL Cascade (bi ∧ cross)                   same two files, gated at inference              —          0.954    0.974

The bi-encoder uses a boosted pos_weight = 1.5 · √(#neg/#pos) and F2-score threshold tuning so it passes ≈ 95 % of true positives to the cross-encoder. The cascade prediction is the logical intersection of the two stages, trading ~0.002 F1 for higher precision (0.97) and lower MAE_w (0.038).
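The class weighting and threshold tuning described above can be sketched as follows; `tune_threshold` and its grid are illustrative, not the repo's exact procedure.

```python
import math

def boosted_pos_weight(n_pos, n_neg, boost=1.5):
    """pos_weight = 1.5 * sqrt(#neg / #pos), boosting the rare positive class."""
    return boost * math.sqrt(n_neg / n_pos)

def f_beta(tp, fp, fn, beta=2.0):
    """F-beta score; beta=2 weights recall higher than precision."""
    if tp == 0:
        return 0.0
    b2 = beta * beta
    return (1 + b2) * tp / ((1 + b2) * tp + b2 * fn + fp)

def tune_threshold(probs, labels):
    """Pick the decision threshold maximizing F2 on validation pairs (sketch)."""
    best_t, best_f2 = 0.5, -1.0
    for i in range(5, 96):
        t = i / 100
        tp = sum(1 for p, y in zip(probs, labels) if p >= t and y)
        fp = sum(1 for p, y in zip(probs, labels) if p >= t and not y)
        fn = sum(1 for p, y in zip(probs, labels) if p < t and y)
        f2 = f_beta(tp, fp, fn)
        if f2 > best_f2:
            best_t, best_f2 = t, f2
    return best_t

print(boosted_pos_weight(100, 900))  # 4.5 = 1.5 * sqrt(9)
```

Optimizing F2 instead of F1 pushes the bi-encoder's threshold down, which is what lets it stay recall-heavy and defer precision to the cross-encoder stage.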

Three-objective regression quality (V*ALL)

Model                 Test MAE_w  Test MAE_σ  HO MAE_w  HO MAE_σ
V12ALL-SAGE           0.047       0.038       0.039     0.035
V12ALL-GCN            0.060       0.056       0.049     0.051
V11ALL-LoRA r=8       0.087       0.069       0.047     0.059
V11ALL-Full           0.059       0.042       0.041     0.037
V13ALL Cross-Encoder  0.047       0.048       0.042     0.046
V13ALL Cascade        0.038       0.046       0.040     0.045

w ∈ [-1, +1] (tanh), σ ∈ [0.8, +∞) (softplus + 0.8). Regression MAE is computed only on true-positive nodes (where rel_gt = 1).
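The output mappings and the masked MAE above can be written out in scalar form (the repo applies them tensor-wise; this is just a sketch):

```python
import math

def regression_heads(raw_w, raw_sigma):
    """Map raw head outputs to their ranges: w = tanh(.) in [-1, +1],
    sigma = softplus(.) + 0.8 in [0.8, inf)."""
    softplus = math.log1p(math.exp(raw_sigma))  # numerically stable softplus
    return math.tanh(raw_w), softplus + 0.8

def masked_mae(pred, target, rel_gt):
    """Mean absolute error over true-positive nodes only (rel_gt == 1)."""
    pairs = [(p, t) for p, t, r in zip(pred, target, rel_gt) if r == 1]
    return sum(abs(p - t) for p, t in pairs) / len(pairs)

w, sigma = regression_heads(0.0, 0.0)
print(w, sigma)  # 0.0 and log(2) + 0.8 ~ 1.493
print(masked_mae([0.5, -1.0, 0.2], [0.4, 0.0, 0.2], [1, 0, 1]))  # 0.05
```

Masking the regression loss to true positives keeps the weight/sigma heads from being dominated by the ~92 % of nodes that never enter the cost field.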

Dataset

dataset/v4_dedup_enriched_1024/ — pre-tokenized HF Dataset (ModernBERT tokenizer, enriched text mode, max_length = 1024). Source: train_ready_v4_dedup.jsonl (7 911 records, 88 Matterport3D scenes, ≈ 783 K nodes, 7.6 % relevant).

Quick start (from HF Hub)

1. Install dependencies

pip install torch torch-geometric huggingface_hub sentence-transformers \
            transformers peft datasets

2. Download models

# CLI: entire repo
hf download Catkamakura/roman-scene-graph --local-dir roman-scene-graph

# CLI: specific model only
hf download Catkamakura/roman-scene-graph v13all/cross_encoder_full/model_best.pt \
    --local-dir roman-scene-graph

# Python: entire repo
from huggingface_hub import snapshot_download
snapshot_download("Catkamakura/roman-scene-graph", local_dir="roman-scene-graph")

# Python: specific file
from huggingface_hub import hf_hub_download
path = hf_hub_download("Catkamakura/roman-scene-graph",
                      "v13all/cross_encoder_full/model_best.pt")

3. Clone the code repo

The checkpoints need the source code to run:

git clone <your-code-repo-url> roman-code
cd roman-code

4. V12 GNN inference (binary)

import torch
from huggingface_hub import hf_hub_download
import text_encoders
from SceneGraphDatasetV4E import SceneGraphDatasetV4E
from train_v12_sage_vs_gcn import SceneGraphSAGE, forward_with_tiled_instr

model_path = hf_hub_download("Catkamakura/roman-scene-graph", "v12/SAGE/best.pt")
model = SceneGraphSAGE(in_channels=1262, hidden_channels_arr=[256]*3, out_channels=64)
model.load_state_dict(torch.load(model_path, weights_only=True))
model.eval()

te = text_encoders.DictTextEncoder(
    "embeddings/sentence-transformers/all-MiniLM-L6-v2_embeddings_False.pkl")
ds = SceneGraphDatasetV4E("your_input.jsonl", text_encoder=te,
                          include_parent_room=True)
ds.encode_all_node_features()
ds.all_graphs_make_x()

with torch.no_grad():
    scores = forward_with_tiled_instr(model, ds[0], "cpu")
    relevant = torch.sigmoid(scores) > 0.5

5. V12ALL GNN inference (relevance + weight + sigma)

import torch
from huggingface_hub import hf_hub_download
from train_v12all import SceneGraphSAGE_ALL

model_path = hf_hub_download("Catkamakura/roman-scene-graph",
                              "v12all/SAGE/best.pt")
model = SceneGraphSAGE_ALL(in_channels=1262, hidden_channels_arr=[256]*3,
                            out_channels=64)
model.load_state_dict(torch.load(model_path, weights_only=True))
model.eval()

# Same data-prep as V12 (SceneGraphDatasetV4E); then:
with torch.no_grad():
    # x = [graph.x ; tiled instruction] as in V12
    rel_logits, w_hat, sigma_hat = model(x, graph.edge_index)
    relevant = torch.sigmoid(rel_logits) > 0.5
    # w_hat ∈ [-1, +1] (tanh), sigma_hat ≥ 0.8 (softplus + 0.8)

6. V11T ModernBERT inference (binary, token-span)

import torch
from huggingface_hub import hf_hub_download, snapshot_download
from datasets import Dataset
from model_modernbert import SceneGraphModernBERT

model_path = hf_hub_download("Catkamakura/roman-scene-graph",
                              "v11t/full/model_best.pt")
ds_root = snapshot_download("Catkamakura/roman-scene-graph",
                            allow_patterns="dataset/v4_dedup_enriched_1024/*",
                            local_dir="roman-scene-graph")

model = SceneGraphModernBERT(backbone="answerdotai/ModernBERT-base",
                              train_mode="full")
model.load_state_dict(torch.load(model_path, weights_only=True))
model.eval()

dataset = Dataset.load_from_disk(
    "roman-scene-graph/dataset/v4_dedup_enriched_1024")

7. V11ALL ModernBERT inference (3 objectives)

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer
from model_modernbert_all import SceneGraphModernBERTALL, apply_lora_all

tok = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
backbone = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
model = SceneGraphModernBERTALL(backbone)
# For the LoRA r=8 variant, wrap the backbone before loading weights:
#   model, *_ = apply_lora_all(model, r=8)

ckpt = hf_hub_download("Catkamakura/roman-scene-graph",
                       "v11all/full/model_best.pt")
model.load_state_dict(torch.load(ckpt, weights_only=True), strict=False)
model.eval()

# Inputs: input_ids [1, T], attention_mask [1, T],
# node_spans  [list of (start, end) per node]
rel_logits, w_hat, sigma_hat = model(input_ids, attention_mask,
                                     node_spans=[spans])

8. V13 / V13ALL cross-encoder inference (per-pair)

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer
from train_v13all import CrossEncoderALL

ckpt = hf_hub_download("Catkamakura/roman-scene-graph",
                       "v13all/cross_encoder_full/model_best.pt")
backbone  = AutoModel.from_pretrained("answerdotai/ModernBERT-base")
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = CrossEncoderALL(backbone)
model.load_state_dict(torch.load(ckpt, weights_only=True))
model.eval()

constraint = "avoid the kitchen"
node_text  = "{type: object, name: chair, room: kitchen, floor: B, size: 0.5x0.8x0.4m}"
enc = tokenizer(constraint, node_text, return_tensors="pt",
                padding="max_length", max_length=128, truncation=True)
with torch.no_grad():
    rel_logits, w_hat, sigma_hat = model(**enc)

9. V13ALL cascade (bi-filter → cross-score)

from huggingface_hub import hf_hub_download
from train_v13all import BiEncoder, CrossEncoderALL, apply_lora
from transformers import AutoModel, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
# Bi-encoder (LoRA)
bi_ckpt  = hf_hub_download("Catkamakura/roman-scene-graph",
                           "v13all/bi_encoder_lora_r8_recall/model_best.pt")
bi = BiEncoder(apply_lora(AutoModel.from_pretrained("answerdotai/ModernBERT-base")))
bi.load_state_dict(torch.load(bi_ckpt, weights_only=True))
bi.eval()
# Cross-encoder (Full)
cr_ckpt = hf_hub_download("Catkamakura/roman-scene-graph",
                          "v13all/cross_encoder_full/model_best.pt")
cr = CrossEncoderALL(AutoModel.from_pretrained("answerdotai/ModernBERT-base"))
cr.load_state_dict(torch.load(cr_ckpt, weights_only=True))
cr.eval()

TAU_BI, TAU_CR = 0.40, 0.50  # recommended cascade thresholds

def score(constraint, node_texts):
    # Bi-encoder filter
    q = tokenizer(constraint, return_tensors="pt", padding="max_length",
                  max_length=128, truncation=True)
    d = tokenizer(node_texts, return_tensors="pt", padding="max_length",
                  max_length=128, truncation=True)
    with torch.no_grad():
        s_bi = bi(q["input_ids"], q["attention_mask"],
                  d["input_ids"], d["attention_mask"])
        p_bi = torch.sigmoid(s_bi)
    # Cross-encoder score (on survivors — or all, if GPU budget allows)
    pairs = tokenizer([constraint]*len(node_texts), node_texts,
                      return_tensors="pt", padding="max_length",
                      max_length=128, truncation=True)
    with torch.no_grad():
        rel, w, sigma = cr(**pairs)
    # Cascade intersection
    final_rel = (p_bi > TAU_BI) & (torch.sigmoid(rel) > TAU_CR)
    return final_rel, w, sigma

10. Using a local download (no internet after first fetch)

hf download Catkamakura/roman-scene-graph --local-dir roman-scene-graph

Then load from roman-scene-graph/<path> instead of using hf_hub_download.

Training

# V12 binary
python train_v12_sage_vs_gcn.py --model both \
    --dataset_path training_data_v2/train_ready_v4_dedup.jsonl \
    --dictFile embeddings/sentence-transformers/all-MiniLM-L6-v2_embeddings_False.pkl

# V12ALL (3 objectives)
./run_v12all.sh

# V11T binary
./run_v11_token_v4.sh

# V11ALL (3 objectives)
./run_v11all.sh

# V13 binary (bi-encoder + cross-encoder)
./run_v13.sh

# V13ALL (3 objectives, runs bi-encoder → cross-encoder → cascade eval)
./run_v13all.sh

Repo structure

Catkamakura/roman-scene-graph/
├── README.md
├── v11t/{lora_r8, lora_r16, full}/            (binary, token-span)
├── v11all/{lora_r8, full}/                    (3 objectives, token-span)
├── v12/{SAGE, GCN}/                           (binary, GNN)
├── v12all/{SAGE, GCN}/                        (3 objectives, GNN)
├── v13/{bi_encoder_lora_r8, cross_encoder_full}/      (binary, retrieval)
├── v13all/{bi_encoder_lora_r8_recall, cross_encoder_full}/  (3 obj., retrieval)
└── dataset/v4_dedup_enriched_1024/            (pre-tokenized HF Dataset)

Code dependencies

File                                              Purpose
SceneGraphDatasetV4E.py                           V4E enriched dataset loader (floor/material/affordance)
train_v12_sage_vs_gcn.py                          V12 binary training + SceneGraphSAGE
train_v12all.py                                   V12ALL training + SceneGraphSAGE_ALL, SceneGraphGCN_ALL
model_v2.py                                       SceneGraphGCNv2 (reused as V12 GCN baseline)
model_modernbert.py                               V11T SceneGraphModernBERT (binary)
model_modernbert_all.py                           V11ALL SceneGraphModernBERTALL (3 objectives)
SceneGraphDatasetBERT.py                          V11T dataset loader (token-span)
train_v11.py, train_v11all.py                     V11T / V11ALL training scripts
prepare_hf_dataset.py, prepare_hf_dataset_all.py  Build pre-tokenized HF datasets
train_v13_retrieval.py                            V13 binary (bi-encoder + cross-encoder) training
train_v13all.py                                   V13ALL training + cascade eval
prepare_retrieval_dataset.py, prepare_retrieval_dataset_all.py  Build retrieval pair datasets
text_encoders/                                    DictTextEncoder for pre-computed embeddings
loss_v2.py, loss_retrieval.py                     Loss functions
embeddings/sentence-transformers/all-MiniLM-L6-v2_embeddings_False.pkl  Pre-computed label/constraint embeddings (V12 only)
MODEL_ARCHITECTURES.md / .tex / .pdf              Per-model architecture reference with vertical diagrams