Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gitignore +17 -0
- README.md +147 -8
- abst.png +3 -0
- app.py +174 -0
- configs/base.yaml +11 -0
- configs/gatv2.yaml +36 -0
- evaluate.py +288 -0
- infer.py +797 -0
- mecari/__init__.py +9 -0
- mecari/analyzers/mecab.py +151 -0
- mecari/config/config.py +84 -0
- mecari/data/data_module.py +361 -0
- mecari/featurizers/lexical.py +116 -0
- mecari/models/__init__.py +6 -0
- mecari/models/base.py +214 -0
- mecari/models/gatv2.py +139 -0
- mecari/utils/__init__.py +4 -0
- mecari/utils/morph_utils.py +51 -0
- mecari/utils/signature.py +39 -0
- packages.txt +5 -0
- preprocess.py +366 -0
- pyproject.toml +56 -0
- requirements.txt +20 -0
- runtime.txt +1 -0
- sample_model/config.yaml +39 -0
- sample_model/model.pt +3 -0
- train.py +388 -0
- up_hf.py +18 -0
- uv.lock +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
abst.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
.venv
|
| 10 |
+
|
| 11 |
+
annotations*/
|
| 12 |
+
experiments/
|
| 13 |
+
lightning_logs/
|
| 14 |
+
cache/
|
| 15 |
+
results/
|
| 16 |
+
|
| 17 |
+
KWDLC/
|
README.md
CHANGED
|
@@ -1,14 +1,153 @@
|
|
| 1 |
---
|
| 2 |
-
title: Mecari
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
license: cc-by-nc-4.0
|
| 11 |
-
short_description: 'Demo of Mecari: GNN-based Morphological Analyzer'
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Mecari Morpheme Analyzer
|
| 3 |
+
emoji: 🧩
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.37.2
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# Mecari (Japanese Morphological Analysis with Graph Neural Networks)
|
| 13 |
+
|
| 14 |
+
## Training
|
| 15 |
+
|
| 16 |
+
### Overview
|
| 17 |
+
|
| 18 |
+
Mecari [1] is a GNN‑based Japanese morphological analyzer. It supports training from partially annotated graphs (only '+'/'-' where available; '?' is ignored) and aims for fast training and inference.
|
| 19 |
+
|
| 20 |
+
<p align="center">
|
| 21 |
+
<img src="abst.png" alt="Overview" width="70%" />
|
| 22 |
+
<!-- Adjust width (e.g., 60%, 50%, or px) as desired -->
|
| 23 |
+
|
| 24 |
+
</p>
|
| 25 |
+
|
| 26 |
+
### Graph
|
| 27 |
+
The graph is built from MeCab morpheme candidates.
|
| 28 |
+
|
| 29 |
+
### Annotation
|
| 30 |
+
Annotations are created by matching morpheme candidates to gold labels.
|
| 31 |
+
Annotations serve as the training targets (supervision) during learning.
|
| 32 |
+
- `+`: Candidate that exactly matches the gold.
|
| 33 |
+
- `-`: Any other candidate that overlaps by 1+ character with a `+` candidate.
|
| 34 |
+
- `?`: All other candidates (ignored during training).
|
| 35 |
+
|
| 36 |
+
### Training
|
| 37 |
+
Nodes are featurized with JUMAN++‑style unigram features, edges are modeled as undirected (bidirectional), and a GATv2 [2] is trained on the resulting graphs.
|
| 38 |
+
|
| 39 |
+
### Inference
|
| 40 |
+
Use the model’s node scores and run Viterbi to search the optimal non‑overlapping path.
|
| 41 |
+
|
| 42 |
+
## Results (KWDLC test)
|
| 43 |
+
|
| 44 |
+
- Trained model (sample_model): Seg F1 0.9725, POS F1 0.9562
|
| 45 |
+
- MeCab (JUMANDIC) baseline: Seg F1 0.9677, POS F1 0.9465
|
| 46 |
+
|
| 47 |
+
The GATv2 model trained with this repository (current code and `configs/gatv2.yaml`) using the official KWDLC split outperforms MeCab on both segmentation and POS accuracy.
|
| 48 |
+
|
| 49 |
+
## Tested Environment
|
| 50 |
+
|
| 51 |
+
- OS: Ubuntu 24.04.3 LTS (Noble Numbat)
|
| 52 |
+
- Python: 3.11.3
|
| 53 |
+
- PyTorch: 2.2.2+cu121
|
| 54 |
+
- CUDA (runtime): 12.1 (cu121)
|
| 55 |
+
- MeCab (binary): 0.996
|
| 56 |
+
- JUMANDIC: `/var/lib/mecab/dic/juman-utf8`
|
| 57 |
+
|
| 58 |
+
## MeCab Setup (Ubuntu 24.04)
|
| 59 |
+
1) Install packages (includes the JUMANDIC dictionary)
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
sudo apt update
|
| 63 |
+
sudo apt install -y mecab mecab-utils libmecab-dev mecab-jumandic-utf8
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
2) Verify installation
|
| 67 |
+
|
| 68 |
+
```bash
|
| 69 |
+
mecab -v # e.g., mecab of 0.996
|
| 70 |
+
test -d /var/lib/mecab/dic/juman-utf8 && echo "JUMANDIC OK"
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
## Project Setup
|
| 74 |
+
|
| 75 |
+
```bash
|
| 76 |
+
# Install uv if needed
|
| 77 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 78 |
+
|
| 79 |
+
# Create venv and install dependencies
|
| 80 |
+
uv venv
|
| 81 |
+
source .venv/bin/activate
|
| 82 |
+
uv sync
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## Quickstart (Morphological analysis)
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
# Analyze a single sentence with the bundled sample model
|
| 89 |
+
python infer.py --text "東京都の外国人参政権"
|
| 90 |
+
|
| 91 |
+
# Interactive mode
|
| 92 |
+
python infer.py
|
| 93 |
+
|
| 94 |
+
# After training, specify an experiment to use a custom model
|
| 95 |
+
python infer.py --experiment gatv2_YYYYMMDD_HHMMSS --text "..."
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
Note
|
| 99 |
+
- When no experiment is specified, the model at `sample_model/` is loaded by default.
|
| 100 |
+
|
| 101 |
+
## Train by yourself
|
| 102 |
+
### KWDLC Setup (Required)
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
cd /path/to/Mecari
|
| 106 |
+
git clone --depth 1 https://github.com/ku-nlp/KWDLC
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
- Training requires KWDLC (non‑KWDLC training is not supported at the moment).
|
| 110 |
+
- Splits strictly follow the official `dev.id` / `test.id` files.
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
### Preprocessing
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
python preprocess.py --config configs/gatv2.yaml
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
### Training
|
| 120 |
+
|
| 121 |
+
```bash
|
| 122 |
+
python train.py --config configs/gatv2.yaml
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
- Outputs are saved under `experiments/<name>/`.
|
| 126 |
+
- The bundled model was trained with the current codebase and configuration (`configs/gatv2.yaml`).
|
| 127 |
+
|
| 128 |
+
### Evaluation
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
python evaluate.py --max-samples 50 \
|
| 132 |
+
--experiment gatv2_YYYYMMDD_HHMMSS
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
## License
|
| 137 |
+
|
| 138 |
+
CC BY‑NC 4.0 (non‑commercial use only)
|
| 139 |
+
|
| 140 |
+
## Acknowledgments
|
| 141 |
+
- [1] Technical inspiration: Mecari, a morphological analysis system developed by Google, as described in “Data processing for Japanese text‑to‑pronunciation models” by G. Mazovetskiy and T. Kudo (NLP2024 Workshop on Japanese Language Resources). URL: https://jedworkshop.github.io/JLR2024/materials/b-2.pdf (pp. 19–23)
|
| 142 |
+
- [2] Graph architecture: Brody, Shaked, Uri Alon, and Eran Yahav. "How Attentive Are Graph Attention Networks?" 10th International Conference on Learning Representations, ICLR 2022.
|
| 143 |
+
|
| 144 |
+
## Disclaimer
|
| 145 |
+
- Independent academic implementation for educational and research purposes.
|
| 146 |
+
- Core concepts (graph‑based morpheme boundary annotation) follow the published work; implementation details and code structure are our interpretation.
|
| 147 |
+
- Not affiliated with, endorsed by, or connected to Google or its subsidiaries.
|
| 148 |
+
|
| 149 |
+
## Purpose
|
| 150 |
+
- Academic research
|
| 151 |
+
- Education
|
| 152 |
+
- Technical skill development
|
| 153 |
+
- Understanding of NLP techniques
|
abst.png
ADDED
|
Git LFS Details
|
app.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import subprocess
|
| 6 |
+
import gradio as gr
|
| 7 |
+
|
| 8 |
+
# Ensure wandb never starts in Spaces
|
| 9 |
+
os.environ["WANDB_MODE"] = "disabled"
|
| 10 |
+
|
| 11 |
+
# Resolve MeCab binary for this process
|
| 12 |
+
_default_mecab = "/usr/bin/mecab" if os.path.exists("/usr/bin/mecab") else "mecab"
|
| 13 |
+
MECAB_BIN = os.getenv("MECAB_BIN", _default_mecab)
|
| 14 |
+
os.environ["MECAB_BIN"] = MECAB_BIN
|
| 15 |
+
|
| 16 |
+
# Lazy-loaded model
|
| 17 |
+
_model = None
|
| 18 |
+
_exp_info = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _ensure_model():
    """Load the Mecari model once and cache it in module globals.

    Raises:
        RuntimeError: if no loadable model is found (neither an experiment
        nor the bundled sample_model/).
    """
    global _model, _exp_info
    if _model is not None:
        return  # already loaded; nothing to do

    from infer import load_model

    loaded = load_model()
    if loaded is None:
        raise RuntimeError(
            "Model could not be loaded. Ensure sample_model/ exists with config.yaml and model.pt."
        )
    _model, _exp_info = loaded
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _to_mecab_lines(results, optimal_morphemes=None) -> str:
|
| 35 |
+
# Build MeCab-like output lines
|
| 36 |
+
def mecab_features(m):
|
| 37 |
+
pos = m.get("pos", "*")
|
| 38 |
+
pos1 = m.get("pos_detail1", "*")
|
| 39 |
+
pos2 = m.get("pos_detail2", "*")
|
| 40 |
+
ctype = m.get("inflection_type", "*")
|
| 41 |
+
cform = m.get("inflection_form", "*")
|
| 42 |
+
base = m.get("base_form", m.get("lemma", "*")) or "*"
|
| 43 |
+
# Mecari output includes reading as 7th field
|
| 44 |
+
reading = m.get("reading", "*") or "*"
|
| 45 |
+
return f"{pos},{pos1},{pos2},{ctype},{cform},{base},{reading}"
|
| 46 |
+
|
| 47 |
+
items = (
|
| 48 |
+
optimal_morphemes
|
| 49 |
+
if optimal_morphemes
|
| 50 |
+
else [
|
| 51 |
+
{
|
| 52 |
+
"surface": r.get("surface", ""),
|
| 53 |
+
"pos": r.get("pos", "*"),
|
| 54 |
+
"pos_detail1": "*",
|
| 55 |
+
"pos_detail2": "*",
|
| 56 |
+
"inflection_type": "*",
|
| 57 |
+
"inflection_form": "*",
|
| 58 |
+
"base_form": r.get("surface", ""),
|
| 59 |
+
"reading": r.get("reading", "*"),
|
| 60 |
+
}
|
| 61 |
+
for r in results
|
| 62 |
+
]
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
lines = [f"{m.get('surface','')}\t{mecab_features(m)}" for m in items]
|
| 66 |
+
lines.append("EOS")
|
| 67 |
+
return "\n".join(lines)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def mecab_plain(text: str) -> str:
    """Run system MeCab and return its raw parsing (surface\tCSV ...\nEOS)."""
    try:
        from mecari.analyzers.mecab import MeCabAnalyzer

        analyzer = MeCabAnalyzer()
        # Environment variable wins over the analyzer's configured binary.
        binary = os.getenv("MECAB_BIN", analyzer.mecab_bin)
        cmd = [binary]
        dic = analyzer.jumandic_path
        if isinstance(dic, str) and os.path.isdir(dic):
            cmd += ["-d", dic]

        proc = subprocess.run(cmd, input=text, text=True, capture_output=True)
        combined = (proc.stdout or "") + ("\n" + proc.stderr if proc.stderr else "")
        if proc.returncode != 0:
            return combined.strip() or f"mecab error rc={proc.returncode}"

        # Keep only the first 6 feature columns, padding short rows with '*'
        # (drops JUMANDIC tail fields such as カテゴリ:* / ドメイン:*).
        rendered = []
        for row in combined.splitlines():
            if not row or row.strip() == "EOS":
                rendered.append("EOS")
                continue
            if "\t" not in row:
                rendered.append(row)
                continue
            surface, feats = row.split("\t", 1)
            fields = [s.strip() for s in feats.split(",")]
            padded = (fields[:6] + ["*"] * 6)[:6]
            rendered.append(f"{surface}\t{','.join(padded)}")

        # Guarantee exactly one trailing EOS marker.
        if not rendered or rendered[-1] != "EOS":
            rendered.append("EOS")
        return "\n".join(rendered)
    except FileNotFoundError:
        return "MeCabバイナリが見つかりません(MECAB_BINやpackages.txtを確認)。"
    except Exception as e:
        return f"mecab実行時エラー: {e}"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def analyze(text: str):
    """Analyze `text` with Mecari and MeCab; return (mecari_output, mecab_output).

    On failure, the first element carries a human-readable error message
    instead of crashing the Gradio handler.
    """
    # Guard clause: empty/whitespace-only input yields empty outputs.
    if not text or not text.strip():
        return "", ""

    try:
        _ensure_model()
        from infer import predict_morphemes_from_text

        stripped = text.strip()
        prediction = predict_morphemes_from_text(stripped, _model, _exp_info, silent=True)
        if not prediction:
            # Model failed; still show the MeCab baseline for comparison.
            return "推論に失敗しました。", mecab_plain(stripped)
        candidates, best_path = prediction
        return _to_mecab_lines(candidates, best_path), mecab_plain(stripped)
    except FileNotFoundError:
        return (
            "MeCabが見つかりません。Spaceのpackages.txtに 'mecab' と 'mecab-jumandic-utf8' を含めてビルドし直すか、\n"
            "変数 MECAB_BIN=/usr/bin/mecab を設定してください。"
        ), ""
    except Exception as e:
        import traceback

        tb = traceback.format_exc()
        return f"エラー: {e}\n\n{tb}", ""
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
FONT_CSS = """
|
| 138 |
+
/* Prefer common system fonts for Latin text */
|
| 139 |
+
body, .gradio-container, .prose, textarea, input, button,
|
| 140 |
+
.gr-text-input input, .gr-text-input textarea, .gr-textbox textarea {
|
| 141 |
+
font-family: system-ui, -apple-system, 'Segoe UI', Roboto, 'Noto Sans',
|
| 142 |
+
'Helvetica Neue', Arial, 'Apple Color Emoji', 'Segoe UI Emoji',
|
| 143 |
+
sans-serif !important;
|
| 144 |
+
}
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=FONT_CSS) as demo:
|
| 148 |
+
gr.Markdown(
|
| 149 |
+
"""
|
| 150 |
+
# Mecari Morpheme Analyzer
|
| 151 |
+
|
| 152 |
+
GNNベースの形態素解析器"Mecari"のデモです。github: https://github.com/zbller/Mecari
|
| 153 |
+
"""
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
with gr.Row():
|
| 157 |
+
inp = gr.Textbox(label="テキスト入力", value="とうきょうに行った", placeholder="とうきょうに行った", lines=3)
|
| 158 |
+
btn = gr.Button("解析する")
|
| 159 |
+
with gr.Row():
|
| 160 |
+
out_mecari = gr.Textbox(label="Mecari", lines=10)
|
| 161 |
+
out_mecab = gr.Textbox(label="MeCab(Jumandic)", lines=10)
|
| 162 |
+
btn.click(fn=analyze, inputs=inp, outputs=[out_mecari, out_mecab])
|
| 163 |
+
|
| 164 |
+
# Optional warm-up
|
| 165 |
+
def _warmup():
|
| 166 |
+
try:
|
| 167 |
+
_ensure_model()
|
| 168 |
+
except Exception:
|
| 169 |
+
pass
|
| 170 |
+
|
| 171 |
+
_warmup()
|
| 172 |
+
|
| 173 |
+
if __name__ == "__main__":
|
| 174 |
+
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7863")))
|
configs/base.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
features:
|
| 2 |
+
lexical_feature_dim: 100000
|
| 3 |
+
|
| 4 |
+
training:
|
| 5 |
+
deterministic: false
|
| 6 |
+
annotations_dir: "annotations"
|
| 7 |
+
project_name: "mecari"
|
| 8 |
+
|
| 9 |
+
inference:
|
| 10 |
+
checkpoint_dir: "experiments"
|
| 11 |
+
experiment_name: null
|
configs/gatv2.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
extends: "base.yaml"
|
| 2 |
+
|
| 3 |
+
model:
|
| 4 |
+
type: "gatv2"
|
| 5 |
+
hidden_dim: 64
|
| 6 |
+
num_layers: 4
|
| 7 |
+
num_heads: 4
|
| 8 |
+
dropout: 0.1
|
| 9 |
+
num_classes: 1
|
| 10 |
+
share_weights: false
|
| 11 |
+
|
| 12 |
+
edge_features:
|
| 13 |
+
use_bidirectional_edges: true
|
| 14 |
+
|
| 15 |
+
training:
|
| 16 |
+
learning_rate: 0.001
|
| 17 |
+
batch_size: 128
|
| 18 |
+
max_steps: 10000
|
| 19 |
+
patience: 10
|
| 20 |
+
gradient_clip_val: 0.5
|
| 21 |
+
gradient_clip_algorithm: "norm"
|
| 22 |
+
num_workers: 4
|
| 23 |
+
accumulate_grad_batches: 1
|
| 24 |
+
seed: 42
|
| 25 |
+
warmup_steps: 500
|
| 26 |
+
warmup_start_lr: 0.0
|
| 27 |
+
optimizer:
|
| 28 |
+
type: "adamw"
|
| 29 |
+
weight_decay: 0.001
|
| 30 |
+
use_wandb: true
|
| 31 |
+
log_every_n_steps: 50
|
| 32 |
+
val_check_interval: 1.0
|
| 33 |
+
|
| 34 |
+
loss:
|
| 35 |
+
use_pos_weight: true
|
| 36 |
+
label_smoothing: 0.0
|
evaluate.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""Unified evaluation for MeCab (JUMANDIC) and the trained model.
|
| 5 |
+
|
| 6 |
+
Evaluates both systems on the same KWDLC test data and compares results.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import subprocess
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Dict, List
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def parse_knp_file(knp_file: Path) -> List[Dict]:
    """Read a KNP-format file and collect gold morphemes per sentence.

    Returns a list of ``{"morphemes": [...], "text": str}`` dicts, one per
    sentence. A morpheme line contributes surface (col 0), reading (col 1)
    and POS (col 3); phrase/clause markers ('+'/'*') and comment lines are
    skipped. Sentences are delimited by `# S-ID:` headers and `EOS` lines.
    """
    sentences: List[Dict] = []
    morphemes: List[Dict] = []
    text = ""

    def _flush():
        # Close out the current sentence if it has any morphemes.
        nonlocal morphemes, text
        if morphemes:
            sentences.append({"morphemes": morphemes, "text": text})
        morphemes = []
        text = ""

    with open(knp_file, "r", encoding="utf-8") as fh:
        for raw in fh:
            line = raw.rstrip("\n")

            if line.startswith("#"):
                # A new S-ID header starts a new sentence.
                if line.startswith("# S-ID:"):
                    _flush()
                continue
            if line == "EOS":
                _flush()
            elif line.startswith(("+", "*")):
                continue  # bunsetsu / tag markers carry no morphemes
            elif line:
                cols = line.split(" ")
                if len(cols) >= 4:
                    morphemes.append({"surface": cols[0], "reading": cols[1], "pos": cols[3]})
                    text += cols[0]

    return sentences
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def analyze_with_mecab(text: str) -> List[Dict]:
    """Analyze text with MeCab (JUMANDIC) using a simple best-path parse."""
    try:
        proc = subprocess.run(
            ["mecab", "-d", "/var/lib/mecab/dic/juman-utf8"],
            input=text,
            capture_output=True,
            text=True,
            encoding="utf-8",
        )
        if proc.returncode != 0:
            return []

        morphemes: List[Dict] = []
        for row in proc.stdout.strip().split("\n"):
            if row == "EOS":
                break  # stop at the end-of-sentence marker
            cols = row.split("\t")
            if len(cols) < 2:
                continue
            feats = cols[1].split(",")
            if len(feats) < 7:
                continue
            # No fallback of reading to surface when it is missing ('*').
            yomi = feats[7] if len(feats) > 7 and feats[7] != "*" else ""
            morphemes.append({"surface": cols[0], "reading": yomi, "pos": feats[0]})
        return morphemes
    except Exception as e:
        print(f"MeCab error: {e}")
        return []
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def analyze_with_jumanpp(text: str) -> List[Dict]:
    """Analyze text with JUMAN++ (optional baseline)."""
    try:
        proc = subprocess.run(["jumanpp"], input=text, capture_output=True, text=True, encoding="utf-8")
        if proc.returncode != 0:
            return []

        morphemes: List[Dict] = []
        for row in proc.stdout.strip().split("\n"):
            # Skip alternative-analysis lines ('@ ...') and the EOS marker.
            if row.startswith("@") or row == "EOS":
                continue
            cols = row.split(" ")
            if len(cols) >= 12:
                morphemes.append({"surface": cols[0], "reading": cols[1], "pos": cols[3]})
        return morphemes
    except Exception as e:
        print(f"JUMAN++ error: {e}")
        return []
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def analyze_with_model(text: str, model, experiment_info) -> List[Dict]:
    """Analyze text with the trained model."""
    try:
        import infer

        _, best_path = infer.predict_morphemes_from_text(
            text, model=model, experiment_info=experiment_info, silent=True
        )
        # Keep only the fields the evaluator compares against gold.
        return [
            {"surface": m["surface"], "reading": m.get("reading", ""), "pos": m.get("pos", "*")}
            for m in best_path
        ]
    except Exception as e:
        print(f"Model inference error: {e}")
        return []
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def evaluate_morphemes(gold_morphemes: List[Dict], pred_morphemes: List[Dict]) -> Dict:
    """Compute segmentation and POS precision/recall/F1 between gold and predictions.

    Each morpheme sequence is converted to character-offset spans. A span is
    correct for segmentation when its (start, end) matches a gold span, and
    correct for POS scoring when (start, end, pos) matches. All six metrics
    are returned; empty inputs yield 0 rather than dividing by zero.

    The span-building and P/R/F1 logic was previously duplicated for the
    gold and predicted sides; it is factored into helpers here.
    """

    def _to_spans(morphemes: List[Dict]) -> List[tuple]:
        # Convert morphemes to (start, end, pos) character-offset spans.
        spans = []
        offset = 0
        for m in morphemes:
            end = offset + len(m["surface"])
            spans.append((offset, end, m["pos"]))
            offset = end
        return spans

    def _prf(correct: int, n_pred: int, n_gold: int) -> tuple:
        # Precision/recall/F1 with zero-division guards (0 when undefined).
        precision = correct / n_pred if n_pred else 0
        recall = correct / n_gold if n_gold else 0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom > 0 else 0
        return precision, recall, f1

    gold_spans = _to_spans(gold_morphemes)
    pred_spans = _to_spans(pred_morphemes)

    # Segmentation accuracy (boundaries only, POS ignored)
    gold_seg = {(s, e) for s, e, _ in gold_spans}
    pred_seg = {(s, e) for s, e, _ in pred_spans}
    seg_precision, seg_recall, seg_f1 = _prf(len(gold_seg & pred_seg), len(pred_seg), len(gold_seg))

    # Accuracy requiring the POS tag to match as well
    gold_pos = set(gold_spans)
    pred_pos = set(pred_spans)
    pos_precision, pos_recall, pos_f1 = _prf(len(gold_pos & pred_pos), len(pred_pos), len(gold_pos))

    return {
        "seg_precision": seg_precision,
        "seg_recall": seg_recall,
        "seg_f1": seg_f1,
        "pos_precision": pos_precision,
        "pos_recall": pos_recall,
        "pos_f1": pos_f1,
    }
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def main():
    """Evaluate MeCab (JUMANDIC) and the trained model on the KWDLC test split.

    Improvements over the previous version:
    - `knp_base` is computed once instead of inside the per-document loop
      (it is loop-invariant).
    - Blank lines in the id file are skipped; previously a trailing newline
      produced an empty id that could never match a KNP file.
    """
    parser = argparse.ArgumentParser(description="Unified evaluation script")
    parser.add_argument("--kwdlc-dir", type=str, default="KWDLC", help="Path to KWDLC root directory")
    parser.add_argument(
        "--test-ids", type=str, default="KWDLC/id/split_for_pas/test.id", help="File containing test IDs (one per line)"
    )
    parser.add_argument(
        "--max-samples", type=int, default=None, help="Max number of samples to evaluate (default: all)"
    )
    parser.add_argument("--experiment", "-e", type=str, required=True, help="Experiment name to evaluate")

    args = parser.parse_args()

    # Load test document IDs, ignoring blank lines.
    with open(args.test_ids, "r") as f:
        test_ids = [line.strip() for line in f if line.strip()]

    if args.max_samples is not None:
        test_ids = test_ids[: args.max_samples]

    print(f"Evaluating: {len(test_ids)} files")

    import infer

    model_info = infer.load_model(experiment_name=args.experiment)
    if model_info:
        model, experiment_info = model_info
        # Force CPU execution for evaluation
        device = torch.device("cpu")
        model = model.to(device)
        experiment_info["device"] = device
        print(f"Model: {experiment_info['name']}")
    else:
        print("Failed to load model")
        model = None
        experiment_info = None

    mecab_results = []
    model_results = []

    # Loop-invariant: the KNP root directory never changes per test id.
    knp_base = Path(args.kwdlc_dir) / "knp"

    print("\nStart evaluation...")
    for test_id in tqdm(test_ids, desc="evaluating"):
        # Locate the KNP file for this document id under knp/w*/
        knp_path = None
        for subdir in knp_base.glob("w*"):
            candidate = subdir / f"{test_id}.knp"
            if candidate.exists():
                knp_path = candidate
                break

        if knp_path is None:
            continue

        # Read gold data and score each sentence
        for sent_data in parse_knp_file(knp_path):
            text = sent_data["text"]
            gold_morphemes = sent_data["morphemes"]

            # MeCab (JUMANDIC) baseline
            pred_mecab = analyze_with_mecab(text)
            if pred_mecab:
                mecab_results.append(evaluate_morphemes(gold_morphemes, pred_mecab))

            # Trained model
            if model is not None:
                pred_model = analyze_with_model(text, model, experiment_info)
                if pred_model:
                    model_results.append(evaluate_morphemes(gold_morphemes, pred_model))

    # Aggregate and display results
    print("\n" + "=" * 70)
    print("Evaluation Results (KWDLC test data)")
    print("=" * 70)
    print(f"Num evaluated: MeCab={len(mecab_results)}, Model={len(model_results)}")

    if mecab_results:
        avg_seg_f1 = sum(r["seg_f1"] for r in mecab_results) / len(mecab_results)
        avg_pos_f1 = sum(r["pos_f1"] for r in mecab_results) / len(mecab_results)
        print("\n[1] MeCab (JUMANDIC):")
        print(f"  Seg F1: {avg_seg_f1:.4f}")
        print(f"  POS F1: {avg_pos_f1:.4f}")

    if model_results:
        avg_seg_f1 = sum(r["seg_f1"] for r in model_results) / len(model_results)
        avg_pos_f1 = sum(r["pos_f1"] for r in model_results) / len(model_results)
        print(f"\n[2] Trained model ({experiment_info['name']}):")
        print(f"  Seg F1: {avg_seg_f1:.4f}")
        print(f"  POS F1: {avg_pos_f1:.4f}")
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
if __name__ == "__main__":
|
| 288 |
+
main()
|
infer.py
ADDED
|
@@ -0,0 +1,797 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
# Show immediate feedback from the moment the command starts
|
| 5 |
+
print("Loading model...", flush=True)
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
from typing import Any, Dict, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
# Disable WandB during inference to avoid hanging processes
|
| 12 |
+
os.environ["WANDB_MODE"] = "disabled"
|
| 13 |
+
|
| 14 |
+
from importlib import import_module
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import yaml
|
| 19 |
+
|
| 20 |
+
from mecari.analyzers.mecab import MeCabAnalyzer
|
| 21 |
+
from mecari.data.data_module import DataModule
|
| 22 |
+
from mecari.utils.morph_utils import build_adjacent_edges, dedup_morphemes, normalize_mecab_candidates
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def set_seed(seed: int = 42) -> None:
    """Seed every relevant RNG so inference runs are reproducible.

    Args:
        seed: Seed applied to Python's, NumPy's, and all Torch generators.
    """
    # Seed the three independent RNG families with the same value.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade a little speed for deterministic cuDNN kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
| 38 |
+
|
| 39 |
+
|
| 40 |
+
set_seed(42)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _find_best_checkpoint(checkpoints_dir: str, prefer_metric: str = "val_error") -> Tuple[Optional[str], float]:
|
| 44 |
+
"""Find the best checkpoint file in a directory.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
checkpoints_dir: Path to the checkpoints directory.
|
| 48 |
+
prefer_metric: Preferred metric ("val_error" or "val_loss").
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Tuple of (best checkpoint filename, score).
|
| 52 |
+
"""
|
| 53 |
+
checkpoint_files = [f for f in os.listdir(checkpoints_dir) if f.endswith(".ckpt")]
|
| 54 |
+
if not checkpoint_files:
|
| 55 |
+
return None, float("inf")
|
| 56 |
+
|
| 57 |
+
best_checkpoint = None
|
| 58 |
+
best_score = float("inf")
|
| 59 |
+
|
| 60 |
+
# Prefer filenames that include the metric keyword (e.g., val_error=..., val_error_epoch=...)
|
| 61 |
+
for ckpt_file in checkpoint_files:
|
| 62 |
+
if prefer_metric == "val_loss" and ("val_loss=" in ckpt_file or "val_loss_epoch=" in ckpt_file):
|
| 63 |
+
try:
|
| 64 |
+
if "val_loss_epoch=" in ckpt_file:
|
| 65 |
+
score_str = ckpt_file.split("val_loss_epoch=")[-1].split(".ckpt")[0]
|
| 66 |
+
else:
|
| 67 |
+
score_str = ckpt_file.split("val_loss=")[-1].split(".ckpt")[0]
|
| 68 |
+
score = float(score_str)
|
| 69 |
+
if score < best_score:
|
| 70 |
+
best_score = score
|
| 71 |
+
best_checkpoint = ckpt_file
|
| 72 |
+
except (ValueError, IndexError):
|
| 73 |
+
pass
|
| 74 |
+
elif prefer_metric == "val_error" and ("val_error=" in ckpt_file or "val_error_epoch=" in ckpt_file):
|
| 75 |
+
try:
|
| 76 |
+
if "val_error_epoch=" in ckpt_file:
|
| 77 |
+
score_str = ckpt_file.split("val_error_epoch=")[-1].split(".ckpt")[0]
|
| 78 |
+
else:
|
| 79 |
+
score_str = ckpt_file.split("val_error=")[-1].split(".ckpt")[0]
|
| 80 |
+
score = float(score_str)
|
| 81 |
+
if score < best_score:
|
| 82 |
+
best_score = score
|
| 83 |
+
best_checkpoint = ckpt_file
|
| 84 |
+
except (ValueError, IndexError):
|
| 85 |
+
pass
|
| 86 |
+
|
| 87 |
+
# If not found, try the alternative metric
|
| 88 |
+
if not best_checkpoint:
|
| 89 |
+
other_metric = "val_loss" if prefer_metric == "val_error" else "val_error"
|
| 90 |
+
for ckpt_file in checkpoint_files:
|
| 91 |
+
if other_metric == "val_loss" and "val_loss=" in ckpt_file:
|
| 92 |
+
try:
|
| 93 |
+
score_str = ckpt_file.split("val_loss=")[1].split("-loss.ckpt")[0]
|
| 94 |
+
score = float(score_str)
|
| 95 |
+
if score < best_score:
|
| 96 |
+
best_score = score
|
| 97 |
+
best_checkpoint = ckpt_file
|
| 98 |
+
except (ValueError, IndexError):
|
| 99 |
+
pass
|
| 100 |
+
elif other_metric == "val_error" and "val_error=" in ckpt_file:
|
| 101 |
+
try:
|
| 102 |
+
score_str = ckpt_file.split("val_error=")[1].split(".ckpt")[0]
|
| 103 |
+
score = float(score_str)
|
| 104 |
+
if score < best_score:
|
| 105 |
+
best_score = score
|
| 106 |
+
best_checkpoint = ckpt_file
|
| 107 |
+
except (ValueError, IndexError):
|
| 108 |
+
pass
|
| 109 |
+
|
| 110 |
+
# Additional fallback: parse score from filename pattern (model-epoch-score.ckpt)
|
| 111 |
+
if not best_checkpoint:
|
| 112 |
+
for ckpt_file in sorted(checkpoint_files):
|
| 113 |
+
if ckpt_file == "last.ckpt":
|
| 114 |
+
continue
|
| 115 |
+
try:
|
| 116 |
+
stem = ckpt_file[:-5] if ckpt_file.endswith(".ckpt") else ckpt_file
|
| 117 |
+
# Fallback: treat the last hyphen-separated token as a score
|
| 118 |
+
last_tok = stem.split("-")[-1]
|
| 119 |
+
score = float(last_tok)
|
| 120 |
+
if score < best_score:
|
| 121 |
+
best_score = score
|
| 122 |
+
best_checkpoint = ckpt_file
|
| 123 |
+
except Exception:
|
| 124 |
+
continue
|
| 125 |
+
# Final fallback: use last.ckpt or the first file
|
| 126 |
+
if not best_checkpoint:
|
| 127 |
+
if "last.ckpt" in checkpoint_files:
|
| 128 |
+
best_checkpoint = "last.ckpt"
|
| 129 |
+
else:
|
| 130 |
+
best_checkpoint = sorted(checkpoint_files)[0]
|
| 131 |
+
|
| 132 |
+
return best_checkpoint, best_score
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _load_model_by_type(model_type: str, checkpoint_path: str) -> Any:
|
| 136 |
+
"""Load the appropriate model class based on type.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
model_type: Model type ("gat" or "gatv2").
|
| 140 |
+
checkpoint_path: Path to the checkpoint file.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
Loaded model instance.
|
| 144 |
+
"""
|
| 145 |
+
if model_type == "gatv2":
|
| 146 |
+
cls = getattr(import_module("mecari.models.gatv2"), "MecariGATv2")
|
| 147 |
+
model = cls.load_from_checkpoint(checkpoint_path, strict=False, map_location="cpu")
|
| 148 |
+
|
| 149 |
+
model.eval()
|
| 150 |
+
model.cpu()
|
| 151 |
+
return model
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _instantiate_model_from_config(config: Dict[str, Any]):
|
| 155 |
+
"""Instantiate a model using config fields (no checkpoint loading)."""
|
| 156 |
+
model_cfg = config.get("model", {})
|
| 157 |
+
training_cfg = config.get("training", {})
|
| 158 |
+
features_cfg = config.get("features", {})
|
| 159 |
+
|
| 160 |
+
if model_cfg.get("type") != "gatv2":
|
| 161 |
+
raise ValueError(f"Unsupported model type: {model_cfg.get('type')}")
|
| 162 |
+
|
| 163 |
+
MecariGATv2 = getattr(import_module("mecari.models.gatv2"), "MecariGATv2")
|
| 164 |
+
model = MecariGATv2(
|
| 165 |
+
hidden_dim=model_cfg.get("hidden_dim", 64),
|
| 166 |
+
num_classes=model_cfg.get("num_classes", 1),
|
| 167 |
+
learning_rate=training_cfg.get("learning_rate", 1e-3),
|
| 168 |
+
lexical_feature_dim=features_cfg.get("lexical_feature_dim", 100000),
|
| 169 |
+
num_heads=model_cfg.get("num_heads", 4),
|
| 170 |
+
share_weights=model_cfg.get("share_weights", False),
|
| 171 |
+
dropout=model_cfg.get("dropout", 0.1),
|
| 172 |
+
attn_dropout=model_cfg.get("attn_dropout", model_cfg.get("attention_dropout", 0.1)),
|
| 173 |
+
add_self_loops_flag=model_cfg.get("add_self_loops", True),
|
| 174 |
+
edge_dropout=model_cfg.get("edge_dropout", 0.0),
|
| 175 |
+
norm=model_cfg.get("norm", "layer"),
|
| 176 |
+
)
|
| 177 |
+
return model
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _load_model_from_state(config_path: str, state_path: str):
    """Restore a model from a plain state_dict file plus its config.yaml."""
    with open(config_path, "r", encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)

    model = _instantiate_model_from_config(cfg)
    raw_state = torch.load(state_path, map_location="cpu")

    # Exported Lightning checkpoints may nest everything under 'state_dict'
    # with a uniform 'model.' prefix on every key; unwrap that container first.
    if (
        isinstance(raw_state, dict)
        and "state_dict" in raw_state
        and all(key.startswith("model.") for key in raw_state["state_dict"].keys())
    ):
        raw_state = raw_state["state_dict"]

    # Strip any 'model.' prefixes so keys line up with the bare module.
    prefix = "model."
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in raw_state.items()
    }

    model.load_state_dict(cleaned, strict=False)
    model.eval()
    model.cpu()
    return model
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def load_model(
    experiment_name: Optional[str] = None, model_type: Optional[str] = None, prefer_metric: str = "val_error"
) -> Optional[Tuple[Any, Dict[str, Any]]]:
    """Load a trained model and its experiment info.

    Default behavior: load the single model under sample_model/.
    If --experiment is provided (or sample_model is unavailable), use experiments/.

    Args:
        experiment_name: Experiment directory name under ``experiments/``.
            When falsy, the bundled ``sample_model/`` is loaded instead.
        model_type: Optional filter on the model type during auto-selection.
        prefer_metric: Metric used to rank checkpoints ("val_error" or "val_loss").

    Returns:
        Tuple of (model, experiment_info dict), or None on any failure.
    """
    # Default: load from sample_model/
    if not experiment_name:
        root = "sample_model"
        if os.path.exists(root):
            fixed_config = os.path.join(root, "config.yaml")
            state_path = os.path.join(root, "model.pt")
            if os.path.exists(fixed_config) and os.path.exists(state_path):
                try:
                    with open(fixed_config, "r", encoding="utf-8") as f:
                        config = yaml.safe_load(f)
                    model = _load_model_from_state(fixed_config, state_path)
                    experiment_info = {
                        "name": os.path.basename(root),
                        "path": root,
                        "best_metric": None,
                        "best_score": None,
                        "model_type": config.get("model", {}).get("type", "unknown"),
                        "best_model_path": state_path,
                        "config": config,
                    }
                    return model, experiment_info
                except Exception as e:
                    print(f"Failed to load sample model: {e}")
                    return None
            print("sample_model/model.pt or config.yaml not found")
            return None
        else:
            print("sample_model directory not found")
            return None

    # Specific experiment provided
    if experiment_name:
        exp_path = os.path.join("experiments", experiment_name)
        config_path = os.path.join(exp_path, "config.yaml")
        checkpoints_dir = os.path.join(exp_path, "checkpoints")

        if not os.path.exists(config_path) or not os.path.exists(checkpoints_dir):
            print(f"Experiment not found: {experiment_name}")
            return None

        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)

            model_type_from_config = config.get("model", {}).get("type", "unknown")
            best_checkpoint, best_score = _find_best_checkpoint(checkpoints_dir, prefer_metric)

            if not best_checkpoint:
                print("No checkpoint found")
                return None

            metric_name = "val_loss" if prefer_metric == "val_loss" else "val_error"

            experiment_info = {
                "name": experiment_name,
                "path": exp_path,
                "val_error": best_score if prefer_metric == "val_error" else None,
                "val_loss": best_score if prefer_metric == "val_loss" else None,
                "best_metric": metric_name,
                "best_score": best_score,
                "model_type": model_type_from_config,
                "best_model_path": os.path.join(checkpoints_dir, best_checkpoint),
                "config": config,
            }
        except Exception as e:
            print(f"Failed to read experiment info: {e}")
            return None

    # Auto-select the best experiment
    # NOTE(review): with the early returns above this branch is currently
    # unreachable, but it previously crashed with a NameError if reached.
    else:
        # BUG FIX: experiments_dir was referenced without ever being defined.
        experiments_dir = "experiments"
        if not os.path.exists(experiments_dir):
            print("Experiments directory does not exist")
            return None

        experiments = []
        for exp_dir in os.listdir(experiments_dir):
            exp_path = os.path.join(experiments_dir, exp_dir)
            config_path = os.path.join(exp_path, "config.yaml")
            checkpoints_dir = os.path.join(exp_path, "checkpoints")

            if not os.path.exists(config_path) or not os.path.exists(checkpoints_dir):
                continue

            try:
                with open(config_path, "r", encoding="utf-8") as f:
                    config = yaml.safe_load(f)

                exp_model_type = config.get("model", {}).get("type", "unknown")

                # Optional filter: keep only experiments of the requested type.
                if model_type and exp_model_type.lower() != model_type.lower():
                    continue

                best_checkpoint, best_score = _find_best_checkpoint(checkpoints_dir, prefer_metric)
                if best_checkpoint:
                    metric_name = "val_loss" if prefer_metric == "val_loss" else "val_error"
                    experiments.append(
                        {
                            "name": exp_dir,
                            "path": exp_path,
                            "val_error": best_score if prefer_metric == "val_error" else None,
                            "val_loss": best_score if prefer_metric == "val_loss" else None,
                            "best_metric": metric_name,
                            "best_score": best_score,
                            "model_type": exp_model_type,
                            "best_model_path": os.path.join(checkpoints_dir, best_checkpoint),
                            "config": config,
                        }
                    )
            except Exception:
                continue

        if not experiments:
            print("No available experiments found")
            return None

        # Lower score is better for both supported metrics.
        experiment_info = min(experiments, key=lambda x: x["best_score"])

    # Load model
    print(f"Loading model: {experiment_info['best_model_path']}")
    print(f"Experiment: {experiment_info['name']}")

    try:
        model = _load_model_by_type(experiment_info["model_type"], experiment_info["best_model_path"])

        # No BERT features in this pipeline

        return model, experiment_info
    except Exception as e:
        print(f"Model loading error: {e}")
        return None
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def viterbi_decode_from_morphemes(logits: torch.Tensor, morphemes: list, edges: list, silent: bool = False) -> list:
    """Edge-based Viterbi decoding over a morpheme-candidate lattice.

    Runs a longest-path dynamic program: each node (morpheme candidate)
    contributes its model logit to a path's score, edges connect candidates
    whose character spans are compatible, and the highest-scoring path from
    the earliest start position to the latest end position is returned.

    Args:
        logits: Logits per morpheme (must be index-aligned with ``morphemes``).
        morphemes: List of morpheme records; each is expected to carry
            "surface", "start_pos", "end_pos", "pos", and optionally "logit".
        edges: Edge list among morpheme indices; each is a dict with
            "source_idx" and "target_idx".
        silent: If True, suppress debug prints.

    Returns:
        Indices of morphemes on the optimal path (empty if no end node is
        reachable; degenerate range if lengths mismatch).
    """
    # Guard: logits and morphemes must line up index-for-index.
    if len(logits) != len(morphemes):
        if not silent:
            print(f"Warning: #logits ({len(logits)}) != #morphemes ({len(morphemes)})")
        return list(range(min(len(logits), len(morphemes))))

    if not silent:
        print("\n=== Viterbi Decode ===")
        print(f"#Morphemes: {len(morphemes)}")
        print(f"Using edge info: {len(edges)} edges")

        print("\nNode logits:")
        for idx, (morph, logit) in enumerate(zip(morphemes, logits)):
            print(
                f" [{idx:3d}] {morph['surface']:10s} ({morph['start_pos']:2d}-{morph['end_pos']:2d}) {morph['pos']:10s} logit={logit:.3f}"
            )

    # Build adjacency from edges (forward edges only)
    n = len(morphemes)
    adj_list = [[] for _ in range(n)]
    for edge in edges:
        source_idx = edge["source_idx"]
        target_idx = edge["target_idx"]
        if 0 <= source_idx < n and 0 <= target_idx < n:
            # Add forward edges only (source.end_pos <= target.start_pos)
            # so the graph is a DAG ordered by character position.
            source_end = morphemes[source_idx].get("end_pos", 0)
            target_start = morphemes[target_idx].get("start_pos", 0)
            if source_end <= target_start:
                adj_list[source_idx].append(target_idx)

    # POS to UD mapping (for display)
    pos_to_ud = {
        "名詞": "NOUN",
        "動詞": "VERB",
        "形容詞": "ADJ",
        "副詞": "ADV",
        "助詞": "ADP",  # approximate
        "助動詞": "AUX",
        "接続詞": "CCONJ",
        "連体詞": "DET",
        "感動詞": "INTJ",
        "代名詞": "PRON",
        "形状詞": "ADJ",
        "補助記号": "PUNCT",
        "接頭辞": "PREFIX",
        "接尾辞": "SUFFIX",
    }

    if not silent:
        print("\nMorpheme details:")
        for i, morpheme in enumerate(morphemes):
            start_pos = morpheme.get("start_pos", 0)
            end_pos = morpheme.get("end_pos", 0)
            surface = morpheme.get("surface", "")
            logit = morpheme.get("logit", 0.0)
            pos = morpheme.get("pos", "")
            # Use only the coarse POS (before the first comma) for the UD lookup.
            pos_main = pos.split(",")[0] if "," in pos else pos
            ud_pos = pos_to_ud.get(pos_main, "X")
            print(f" {i}: {surface} ({start_pos}-{end_pos}) {pos_main}({ud_pos}) logit={logit:.3f}")

    # Dynamic programming
    # NOTE: scoring below uses morpheme["logit"] (attached by the caller),
    # not the `logits` tensor directly; the tensor is only used for the
    # length check and the debug dump above.
    dp = [-float("inf")] * n  # max score to each node
    parent = [-1] * n  # best predecessor per node

    # Find start nodes (earliest start position)
    start_nodes = []
    min_start_pos = min(m.get("start_pos", 0) for m in morphemes)
    for i, m in enumerate(morphemes):
        if m.get("start_pos", 0) == min_start_pos:
            start_nodes.append(i)

    # Initialize start nodes
    for i in start_nodes:
        dp[i] = morphemes[i].get("logit", 0.0)

    # Process nodes in position order (topological-like)
    node_positions = [(i, morphemes[i].get("start_pos", 0), morphemes[i].get("end_pos", 0)) for i in range(n)]
    node_positions.sort(key=lambda x: (x[1], x[2]))  # sort by start_pos, end_pos

    # Relax edges for each node in order
    for node_idx, _, _ in node_positions:
        if dp[node_idx] == -float("inf"):
            continue  # unreachable node

        # Relax transitions to reachable next nodes
        for next_idx in adj_list[node_idx]:
            new_score = dp[node_idx] + morphemes[next_idx].get("logit", 0.0)
            if new_score > dp[next_idx]:
                dp[next_idx] = new_score
                parent[next_idx] = node_idx

    # Select best end node at the final position
    end_nodes = []
    max_end_pos = max(m.get("end_pos", 0) for m in morphemes)
    for i, m in enumerate(morphemes):
        if m.get("end_pos", 0) == max_end_pos:
            end_nodes.append(i)

    best_end_idx = -1
    best_score = -float("inf")
    for i in end_nodes:
        if dp[i] > best_score:
            best_score = dp[i]
            best_end_idx = i

    # Backtracking with safety cap to avoid infinite loops
    path = []
    current = best_end_idx
    max_iterations = n * 2  # safety cap
    iteration_count = 0
    visited = set()

    while current != -1 and iteration_count < max_iterations:
        if current in visited:
            # NOTE(review): these warnings print even when silent=True —
            # presumably intentional since they indicate corrupted state.
            print(f"Warning: Detected cycle during backtracking (node {current})")
            break
        visited.add(current)
        path.append(current)
        current = parent[current]
        iteration_count += 1

    if iteration_count >= max_iterations:
        print(f"Warning: Backtracking reached max iterations ({max_iterations})")

    path.reverse()

    # Display
    if path:
        total_score = sum(morphemes[idx].get("logit", 0.0) for idx in path)
        if not silent:
            print(f"\nOptimal path (total score: {total_score:.3f}):")
            for idx in path:
                morpheme = morphemes[idx]
                logit = morpheme.get("logit", 0.0)
                print(f" {morpheme['surface']} (logit: {logit:.3f})")

    return path
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
##
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# Global singletons (lazy initialization)
|
| 501 |
+
_analyzer = None
|
| 502 |
+
_data_module_cache = {}
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def predict_morphemes_from_text(text, model=None, experiment_info=None, silent=False):
    """Predict morpheme boundaries from text.

    Steps:
    1. Analyze with MeCab to get candidates.
    2. Build nodes/edges from morphemes and connections.
    3. Run the model to get per-node scores.
    4. Run Viterbi decoding over nodes and edges.

    Args:
        text: Input text.
        model: Model to use; when None, :func:`load_model` is called lazily.
        experiment_info: Experiment metadata dict (must contain "config"
            when a model is passed in explicitly).
        silent: If True, suppress prints.

    Returns:
        Tuple ``(results, optimal_morphemes)``: per-candidate prediction
        dicts, and the morphemes on the Viterbi-optimal path (each augmented
        with "num_candidates" and "selected_rank"). Both lists are empty on
        any failure.
    """
    global _analyzer

    if model is None:
        result = load_model()
        if result is None:
            return [], []
        model, experiment_info = result

    if not silent:
        print(f"Input text: {text}")

    # 1) Get morpheme candidates (initialize analyzer on first use)
    if _analyzer is None:
        _analyzer = MeCabAnalyzer()

    # Fetch candidates directly via analyzer and deduplicate
    candidates = _analyzer.get_morpheme_candidates(text)
    candidates = normalize_mecab_candidates(candidates)
    candidates = dedup_morphemes(candidates)

    if not candidates:
        print("Error: Failed to obtain morpheme candidates")
        return [], []

    if not silent:
        print(f"#Candidates: {len(candidates)}")

    # 2) Use candidates as morphemes
    morphemes = candidates

    # Validate type
    if not isinstance(morphemes, list):
        print(f"Warning: morphemes is not a list: {type(morphemes)}")
        morphemes = []

    # Add lexical features using the shared DataModule implementation
    # (a throwaway DataModule is enough: only compute_lexical_features is used).
    dm_tmp = DataModule(annotations_dir="dummy", batch_size=1, num_workers=0, lexical_feature_dim=100000, silent=True)
    morphemes = dm_tmp.compute_lexical_features(morphemes, text)

    # Build edges (adjacent only)
    edges = build_adjacent_edges(morphemes)

    # Add annotation field as '?' for inference
    for morpheme in morphemes:
        if "annotation" not in morpheme:
            morpheme["annotation"] = "?"

    if not silent:
        print(f"Unified graph: {len(morphemes)} nodes, {len(edges)} edges")

    # 3) Initialize DataModule per experiment settings
    features_config = experiment_info["config"].get("features", {})
    training_config = experiment_info["config"].get("training", {})
    edge_config = experiment_info["config"].get("edge_features", {})

    # Cache DataModule by annotations_dir
    global _data_module_cache
    cache_key = str(training_config.get("annotations_dir", "annotations_kwdlc"))

    if cache_key not in _data_module_cache:
        # Always use lexical features
        _data_module_cache[cache_key] = DataModule(
            annotations_dir=training_config.get("annotations_dir", "annotations_kwdlc"),
            batch_size=1,
            num_workers=0,
            silent=silent,
            lexical_feature_dim=features_config.get("lexical_feature_dim", 100000),
            use_bidirectional_edges=edge_config.get("use_bidirectional_edges", True),
        )

    data_module = _data_module_cache[cache_key]

    # Build graph using the same public API as preprocessing
    graph = data_module.create_graph_from_morphemes_data(
        morphemes=morphemes,
        edges=edges,
        text=text,
        for_training=False,
    )

    if graph is None:
        print("Error: Failed to create PyTorch graph")
        return [], []

    # Inference

    # Device (CPU by default)
    device = torch.device("cpu")

    # Respect explicit device from experiment_info if present
    if experiment_info and "device" in experiment_info:
        device = experiment_info["device"]

    with torch.no_grad():
        # Ensure lexical feature tensors exist
        if not hasattr(graph, "lexical_indices") or graph.lexical_indices is None:
            print("Error: lexical_indices not found")
            return [], []

        # Positional model signature; the fourth argument (unused here) is None.
        logits = model(
            graph.lexical_indices.to(device),  # lexical_indices
            graph.lexical_values.to(device),  # lexical_values
            graph.edge_index.to(device),
            None,
            graph.edge_attr.to(device) if graph.edge_attr is not None else None,
        ).squeeze()

        # A single-node graph squeezes to a 0-d tensor; restore the batch dim.
        if logits.dim() == 0:
            logits = logits.unsqueeze(0)
        probabilities = torch.sigmoid(logits)
        predictions = (probabilities >= 0.5).float()

        # Move back to CPU for post-processing
        logits = logits.cpu()
        probabilities = probabilities.cpu()
        predictions = predictions.cpu()

    # Attach predictions to morphemes
    for i, morpheme in enumerate(morphemes):
        if i < len(predictions):
            morpheme["predicted_annotation"] = "+" if predictions[i] == 1 else "-"
            morpheme["logit"] = logits[i].item()
            morpheme["probability"] = probabilities[i].item()

    # 4) Viterbi decode over nodes/edges (no CRF)
    optimal_path = viterbi_decode_from_morphemes(logits, morphemes, edges, silent=silent)

    # Format results
    results = []
    for i, morpheme in enumerate(morphemes):
        is_in_optimal_path = optimal_path and i in optimal_path

        result = {
            "surface": morpheme["surface"],
            "pos": morpheme["pos"],
            "reading": morpheme["reading"],
            "predicted_annotation": morpheme.get("predicted_annotation", "?"),
            "logit": morpheme.get("logit", 0.0),
            "probability": morpheme.get("probability", 0.5),
            "in_optimal_path": is_in_optimal_path,
        }

        results.append(result)

    # Collect morphemes on the optimal path
    optimal_morphemes = []
    if optimal_path:
        # Count candidates per span
        position_candidates = {}
        for i, m in enumerate(morphemes):
            pos_key = (m.get("start_pos", 0), m.get("end_pos", 0))
            if pos_key not in position_candidates:
                position_candidates[pos_key] = []
            position_candidates[pos_key].append(i)

        for idx in optimal_path:
            if idx < len(morphemes):
                morph = morphemes[idx].copy()
                # Add candidate count and selected rank for this span
                pos_key = (morph.get("start_pos", 0), morph.get("end_pos", 0))
                if pos_key in position_candidates:
                    candidates_at_pos = position_candidates[pos_key]
                    morph["num_candidates"] = len(candidates_at_pos)
                    morph["selected_rank"] = candidates_at_pos.index(idx) + 1 if idx in candidates_at_pos else 0
                optimal_morphemes.append(morph)

    return results, optimal_morphemes
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
def print_results(results, optimal_morphemes=None, verbose: bool = False):
    """Print morphemes in MeCab-like format (surface\tCSV features)."""
    if not results:
        return

    def _feature_csv(m):
        # Field order mirrors MeCab's CSV feature output; empty base/reading
        # fall back to "*".
        fields = [
            m.get("pos", "*"),
            m.get("pos_detail1", "*"),
            m.get("pos_detail2", "*"),
            m.get("inflection_type", "*"),
            m.get("inflection_form", "*"),
            m.get("base_form", m.get("lemma", "*")) or "*",
            m.get("reading", "*") or "*",
        ]
        return ",".join(str(v) for v in fields)

    if optimal_morphemes:
        rows = optimal_morphemes
    else:
        # No Viterbi-path morphemes available: synthesize minimal records
        # from the flat result dicts so output shape stays the same.
        rows = [
            {
                "surface": r.get("surface", ""),
                "pos": r.get("pos", "*"),
                "pos_detail1": "*",
                "pos_detail2": "*",
                "inflection_type": "*",
                "inflection_form": "*",
                "base_form": r.get("surface", ""),
                "reading": r.get("reading", "*"),
            }
            for r in results
        ]

    for row in rows:
        print(f"{row.get('surface', '')}\t{_feature_csv(row)}")
    print("EOS")
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def main():
    """Main inference entrypoint.

    With --text, runs a single analysis and prints MeCab-style output;
    otherwise starts an interactive REPL until quit/exit/EOF/Ctrl-C.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Mecari morphological analysis inference")
    parser.add_argument("--text", "-t", help="Input text directly")
    parser.add_argument("--experiment", "-e", help="Experiment name to load (e.g., gat_20250730_145624)")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output (include UD POS)")
    args = parser.parse_args()

    loaded = load_model(experiment_name=args.experiment) if args.experiment else load_model()
    if loaded is None:
        return
    model, experiment_info = loaded

    def _analyze_and_print(text: str) -> None:
        # One inference round: predict, then print MeCab-style lines.
        outcome = predict_morphemes_from_text(text, model, experiment_info, silent=not args.verbose)
        if outcome:
            morph_results, best_path = outcome
            print_results(morph_results, best_path, verbose=args.verbose)
        else:
            print("Inference failed.")

    if args.text:
        _analyze_and_print(args.text)
        return

    # Interactive mode
    print("\nMecari morphological inference")
    print("Enter text (e.g., Tokyo is nice)")
    print("Type 'quit' or 'exit' to finish.\n")

    while True:
        try:
            user_input = input("Input: ").strip()
            if user_input.lower() in ["quit", "exit", "q"]:
                print("Exiting.")
                break
            if not user_input:
                continue
            print(f"Text: {user_input}")
            _analyze_and_print(user_input)
            print()
        except (EOFError, KeyboardInterrupt):
            print("\nExiting.")
            break
        except Exception as e:
            import traceback

            print(f"\nAn error occurred: {e}")
            traceback.print_exc()
            continue
+
|
| 796 |
+
if __name__ == "__main__":
|
| 797 |
+
main()
|
mecari/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mecari - Japanese Morphological Analysis with Graph Neural Networks"""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
| 4 |
+
|
| 5 |
+
# Export minimal API (avoid heavy imports at package import time)
|
| 6 |
+
from mecari.config.config import get_model_config, load_config, override_config, save_config # noqa: F401
|
| 7 |
+
from mecari.data.data_module import DataModule # noqa: F401
|
| 8 |
+
|
| 9 |
+
__all__ = ["DataModule", "get_model_config", "override_config", "save_config", "load_config"]
|
mecari/analyzers/mecab.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import subprocess
|
| 6 |
+
import tempfile
|
| 7 |
+
from typing import Dict, List
|
| 8 |
+
|
| 9 |
+
from mecari.utils.signature import signature_key
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _byte_to_char_map(text: str) -> dict[int, int]:
|
| 13 |
+
mapping: dict[int, int] = {}
|
| 14 |
+
cpos = 0
|
| 15 |
+
bpos = 0
|
| 16 |
+
for ch in text:
|
| 17 |
+
mapping[bpos] = cpos
|
| 18 |
+
bpos += len(ch.encode("utf-8"))
|
| 19 |
+
cpos += 1
|
| 20 |
+
mapping[bpos] = cpos
|
| 21 |
+
return mapping
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MeCabAnalyzer:
|
| 25 |
+
"""Obtain morpheme candidates for building graph.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
jumandic_path: Filesystem path to the JUMANDIC dictionary used by MeCab.
|
| 29 |
+
mecab_bin: Optional MeCab binary name or full path. If None, resolves
|
| 30 |
+
from the MECAB_BIN environment variable or defaults to "mecab".
|
| 31 |
+
|
| 32 |
+
Methods:
|
| 33 |
+
version(): Return the MeCab version string, or an empty string on error.
|
| 34 |
+
get_morpheme_candidates(text): Analyze text and return a list of
|
| 35 |
+
morpheme dicts with fields such as:
|
| 36 |
+
- surface, base_form, reading
|
| 37 |
+
- pos, pos_detail1/2/3
|
| 38 |
+
- inflection_type, inflection_form
|
| 39 |
+
- start_pos, end_pos (character offsets)
|
| 40 |
+
Unknown or unavailable values are filled with "*" or empty strings.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
jumandic_path: str | None = None,
|
| 46 |
+
mecab_bin: str | None = None,
|
| 47 |
+
) -> None:
|
| 48 |
+
# Prefer JUMANDIC if present; otherwise fall back to IPADIC
|
| 49 |
+
if jumandic_path is None:
|
| 50 |
+
candidates = [
|
| 51 |
+
"/var/lib/mecab/dic/juman-utf8",
|
| 52 |
+
"/usr/lib/x86_64-linux-gnu/mecab/dic/juman-utf8",
|
| 53 |
+
]
|
| 54 |
+
ipadic_candidates = [
|
| 55 |
+
"/var/lib/mecab/dic/ipadic",
|
| 56 |
+
"/usr/lib/x86_64-linux-gnu/mecab/dic/ipadic",
|
| 57 |
+
]
|
| 58 |
+
chosen = next((p for p in candidates if os.path.isdir(p)), None)
|
| 59 |
+
if chosen is None:
|
| 60 |
+
chosen = next((p for p in ipadic_candidates if os.path.isdir(p)), None)
|
| 61 |
+
self.jumandic_path = chosen # may be None; handled below
|
| 62 |
+
else:
|
| 63 |
+
self.jumandic_path = jumandic_path
|
| 64 |
+
|
| 65 |
+
# Allow selecting a specific mecab binary via arg or env var; default to common path
|
| 66 |
+
self.mecab_bin = mecab_bin or os.getenv("MECAB_BIN") or (
|
| 67 |
+
"/usr/bin/mecab" if os.path.exists("/usr/bin/mecab") else "mecab"
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def version(self) -> str:
|
| 71 |
+
try:
|
| 72 |
+
out = subprocess.run([self.mecab_bin, "-v"], capture_output=True, text=True)
|
| 73 |
+
return (out.stdout or out.stderr).strip()
|
| 74 |
+
except Exception:
|
| 75 |
+
return ""
|
| 76 |
+
|
| 77 |
+
def get_morpheme_candidates(self, text: str) -> List[Dict]:
|
| 78 |
+
"""Return a flat list of JUMANDIC candidates (robust %H format)."""
|
| 79 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, encoding="utf-8") as f:
|
| 80 |
+
f.write(text)
|
| 81 |
+
temp_file = f.name
|
| 82 |
+
try:
|
| 83 |
+
fmt = "%pi\t%m\t%H\t%ps\t%pe\n"
|
| 84 |
+
cmd = [self.mecab_bin]
|
| 85 |
+
# Pass dictionary only if we have a resolvable path
|
| 86 |
+
if isinstance(self.jumandic_path, str) and os.path.isdir(self.jumandic_path):
|
| 87 |
+
cmd += ["-d", self.jumandic_path]
|
| 88 |
+
cmd += ["-F", fmt, "-E", "", "-a", temp_file]
|
| 89 |
+
result = subprocess.run(cmd, capture_output=True, text=True, encoding="utf-8", errors="ignore")
|
| 90 |
+
stdout = result.stdout
|
| 91 |
+
finally:
|
| 92 |
+
try:
|
| 93 |
+
import os
|
| 94 |
+
|
| 95 |
+
os.unlink(temp_file)
|
| 96 |
+
except Exception:
|
| 97 |
+
pass
|
| 98 |
+
if result.returncode != 0:
|
| 99 |
+
return []
|
| 100 |
+
byte_to_char = _byte_to_char_map(text)
|
| 101 |
+
out: list[dict] = []
|
| 102 |
+
seen = set()
|
| 103 |
+
for line in stdout.strip().split("\n"):
|
| 104 |
+
if not line:
|
| 105 |
+
continue
|
| 106 |
+
parts = line.split("\t")
|
| 107 |
+
if len(parts) < 5:
|
| 108 |
+
continue
|
| 109 |
+
node_id, surface, features, sb, eb = parts[0], parts[1], parts[2], parts[3], parts[4]
|
| 110 |
+
if surface in ("BOS", "EOS"):
|
| 111 |
+
continue
|
| 112 |
+
if not surface.strip():
|
| 113 |
+
continue
|
| 114 |
+
try:
|
| 115 |
+
start_byte = int(sb)
|
| 116 |
+
end_byte = int(eb)
|
| 117 |
+
except ValueError:
|
| 118 |
+
continue
|
| 119 |
+
start_pos = byte_to_char.get(start_byte, 0)
|
| 120 |
+
end_pos = byte_to_char.get(end_byte, len(text))
|
| 121 |
+
fs = features.split(",")
|
| 122 |
+
pos = fs[0] if len(fs) > 0 else "*"
|
| 123 |
+
pos1 = fs[1] if len(fs) > 1 else "*"
|
| 124 |
+
is_conj = pos in ("動詞", "形容詞", "助動詞")
|
| 125 |
+
ctype = fs[2] if len(fs) > 2 and fs[2] != "*" and is_conj else "*"
|
| 126 |
+
cform = fs[3] if len(fs) > 3 and fs[3] != "*" and is_conj else "*"
|
| 127 |
+
pos2 = (fs[2] if len(fs) > 2 else "*") if not is_conj else "*"
|
| 128 |
+
pos3 = (fs[3] if len(fs) > 3 else "*") if not is_conj else "*"
|
| 129 |
+
base = fs[4] if len(fs) > 4 and fs[4] != "*" else ""
|
| 130 |
+
reading = fs[5] if len(fs) > 5 and fs[5] != "*" else ""
|
| 131 |
+
m = {
|
| 132 |
+
"surface": surface,
|
| 133 |
+
"pos": pos,
|
| 134 |
+
"pos_detail1": pos1,
|
| 135 |
+
"pos_detail2": pos2,
|
| 136 |
+
"pos_detail3": pos3,
|
| 137 |
+
"base_form": base,
|
| 138 |
+
"reading": reading,
|
| 139 |
+
"inflection_type": ctype,
|
| 140 |
+
"inflection_form": cform,
|
| 141 |
+
"start_pos": start_pos,
|
| 142 |
+
"end_pos": end_pos,
|
| 143 |
+
"annotation": "?",
|
| 144 |
+
"node_id": node_id,
|
| 145 |
+
}
|
| 146 |
+
key = signature_key(m)
|
| 147 |
+
if key in seen:
|
| 148 |
+
continue
|
| 149 |
+
seen.add(key)
|
| 150 |
+
out.append(m)
|
| 151 |
+
return out
|
mecari/config/config.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
import yaml
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def load_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML config with inheritance (defaults/extends).

    Supports two inheritance styles:
    - Hydra-style ``defaults``: a list of base config names, resolved
      relative to this config's directory with ``.yaml`` appended, merged
      in order (later bases override earlier ones).
    - Legacy ``extends``: a single path to a base config.

    Child values deep-override base values (via override_config); ``None``
    values in the child are dropped so they do not clobber base settings.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Fully merged configuration dict.

    Raises:
        FileNotFoundError: If config_path does not exist.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    with open(config_path, "r", encoding="utf-8") as f:
        # BUG FIX: yaml.safe_load returns None for an empty document, which
        # made the `"defaults" in config` check below raise TypeError.
        config = yaml.safe_load(f) or {}

    # Handle inheritance (Hydra-style defaults or legacy extends)
    if "defaults" in config:
        # Hydra-style defaults (list format)
        defaults = config["defaults"]
        if isinstance(defaults, list):
            base_config: Dict[str, Any] = {}
            for default_item in defaults:
                if not isinstance(default_item, str):
                    continue  # non-string entries are unsupported; skip
                base_config_path = default_item
                if not os.path.isabs(base_config_path):
                    config_dir = os.path.dirname(config_path)
                    base_config_path = os.path.join(config_dir, base_config_path + ".yaml")

                if os.path.exists(base_config_path):
                    loaded = load_config(base_config_path)  # recurse for nested bases
                    base_config = override_config(base_config, loaded)

            child_config = {k: v for k, v in config.items() if k != "defaults" and v is not None}
            config = override_config(base_config, child_config)
    elif "extends" in config:
        # Legacy extends format
        base_config_path = config["extends"]
        if not os.path.isabs(base_config_path):
            config_dir = os.path.dirname(config_path)
            base_config_path = os.path.join(config_dir, base_config_path)

        base_config = load_config(base_config_path)

        child_config = {k: v for k, v in config.items() if k != "extends" and v is not None}
        config = override_config(base_config, child_config)

    return config
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_model_config(model_type: str) -> Dict[str, Any]:
    """Return the config loaded from ``configs/<model_type>.yaml``."""
    return load_config(f"configs/{model_type}.yaml")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def override_config(config: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of *config* with *overrides* merged in recursively.

    Nested dicts are merged key-by-key; any other value (including lists)
    replaces the base value outright. Neither input is mutated.
    """
    import copy

    def _merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None:
        # Recurse only when both sides hold dicts; otherwise replace.
        for key, value in src.items():
            if isinstance(value, dict) and isinstance(dst.get(key), dict):
                _merge(dst[key], value)
            else:
                dst[key] = value

    merged = copy.deepcopy(config)
    _merge(merged, overrides)
    return merged
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def save_config(config: Dict[str, Any], output_path: str):
    """Save config as YAML.

    Creates the parent directory if needed.

    Args:
        config: Configuration dict to serialize.
        output_path: Destination YAML file path.
    """
    # BUG FIX: os.path.dirname() is "" for a bare filename, and
    # os.makedirs("") raises FileNotFoundError — only create when non-empty.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True, indent=2)
|
mecari/data/data_module.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from typing import Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
import pytorch_lightning as pl
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
from torch_geometric.data import Data, DataLoader
|
| 11 |
+
|
| 12 |
+
# Required import for lexical feature computation
|
| 13 |
+
from mecari.featurizers.lexical import (
|
| 14 |
+
LexicalNGramFeaturizer as LexFeaturizer,
|
| 15 |
+
Morpheme as LexMorpheme,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
"""Data module for lexical-graph training using prebuilt .pt graphs only."""
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Prebuilt .pt graph dataset
|
| 23 |
+
class _PtGraphDataset(Dataset):
    """Prebuilt PyG graph tensors saved as .pt per sentence.

    Each file is expected to be a dict with keys:
    - 'graph': torch_geometric.data.Data
    - 'source_id': str (used for split)
    - optional: 'text'
    """

    def __init__(self, files: List[str]) -> None:
        self.files = files

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> Data:
        path = self.files[idx]
        loaded = torch.load(path, map_location="cpu")
        # Accept either the wrapping dict or a bare Data object.
        graph = loaded["graph"] if isinstance(loaded, dict) and "graph" in loaded else loaded
        if not isinstance(graph, Data):
            raise RuntimeError(f"Invalid graph object in: {path}")
        graph.data_index = idx  # remember origin index for traceability
        return graph
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Safe globals registration for PyTorch 2.6+
# NOTE(review): presumably needed because torch.load defaults to
# weights_only=True on newer PyTorch, which only unpickles allow-listed
# classes — confirm against the torch version pinned by this project.
# Best-effort: ImportError covers missing torch_geometric internals,
# AttributeError covers older torch without add_safe_globals.
try:
    import torch.serialization
    from torch_geometric.data.data import DataEdgeAttr

    torch.serialization.add_safe_globals([DataEdgeAttr, Data])
except (ImportError, AttributeError):
    pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class DataModule(pl.LightningDataModule):
|
| 62 |
+
"""Loads .pt graphs and builds lexical graph features for training."""
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
annotations_dir: str = "annotations",
|
| 67 |
+
batch_size: int = 32,
|
| 68 |
+
num_workers: int = 0,
|
| 69 |
+
max_files: Optional[int] = None,
|
| 70 |
+
use_bidirectional_edges: bool = True,
|
| 71 |
+
annotations_override_dir: Optional[str] = None,
|
| 72 |
+
silent: bool = False,
|
| 73 |
+
lexical_feature_dim: int = 100000,
|
| 74 |
+
lexical_max_features: int = 20,
|
| 75 |
+
) -> None:
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.annotations_dir = annotations_dir
|
| 78 |
+
self.annotations_override_dir = annotations_override_dir
|
| 79 |
+
self.batch_size = batch_size
|
| 80 |
+
self.num_workers = num_workers
|
| 81 |
+
self.max_files = max_files
|
| 82 |
+
self.use_bidirectional_edges = True
|
| 83 |
+
self.silent = silent
|
| 84 |
+
self.lexical_feature_dim = lexical_feature_dim
|
| 85 |
+
self.lexical_max_features = int(lexical_max_features)
|
| 86 |
+
self.use_bidirectional_edges = bool(use_bidirectional_edges)
|
| 87 |
+
|
| 88 |
+
# Initialized in setup()
|
| 89 |
+
self.train_dataset = []
|
| 90 |
+
self.val_dataset = []
|
| 91 |
+
self.test_dataset = []
|
| 92 |
+
# Eagerly initialize lexical featurizer (small and picklable)
|
| 93 |
+
self._lex_featurizer = LexFeaturizer(dim=int(self.lexical_feature_dim), add_bias=True)
|
| 94 |
+
# POS mapping for evaluation breakdown
|
| 95 |
+
self.pos_to_id = {
|
| 96 |
+
"名詞": 1,
|
| 97 |
+
"動詞": 2,
|
| 98 |
+
"形容詞": 3,
|
| 99 |
+
"副詞": 4,
|
| 100 |
+
"助詞": 5,
|
| 101 |
+
"助動詞": 6,
|
| 102 |
+
"接続詞": 7,
|
| 103 |
+
"連体詞": 8,
|
| 104 |
+
"感動詞": 9,
|
| 105 |
+
"形状詞": 10,
|
| 106 |
+
"補助記号": 11,
|
| 107 |
+
"接頭辞": 12,
|
| 108 |
+
"接尾辞": 13,
|
| 109 |
+
"特殊": 14,
|
| 110 |
+
}
|
| 111 |
+
self.id_to_pos = {v: k for k, v in self.pos_to_id.items()}
|
| 112 |
+
|
| 113 |
+
def create_graph_from_morphemes_data(self, *args, **kwargs) -> Optional[Data]:
|
| 114 |
+
"""Create a lexical graph from morpheme data (or candidates)."""
|
| 115 |
+
if "candidates" in kwargs:
|
| 116 |
+
candidates = kwargs.pop("candidates")
|
| 117 |
+
text = kwargs.get("text", "")
|
| 118 |
+
morphemes_edges = self._build_graph_from_candidates(candidates, text)
|
| 119 |
+
if not morphemes_edges:
|
| 120 |
+
return None
|
| 121 |
+
kwargs["morphemes"] = morphemes_edges["morphemes"]
|
| 122 |
+
kwargs["edges"] = morphemes_edges["edges"]
|
| 123 |
+
return self._create_lexical_graph(*args, **kwargs)
|
| 124 |
+
|
| 125 |
+
# --- Lexical features helper (for preprocessing) ---
|
| 126 |
+
def compute_lexical_features(self, morphemes: List[Dict], text: str) -> List[Dict]:
|
| 127 |
+
"""Add lexical_features to each morpheme using Mecari's lexical featurizer.
|
| 128 |
+
|
| 129 |
+
Requires mecari.featurizers.lexical to be importable. Raises a clear error
|
| 130 |
+
if the featurizer is unavailable (training/inference depend on it).
|
| 131 |
+
"""
|
| 132 |
+
if not morphemes:
|
| 133 |
+
return morphemes
|
| 134 |
+
|
| 135 |
+
for m in morphemes:
|
| 136 |
+
try:
|
| 137 |
+
morph_obj = LexMorpheme(
|
| 138 |
+
surf=m.get("surface", ""),
|
| 139 |
+
lemma=m.get("base_form", ""),
|
| 140 |
+
pos=m.get("pos", "*"),
|
| 141 |
+
pos1=m.get("pos_detail1", "*"),
|
| 142 |
+
ctype=m.get("inflection_type", "*"),
|
| 143 |
+
cform=m.get("inflection_form", "*"),
|
| 144 |
+
reading=m.get("reading", "*"),
|
| 145 |
+
)
|
| 146 |
+
st = m.get("start_pos", 0)
|
| 147 |
+
ed = m.get("end_pos", st + len(m.get("surface", "")))
|
| 148 |
+
prev_char = text[st - 1] if st > 0 else None
|
| 149 |
+
next_char = text[ed] if ed < len(text) else None
|
| 150 |
+
feats = self._lex_featurizer.unigram_feats(morph_obj, prev_char, next_char)
|
| 151 |
+
m["lexical_features"] = feats
|
| 152 |
+
except Exception:
|
| 153 |
+
# on any failure, leave unchanged
|
| 154 |
+
pass
|
| 155 |
+
return morphemes
|
| 156 |
+
|
| 157 |
+
def _create_lexical_graph(
|
| 158 |
+
self, morphemes: List[Dict], edges: List[Dict], text: str, for_training: bool = True
|
| 159 |
+
) -> Optional[Data]:
|
| 160 |
+
"""Build a graph using lexical features."""
|
| 161 |
+
if not morphemes:
|
| 162 |
+
return None
|
| 163 |
+
|
| 164 |
+
# Sparse lexical features per node
|
| 165 |
+
all_indices = []
|
| 166 |
+
all_values = []
|
| 167 |
+
all_lengths = []
|
| 168 |
+
annotations = []
|
| 169 |
+
valid_mask = []
|
| 170 |
+
|
| 171 |
+
max_features = 0
|
| 172 |
+
for morpheme in morphemes:
|
| 173 |
+
lexical_feats = morpheme.get("lexical_features", [])
|
| 174 |
+
indices = []
|
| 175 |
+
values = []
|
| 176 |
+
for idx, val in lexical_feats:
|
| 177 |
+
if 0 <= idx < self.lexical_feature_dim:
|
| 178 |
+
indices.append(idx)
|
| 179 |
+
values.append(val)
|
| 180 |
+
all_lengths.append(len(indices))
|
| 181 |
+
max_features = max(max_features, len(indices))
|
| 182 |
+
|
| 183 |
+
all_indices.append(indices)
|
| 184 |
+
all_values.append(values)
|
| 185 |
+
|
| 186 |
+
if for_training:
|
| 187 |
+
annotation = morpheme.get("annotation", "?")
|
| 188 |
+
if annotation == "+":
|
| 189 |
+
annotations.append(1)
|
| 190 |
+
valid_mask.append(True)
|
| 191 |
+
elif annotation == "-":
|
| 192 |
+
annotations.append(0)
|
| 193 |
+
valid_mask.append(True)
|
| 194 |
+
else:
|
| 195 |
+
annotations.append(0)
|
| 196 |
+
valid_mask.append(False)
|
| 197 |
+
|
| 198 |
+
# Fixed-size padding/truncation for batching
|
| 199 |
+
FIXED_MAX_FEATURES = int(getattr(self, "lexical_max_features", 20))
|
| 200 |
+
|
| 201 |
+
padded_indices = []
|
| 202 |
+
padded_values = []
|
| 203 |
+
for indices, values in zip(all_indices, all_values):
|
| 204 |
+
if len(indices) > FIXED_MAX_FEATURES:
|
| 205 |
+
padded_indices.append(indices[:FIXED_MAX_FEATURES])
|
| 206 |
+
padded_values.append(values[:FIXED_MAX_FEATURES])
|
| 207 |
+
else:
|
| 208 |
+
pad_length = FIXED_MAX_FEATURES - len(indices)
|
| 209 |
+
padded_indices.append(indices + [0] * pad_length)
|
| 210 |
+
padded_values.append(values + [0.0] * pad_length)
|
| 211 |
+
|
| 212 |
+
edge_index = self._build_edge_index(edges, len(morphemes))
|
| 213 |
+
|
| 214 |
+
# POS ids per node (for evaluation breakdown)
|
| 215 |
+
pos_ids = []
|
| 216 |
+
for m in morphemes:
|
| 217 |
+
pos = m.get("pos", "*")
|
| 218 |
+
pos_ids.append(self.pos_to_id.get(pos, 0))
|
| 219 |
+
|
| 220 |
+
graph_data = Data(
|
| 221 |
+
lexical_indices=torch.tensor(padded_indices, dtype=torch.long),
|
| 222 |
+
lexical_values=torch.tensor(padded_values, dtype=torch.float32),
|
| 223 |
+
lexical_lengths=torch.tensor(all_lengths, dtype=torch.long),
|
| 224 |
+
edge_index=edge_index,
|
| 225 |
+
num_nodes=len(morphemes),
|
| 226 |
+
)
|
| 227 |
+
graph_data.pos_ids = torch.tensor(pos_ids, dtype=torch.long)
|
| 228 |
+
if for_training:
|
| 229 |
+
graph_data.y = torch.tensor(annotations, dtype=torch.float32)
|
| 230 |
+
graph_data.valid_mask = torch.tensor(valid_mask, dtype=torch.bool)
|
| 231 |
+
|
| 232 |
+
return graph_data
|
| 233 |
+
|
| 234 |
+
def _build_edge_index(self, edges: List[Dict], num_nodes: int) -> torch.Tensor:
|
| 235 |
+
"""Build a PyG edge_index tensor from edge dicts."""
|
| 236 |
+
if not edges:
|
| 237 |
+
return torch.tensor([[], []], dtype=torch.long)
|
| 238 |
+
|
| 239 |
+
source_indices = []
|
| 240 |
+
target_indices = []
|
| 241 |
+
|
| 242 |
+
for edge in edges:
|
| 243 |
+
source = edge.get("source_idx", 0)
|
| 244 |
+
target = edge.get("target_idx", 0)
|
| 245 |
+
|
| 246 |
+
if 0 <= source < num_nodes and 0 <= target < num_nodes:
|
| 247 |
+
source_indices.append(source)
|
| 248 |
+
target_indices.append(target)
|
| 249 |
+
if self.use_bidirectional_edges:
|
| 250 |
+
source_indices.append(target)
|
| 251 |
+
target_indices.append(source)
|
| 252 |
+
|
| 253 |
+
if not source_indices:
|
| 254 |
+
return torch.tensor([[], []], dtype=torch.long)
|
| 255 |
+
|
| 256 |
+
return torch.tensor([source_indices, target_indices], dtype=torch.long)
|
| 257 |
+
|
| 258 |
+
def _load_kwdlc_ids(self, ids_file: str) -> set:
|
| 259 |
+
"""Load KWDLC ID list (one ID per line)."""
|
| 260 |
+
ids = set()
|
| 261 |
+
if ids_file and os.path.exists(ids_file):
|
| 262 |
+
with open(ids_file, "r") as f:
|
| 263 |
+
for line in f:
|
| 264 |
+
ids.add(line.strip())
|
| 265 |
+
return ids
|
| 266 |
+
|
| 267 |
+
def load_annotation_data(self, max_files: Optional[int] = None) -> List[Dict]:
|
| 268 |
+
"""Detect and list available .pt annotation graph files."""
|
| 269 |
+
if os.path.isdir(self.annotations_dir):
|
| 270 |
+
pt_files = [
|
| 271 |
+
os.path.join(self.annotations_dir, fn)
|
| 272 |
+
for fn in sorted(os.listdir(self.annotations_dir))
|
| 273 |
+
if fn.endswith(".pt")
|
| 274 |
+
]
|
| 275 |
+
if pt_files:
|
| 276 |
+
if max_files is not None:
|
| 277 |
+
pt_files = pt_files[:max_files]
|
| 278 |
+
return [{"_mode": "pt", "_pt_files": pt_files}]
|
| 279 |
+
raise FileNotFoundError(f"No annotation graphs found under: {self.annotations_dir}")
|
| 280 |
+
|
| 281 |
+
def setup(self, stage: Optional[str] = None) -> None:
|
| 282 |
+
"""Build train/val/test datasets from discovered .pt files."""
|
| 283 |
+
annotation_data = self.load_annotation_data(max_files=self.max_files)
|
| 284 |
+
|
| 285 |
+
if not annotation_data:
|
| 286 |
+
self.train_dataset = []
|
| 287 |
+
self.val_dataset = []
|
| 288 |
+
self.test_dataset = []
|
| 289 |
+
return
|
| 290 |
+
|
| 291 |
+
dev_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "dev.id"))
|
| 292 |
+
test_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "test.id"))
|
| 293 |
+
|
| 294 |
+
mode = annotation_data[0].get("_mode")
|
| 295 |
+
if mode == "pt":
|
| 296 |
+
files: List[str] = annotation_data[0]["_pt_files"]
|
| 297 |
+
train_files: List[str] = []
|
| 298 |
+
val_files: List[str] = []
|
| 299 |
+
test_files: List[str] = []
|
| 300 |
+
|
| 301 |
+
# Use KWDLC split ids (mandatory)
|
| 302 |
+
dev_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "dev.id"))
|
| 303 |
+
test_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "test.id"))
|
| 304 |
+
|
| 305 |
+
for fp in files:
|
| 306 |
+
sid = None
|
| 307 |
+
try:
|
| 308 |
+
obj = torch.load(fp, map_location="cpu")
|
| 309 |
+
if isinstance(obj, dict):
|
| 310 |
+
sid = obj.get("source_id")
|
| 311 |
+
except Exception:
|
| 312 |
+
pass
|
| 313 |
+
if sid and (dev_ids or test_ids):
|
| 314 |
+
if sid in test_ids:
|
| 315 |
+
test_files.append(fp)
|
| 316 |
+
elif sid in dev_ids:
|
| 317 |
+
val_files.append(fp)
|
| 318 |
+
else:
|
| 319 |
+
train_files.append(fp)
|
| 320 |
+
else:
|
| 321 |
+
train_files.append(fp)
|
| 322 |
+
|
| 323 |
+
# Build datasets strictly based on KWDLC dev/test ids
|
| 324 |
+
self.train_dataset = _PtGraphDataset(train_files)
|
| 325 |
+
self.val_dataset = _PtGraphDataset(val_files)
|
| 326 |
+
self.test_dataset = _PtGraphDataset(test_files)
|
| 327 |
+
|
| 328 |
+
if len(self.val_dataset) == 0 or len(self.test_dataset) == 0:
|
| 329 |
+
raise RuntimeError(
|
| 330 |
+
"KWDLC dev/test split produced empty val/test datasets. Ensure KWDLC id files exist and source_id is set in .pt files."
|
| 331 |
+
)
|
| 332 |
+
else:
|
| 333 |
+
raise RuntimeError("Unsupported annotation mode; expected pt")
|
| 334 |
+
|
| 335 |
+
print(
|
| 336 |
+
f"Data split: train={len(self.train_dataset)}, val={len(self.val_dataset)}, test={len(self.test_dataset)}"
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
def _create_dataloader(self, dataset: List[Data], batch_size: int, shuffle: bool = False) -> DataLoader:
|
| 340 |
+
"""Create a DataLoader with optional workers/prefetching."""
|
| 341 |
+
return DataLoader(
|
| 342 |
+
dataset,
|
| 343 |
+
batch_size=batch_size,
|
| 344 |
+
shuffle=shuffle,
|
| 345 |
+
num_workers=self.num_workers,
|
| 346 |
+
pin_memory=False,
|
| 347 |
+
persistent_workers=True if self.num_workers > 0 else False,
|
| 348 |
+
prefetch_factor=2 if self.num_workers > 0 else None,
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
def train_dataloader(self) -> DataLoader:
|
| 352 |
+
"""Return train DataLoader."""
|
| 353 |
+
return self._create_dataloader(self.train_dataset, self.batch_size, shuffle=True)
|
| 354 |
+
|
| 355 |
+
def val_dataloader(self) -> DataLoader:
|
| 356 |
+
"""Return val DataLoader."""
|
| 357 |
+
return self._create_dataloader(self.val_dataset, self.batch_size, shuffle=False)
|
| 358 |
+
|
| 359 |
+
    def test_dataloader(self) -> DataLoader:
        """Return the test DataLoader (deterministic order)."""
        return self._create_dataloader(self.test_dataset, self.batch_size, shuffle=False)
|
mecari/featurizers/lexical.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# -------- Basic data structures --------
|
| 7 |
+
@dataclass
class Morpheme:
    """A single morpheme candidate with analyzer-style attributes.

    Optional fields default to "*" when the analyzer does not provide them.
    """

    surf: str  # surface form as it appears in the text
    lemma: str  # lemma (dictionary base form)
    pos: str  # part of speech (coarse)
    pos1: str = "*"  # part of speech (fine subdivision)
    ctype: str = "*"  # conjugation type
    cform: str = "*"  # conjugation form
    reading: str = "*"  # reading in kana, "*" when unknown
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# -------- Utilities --------
|
| 19 |
+
def _stable_hash(s: str, dim: int) -> int:
|
| 20 |
+
# md5 stable hash -> lower 8 bytes -> modulo by dim
|
| 21 |
+
d = hashlib.md5(s.encode("utf-8")).digest()
|
| 22 |
+
return int.from_bytes(d[:8], "little") % dim
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _charclass(ch: str) -> str:
|
| 26 |
+
# Simple character classes (for boundary features)
|
| 27 |
+
if not ch:
|
| 28 |
+
return "O"
|
| 29 |
+
try:
|
| 30 |
+
o = ord(ch)
|
| 31 |
+
except Exception:
|
| 32 |
+
return "O"
|
| 33 |
+
if 0x3040 <= o <= 0x309F:
|
| 34 |
+
return "H" # hiragana
|
| 35 |
+
if 0x30A0 <= o <= 0x30FF:
|
| 36 |
+
return "K" # katakana
|
| 37 |
+
if 0x4E00 <= o <= 0x9FFF or 0x3400 <= o <= 0x4DBF:
|
| 38 |
+
return "C" # kanji
|
| 39 |
+
if 0x0030 <= o <= 0x0039 or 0xFF10 <= o <= 0xFF19:
|
| 40 |
+
return "D" # digits
|
| 41 |
+
if 0x0041 <= o <= 0x007A or 0xFF21 <= o <= 0xFF5A:
|
| 42 |
+
return "A" # letters
|
| 43 |
+
if ch.isspace():
|
| 44 |
+
return "S"
|
| 45 |
+
return "O" # other
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _affix(s: str, n: int) -> str:
|
| 49 |
+
return s[:n] if len(s) >= n else s
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _suffix(s: str, n: int) -> str:
|
| 53 |
+
return s[-n:] if len(s) >= n else s
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# -------- Lexical n-gram featurizer --------
|
| 57 |
+
class LexicalNGramFeaturizer:
    """Build unigram + boundary features as sparse (index, value) pairs.

    Feature strings are hashed into a fixed-size space of ``dim`` buckets
    (the "hashing trick"), so no vocabulary has to be stored.
    """

    def __init__(self, dim: int = 1_000_000, add_bias: bool = True):
        self.dim = dim  # size of the hashed feature space
        self.add_bias = add_bias  # whether to emit a constant bias feature

    def _push(self, feats: List[Tuple[int, float]], key: str, val: float = 1.0):
        """Hash *key* into the feature space and append (index, val) to *feats*."""
        feats.append((_stable_hash(key, self.dim), val))

    def unigram_feats(self, m: Morpheme, prev_char: Optional[str], next_char: Optional[str]) -> List[Tuple[int, float]]:
        """Return the sparse feature list for a single morpheme.

        ``prev_char``/``next_char`` are the characters adjacent to the
        morpheme in the raw sentence, or None at sentence edges.
        """
        f: List[Tuple[int, float]] = []
        # POS
        self._push(f, f"U:POS={m.pos}")
        self._push(f, f"U:POS1={m.pos}:{m.pos1}")
        # Lexicalized (surface/lemma) + POS
        self._push(f, f"U:LEM={m.lemma}")
        self._push(f, f"U:SURF={m.surf}")
        self._push(f, f"U:LEM+POS={m.lemma}|{m.pos}")
        self._push(f, f"U:SURF+POS1={m.surf}|{m.pos}:{m.pos1}")
        # Conjugation
        self._push(f, f"U:CFORM={m.ctype}:{m.cform}")
        # Reading (coarse)
        if m.reading and m.reading != "*":
            self._push(f, f"U:READ={m.reading}")
        # Prefix/Suffix (string n-grams)
        self._push(f, f"U:PREF2={_affix(m.surf, 2)}")
        self._push(f, f"U:SUF2={_suffix(m.surf, 2)}")
        # Boundary char types (1 char left/right)
        if prev_char:
            self._push(f, f"U:BTYPE_L={_charclass(prev_char)}->{_charclass(m.surf[:1])}")
        if next_char:
            self._push(f, f"U:BTYPE_R={_charclass(m.surf[-1:])}->{_charclass(next_char)}")
        if self.add_bias:
            self._push(f, "U:BIAS")
        return f

    def featurize_sequence(
        self, morphs: List[Morpheme], raw_sentence: Optional[str] = None
    ) -> List[List[Tuple[int, float]]]:
        """Featurize every morpheme in *morphs*; returns one feature list per node.

        When *raw_sentence* is omitted it is reconstructed by concatenating
        surfaces, which assumes the morphemes tile the sentence exactly.
        """
        if raw_sentence is None:
            raw_sentence = "".join(m.surf for m in morphs)
        # Character offsets of each morpheme, assuming contiguous tiling.
        spans = []
        cur = 0
        for m in morphs:
            st, ed = cur, cur + len(m.surf)
            spans.append((st, ed))
            cur = ed

        # Fix: the previous version returned only the LAST node's features
        # (the loop variable after the loop) and raised UnboundLocalError on
        # an empty sequence. Accumulate features for every node instead.
        all_feats: List[List[Tuple[int, float]]] = []
        for i, m in enumerate(morphs):
            st, ed = spans[i]
            prev_char = raw_sentence[st - 1] if st > 0 else None
            next_char = raw_sentence[ed] if ed < len(raw_sentence) else None
            all_feats.append(self.unigram_feats(m, prev_char, next_char))
        return all_feats
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
if __name__ == "__main__":
|
| 116 |
+
pass
|
mecari/models/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from .base import BaseMecariGNN # noqa: F401
|
| 5 |
+
from .gatv2 import MecariGATv2 # noqa: F401
|
| 6 |
+
__all__ = ["BaseMecariGNN", "MecariGATv2"]
|
mecari/models/base.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""Base model with lexical features only."""
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import pytorch_lightning as pl
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BaseMecariGNN(pl.LightningModule):
    """Base LightningModule for Mecari morpheme-graph classifiers.

    Owns the hashed-lexical-feature embedding, the BCE train/val/test loops
    and optimizer/warmup configuration. Subclasses implement ``forward``.
    """

    def __init__(
        self,
        hidden_dim: int = 512,
        num_classes: int = 1,
        learning_rate: float = 1e-3,
        lexical_feature_dim: int = 100000,
    ) -> None:
        """Create the embedding, normalization/dropout and classifier head.

        Args:
            hidden_dim: width of node representations.
            num_classes: stored for reference; the classifier head always
                emits a single logit per node (binary task).
            learning_rate: base LR consumed by ``configure_optimizers``.
            lexical_feature_dim: number of embedding rows for the hashed
                lexical feature space (row 0 is reserved as padding).
        """
        super().__init__()
        self.save_hyperparameters()

        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.lexical_feature_dim = lexical_feature_dim

        # Row 0 is the padding slot: excluded from Xavier init and zeroed.
        self.lexical_embedding = nn.Embedding(
            num_embeddings=lexical_feature_dim, embedding_dim=hidden_dim, padding_idx=0, sparse=False
        )
        nn.init.xavier_uniform_(self.lexical_embedding.weight[1:])
        self.lexical_embedding.weight.data[0].fill_(0)

        self.lexical_norm = nn.LayerNorm(hidden_dim)
        self.lexical_dropout = nn.Dropout(0.2)

        # Projection of the input features used by subclasses as a skip path.
        self.residual_proj = nn.Linear(hidden_dim, hidden_dim)

        # Per-node binary classifier head: hidden -> hidden -> 1 logit.
        self.node_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden_dim, 1)
        )

    def _process_features(
        self, lexical_indices: torch.Tensor, lexical_values: torch.Tensor, bert_features: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embed sparse (index, value) lexical features and pool per node.

        Embeddings are scaled by their feature values and summed over the
        feature axis (dim=1), then LayerNorm + dropout are applied.
        ``bert_features`` is accepted for interface symmetry but unused here.
        # assumes lexical_indices/lexical_values are (num_nodes, max_feats),
        # with index 0 padding (embeds to zero) — TODO confirm with data module
        """
        embedded = self.lexical_embedding(lexical_indices)
        weighted = embedded * lexical_values.unsqueeze(-1)
        aggregated = weighted.sum(dim=1)
        processed = self.lexical_dropout(self.lexical_norm(aggregated))
        return processed

    def forward(self, lexical_indices, lexical_values, edge_index, bert_features=None, edge_attr=None):
        """Forward pass (implemented in subclasses)."""
        raise NotImplementedError("Subclasses must implement forward method")

    def training_step(self, batch, batch_idx):
        """One optimization step: masked BCE loss plus error-rate logging.

        ``batch`` must carry ``lexical_indices``/``lexical_values``,
        ``edge_index``, a boolean ``valid_mask`` selecting scorable nodes
        and per-node binary targets ``y``.
        """
        # bert_features slot is passed as None throughout this pipeline.
        node_predictions = self(
            batch.lexical_indices,
            batch.lexical_values,
            batch.edge_index,
            None,
            batch.edge_attr if hasattr(batch, "edge_attr") else None,
        ).squeeze()
        # NOTE(review): squeeze() drops every size-1 dim; a batch with a
        # single node would also lose the node dimension — confirm batches
        # always contain more than one node.

        valid_mask = batch.valid_mask
        valid_predictions = node_predictions[valid_mask]
        valid_targets = batch.y[valid_mask]

        loss = self._compute_bce_loss(valid_predictions, valid_targets, stage="train")

        # Error rate at a fixed 0.5 threshold, computed without gradients.
        with torch.no_grad():
            pred_probs = torch.sigmoid(valid_predictions)
            pred_binary = (pred_probs > 0.5).float()
            correct = (pred_binary == valid_targets).sum()
            accuracy = correct / valid_targets.numel()
            error_rate = 1.0 - accuracy

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        self.log("train_error", error_rate, prog_bar=True, on_step=True, on_epoch=True)

        # Surface the (possibly warmup-scaled) LR for monitoring.
        if self.trainer and self.trainer.optimizers:
            current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
            self.log("learning_rate", current_lr, on_step=True, on_epoch=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: same masked BCE + error computation as training."""
        node_predictions = self(
            batch.lexical_indices,
            batch.lexical_values,
            batch.edge_index,
            None,
            batch.edge_attr if hasattr(batch, "edge_attr") else None,
        ).squeeze()

        valid_mask = batch.valid_mask
        valid_predictions = node_predictions[valid_mask]
        valid_targets = batch.y[valid_mask]

        loss = self._compute_bce_loss(valid_predictions, valid_targets, stage="val")

        with torch.no_grad():
            pred_probs = torch.sigmoid(valid_predictions)
            pred_binary = (pred_probs > 0.5).float()
            correct = (pred_binary == valid_targets).sum()
            accuracy = correct / valid_targets.numel()
            error_rate = 1.0 - accuracy

        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_error", error_rate, prog_bar=True, on_step=True, on_epoch=True)

        # Duplicate epoch-level metrics under explicit *_epoch names for
        # checkpoint monitors that expect them.
        self.log("val_loss_epoch", loss, on_step=False, on_epoch=True)
        self.log("val_error_epoch", error_rate, on_step=False, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Build the optimizer and an optional linear-warmup scheduler.

        Reads ``self.training_config`` (attached externally; absent means
        defaults): ``optimizer.type``/``optimizer.weight_decay`` plus
        ``warmup_steps``/``warmup_start_lr`` for the LR ramp.
        """
        optimizer_config = getattr(self, "training_config", {}).get("optimizer", {})
        optimizer_type = optimizer_config.get("type", "adamw")

        if optimizer_type == "adamw":
            optimizer = torch.optim.AdamW(
                self.parameters(), lr=self.learning_rate, weight_decay=optimizer_config.get("weight_decay", 0.01)
            )
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # Optional warmup scheduler (linear warmup to base LR)
        tc = getattr(self, "training_config", {}) or {}
        warmup_steps = int(tc.get("warmup_steps", 0) or 0)
        warmup_start_lr = float(tc.get("warmup_start_lr", 0.0) or 0.0)
        if warmup_steps > 0 and self.learning_rate > 0.0:
            # Fraction of base LR to start from, clamped into [0, 1].
            start_factor = max(0.0, min(1.0, warmup_start_lr / float(self.learning_rate)))

            def lr_lambda(step: int):
                # Linear ramp start_factor -> 1.0 over warmup_steps, then flat.
                if step <= 0:
                    return start_factor
                if step < warmup_steps:
                    return start_factor + (1.0 - start_factor) * (step / float(warmup_steps))
                return 1.0

            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": "step",
                    "frequency": 1,
                    "name": "linear_warmup",
                },
            }
        return {"optimizer": optimizer}

    def test_step(self, batch, batch_idx):
        """Test step: logs error rate and accuracy (no loss computed)."""
        node_predictions = self(
            batch.lexical_indices,
            batch.lexical_values,
            batch.edge_index,
            None,
            batch.edge_attr if hasattr(batch, "edge_attr") else None,
        ).squeeze()

        valid_mask = batch.valid_mask
        valid_predictions = node_predictions[valid_mask]
        valid_targets = batch.y[valid_mask]

        with torch.no_grad():
            pred_probs = torch.sigmoid(valid_predictions)
            pred_binary = (pred_probs > 0.5).float()
            correct = (pred_binary == valid_targets).sum()
            accuracy = correct / valid_targets.numel()
            error_rate = 1.0 - accuracy

        self.log("test_error", error_rate, on_step=False, on_epoch=True)
        self.log("test_accuracy", accuracy, on_step=False, on_epoch=True)

        return error_rate

    def _compute_bce_loss(self, logits: torch.Tensor, targets: torch.Tensor, stage: str = "train") -> torch.Tensor:
        """BCEWithLogits loss with optional label smoothing and pos_weight.

        - label_smoothing: smooth targets toward 0.5 by epsilon.
        - pos_weight: handle class imbalance using ratio (neg/pos) per batch, robustly.
        ``stage`` is currently unused; kept for caller symmetry.
        """
        loss_cfg = getattr(self, "training_config", {}).get("loss", {})
        eps = float(loss_cfg.get("label_smoothing", 0.0) or 0.0)
        use_pos_weight = bool(loss_cfg.get("use_pos_weight", True))

        # Compute pos_weight from unsmoothed targets
        pos = torch.clamp(targets.sum(), min=0.0)
        total = torch.tensor(targets.numel(), device=targets.device, dtype=targets.dtype)
        neg = total - pos
        pos_weight = None
        if use_pos_weight and pos > 0 and neg > 0:
            # pos_weight = neg/pos; clamp to avoid extreme values
            pw = (neg / pos).detach()
            pw = torch.clamp(pw, 0.5, 50.0)  # safety bounds
            pos_weight = pw

        # Apply label smoothing to targets: y' = (1-eps)*y + 0.5*eps
        if eps > 0.0:
            targets = (1.0 - eps) * targets + 0.5 * eps

        loss = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            pos_weight=pos_weight,
        )

        return loss
|
mecari/models/gatv2.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""GATv2 model for morpheme graph classification."""
|
| 4 |
+
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch_geometric.nn import GATv2Conv
|
| 8 |
+
from torch_geometric.utils import add_self_loops, dropout_adj
|
| 9 |
+
|
| 10 |
+
from .base import BaseMecariGNN
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MecariGATv2(BaseMecariGNN):
    """Graph Attention Network v2 for morpheme analysis.

    Stacks ``num_layers`` GATv2 convolutions: multi-head concatenated
    layers up to the last, which uses a single averaged head so the output
    width returns to ``hidden_dim`` before classification.
    """

    def __init__(
        self,
        hidden_dim: int = 512,
        num_heads: int = 8,
        num_layers: int = 4,
        num_classes: int = 1,
        learning_rate: float = 1e-3,
        lexical_feature_dim: int = 100000,
        share_weights: bool = False,  # share-weights option of GATv2
        # New knobs
        dropout: float = 0.1,  # feature dropout between layers
        attn_dropout: float = 0.1,  # attention-coefficient dropout inside GATv2Conv
        add_self_loops_flag: bool = True,  # add self loops once in forward()
        edge_dropout: float = 0.0,  # random edge removal during training
        norm: str = "layer",  # "layer" -> LayerNorm, anything else -> BatchNorm1d
        **kwargs,  # Ignore extra params for config compatibility
    ):
        super().__init__(
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            learning_rate=learning_rate,
            lexical_feature_dim=lexical_feature_dim,
        )
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.share_weights = share_weights
        self.feat_dropout_p = dropout
        self.attn_dropout_p = attn_dropout
        self.add_self_loops_flag = add_self_loops_flag
        self.edge_dropout_p = edge_dropout
        self.norm_type = (norm or "layer").lower()

        # GATv2 layers; every conv is built with add_self_loops=False because
        # self-loops are added once (optionally) in forward().
        # NOTE(review): num_layers == 1 takes the i == 0 branch (multi-head
        # concat output) but pairs it with a hidden_dim-wide norm below —
        # likely unsupported; confirm num_layers >= 2 in configs.
        self.gatv2_layers = nn.ModuleList()
        self.layer_norms = nn.ModuleList()

        for i in range(num_layers):
            if i == 0:
                # First layer: hidden_dim in, num_heads concatenated heads out.
                self.gatv2_layers.append(
                    GATv2Conv(
                        hidden_dim,
                        hidden_dim,
                        heads=num_heads,
                        dropout=self.attn_dropout_p,
                        share_weights=share_weights,
                        add_self_loops=False,
                    )
                )
            elif i == num_layers - 1:
                # Last layer - single head
                self.gatv2_layers.append(
                    GATv2Conv(
                        hidden_dim * num_heads,
                        hidden_dim,
                        heads=1,
                        concat=False,
                        dropout=self.attn_dropout_p,
                        share_weights=share_weights,
                        add_self_loops=False,
                    )
                )
            else:
                # Middle layers
                self.gatv2_layers.append(
                    GATv2Conv(
                        hidden_dim * num_heads,
                        hidden_dim,
                        heads=num_heads,
                        dropout=self.attn_dropout_p,
                        share_weights=share_weights,
                        add_self_loops=False,
                    )
                )

            # Layer normalization (all layers): width matches the conv output
            # (hidden_dim * num_heads for concat layers, hidden_dim for last).
            if i < num_layers - 1:
                self.layer_norms.append(
                    nn.LayerNorm(hidden_dim * num_heads)
                    if self.norm_type == "layer"
                    else nn.BatchNorm1d(hidden_dim * num_heads)
                )
            else:
                self.layer_norms.append(
                    nn.LayerNorm(hidden_dim) if self.norm_type == "layer" else nn.BatchNorm1d(hidden_dim)
                )

    def forward(self, lexical_indices, lexical_values, edge_index, bert_features=None, edge_attr=None):
        """Forward pass of GATv2.

        Pools lexical features per node, runs the GATv2 stack with
        per-layer residuals, and returns one logit per node (shape
        (num_nodes, 1)). ``edge_attr`` is accepted but not used.
        """
        x = self._process_features(lexical_indices, lexical_values, bert_features)

        # Skip connection from the input features, added at the last layer.
        residual = self.residual_proj(x)

        ei = edge_index
        if self.add_self_loops_flag:
            ei, _ = add_self_loops(ei, num_nodes=x.size(0))
        # NOTE(review): dropout_adj is deprecated in recent torch_geometric
        # releases (dropout_edge is the replacement) — confirm pinned version.
        if self.edge_dropout_p > 0 and self.training:
            ei, _ = dropout_adj(ei, p=self.edge_dropout_p, force_undirected=False, training=True)

        # Apply GATv2 layers
        prev = None
        for i in range(self.num_layers):
            prev = x
            x = self.gatv2_layers[i](x, ei)
            x = self.layer_norms[i](x)

            # Per-layer residual if dimension matches (middle layers)
            if x.shape == prev.shape and i < self.num_layers - 1:
                x = x + prev

            # Add residual at last layer
            if i == self.num_layers - 1:
                x = x + residual

            x = F.elu(x)

            # Dropout except last layer
            if i < self.num_layers - 1:
                x = F.dropout(x, p=self.feat_dropout_p, training=self.training)

        # Classification
        logits = self.node_classifier(x)

        return logits
|
mecari/utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
__all__ = []
|
mecari/utils/morph_utils.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Any, Optional
|
| 6 |
+
|
| 7 |
+
from mecari.utils.signature import signature_key
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def dedup_morphemes(morphs: List[Dict]) -> List[Dict]:
    """Drop duplicate morpheme candidates and return them in canonical order.

    Duplicates are detected via ``signature_key`` (first occurrence wins);
    the survivors are sorted by start offset, longest span first, then by
    surface, reading and POS as tie-breakers.
    """
    unique: List[Dict] = []
    seen_keys = set()
    for morph in morphs:
        sig = signature_key(morph)
        if sig not in seen_keys:
            seen_keys.add(sig)
            unique.append(morph)

    def _order(morph: Dict):
        start = morph.get("start_pos", 0)
        span_len = morph.get("end_pos", 0) - morph.get("start_pos", 0)
        return (
            start,
            -span_len,
            morph.get("surface", ""),
            morph.get("reading", ""),
            morph.get("pos", "*"),
        )

    unique.sort(key=_order)
    return unique
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def build_adjacent_edges(morphs: List[Dict]) -> List[Dict]:
    """Connect each morpheme to every later morpheme starting where it ends.

    Produces directed "forward" lattice edges with source_idx < target_idx.
    Previously this scanned all O(n^2) pairs; indexing morphemes by their
    start offset yields the same edges in the same order in O(n + e).
    """
    # Map start offset -> indices of morphemes beginning there (in order).
    by_start: Dict[int, List[int]] = {}
    for idx, m in enumerate(morphs):
        by_start.setdefault(m.get("start_pos", 0), []).append(idx)

    edges: List[Dict] = []
    for i, s in enumerate(morphs):
        for j in by_start.get(s.get("end_pos", 0), []):
            # Preserve the original constraint: only forward pairs (i < j).
            if j > i:
                edges.append({"source_idx": i, "target_idx": j, "edge_type": "forward"})
    return edges
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def normalize_mecab_candidates(candidates: List[Dict]) -> List[Dict]:
    """Apply in-place normalizations shared by preprocessing and inference.

    Currently a single rule: digit-only surfaces whose ``base_form`` is
    missing or empty get the surface copied into ``base_form``. The input
    list is mutated and returned so calls can be chained.
    """
    for cand in candidates:
        surface = cand.get("surface", "")
        if not surface.isdigit():
            continue
        if cand.get("base_form") in (None, ""):
            cand["base_form"] = surface
    return candidates
|
mecari/utils/signature.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Dict, Tuple
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def to_katakana(s) -> str:
    """Convert hiragana characters in *s* to katakana.

    Input handling is deliberately lenient: a list/tuple is flattened by
    concatenating its string elements, None becomes "", and any other
    non-string value is stringified before conversion.
    """
    if isinstance(s, (list, tuple)):
        text = "".join(part for part in s if isinstance(part, str))
    elif isinstance(s, str):
        text = s
    else:
        text = "" if s is None else str(s)
    # Hiragana letters (U+3041..U+3096) map to katakana at a fixed +0x60 offset.
    return "".join(
        chr(ord(ch) + 0x60) if "\u3041" <= ch <= "\u3096" else ch
        for ch in text
    )
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def signature_key(m: Dict) -> Tuple:
    """Return a stable tuple identifying a morpheme dict for deduplication.

    The key covers the character span, surface, coarse POS plus the first
    POS detail, the lemma (``base_form`` falling back to ``lemma``), and
    the reading normalized to katakana so kana variants collide.
    """
    surf = m.get("surface", "")
    start = m.get("start_pos", 0)
    end = m.get("end_pos", start + len(surf))
    lemma = m.get("base_form") or m.get("lemma") or ""
    reading = to_katakana(m.get("reading") or "")
    return (
        start,
        end,
        surf,
        m.get("pos", "*"),
        m.get("pos_detail1", "*"),
        lemma,
        reading,
    )
|
packages.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mecab
|
| 2 |
+
mecab-utils
|
| 3 |
+
libmecab-dev
|
| 4 |
+
mecab-jumandic-utf8
|
| 5 |
+
mecab-ipadic-utf8
|
preprocess.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Build training graphs from KWDLC with JUMANDIC.
|
| 6 |
+
|
| 7 |
+
Pipeline:
|
| 8 |
+
1) Read gold morphemes from KNP files
|
| 9 |
+
2) Parse text with MeCab (JUMANDIC) to get candidate morphemes
|
| 10 |
+
3) Match candidates to gold and assign annotations ('+', '-', '?')
|
| 11 |
+
4) Save graph data as .pt
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
from collections import defaultdict
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Dict, List
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import yaml
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
from mecari.analyzers.mecab import MeCabAnalyzer
|
| 24 |
+
from mecari.data.data_module import DataModule
|
| 25 |
+
from mecari.featurizers.lexical import LexicalNGramFeaturizer as LexicalFeaturizer
|
| 26 |
+
from mecari.featurizers.lexical import Morpheme
|
| 27 |
+
from mecari.utils.morph_utils import build_adjacent_edges, dedup_morphemes, normalize_mecab_candidates
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def add_lexical_features(morphemes: List[Dict], text: str, feature_dim: int = 100000) -> List[Dict]:
    """Add lexical (index, value) pairs to morphemes. Not used when saving JSON.

    Kept for backward-compatibility and test equivalence.

    Each morpheme dict gains a "lexical_features" list produced by
    LexicalNGramFeaturizer.unigram_feats; the input list is mutated in
    place and returned.
    """
    featurizer = LexicalFeaturizer(dim=feature_dim, add_bias=True)
    for m in morphemes:
        surf = m.get("surface", "")
        # Bridge dict-shaped morphemes to the featurizer's Morpheme dataclass;
        # conjugation fields are not present in the dicts, so "*" is used.
        morph_obj = Morpheme(
            surf=surf,
            lemma=m.get("base_form", surf),
            pos=m.get("pos", "*"),
            pos1=m.get("pos_detail1", "*"),
            ctype="*",
            cform="*",
            reading=m.get("reading", "*"),
        )
        st = m.get("start_pos", 0)
        ed = m.get("end_pos", st + len(surf))
        # Characters adjacent to the span in the raw sentence, None at edges
        # (and None when offsets fall outside the text).
        prev_char = text[st - 1] if st > 0 and st <= len(text) else None
        next_char = text[ed] if ed < len(text) else None
        feats = featurizer.unigram_feats(morph_obj, prev_char, next_char)
        m["lexical_features"] = feats
    return morphemes
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def hiragana_to_katakana(text: str) -> str:
    """Convert every hiragana character in *text* to its katakana counterpart.

    Covers the full hiragana letter block U+3041..U+3096 (including ゔ/ゕ/ゖ),
    matching ``mecari.utils.signature.to_katakana``; the previous range
    ("ぁ".."ん") stopped at U+3093 and left those three characters
    unconverted. Other characters pass through unchanged.
    """
    # Katakana letters sit at a fixed +0x60 (96) offset from hiragana.
    return "".join(chr(ord(c) + 0x60) if "\u3041" <= c <= "\u3096" else c for c in text)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _load_gold_with_kyoto(knp_path: Path) -> List[Dict]:
    """Load sentences and morphemes from a KNP file using kyoto-reader (required).

    Returns a list of {"text": str, "morphemes": [dict, ...]} entries, one
    per sentence; character offsets are recomputed by accumulating surface
    lengths, so they assume the morphemes tile the sentence text exactly.

    Raises:
        RuntimeError: if kyoto-reader is not installed or parsing fails.
    """
    try:
        from kyoto_reader import KyotoReader  # type: ignore
    except Exception as e:  # pragma: no cover
        raise RuntimeError("kyoto-reader is required for gold loading. Install it (pip install kyoto-reader).") from e

    try:
        try:
            # Newer kyoto-reader versions accept n_jobs; fall back for older ones.
            reader = KyotoReader(str(knp_path), n_jobs=0)
        except TypeError:
            reader = KyotoReader(str(knp_path))
        sents: List[Dict] = []
        for doc in reader.process_all_documents(n_jobs=0):
            if doc is None:
                continue
            for sent in doc.sentences:
                text = sent.surf
                morphemes: List[Dict] = []
                pos = 0  # running character offset within the sentence
                for mrph in sent.mrph_list():
                    # Map JUMAN-style attributes (midasi/yomi/genkei/hinsi/bunrui)
                    # onto the project's dict schema, defaulting missing values.
                    surf = getattr(mrph, "midasi", "") or ""
                    read = getattr(mrph, "yomi", surf) or surf
                    lemma = getattr(mrph, "genkei", surf) or surf
                    pos_main = getattr(mrph, "hinsi", "*") or "*"
                    pos1 = getattr(mrph, "bunrui", "*") or "*"
                    st = pos
                    ed = st + len(surf)
                    pos = ed
                    morphemes.append(
                        {
                            "surface": surf,
                            "reading": read,
                            "base_form": lemma,
                            "pos": pos_main,
                            "pos_detail1": pos1,
                            "pos_detail2": "*",
                            "pos_detail3": "*",
                            "start_pos": st,
                            "end_pos": ed,
                        }
                    )
                sents.append({"text": text, "morphemes": morphemes})
        return sents
    except Exception as e:
        raise RuntimeError(f"Failed to parse KNP with kyoto-reader: {knp_path}") from e
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def match_morphemes_with_gold(candidates: List[Dict], gold_morphemes: List[Dict], text: str) -> List[Dict]:
    """Match candidate morphemes to gold and assign annotations ('?', '+', '-').

    Policy:
    - Initialize every candidate as '?'
    - Mark '+' for candidates that strictly match gold (surface, POS, base, reading)
    - Mark '-' for candidates that overlap any '+' span
    """
    # Reconstruct gold character spans by accumulating surface lengths.
    gold_entries = []
    offset = 0
    for gold in gold_morphemes:
        surface = gold.get("surface", "")
        start, end = offset, offset + len(surface)
        offset = end
        gold_entries.append(
            {
                "start_pos": start,
                "end_pos": end,
                "surface": surface,
                "pos": gold.get("pos", "*"),
                "pos_detail1": gold.get("pos_detail1", "*"),
                "pos_detail2": gold.get("pos_detail2", "*"),
                "base_form": gold.get("base_form", ""),
                "reading": hiragana_to_katakana(gold.get("reading", "")),
            }
        )

    # Copy candidates; default annotation '?' and inflection slots '*'.
    annotated: List[Dict] = []
    for cand in candidates:
        item = dict(cand)
        item["annotation"] = "?"
        item.setdefault("inflection_type", "*")
        item.setdefault("inflection_form", "*")
        annotated.append(item)

    def _span(item: Dict) -> tuple[int, int]:
        start = item.get("start_pos", 0)
        return start, item.get("end_pos", start + len(item.get("surface", "")))

    # Index candidates by their character span for O(1) lookup per gold.
    by_span: dict[tuple[int, int], list[Dict]] = {}
    for item in annotated:
        by_span.setdefault(_span(item), []).append(item)

    # Strict equality first; a reading mismatch is the only allowed fallback.
    matched_spans: List[tuple[int, int]] = []
    for gold in gold_entries:
        span = (gold["start_pos"], gold["end_pos"])
        same_span = by_span.get(span, [])
        if not same_span:
            continue
        strict: list[Dict] = []
        loose: list[Dict] = []
        for item in same_span:
            fields_match = (
                item.get("surface", "") == gold["surface"]
                and item.get("pos", "*") == gold["pos"]
                and item.get("pos_detail1", "*") == gold.get("pos_detail1", "*")
                and item.get("base_form", "") == gold["base_form"]
            )
            if not fields_match:
                continue
            if hiragana_to_katakana(item.get("reading", "")) == gold["reading"]:
                strict.append(item)
            else:
                loose.append(item)
        chosen = strict if strict else loose
        if chosen:
            for item in chosen:
                item["annotation"] = "+"
            matched_spans.append(span)
            # Same-span competitors that were not chosen become negatives.
            for item in same_span:
                if item not in chosen and item.get("annotation") != "+":
                    item["annotation"] = "-"

    # Demote anything overlapping (>= 1 shared character) a '+' span;
    # merely touching spans do not count as overlap.
    plus_spans = [_span(item) for item in annotated if item.get("annotation") == "+"]
    for item in annotated:
        if item.get("annotation") == "+":
            continue
        start, end = _span(item)
        if any(max(start, ps) < min(end, pe) for ps, pe in plus_spans):
            item["annotation"] = "-"
    return annotated
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def main():
    """Create GNN training data (.pt graphs) from KWDLC KNP files.

    Pipeline per sentence: load gold morphemes via kyoto-reader, generate
    MeCab candidates, annotate them against the gold ('+'/'-'/'?'), build
    adjacency edges, compute lexical features, and save one graph per
    sentence under the output directory.
    """
    parser = argparse.ArgumentParser(description="Create training data from KWDLC (JUMANDIC)")
    parser.add_argument("--input-dir", type=str, default="KWDLC/knp", help="Directory containing KNP files")
    parser.add_argument("--config", type=str, default="configs/gat.yaml", help="Path to config file")
    parser.add_argument("--limit", type=int, help="Max number of files to process")
    parser.add_argument("--test-only", action="store_true", help="Process only test split IDs")
    parser.add_argument("--jumandic-path", type=str, default="/var/lib/mecab/dic/juman-utf8", help="Path to JUMANDIC")
    args = parser.parse_args()

    # Load the config; a missing file silently falls back to defaults.
    config = {}
    if args.config and Path(args.config).exists():
        with open(args.config, "r") as f:
            config = yaml.safe_load(f)

        # Support single-level config inheritance via an "extends" key:
        # the child config's values override the parent's, recursively.
        if "extends" in config:
            parent_config_path = Path(args.config).parent / config["extends"]
            if parent_config_path.exists():
                with open(parent_config_path, "r") as f:
                    parent_config = yaml.safe_load(f)

                def deep_merge(base, override):
                    # Recursively merge override into base (mutates base).
                    for key, value in override.items():
                        if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                            deep_merge(base[key], value)
                        else:
                            base[key] = value
                    return base

                config = deep_merge(parent_config, config)

    features_config = config.get("features", {})
    feature_dim = features_config.get("lexical_feature_dim", 100000)
    training_config = config.get("training", {})

    # Output directory: prefer the config's annotations_dir, else a default.
    if training_config.get("annotations_dir"):
        output_dir = Path(training_config.get("annotations_dir"))
    else:
        output_dir = Path("annotations_kwdlc_juman")
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Lexical features: using {feature_dim} dims")
    print(f"Output directory: {output_dir}")

    analyzer = MeCabAnalyzer(
        jumandic_path=args.jumandic_path,
    )

    knp_files = []

    if args.test_only:
        # Restrict to the KWDLC test split listed in the split id file.
        test_id_file = Path("KWDLC/id/split_for_pas/test.id")
        if test_id_file.exists():
            with open(test_id_file, "r") as f:
                test_ids = [line.strip() for line in f if line.strip()]

            knp_base_dir = Path(args.input_dir)
            for file_id in test_ids:
                # KWDLC layout: files live under a directory named by the
                # first 13 characters of the file id.
                dir_name = file_id[:13]
                file_name = f"{file_id}.knp"
                knp_path = knp_base_dir / dir_name / file_name
                if knp_path.exists():
                    knp_files.append(knp_path)
    else:
        knp_dir = Path(args.input_dir)
        knp_files = sorted(knp_dir.glob("**/*.knp"))

    if args.limit:
        knp_files = knp_files[: args.limit]

    print(f"Files to process: {len(knp_files)}")
    print(f"JUMANDIC: {args.jumandic_path}")
    print(f"Output to: {output_dir}")

    total_stats = defaultdict(int)
    annotation_idx = 0

    # DataModule is used only for its feature/graph builders here.
    dm = DataModule(
        annotations_dir=str(output_dir),
        lexical_feature_dim=int(feature_dim),
        use_bidirectional_edges=bool(config.get("edge_features", {}).get("use_bidirectional_edges", True)),
    )

    # Save .pt files directly under the output_dir

    for knp_path in tqdm(knp_files, desc="processing"):
        try:
            sentences = _load_gold_with_kyoto(knp_path)
            if not sentences:
                continue

            doc_id = knp_path.stem
            for s in sentences:
                s["source_id"] = doc_id

            for sent_idx, sentence in enumerate(sentences):
                text = sentence["text"]
                gold_morphemes = sentence["morphemes"]
                source_id = sentence.get("source_id", doc_id)

                # Candidate lattice: analyze, normalize, then deduplicate.
                candidates = analyzer.get_morpheme_candidates(text)
                candidates = normalize_mecab_candidates(candidates)
                candidates = dedup_morphemes(candidates)
                if not candidates:
                    continue

                annotated_morphemes = match_morphemes_with_gold(candidates, gold_morphemes, text)

                edges = build_adjacent_edges(annotated_morphemes)

                # Drop any stale features; they are recomputed below.
                for m in annotated_morphemes:
                    if "lexical_features" in m:
                        m.pop("lexical_features", None)

                morphemes_with_feats = dm.compute_lexical_features(annotated_morphemes, text)
                graph = dm.create_graph_from_morphemes_data(
                    morphemes=morphemes_with_feats,
                    edges=edges,
                    text=text,
                    for_training=True,
                )
                if graph is None:
                    continue

                # One graph file per sentence, globally indexed.
                graph_file = output_dir / f"graph_{annotation_idx:04d}.pt"
                payload = {
                    "graph": graph,
                    "source_id": source_id,
                    "text": text,
                }
                torch.save(payload, graph_file)

                total_stats["sentences"] += 1
                total_stats["morphemes"] += len(annotated_morphemes)
                total_stats["positive"] += sum(1 for m in annotated_morphemes if m.get("annotation") == "+")
                total_stats["negative"] += sum(1 for m in annotated_morphemes if m.get("annotation") == "-")

                annotation_idx += 1

            total_stats["files"] += 1

        except Exception as e:
            # Best-effort batch processing: log and continue with next file.
            print(f"Error ({knp_path}): {e}")
            total_stats["errors"] += 1

    print("\n" + "=" * 50)
    print("Processing complete")
    print("=" * 50)
    print(f"Files: {total_stats['files']}")
    print(f"Sentences: {total_stats['sentences']}")
    print(f"Morphemes: {total_stats['morphemes']}")
    print(f"Positive (+): {total_stats['positive']}")
    print(f"Negative (-): {total_stats['negative']}")
    # Report errors only when any occurred.
    if total_stats["errors"] > 0:
        print(f"Errors: {total_stats['errors']}")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "mecari-morpheme"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Japanese morphological analysis using Graph Neural Networks"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.11,<3.12"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"torch>=2.2,<2.3",
|
| 9 |
+
"pytorch-lightning>=2.0.0",
|
| 10 |
+
"torch-geometric>=2.4,<2.5",
|
| 11 |
+
"numpy>=1.24,<2.0",
|
| 12 |
+
"pyyaml>=6.0",
|
| 13 |
+
"tqdm>=4.65.0",
|
| 14 |
+
"kyoto-reader>=2.5.0",
|
| 15 |
+
# Optional: enabled by default via config
|
| 16 |
+
"wandb>=0.15.0",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[project.optional-dependencies]
|
| 20 |
+
dev = [
|
| 21 |
+
"ipython>=8.14.0",
|
| 22 |
+
"jupyter>=1.0.0",
|
| 23 |
+
"notebook>=7.0.0",
|
| 24 |
+
"pytest>=7.4.0",
|
| 25 |
+
"black>=23.0.0",
|
| 26 |
+
"ruff>=0.1.0",
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
[build-system]
|
| 30 |
+
requires = ["setuptools>=61.0"]
|
| 31 |
+
build-backend = "setuptools.build_meta"
|
| 32 |
+
|
| 33 |
+
[tool.setuptools]
|
| 34 |
+
packages = ["mecari"]
|
| 35 |
+
|
| 36 |
+
[tool.uv]
|
| 37 |
+
index-url = "https://pypi.org/simple"
|
| 38 |
+
# Use CUDA 12.1 compatible PyG wheels (matches torch 2.2.x + cu121 environment)
|
| 39 |
+
find-links = ["https://data.pyg.org/whl/torch-2.2.0+cu121.html"]
|
| 40 |
+
|
| 41 |
+
# torch must be available at build time for torch-cluster and similar extensions
|
| 42 |
+
[tool.uv.extra-build-dependencies]
|
| 43 |
+
# Ensure torch is available when resolving extension wheels
|
| 44 |
+
torch-geometric = ["torch"]
|
| 45 |
+
|
| 46 |
+
[tool.ruff]
|
| 47 |
+
line-length = 120
|
| 48 |
+
target-version = "py311"
|
| 49 |
+
|
| 50 |
+
[tool.ruff.lint]
|
| 51 |
+
select = ["E", "F", "I"]
|
| 52 |
+
ignore = ["E501"] # line too long
|
| 53 |
+
|
| 54 |
+
[tool.black]
|
| 55 |
+
line-length = 120
|
| 56 |
+
target-version = ['py311']
|
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--find-links https://data.pyg.org/whl/torch-2.2.0+cpu.html
|
| 2 |
+
|
| 3 |
+
# Core runtime
|
| 4 |
+
torch==2.2.2
|
| 5 |
+
torch-scatter
|
| 6 |
+
torch-sparse
|
| 7 |
+
torch-cluster
|
| 8 |
+
torch-spline-conv
|
| 9 |
+
torch-geometric==2.4.0
|
| 10 |
+
pytorch-lightning==2.5.2
|
| 11 |
+
numpy>=1.24,<2.1
|
| 12 |
+
pyyaml>=6.0
|
| 13 |
+
tqdm>=4.65.0
|
| 14 |
+
kyoto-reader>=2.5.0
|
| 15 |
+
|
| 16 |
+
# UI
|
| 17 |
+
gradio>=4.37.0
|
| 18 |
+
|
| 19 |
+
# Optional logger (disabled at runtime)
|
| 20 |
+
wandb>=0.15.0
|
runtime.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python-3.11
|
sample_model/config.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
edge_features:
|
| 2 |
+
use_bidirectional_edges: true
|
| 3 |
+
features:
|
| 4 |
+
lexical_feature_dim: 100000
|
| 5 |
+
inference:
|
| 6 |
+
checkpoint_dir: experiments
|
| 7 |
+
experiment_name: null
|
| 8 |
+
loss:
|
| 9 |
+
label_smoothing: 0.0
|
| 10 |
+
use_pos_weight: true
|
| 11 |
+
model:
|
| 12 |
+
dropout: 0.1
|
| 13 |
+
hidden_dim: 64
|
| 14 |
+
num_classes: 1
|
| 15 |
+
num_heads: 4
|
| 16 |
+
num_layers: 4
|
| 17 |
+
share_weights: false
|
| 18 |
+
type: gatv2
|
| 19 |
+
training:
|
| 20 |
+
accumulate_grad_batches: 1
|
| 21 |
+
annotations_dir: annotations_new
|
| 22 |
+
batch_size: 128
|
| 23 |
+
deterministic: false
|
| 24 |
+
gradient_clip_algorithm: norm
|
| 25 |
+
gradient_clip_val: 0.5
|
| 26 |
+
learning_rate: 0.001
|
| 27 |
+
log_every_n_steps: 50
|
| 28 |
+
max_steps: 10000
|
| 29 |
+
num_workers: 4
|
| 30 |
+
optimizer:
|
| 31 |
+
type: adamw
|
| 32 |
+
weight_decay: 0.001
|
| 33 |
+
patience: 10
|
| 34 |
+
project_name: mecari
|
| 35 |
+
seed: 42
|
| 36 |
+
use_wandb: true
|
| 37 |
+
val_check_interval: 1.0
|
| 38 |
+
warmup_start_lr: 0.0
|
| 39 |
+
warmup_steps: 500
|
sample_model/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ccfc112d4a0dcdc0b087c9cabf1b45f1aeae2f1cb8a6f86196a115aa594f68d7
|
| 3 |
+
size 26975745
|
train.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
# Disable tokenizer parallelism warning
|
| 7 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import random
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from importlib import import_module
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pytorch_lightning as pl
|
| 17 |
+
import torch
|
| 18 |
+
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
|
| 19 |
+
|
| 20 |
+
from mecari.config.config import get_model_config, override_config, save_config
|
| 21 |
+
from mecari.data.data_module import DataModule
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def set_seed(seed: int = 42, deterministic: bool = True) -> None:
    """Seed every RNG used during training for reproducibility.

    Args:
        seed: Random seed value.
        deterministic: If True, enforce deterministic cuDNN behavior (slower).
    """
    # Seed Python, NumPy and all PyTorch (CPU + CUDA) generators alike.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)
    # cuDNN autotuning is disabled in deterministic mode and vice versa.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    pl.seed_everything(seed)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_config_sections(config: dict) -> dict:
    """Extract structured sections from a unified config dict."""
    # "model" and "training" are mandatory; the rest default to empty dicts.
    sections = {
        "model": config["model"],
        "training": config["training"],
    }
    sections["features"] = config.get("features", {})
    sections["edge"] = config.get("edge_features", {})
    return sections
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def calculate_feature_dim(config: dict) -> int:
    """Return feature dimension from config (lexical features by default)."""
    # Missing sections fall back to the 100k-dim lexical default.
    return config.get("features", {}).get("lexical_feature_dim", 100000)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def create_data_module(config: dict) -> DataModule:
    """Build the lexical-only DataModule from a unified config dict."""
    training_section = config["training"]
    edge_section = config.get("edge_features", {})
    lex_dim = config.get("features", {}).get("lexical_feature_dim", 100000)

    return DataModule(
        annotations_dir=training_section["annotations_dir"],
        batch_size=training_section["batch_size"],
        num_workers=training_section["num_workers"],
        max_files=training_section.get("max_files"),
        use_bidirectional_edges=edge_section.get("use_bidirectional_edges", True),
        annotations_override_dir=training_section.get("annotations_override_dir"),
        lexical_feature_dim=lex_dim,
    )
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def setup_loggers(config: dict, experiment_name: str):
    """Configure optional loggers (currently Weights & Biases only).

    Args:
        config: Unified config; reads ``training.use_wandb`` and
            ``training.project_name``.
        experiment_name: Run name; also determines the wandb save directory.

    Returns:
        A non-empty list of Lightning loggers, or ``False`` when none are
        enabled (the value PyTorch Lightning expects to disable logging).
    """
    loggers = []

    if config["training"]["use_wandb"]:
        try:
            # Import lazily so wandb/lightning are only required when the
            # wandb logger is actually enabled.
            import subprocess

            from pytorch_lightning.loggers import WandbLogger

            # Tag the run with the current git branch/commit when available.
            tags = []
            try:
                branch = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"], text=True).strip()
                tags.append(f"branch:{branch}")
                commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], text=True).strip()
                tags.append(f"commit:{commit}")
            except Exception:
                # Not a git checkout or git is missing; run untagged.
                # (Was a bare `except:`, which also swallowed SystemExit
                # and KeyboardInterrupt.)
                pass

            wandb_logger = WandbLogger(
                project=config["training"]["project_name"],
                name=experiment_name,
                save_dir=f"experiments/{experiment_name}",
                save_code=True,
                log_model=False,
                tags=tags,
            )
            loggers.append(wandb_logger)
            print("✓ Added WandB logger (metrics only)")
        except Exception as e:
            print(f"WandbLogger initialization error: {e}")
    else:
        print("WandB logging disabled")

    # Lightning interprets False as "no logger".
    return loggers if loggers else False
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def create_trainer(config: dict, callbacks: list, loggers, deterministic: bool) -> pl.Trainer:
    """Create a PyTorch Lightning Trainer."""
    training_cfg = config["training"]

    # Single device either way; prefer GPU when present.
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"

    trainer_kwargs = {
        "max_epochs": -1,  # training length is governed by max_steps only
        "max_steps": training_cfg.get("max_steps", 8600),
        "callbacks": callbacks,
        "logger": loggers,
        "accelerator": accelerator,
        "devices": 1,
        "log_every_n_steps": training_cfg["log_every_n_steps"],
        "val_check_interval": training_cfg["val_check_interval"],
        "gradient_clip_val": training_cfg["gradient_clip_val"],
        "enable_checkpointing": True,
        "enable_progress_bar": True,
        "limit_train_batches": 1.0,
        "limit_val_batches": 1.0,
        "limit_test_batches": 1.0,
        "limit_predict_batches": 1.0,
        "fast_dev_run": False,
        "deterministic": deterministic,
        "benchmark": not deterministic,
        "precision": "16-mixed",
    }

    # Optional keys are forwarded only when the config provides them.
    for optional_key in ("gradient_clip_algorithm", "accumulate_grad_batches"):
        if optional_key in training_cfg:
            trainer_kwargs[optional_key] = training_cfg[optional_key]

    return pl.Trainer(**trainer_kwargs)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def create_model_and_datamodule(config: dict, feature_dim: int, data_module: Optional[DataModule] = None):
    """Create the model and ensure a DataModule is available (lexical-only)."""
    sections = get_config_sections(config)
    model_cfg = sections["model"]
    training_cfg = sections["training"]
    features_cfg = sections["features"]

    if data_module is None:
        data_module = create_data_module(config)

    # Only GATv2 is implemented; anything else is a configuration error.
    if model_cfg["type"] != "gatv2":
        raise ValueError(f"Unsupported model type: {model_cfg['type']}")

    # Import lazily to avoid pulling in torch-geometric at module load.
    MecariGATv2 = getattr(import_module("mecari.models.gatv2"), "MecariGATv2")
    model = MecariGATv2(
        hidden_dim=model_cfg["hidden_dim"],
        num_classes=model_cfg["num_classes"],
        learning_rate=training_cfg["learning_rate"],
        lexical_feature_dim=features_cfg.get("lexical_feature_dim", 100000),
        num_heads=model_cfg["num_heads"],
        share_weights=model_cfg.get("share_weights", False),
        dropout=model_cfg.get("dropout", 0.1),
        # Accept either key for attention dropout for config compatibility.
        attn_dropout=model_cfg.get("attn_dropout", model_cfg.get("attention_dropout", 0.1)),
        add_self_loops_flag=model_cfg.get("add_self_loops", True),
        edge_dropout=model_cfg.get("edge_dropout", 0.0),
        norm=model_cfg.get("norm", "layer"),
    )

    return model, data_module
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def main():
|
| 197 |
+
parser = argparse.ArgumentParser(description="Train the morphological analysis model")
|
| 198 |
+
parser.add_argument(
|
| 199 |
+
"--model",
|
| 200 |
+
"-m",
|
| 201 |
+
choices=["gatv2"],
|
| 202 |
+
default="gatv2",
|
| 203 |
+
help="Model type (only gatv2 supported). If a config is provided, config.model.type takes precedence.",
|
| 204 |
+
)
|
| 205 |
+
parser.add_argument("--config", "-c", help="Path to config file (overrides model type if present)")
|
| 206 |
+
parser.add_argument("--batch-size", "-b", type=int, help="Batch size")
|
| 207 |
+
parser.add_argument("--steps", "-s", type=int, help="Max training steps")
|
| 208 |
+
parser.add_argument("--lr", type=float, help="Learning rate")
|
| 209 |
+
parser.add_argument("--hidden-dim", type=int, help="Hidden dimension size")
|
| 210 |
+
parser.add_argument("--patience", type=int, help="Early stopping patience")
|
| 211 |
+
parser.add_argument("--weight-decay", type=float, help="Weight decay")
|
| 212 |
+
parser.add_argument("--no-wandb", action="store_true", help="Disable Weights & Biases logging")
|
| 213 |
+
parser.add_argument("--seed", type=int, help="Random seed")
|
| 214 |
+
parser.add_argument("--no-deterministic", action="store_true", help="Disable deterministic mode for speed")
|
| 215 |
+
parser.add_argument("--resume", type=str, help="Experiment name to resume (e.g., gatv2_20250806_162945)")
|
| 216 |
+
args = parser.parse_args()
|
| 217 |
+
|
| 218 |
+
# Load/merge config
|
| 219 |
+
if args.config:
|
| 220 |
+
from mecari.config.config import load_config
|
| 221 |
+
|
| 222 |
+
config = load_config(args.config)
|
| 223 |
+
if "model" in config and "type" in config["model"]:
|
| 224 |
+
args.model = config["model"]["type"]
|
| 225 |
+
else:
|
| 226 |
+
config = get_model_config(args.model)
|
| 227 |
+
|
| 228 |
+
overrides = {}
|
| 229 |
+
|
| 230 |
+
# Training overrides
|
| 231 |
+
training_overrides = {}
|
| 232 |
+
if args.batch_size:
|
| 233 |
+
training_overrides["batch_size"] = args.batch_size
|
| 234 |
+
if args.steps:
|
| 235 |
+
training_overrides["max_steps"] = args.steps
|
| 236 |
+
if args.lr:
|
| 237 |
+
training_overrides["learning_rate"] = args.lr
|
| 238 |
+
if args.no_wandb:
|
| 239 |
+
training_overrides["use_wandb"] = False
|
| 240 |
+
if args.patience:
|
| 241 |
+
training_overrides["patience"] = args.patience
|
| 242 |
+
if args.seed:
|
| 243 |
+
training_overrides["seed"] = args.seed
|
| 244 |
+
if args.no_deterministic:
|
| 245 |
+
training_overrides["deterministic"] = False
|
| 246 |
+
|
| 247 |
+
if training_overrides:
|
| 248 |
+
overrides["training"] = training_overrides
|
| 249 |
+
|
| 250 |
+
# Model overrides
|
| 251 |
+
if args.hidden_dim:
|
| 252 |
+
overrides["model"] = {"hidden_dim": args.hidden_dim}
|
| 253 |
+
|
| 254 |
+
# Optimizer overrides
|
| 255 |
+
if args.weight_decay:
|
| 256 |
+
overrides.setdefault("training", {})
|
| 257 |
+
overrides["training"]["optimizer"] = {"weight_decay": args.weight_decay}
|
| 258 |
+
|
| 259 |
+
if overrides:
|
| 260 |
+
config = override_config(config, overrides)
|
| 261 |
+
|
| 262 |
+
deterministic = config["training"].get("deterministic", True)
|
| 263 |
+
set_seed(config["training"]["seed"], deterministic=deterministic)
|
| 264 |
+
|
| 265 |
+
if not deterministic:
|
| 266 |
+
print("⚡ Performance mode: deterministic=False (reproducibility not guaranteed)")
|
| 267 |
+
|
| 268 |
+
resume_from_checkpoint = None
|
| 269 |
+
experiment_name = None
|
| 270 |
+
if args.resume:
|
| 271 |
+
experiment_path = os.path.join("experiments", args.resume)
|
| 272 |
+
if os.path.exists(experiment_path):
|
| 273 |
+
checkpoint_dir = os.path.join(experiment_path, "checkpoints")
|
| 274 |
+
if os.path.exists(checkpoint_dir):
|
| 275 |
+
checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".ckpt")]
|
| 276 |
+
if checkpoints:
|
| 277 |
+
checkpoints.sort()
|
| 278 |
+
resume_from_checkpoint = os.path.join(checkpoint_dir, checkpoints[-1])
|
| 279 |
+
print(f"Resuming training from: {resume_from_checkpoint}")
|
| 280 |
+
experiment_name = args.resume
|
| 281 |
+
|
| 282 |
+
config_path = os.path.join(experiment_path, "config.yaml")
|
| 283 |
+
if os.path.exists(config_path):
|
| 284 |
+
from mecari.config.config import load_config
|
| 285 |
+
|
| 286 |
+
config = load_config(config_path)
|
| 287 |
+
print(f"Restored config from: {config_path}")
|
| 288 |
+
else:
|
| 289 |
+
print(f"Warning: No checkpoints found in: {checkpoint_dir}")
|
| 290 |
+
else:
|
| 291 |
+
print(f"Warning: Checkpoint directory not found: {checkpoint_dir}")
|
| 292 |
+
else:
|
| 293 |
+
print(f"Warning: Experiment directory not found: {experiment_path}")
|
| 294 |
+
else:
|
| 295 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 296 |
+
experiment_name = f"{config['model']['type']}_{timestamp}"
|
| 297 |
+
|
| 298 |
+
print(f"Experiment: {experiment_name}")
|
| 299 |
+
print(f"Model: {config['model']['type'].upper()}")
|
| 300 |
+
print("Lexical features: enabled (default)")
|
| 301 |
+
|
| 302 |
+
if torch.cuda.is_available():
|
| 303 |
+
print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")
|
| 304 |
+
print(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
| 305 |
+
else:
|
| 306 |
+
print("💻 Using CPU")
|
| 307 |
+
|
| 308 |
+
data_module = create_data_module(config)
|
| 309 |
+
|
| 310 |
+
feature_dim = calculate_feature_dim(config)
|
| 311 |
+
|
| 312 |
+
model, _ = create_model_and_datamodule(config, feature_dim, data_module)
|
| 313 |
+
|
| 314 |
+
# Attach training config for schedulers, etc.
|
| 315 |
+
model.training_config = config["training"]
|
| 316 |
+
|
| 317 |
+
experiment_dir = f"experiments/{experiment_name}"
|
| 318 |
+
if not args.resume:
|
| 319 |
+
os.makedirs(experiment_dir, exist_ok=True)
|
| 320 |
+
save_config(config, f"{experiment_dir}/config.yaml")
|
| 321 |
+
|
| 322 |
+
checkpoint_callback_error = ModelCheckpoint(
|
| 323 |
+
dirpath=f"experiments/{experiment_name}/checkpoints",
|
| 324 |
+
filename=f"{config['model']['type']}-{{epoch:02d}}-{{val_error_epoch:.3f}}",
|
| 325 |
+
monitor="val_error_epoch",
|
| 326 |
+
mode="min",
|
| 327 |
+
save_top_k=1,
|
| 328 |
+
save_last=True,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
early_stopping = EarlyStopping(
|
| 332 |
+
monitor="val_error_epoch", mode="min", patience=config["training"]["patience"], verbose=True, strict=False
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
loggers = setup_loggers(config, experiment_name)
|
| 336 |
+
|
| 337 |
+
callbacks = [checkpoint_callback_error, early_stopping]
|
| 338 |
+
try:
|
| 339 |
+
if loggers:
|
| 340 |
+
lr_monitor = LearningRateMonitor(logging_interval="step")
|
| 341 |
+
callbacks.append(lr_monitor)
|
| 342 |
+
except Exception:
|
| 343 |
+
pass
|
| 344 |
+
trainer = create_trainer(config, callbacks, loggers, deterministic)
|
| 345 |
+
|
| 346 |
+
print("Starting training...")
|
| 347 |
+
|
| 348 |
+
try:
|
| 349 |
+
if resume_from_checkpoint:
|
| 350 |
+
trainer.fit(model, data_module, ckpt_path=resume_from_checkpoint)
|
| 351 |
+
else:
|
| 352 |
+
trainer.fit(model, data_module)
|
| 353 |
+
training_status = "completed"
|
| 354 |
+
|
| 355 |
+
if data_module.test_dataset:
|
| 356 |
+
print("Evaluating on test data...")
|
| 357 |
+
trainer.test(model, data_module)
|
| 358 |
+
print("Training complete!")
|
| 359 |
+
except KeyboardInterrupt:
|
| 360 |
+
print("\nTraining interrupted...")
|
| 361 |
+
training_status = "interrupted"
|
| 362 |
+
except Exception as e:
|
| 363 |
+
print(f"\nError during training: {e}")
|
| 364 |
+
import traceback
|
| 365 |
+
|
| 366 |
+
traceback.print_exc()
|
| 367 |
+
training_status = "error"
|
| 368 |
+
|
| 369 |
+
print(f"Experiment: {experiment_name}")
|
| 370 |
+
print(f"Experiment dir: experiments/{experiment_name}")
|
| 371 |
+
|
| 372 |
+
print("\n=== Saved models ===")
|
| 373 |
+
|
| 374 |
+
if checkpoint_callback_error.best_model_path:
|
| 375 |
+
best_error = (
|
| 376 |
+
float(checkpoint_callback_error.best_model_score)
|
| 377 |
+
if checkpoint_callback_error.best_model_score is not None
|
| 378 |
+
else 1.0
|
| 379 |
+
)
|
| 380 |
+
print(f" Best val_error: {best_error:.6f}")
|
| 381 |
+
print(f" → {os.path.basename(checkpoint_callback_error.best_model_path)}")
|
| 382 |
+
|
| 383 |
+
print(f"\nFinal epoch: {trainer.current_epoch}")
|
| 384 |
+
print(f"Training status: {training_status}")
|
| 385 |
+
|
| 386 |
+
|
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
up_hf.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Upload the current project folder to the Hugging Face Space ``zbller/Mecari``.

Requires the ``HF_TOKEN`` environment variable to hold a write-scoped token.
"""
import os

from huggingface_hub import HfApi

# Read the token once and fail fast with a clear message instead of letting
# a bare KeyError surface from deep inside the API calls.
token = os.environ.get("HF_TOKEN")
if not token:
    raise SystemExit("HF_TOKEN environment variable is not set; export a write-scoped token first.")

api = HfApi()
repo_id = "zbller/Mecari"

# exist_ok=True makes re-runs idempotent once the Space already exists.
api.create_repo(
    repo_id=repo_id,
    repo_type="space",
    space_sdk="gradio",
    private=False,
    exist_ok=True,
    token=token,
)

# Mirror the working tree into the Space, skipping local-only artifacts
# (VCS metadata, virtualenv, bytecode caches, datasets, experiment outputs).
api.upload_folder(
    repo_id=repo_id,
    repo_type="space",
    folder_path=".",
    path_in_repo=".",
    ignore_patterns=[
        ".git", ".git/**", ".venv", ".venv/**", "__pycache__", "**/__pycache__",
        "KWDLC", "KWDLC/**", "annotations", "annotations/**", "experiments", "experiments/**",
        "mecari_morpheme.egg-info", "mecari_morpheme.egg-info/**",
    ],
    token=token,
)
print(f"Uploaded to https://huggingface.co/spaces/{repo_id}")
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|