zbller committed on
Commit 34c8a90 · verified · 1 Parent(s): 4150c2c

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+abst.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,17 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+.venv
+
+annotations*/
+experiments/
+lightning_logs/
+cache/
+results/
+
+KWDLC/
README.md CHANGED
@@ -1,14 +1,153 @@
 ---
-title: Mecari
-emoji: 🦀
-colorFrom: yellow
-colorTo: purple
+title: Mecari Morpheme Analyzer
+emoji: 🧩
+colorFrom: indigo
+colorTo: blue
 sdk: gradio
-sdk_version: 5.44.1
+sdk_version: 4.37.2
 app_file: app.py
 pinned: false
-license: cc-by-nc-4.0
-short_description: 'Demo of Mecari: GNN-based Morphological Analyzer'
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Mecari (Japanese Morphological Analysis with Graph Neural Networks)
+
+## Training
+
+### Overview
+
+Mecari [1] is a GNN‑based Japanese morphological analyzer. It supports training from partially annotated graphs (`+`/`-` labels are used where available; `?` is ignored) and aims for fast training and inference.
+
+<p align="center">
+  <img src="abst.png" alt="Overview" width="70%" />
+  <!-- Adjust width (e.g., 60%, 50%, or px) as desired -->
+</p>
+
+### Graph
+The graph is built from MeCab morpheme candidates.
+
+### Annotation
+Annotations are created by matching morpheme candidates against the gold labels; they serve as the training targets (supervision) during learning:
+- `+`: a candidate that exactly matches the gold.
+- `-`: any other candidate that overlaps a `+` candidate by at least one character.
+- `?`: all other candidates (ignored during training).
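The `+`/`-`/`?` rule above can be sketched in a few lines. This is a rough illustration over character spans; the function and variable names are hypothetical, not the repository's actual API.

```python
# Sketch of the +/-/? annotation rule. Candidates and gold morphemes are
# (start, end) character spans; names here are illustrative only.

def label_candidates(candidates, gold_spans):
    gold = set(gold_spans)
    labels = {}
    # First pass: candidates that exactly match a gold span get '+',
    # everything else starts as '?'.
    for c in candidates:
        labels[c] = "+" if c in gold else "?"
    plus = [c for c, lab in labels.items() if lab == "+"]
    # Second pass: a '?' candidate overlapping a '+' span by >= 1 char becomes '-'.
    for c in candidates:
        if labels[c] == "?" and any(c[0] < p[1] and p[0] < c[1] for p in plus):
            labels[c] = "-"
    return labels
```

Only the `+` and `-` candidates would contribute to the loss; `?` candidates are masked out.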
+
+### Training
+Nodes are featurized with JUMAN++‑style unigram features, edges are modeled as undirected (bidirectional), and a GATv2 [2] is trained on the resulting graphs.
+
+### Inference
+At inference time, a Viterbi search over the model’s node scores selects the optimal path of non‑overlapping morpheme candidates.
+
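A minimal sketch of this decoding step, assuming each candidate is a scored character span and that at least one full segmentation of the sentence exists (names are illustrative, not the repository's API):

```python
# Dynamic-programming (Viterbi-style) search over scored candidate spans.
# Each candidate is (start, end, score); we pick non-overlapping spans that
# cover the sentence [0, length) with maximal total score.
# Assumes at least one full segmentation exists, otherwise best[length] is missing.

def viterbi_decode(candidates, length):
    # best[i] = (best total score of a segmentation of text[:i], last chosen candidate)
    best = {0: (0.0, None)}
    for i in range(1, length + 1):
        for c in candidates:
            start, end, score = c
            if end == i and start in best:
                total = best[start][0] + score
                if i not in best or total > best[i][0]:
                    best[i] = (total, c)
    # Backtrack from the end of the sentence
    path, i = [], length
    while i > 0:
        c = best[i][1]
        path.append(c)
        i = c[0]
    return list(reversed(path))
```

A higher-scoring long candidate beats a chain of short ones, which is how the node scores resolve segmentation ambiguity.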
+## Results (KWDLC test)
+
+- Trained model (sample_model): Seg F1 0.9725, POS F1 0.9562
+- MeCab (JUMANDIC) baseline: Seg F1 0.9677, POS F1 0.9465
+
+The GATv2 model trained with this repository (current code and `configs/gatv2.yaml`) using the official KWDLC split outperforms MeCab on both segmentation and POS accuracy.
+
+## Tested Environment
+
+- OS: Ubuntu 24.04.3 LTS (Noble Numbat)
+- Python: 3.11.3
+- PyTorch: 2.2.2+cu121
+- CUDA (runtime): 12.1 (cu121)
+- MeCab (binary): 0.996
+- JUMANDIC: `/var/lib/mecab/dic/juman-utf8`
+
+## MeCab Setup (Ubuntu 24.04)
+1) Install packages (includes the JUMANDIC dictionary)
+
+```bash
+sudo apt update
+sudo apt install -y mecab mecab-utils libmecab-dev mecab-jumandic-utf8
+```
+
+2) Verify installation
+
+```bash
+mecab -v  # e.g., mecab of 0.996
+test -d /var/lib/mecab/dic/juman-utf8 && echo "JUMANDIC OK"
+```
+
+## Project Setup
+
+```bash
+# Install uv if needed
+curl -LsSf https://astral.sh/uv/install.sh | sh
+
+# Create venv and install dependencies
+uv venv
+source .venv/bin/activate
+uv sync
+```
+
+## Quickstart (Morphological analysis)
+
+```bash
+# Analyze a single sentence with the bundled sample model
+python infer.py --text "東京都の外国人参政権"
+
+# Interactive mode
+python infer.py
+
+# After training, specify an experiment to use a custom model
+python infer.py --experiment gatv2_YYYYMMDD_HHMMSS --text "..."
+```
+
+Note:
+- When no experiment is specified, the model at `sample_model/` is loaded by default.
+
+## Train it yourself
+### KWDLC Setup (Required)
+
+```bash
+cd /path/to/Mecari
+git clone --depth 1 https://github.com/ku-nlp/KWDLC
+```
+
+- Training requires KWDLC (non‑KWDLC training is not supported at the moment).
+- Splits strictly follow the official `dev.id` / `test.id` files.
+
+### Preprocessing
+
+```bash
+python preprocess.py --config configs/gatv2.yaml
+```
+
+### Training
+
+```bash
+python train.py --config configs/gatv2.yaml
+```
+
+- Outputs are saved under `experiments/<name>/`.
+- The bundled model was trained with the current codebase and configuration (`configs/gatv2.yaml`).
+
+### Evaluation
+
+```bash
+python evaluate.py --max-samples 50 \
+  --experiment gatv2_YYYYMMDD_HHMMSS
+```
+
+## License
+
+CC BY‑NC 4.0 (non‑commercial use only)
+
+## Acknowledgments
+- [1] Technical inspiration: Mecari, a morphological analysis system developed by Google, as described in “Data processing for Japanese text‑to‑pronunciation models” by G. Mazovetskiy and T. Kudo (NLP2024 Workshop on Japanese Language Resources), pp. 19–23. URL: https://jedworkshop.github.io/JLR2024/materials/b-2.pdf
+- [2] Graph architecture: Brody, Shaked, Uri Alon, and Eran Yahav. “How Attentive Are Graph Attention Networks?” 10th International Conference on Learning Representations (ICLR 2022), 2022.
+
+## Disclaimer
+- Independent academic implementation for educational and research purposes.
+- Core concepts (graph‑based morpheme boundary annotation) follow the published work; implementation details and code structure are our interpretation.
+- Not affiliated with, endorsed by, or connected to Google or its subsidiaries.
+
+## Purpose
+- Academic research
+- Education
+- Technical skill development
+- Understanding of NLP techniques
abst.png ADDED

Git LFS Details

  • SHA256: 0a8f3de11ee14f75fe879878912d5c49fb761b5ee773ad97418427922a521742
  • Pointer size: 131 Bytes
  • Size of remote file: 344 kB
app.py ADDED
@@ -0,0 +1,174 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import subprocess
+import gradio as gr
+
+# Ensure wandb never starts in Spaces
+os.environ["WANDB_MODE"] = "disabled"
+
+# Resolve MeCab binary for this process
+_default_mecab = "/usr/bin/mecab" if os.path.exists("/usr/bin/mecab") else "mecab"
+MECAB_BIN = os.getenv("MECAB_BIN", _default_mecab)
+os.environ["MECAB_BIN"] = MECAB_BIN
+
+# Lazy-loaded model
+_model = None
+_exp_info = None
+
+
+def _ensure_model():
+    global _model, _exp_info
+    if _model is None:
+        from infer import load_model
+
+        result = load_model()
+        if result is None:
+            raise RuntimeError(
+                "Model could not be loaded. Ensure sample_model/ exists with config.yaml and model.pt."
+            )
+        _model, _exp_info = result
+
+
+def _to_mecab_lines(results, optimal_morphemes=None) -> str:
+    # Build MeCab-like output lines
+    def mecab_features(m):
+        pos = m.get("pos", "*")
+        pos1 = m.get("pos_detail1", "*")
+        pos2 = m.get("pos_detail2", "*")
+        ctype = m.get("inflection_type", "*")
+        cform = m.get("inflection_form", "*")
+        base = m.get("base_form", m.get("lemma", "*")) or "*"
+        # Mecari output includes reading as 7th field
+        reading = m.get("reading", "*") or "*"
+        return f"{pos},{pos1},{pos2},{ctype},{cform},{base},{reading}"
+
+    items = (
+        optimal_morphemes
+        if optimal_morphemes
+        else [
+            {
+                "surface": r.get("surface", ""),
+                "pos": r.get("pos", "*"),
+                "pos_detail1": "*",
+                "pos_detail2": "*",
+                "inflection_type": "*",
+                "inflection_form": "*",
+                "base_form": r.get("surface", ""),
+                "reading": r.get("reading", "*"),
+            }
+            for r in results
+        ]
+    )
+
+    lines = [f"{m.get('surface','')}\t{mecab_features(m)}" for m in items]
+    lines.append("EOS")
+    return "\n".join(lines)
+
+
+def mecab_plain(text: str) -> str:
+    """Run system MeCab and return its raw parsing (surface\tCSV ...\nEOS)."""
+    try:
+        from mecari.analyzers.mecab import MeCabAnalyzer
+
+        analyzer = MeCabAnalyzer()
+        mecab_bin = os.getenv("MECAB_BIN", analyzer.mecab_bin)
+        args = [mecab_bin]
+        if isinstance(analyzer.jumandic_path, str) and os.path.isdir(analyzer.jumandic_path):
+            args += ["-d", analyzer.jumandic_path]
+        p = subprocess.run(args, input=text, text=True, capture_output=True)
+        out = (p.stdout or "") + ("\n" + p.stderr if p.stderr else "")
+        if p.returncode != 0:
+            return out.strip() or f"mecab error rc={p.returncode}"
+        # Trim extra tail fields (e.g., カテゴリ:*, ドメイン:*) and keep first 6 features
+        lines = []
+        for line in out.splitlines():
+            if not line or line.strip() == "EOS":
+                lines.append("EOS")
+                continue
+            if "\t" in line:
+                surface, feats = line.split("\t", 1)
+                parts = [s.strip() for s in feats.split(",")]
+                trimmed = parts[:6]
+                while len(trimmed) < 6:
+                    trimmed.append("*")
+                lines.append(f"{surface}\t{','.join(trimmed)}")
+            else:
+                lines.append(line)
+        # Ensure trailing EOS only once
+        if not lines or lines[-1] != "EOS":
+            lines.append("EOS")
+        return "\n".join(lines)
+    except FileNotFoundError:
+        return "MeCabバイナリが見つかりません(MECAB_BINやpackages.txtを確認)。"
+    except Exception as e:
+        return f"mecab実行時エラー: {e}"
+
+
+def analyze(text: str):
+    if not text or not text.strip():
+        return "", ""
+
+    try:
+        _ensure_model()
+        from infer import predict_morphemes_from_text
+
+        text = text.strip()
+        result = predict_morphemes_from_text(text, _model, _exp_info, silent=True)
+        if not result:
+            return "推論に失敗しました。", mecab_plain(text)
+        results, optimal_morphemes = result
+        mecari_out = _to_mecab_lines(results, optimal_morphemes)
+        mecab_out = mecab_plain(text)
+        return mecari_out, mecab_out
+    except FileNotFoundError:
+        return (
+            "MeCabが見つかりません。Spaceのpackages.txtに 'mecab' と 'mecab-jumandic-utf8' を含めてビルドし直すか、\n"
+            "変数 MECAB_BIN=/usr/bin/mecab を設定してください。"
+        ), ""
+    except Exception as e:
+        import traceback
+
+        tb = traceback.format_exc()
+        return f"エラー: {e}\n\n{tb}", ""
+
+
+FONT_CSS = """
+/* Prefer common system fonts for Latin text */
+body, .gradio-container, .prose, textarea, input, button,
+.gr-text-input input, .gr-text-input textarea, .gr-textbox textarea {
+  font-family: system-ui, -apple-system, 'Segoe UI', Roboto, 'Noto Sans',
+               'Helvetica Neue', Arial, 'Apple Color Emoji', 'Segoe UI Emoji',
+               sans-serif !important;
+}
+"""
+
+with gr.Blocks(theme=gr.themes.Soft(), css=FONT_CSS) as demo:
+    gr.Markdown(
+        """
+        # Mecari Morpheme Analyzer
+
+        GNNベースの形態素解析器"Mecari"のデモです。github: https://github.com/zbller/Mecari
+        """
+    )
+
+    with gr.Row():
+        inp = gr.Textbox(label="テキスト入力", value="とうきょうに行った", placeholder="とうきょうに行った", lines=3)
+    btn = gr.Button("解析する")
+    with gr.Row():
+        out_mecari = gr.Textbox(label="Mecari", lines=10)
+        out_mecab = gr.Textbox(label="MeCab(Jumandic)", lines=10)
+    btn.click(fn=analyze, inputs=inp, outputs=[out_mecari, out_mecab])
+
+    # Optional warm-up
+    def _warmup():
+        try:
+            _ensure_model()
+        except Exception:
+            pass
+
+    _warmup()
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7863")))
configs/base.yaml ADDED
@@ -0,0 +1,11 @@
+features:
+  lexical_feature_dim: 100000
+
+training:
+  deterministic: false
+  annotations_dir: "annotations"
+  project_name: "mecari"
+
+inference:
+  checkpoint_dir: "experiments"
+  experiment_name: null
configs/gatv2.yaml ADDED
@@ -0,0 +1,36 @@
+extends: "base.yaml"
+
+model:
+  type: "gatv2"
+  hidden_dim: 64
+  num_layers: 4
+  num_heads: 4
+  dropout: 0.1
+  num_classes: 1
+  share_weights: false
+
+edge_features:
+  use_bidirectional_edges: true
+
+training:
+  learning_rate: 0.001
+  batch_size: 128
+  max_steps: 10000
+  patience: 10
+  gradient_clip_val: 0.5
+  gradient_clip_algorithm: "norm"
+  num_workers: 4
+  accumulate_grad_batches: 1
+  seed: 42
+  warmup_steps: 500
+  warmup_start_lr: 0.0
+  optimizer:
+    type: "adamw"
+    weight_decay: 0.001
+  use_wandb: true
+  log_every_n_steps: 50
+  val_check_interval: 1.0
+
+loss:
+  use_pos_weight: true
+  label_smoothing: 0.0
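The `extends: "base.yaml"` key implies a config-inheritance step, with child values overriding base values. As a hedged sketch (the repository's actual loader may differ), a recursive deep merge could look like:

```python
from pathlib import Path


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into `base`; override values win on conflicts."""
    merged = dict(base)
    for key, value in override.items():
        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def load_config(path: str) -> dict:
    """Load a YAML config, resolving an optional `extends` key recursively."""
    import yaml  # PyYAML, as used elsewhere in the repo

    cfg = yaml.safe_load(Path(path).read_text(encoding="utf-8")) or {}
    parent = cfg.pop("extends", None)
    if parent:
        # Resolve the parent path relative to the child config file
        base = load_config(str(Path(path).parent / parent))
        cfg = deep_merge(base, cfg)
    return cfg
```

Under this scheme, `gatv2.yaml`'s `training` block would be merged into `base.yaml`'s `training` block rather than replacing it wholesale.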
evaluate.py ADDED
@@ -0,0 +1,288 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""Unified evaluation for MeCab (JUMANDIC) and the trained model.
+
+Evaluates both systems on the same KWDLC test data and compares results.
+"""
+
+import argparse
+import subprocess
+from pathlib import Path
+from typing import Dict, List
+
+import torch
+from tqdm import tqdm
+
+
+def parse_knp_file(knp_file: Path) -> List[Dict]:
+    """Extract gold morphemes from a KNP file."""
+    sentences = []
+    current_sentence = []
+    current_text = ""
+
+    with open(knp_file, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.rstrip("\n")
+
+            if line.startswith("#"):
+                if line.startswith("# S-ID:"):
+                    if current_sentence:
+                        sentences.append({"morphemes": current_sentence, "text": current_text})
+                        current_sentence = []
+                        current_text = ""
+                continue
+            elif line == "EOS":
+                if current_sentence:
+                    sentences.append({"morphemes": current_sentence, "text": current_text})
+                    current_sentence = []
+                    current_text = ""
+            elif line.startswith("+") or line.startswith("*"):
+                continue
+            elif line:
+                parts = line.split(" ")
+                if len(parts) >= 4:
+                    surface = parts[0]
+                    reading = parts[1]
+                    pos = parts[3]
+
+                    current_sentence.append({"surface": surface, "reading": reading, "pos": pos})
+                    current_text += surface
+
+    return sentences
+
+
+def analyze_with_mecab(text: str) -> List[Dict]:
+    """Analyze text with MeCab (JUMANDIC) using a simple best-path parse."""
+    try:
+        result = subprocess.run(
+            ["mecab", "-d", "/var/lib/mecab/dic/juman-utf8"],
+            input=text,
+            capture_output=True,
+            text=True,
+            encoding="utf-8",
+        )
+
+        if result.returncode != 0:
+            return []
+
+        morphemes = []
+        for line in result.stdout.strip().split("\n"):
+            if line == "EOS":
+                break
+            parts = line.split("\t")
+            if len(parts) >= 2:
+                surface = parts[0]
+                features = parts[1].split(",")
+                if len(features) >= 7:
+                    pos = features[0]
+                    # Do not fall back to the surface form when the reading is missing ('*')
+                    reading = features[7] if len(features) > 7 and features[7] != "*" else ""
+
+                    morphemes.append({"surface": surface, "reading": reading, "pos": pos})
+
+        return morphemes
+    except Exception as e:
+        print(f"MeCab error: {e}")
+        return []
+
+
+def analyze_with_jumanpp(text: str) -> List[Dict]:
+    """Analyze text with JUMAN++ (optional baseline)."""
+    try:
+        result = subprocess.run(["jumanpp"], input=text, capture_output=True, text=True, encoding="utf-8")
+
+        if result.returncode != 0:
+            return []
+
+        morphemes = []
+        for line in result.stdout.strip().split("\n"):
+            if line.startswith("@") or line == "EOS":
+                continue
+            parts = line.split(" ")
+            if len(parts) >= 12:
+                surface = parts[0]
+                reading = parts[1]
+                pos = parts[3]
+
+                morphemes.append({"surface": surface, "reading": reading, "pos": pos})
+
+        return morphemes
+    except Exception as e:
+        print(f"JUMAN++ error: {e}")
+        return []
+
+
+def analyze_with_model(text: str, model, experiment_info) -> List[Dict]:
+    """Analyze text with the trained model."""
+    try:
+        import infer
+
+        results, optimal_morphemes = infer.predict_morphemes_from_text(
+            text, model=model, experiment_info=experiment_info, silent=True
+        )
+
+        morphemes = []
+        for morph in optimal_morphemes:
+            morphemes.append(
+                {"surface": morph["surface"], "reading": morph.get("reading", ""), "pos": morph.get("pos", "*")}
+            )
+
+        return morphemes
+    except Exception as e:
+        print(f"Model inference error: {e}")
+        return []
+
+
+def evaluate_morphemes(gold_morphemes: List[Dict], pred_morphemes: List[Dict]) -> Dict:
+    """Compute segmentation and POS F1 between gold and predictions."""
+    gold_spans = []
+    pred_spans = []
+
+    # Gold spans (from gold morphemes)
+    pos = 0
+    for m in gold_morphemes:
+        surface = m["surface"]
+        end = pos + len(surface)
+        gold_spans.append((pos, end, m["pos"]))
+        pos = end
+
+    # Predicted spans (from predictions)
+    pos = 0
+    for m in pred_morphemes:
+        surface = m["surface"]
+        end = pos + len(surface)
+        pred_spans.append((pos, end, m["pos"]))
+        pos = end
+
+    # Segmentation accuracy (without POS)
+    gold_seg = {(s, e) for s, e, _ in gold_spans}
+    pred_seg = {(s, e) for s, e, _ in pred_spans}
+
+    seg_correct = len(gold_seg & pred_seg)
+    seg_precision = seg_correct / len(pred_seg) if pred_seg else 0
+    seg_recall = seg_correct / len(gold_seg) if gold_seg else 0
+    seg_f1 = 2 * seg_precision * seg_recall / (seg_precision + seg_recall) if (seg_precision + seg_recall) > 0 else 0
+
+    # Accuracy with POS
+    gold_pos = set(gold_spans)
+    pred_pos = set(pred_spans)
+
+    pos_correct = len(gold_pos & pred_pos)
+    pos_precision = pos_correct / len(pred_pos) if pred_pos else 0
+    pos_recall = pos_correct / len(gold_pos) if gold_pos else 0
+    pos_f1 = 2 * pos_precision * pos_recall / (pos_precision + pos_recall) if (pos_precision + pos_recall) > 0 else 0
+
+    return {
+        "seg_precision": seg_precision,
+        "seg_recall": seg_recall,
+        "seg_f1": seg_f1,
+        "pos_precision": pos_precision,
+        "pos_recall": pos_recall,
+        "pos_f1": pos_f1,
+    }
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Unified evaluation script")
+    parser.add_argument("--kwdlc-dir", type=str, default="KWDLC", help="Path to KWDLC root directory")
+    parser.add_argument(
+        "--test-ids", type=str, default="KWDLC/id/split_for_pas/test.id", help="File containing test IDs (one per line)"
+    )
+    parser.add_argument(
+        "--max-samples", type=int, default=None, help="Max number of samples to evaluate (default: all)"
+    )
+    parser.add_argument("--experiment", "-e", type=str, required=True, help="Experiment name to evaluate")
+
+    args = parser.parse_args()
+
+    # Load test document IDs
+    test_ids = []
+    with open(args.test_ids, "r") as f:
+        for line in f:
+            test_ids.append(line.strip())
+
+    if args.max_samples is not None:
+        test_ids = test_ids[: args.max_samples]
+
+    print(f"Evaluating: {len(test_ids)} files")
+
+    import infer
+
+    model_info = infer.load_model(experiment_name=args.experiment)
+    if model_info:
+        model, experiment_info = model_info
+        # Force CPU execution for evaluation
+        device = torch.device("cpu")
+        model = model.to(device)
+        experiment_info["device"] = device
+        print(f"Model: {experiment_info['name']}")
+    else:
+        print("Failed to load model")
+        model = None
+        experiment_info = None
+
+    mecab_results = []
+    model_results = []
+
+    print("\nStart evaluation...")
+    for test_id in tqdm(test_ids, desc="evaluating"):
+        # Find KNP file
+        found = False
+        knp_base = Path(args.kwdlc_dir) / "knp"
+
+        for subdir in knp_base.glob("w*"):
+            candidate = subdir / f"{test_id}.knp"
+            if candidate.exists():
+                knp_path = candidate
+                found = True
+                break
+
+        if not found:
+            continue
+
+        # Read gold data
+        gold_sentences = parse_knp_file(knp_path)
+
+        for sent_data in gold_sentences:
+            text = sent_data["text"]
+            gold_morphemes = sent_data["morphemes"]
+
+            # MeCab (JUMANDIC)
+            pred_mecab = analyze_with_mecab(text)
+            if pred_mecab:
+                result = evaluate_morphemes(gold_morphemes, pred_mecab)
+                mecab_results.append(result)
+
+            # Trained model
+            if model is not None:
+                pred_model = analyze_with_model(text, model, experiment_info)
+                if pred_model:
+                    model_eval = evaluate_morphemes(gold_morphemes, pred_model)
+                    model_results.append(model_eval)
+
+    # Aggregate and display results
+    print("\n" + "=" * 70)
+    print("Evaluation Results (KWDLC test data)")
+    print("=" * 70)
+    print(f"Num evaluated: MeCab={len(mecab_results)}, Model={len(model_results)}")
+
+    # MeCab (JUMANDIC)
+    if mecab_results:
+        avg_seg_f1 = sum(r["seg_f1"] for r in mecab_results) / len(mecab_results)
+        avg_pos_f1 = sum(r["pos_f1"] for r in mecab_results) / len(mecab_results)
+        print("\n[1] MeCab (JUMANDIC):")
+        print(f"  Seg F1: {avg_seg_f1:.4f}")
+        print(f"  POS F1: {avg_pos_f1:.4f}")
+
+    # Trained model
+    if model_results:
+        avg_seg_f1 = sum(r["seg_f1"] for r in model_results) / len(model_results)
+        avg_pos_f1 = sum(r["pos_f1"] for r in model_results) / len(model_results)
+        print(f"\n[2] Trained model ({experiment_info['name']}):")
+        print(f"  Seg F1: {avg_seg_f1:.4f}")
+        print(f"  POS F1: {avg_pos_f1:.4f}")
+
+
+if __name__ == "__main__":
+    main()
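The span-level F1 computed by `evaluate_morphemes` can be checked on toy data. This standalone snippet restates the segmentation part of that logic (simplified to surfaces only; the helper name is illustrative):

```python
# Minimal restatement of the segmentation-F1 logic from evaluate_morphemes,
# for a quick sanity check on toy segmentations.

def seg_f1(gold_surfaces, pred_surfaces):
    def spans(surfaces):
        # Convert a list of surface strings into (start, end) character spans
        out, pos = set(), 0
        for s in surfaces:
            out.add((pos, pos + len(s)))
            pos += len(s)
        return out

    gold, pred = spans(gold_surfaces), spans(pred_surfaces)
    correct = len(gold & pred)
    precision = correct / len(pred) if pred else 0.0
    recall = correct / len(gold) if gold else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
```

Two segmentations of the same text share a span only when both its boundaries agree, which is why an off-by-one split costs both precision and recall.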
infer.py ADDED
@@ -0,0 +1,797 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Show immediate feedback from the moment the command starts
5
+ print("Loading model...", flush=True)
6
+
7
+ import os
8
+ import random
9
+ from typing import Any, Dict, Optional, Tuple
10
+
11
+ # Disable WandB during inference to avoid hanging processes
12
+ os.environ["WANDB_MODE"] = "disabled"
13
+
14
+ from importlib import import_module
15
+
16
+ import numpy as np
17
+ import torch
18
+ import yaml
19
+
20
+ from mecari.analyzers.mecab import MeCabAnalyzer
21
+ from mecari.data.data_module import DataModule
22
+ from mecari.utils.morph_utils import build_adjacent_edges, dedup_morphemes, normalize_mecab_candidates
23
+
24
+
25
+ def set_seed(seed: int = 42) -> None:
26
+ """Set random seeds for reproducibility during inference.
27
+
28
+ Args:
29
+ seed: Random seed value.
30
+ """
31
+ random.seed(seed)
32
+ np.random.seed(seed)
33
+ torch.manual_seed(seed)
34
+ torch.cuda.manual_seed(seed)
35
+ torch.cuda.manual_seed_all(seed)
36
+ torch.backends.cudnn.deterministic = True
37
+ torch.backends.cudnn.benchmark = False
38
+
39
+
40
+ set_seed(42)
41
+
42
+
43
+ def _find_best_checkpoint(checkpoints_dir: str, prefer_metric: str = "val_error") -> Tuple[Optional[str], float]:
44
+ """Find the best checkpoint file in a directory.
45
+
46
+ Args:
47
+ checkpoints_dir: Path to the checkpoints directory.
48
+ prefer_metric: Preferred metric ("val_error" or "val_loss").
49
+
50
+ Returns:
51
+ Tuple of (best checkpoint filename, score).
52
+ """
53
+ checkpoint_files = [f for f in os.listdir(checkpoints_dir) if f.endswith(".ckpt")]
54
+ if not checkpoint_files:
55
+ return None, float("inf")
56
+
57
+ best_checkpoint = None
58
+ best_score = float("inf")
59
+
60
+ # Prefer filenames that include the metric keyword (e.g., val_error=..., val_error_epoch=...)
61
+ for ckpt_file in checkpoint_files:
62
+ if prefer_metric == "val_loss" and ("val_loss=" in ckpt_file or "val_loss_epoch=" in ckpt_file):
63
+ try:
64
+ if "val_loss_epoch=" in ckpt_file:
65
+ score_str = ckpt_file.split("val_loss_epoch=")[-1].split(".ckpt")[0]
66
+ else:
67
+ score_str = ckpt_file.split("val_loss=")[-1].split(".ckpt")[0]
68
+ score = float(score_str)
69
+ if score < best_score:
70
+ best_score = score
71
+ best_checkpoint = ckpt_file
72
+ except (ValueError, IndexError):
73
+ pass
74
+ elif prefer_metric == "val_error" and ("val_error=" in ckpt_file or "val_error_epoch=" in ckpt_file):
75
+ try:
76
+ if "val_error_epoch=" in ckpt_file:
77
+ score_str = ckpt_file.split("val_error_epoch=")[-1].split(".ckpt")[0]
78
+ else:
79
+ score_str = ckpt_file.split("val_error=")[-1].split(".ckpt")[0]
80
+ score = float(score_str)
81
+ if score < best_score:
82
+ best_score = score
83
+ best_checkpoint = ckpt_file
84
+ except (ValueError, IndexError):
85
+ pass
86
+
87
+ # If not found, try the alternative metric
88
+ if not best_checkpoint:
89
+ other_metric = "val_loss" if prefer_metric == "val_error" else "val_error"
90
+ for ckpt_file in checkpoint_files:
91
+ if other_metric == "val_loss" and "val_loss=" in ckpt_file:
92
+ try:
93
+ score_str = ckpt_file.split("val_loss=")[1].split("-loss.ckpt")[0]
94
+ score = float(score_str)
95
+ if score < best_score:
96
+ best_score = score
97
+ best_checkpoint = ckpt_file
98
+ except (ValueError, IndexError):
99
+ pass
100
+ elif other_metric == "val_error" and "val_error=" in ckpt_file:
101
+ try:
102
+ score_str = ckpt_file.split("val_error=")[1].split(".ckpt")[0]
103
+ score = float(score_str)
104
+ if score < best_score:
105
+ best_score = score
106
+ best_checkpoint = ckpt_file
107
+ except (ValueError, IndexError):
108
+ pass
109
+
110
+ # Additional fallback: parse score from filename pattern (model-epoch-score.ckpt)
111
+ if not best_checkpoint:
112
+ for ckpt_file in sorted(checkpoint_files):
113
+ if ckpt_file == "last.ckpt":
114
+ continue
115
+ try:
116
+ stem = ckpt_file[:-5] if ckpt_file.endswith(".ckpt") else ckpt_file
117
+ # Fallback: treat the last hyphen-separated token as a score
118
+ last_tok = stem.split("-")[-1]
119
+ score = float(last_tok)
120
+ if score < best_score:
121
+ best_score = score
122
+ best_checkpoint = ckpt_file
123
+ except Exception:
124
+ continue
125
+ # Final fallback: use last.ckpt or the first file
126
+ if not best_checkpoint:
127
+ if "last.ckpt" in checkpoint_files:
128
+ best_checkpoint = "last.ckpt"
129
+ else:
130
+ best_checkpoint = sorted(checkpoint_files)[0]
131
+
132
+ return best_checkpoint, best_score
133
+
134
+
135
+ def _load_model_by_type(model_type: str, checkpoint_path: str) -> Any:
136
+ """Load the appropriate model class based on type.
137
+
138
+ Args:
139
+ model_type: Model type ("gat" or "gatv2").
140
+ checkpoint_path: Path to the checkpoint file.
141
+
142
+ Returns:
143
+ Loaded model instance.
144
+ """
145
+ if model_type == "gatv2":
146
+ cls = getattr(import_module("mecari.models.gatv2"), "MecariGATv2")
147
+ model = cls.load_from_checkpoint(checkpoint_path, strict=False, map_location="cpu")
148
+
149
+ model.eval()
150
+ model.cpu()
151
+ return model
152
+
153
+
154
+ def _instantiate_model_from_config(config: Dict[str, Any]):
155
+ """Instantiate a model using config fields (no checkpoint loading)."""
156
+ model_cfg = config.get("model", {})
157
+ training_cfg = config.get("training", {})
158
+ features_cfg = config.get("features", {})
159
+
160
+ if model_cfg.get("type") != "gatv2":
161
+ raise ValueError(f"Unsupported model type: {model_cfg.get('type')}")
162
+
163
+ MecariGATv2 = getattr(import_module("mecari.models.gatv2"), "MecariGATv2")
164
+ model = MecariGATv2(
165
+ hidden_dim=model_cfg.get("hidden_dim", 64),
166
+ num_classes=model_cfg.get("num_classes", 1),
167
+ learning_rate=training_cfg.get("learning_rate", 1e-3),
168
+ lexical_feature_dim=features_cfg.get("lexical_feature_dim", 100000),
169
+ num_heads=model_cfg.get("num_heads", 4),
170
+ share_weights=model_cfg.get("share_weights", False),
171
+ dropout=model_cfg.get("dropout", 0.1),
172
+ attn_dropout=model_cfg.get("attn_dropout", model_cfg.get("attention_dropout", 0.1)),
173
+ add_self_loops_flag=model_cfg.get("add_self_loops", True),
174
+ edge_dropout=model_cfg.get("edge_dropout", 0.0),
175
+ norm=model_cfg.get("norm", "layer"),
176
+ )
177
+ return model
178
+
179
+
180
+ def _load_model_from_state(config_path: str, state_path: str):
+     """Load model from a plain state_dict plus config.yaml."""
+     with open(config_path, "r", encoding="utf-8") as f:
+         config = yaml.safe_load(f)
+
+     model = _instantiate_model_from_config(config)
+     state = torch.load(state_path, map_location="cpu")
+     # Lightning checkpoints saved via export may store weights under 'state_dict'
+     if (
+         isinstance(state, dict)
+         and "state_dict" in state
+         and all(k.startswith("model.") for k in state["state_dict"].keys())
+     ):
+         state = state["state_dict"]
+     # Strip a potential 'model.' prefix (depends on how the weights were saved)
+     new_state = {}
+     for k, v in state.items():
+         nk = k[len("model.") :] if k.startswith("model.") else k
+         new_state[nk] = v
+     model.load_state_dict(new_state, strict=False)
+     model.eval()
+     model.cpu()
+     return model
+
+
+ def load_model(
+     experiment_name: Optional[str] = None, model_type: Optional[str] = None, prefer_metric: str = "val_error"
+ ) -> Optional[Tuple[Any, Dict[str, Any]]]:
+     """Load a trained model and its experiment info.
+
+     Default behavior: load the single model under sample_model/.
+     If --experiment is provided (or sample_model is unavailable), use experiments/.
+     """
+     # Default: load from sample_model/
+     if not experiment_name:
+         root = "sample_model"
+         if os.path.exists(root):
+             fixed_config = os.path.join(root, "config.yaml")
+             state_path = os.path.join(root, "model.pt")
+             if os.path.exists(fixed_config) and os.path.exists(state_path):
+                 try:
+                     with open(fixed_config, "r", encoding="utf-8") as f:
+                         config = yaml.safe_load(f)
+                     model = _load_model_from_state(fixed_config, state_path)
+                     experiment_info = {
+                         "name": os.path.basename(root),
+                         "path": root,
+                         "best_metric": None,
+                         "best_score": None,
+                         "model_type": config.get("model", {}).get("type", "unknown"),
+                         "best_model_path": state_path,
+                         "config": config,
+                     }
+                     return model, experiment_info
+                 except Exception as e:
+                     print(f"Failed to load sample model: {e}")
+                     return None
+             print("sample_model/model.pt or config.yaml not found")
+             return None
+         else:
+             print("sample_model directory not found")
+             return None
+
+     # Specific experiment provided
+     if experiment_name:
+         exp_path = os.path.join("experiments", experiment_name)
+         config_path = os.path.join(exp_path, "config.yaml")
+         checkpoints_dir = os.path.join(exp_path, "checkpoints")
+
+         if not os.path.exists(config_path) or not os.path.exists(checkpoints_dir):
+             print(f"Experiment not found: {experiment_name}")
+             return None
+
+         try:
+             with open(config_path, "r", encoding="utf-8") as f:
+                 config = yaml.safe_load(f)
+
+             model_type_from_config = config.get("model", {}).get("type", "unknown")
+             best_checkpoint, best_score = _find_best_checkpoint(checkpoints_dir, prefer_metric)
+
+             if not best_checkpoint:
+                 print("No checkpoint found")
+                 return None
+
+             metric_name = "val_loss" if prefer_metric == "val_loss" else "val_error"
+
+             experiment_info = {
+                 "name": experiment_name,
+                 "path": exp_path,
+                 "val_error": best_score if prefer_metric == "val_error" else None,
+                 "val_loss": best_score if prefer_metric == "val_loss" else None,
+                 "best_metric": metric_name,
+                 "best_score": best_score,
+                 "model_type": model_type_from_config,
+                 "best_model_path": os.path.join(checkpoints_dir, best_checkpoint),
+                 "config": config,
+             }
+         except Exception as e:
+             print(f"Failed to read experiment info: {e}")
+             return None
+
+     # Auto-select the best experiment
+     else:
+         experiments_dir = "experiments"
+         if not os.path.exists(experiments_dir):
+             print("Experiments directory does not exist")
+             return None
+
+         experiments = []
+         for exp_dir in os.listdir(experiments_dir):
+             exp_path = os.path.join(experiments_dir, exp_dir)
+             config_path = os.path.join(exp_path, "config.yaml")
+             checkpoints_dir = os.path.join(exp_path, "checkpoints")
+
+             if not os.path.exists(config_path) or not os.path.exists(checkpoints_dir):
+                 continue
+
+             try:
+                 with open(config_path, "r", encoding="utf-8") as f:
+                     config = yaml.safe_load(f)
+
+                 exp_model_type = config.get("model", {}).get("type", "unknown")
+
+                 if model_type and exp_model_type.lower() != model_type.lower():
+                     continue
+
+                 best_checkpoint, best_score = _find_best_checkpoint(checkpoints_dir, prefer_metric)
+                 if best_checkpoint:
+                     metric_name = "val_loss" if prefer_metric == "val_loss" else "val_error"
+                     experiments.append(
+                         {
+                             "name": exp_dir,
+                             "path": exp_path,
+                             "val_error": best_score if prefer_metric == "val_error" else None,
+                             "val_loss": best_score if prefer_metric == "val_loss" else None,
+                             "best_metric": metric_name,
+                             "best_score": best_score,
+                             "model_type": exp_model_type,
+                             "best_model_path": os.path.join(checkpoints_dir, best_checkpoint),
+                             "config": config,
+                         }
+                     )
+             except Exception:
+                 continue
+
+         if not experiments:
+             print("No available experiments found")
+             return None
+
+         experiment_info = min(experiments, key=lambda x: x["best_score"])
+
+     # Load model
+     print(f"Loading model: {experiment_info['best_model_path']}")
+     print(f"Experiment: {experiment_info['name']}")
+
+     try:
+         model = _load_model_by_type(experiment_info["model_type"], experiment_info["best_model_path"])
+         # No BERT features in this pipeline
+         return model, experiment_info
+     except Exception as e:
+         print(f"Model loading error: {e}")
+         return None
+
+
+ def viterbi_decode_from_morphemes(logits: torch.Tensor, morphemes: list, edges: list, silent: bool = False) -> list:
+     """Edge-based Viterbi decoding.
+
+     Args:
+         logits: Logits per morpheme.
+         morphemes: List of morpheme records.
+         edges: Edge list among morpheme indices.
+         silent: If True, suppress debug prints.
+
+     Returns:
+         Indices of morphemes on the optimal path.
+     """
+     if len(logits) != len(morphemes):
+         if not silent:
+             print(f"Warning: #logits ({len(logits)}) != #morphemes ({len(morphemes)})")
+         return list(range(min(len(logits), len(morphemes))))
+
+     if not silent:
+         print("\n=== Viterbi Decode ===")
+         print(f"#Morphemes: {len(morphemes)}")
+         print(f"Using edge info: {len(edges)} edges")
+
+         print("\nNode logits:")
+         for idx, (morph, logit) in enumerate(zip(morphemes, logits)):
+             print(
+                 f"  [{idx:3d}] {morph['surface']:10s} ({morph['start_pos']:2d}-{morph['end_pos']:2d}) {morph['pos']:10s} logit={logit:.3f}"
+             )
+
+     # Build adjacency from edges (forward edges only)
+     n = len(morphemes)
+     adj_list = [[] for _ in range(n)]
+     for edge in edges:
+         source_idx = edge["source_idx"]
+         target_idx = edge["target_idx"]
+         if 0 <= source_idx < n and 0 <= target_idx < n:
+             # Keep forward edges only (source.end_pos <= target.start_pos)
+             source_end = morphemes[source_idx].get("end_pos", 0)
+             target_start = morphemes[target_idx].get("start_pos", 0)
+             if source_end <= target_start:
+                 adj_list[source_idx].append(target_idx)
+
+     # POS to UD mapping (for display)
+     pos_to_ud = {
+         "名詞": "NOUN",
+         "動詞": "VERB",
+         "形容詞": "ADJ",
+         "副詞": "ADV",
+         "助詞": "ADP",  # approximate
+         "助動詞": "AUX",
+         "接続詞": "CCONJ",
+         "連体詞": "DET",
+         "感動詞": "INTJ",
+         "代名詞": "PRON",
+         "形状詞": "ADJ",
+         "補助記号": "PUNCT",
+         "接頭辞": "PREFIX",
+         "接尾辞": "SUFFIX",
+     }
+
+     if not silent:
+         print("\nMorpheme details:")
+         for i, morpheme in enumerate(morphemes):
+             start_pos = morpheme.get("start_pos", 0)
+             end_pos = morpheme.get("end_pos", 0)
+             surface = morpheme.get("surface", "")
+             logit = morpheme.get("logit", 0.0)
+             pos = morpheme.get("pos", "")
+             pos_main = pos.split(",")[0] if "," in pos else pos
+             ud_pos = pos_to_ud.get(pos_main, "X")
+             print(f"  {i}: {surface} ({start_pos}-{end_pos}) {pos_main}({ud_pos}) logit={logit:.3f}")
+
+     # Dynamic programming
+     dp = [-float("inf")] * n  # max score to each node
+     parent = [-1] * n  # best predecessor per node
+
+     # Find start nodes (earliest start position)
+     start_nodes = []
+     min_start_pos = min(m.get("start_pos", 0) for m in morphemes)
+     for i, m in enumerate(morphemes):
+         if m.get("start_pos", 0) == min_start_pos:
+             start_nodes.append(i)
+
+     # Initialize start nodes
+     for i in start_nodes:
+         dp[i] = morphemes[i].get("logit", 0.0)
+
+     # Process nodes in position order (topological-like)
+     node_positions = [(i, morphemes[i].get("start_pos", 0), morphemes[i].get("end_pos", 0)) for i in range(n)]
+     node_positions.sort(key=lambda x: (x[1], x[2]))  # sort by start_pos, then end_pos
+
+     # Relax edges for each node in order
+     for node_idx, _, _ in node_positions:
+         if dp[node_idx] == -float("inf"):
+             continue  # unreachable node
+
+         # Relax transitions to reachable next nodes
+         for next_idx in adj_list[node_idx]:
+             new_score = dp[node_idx] + morphemes[next_idx].get("logit", 0.0)
+             if new_score > dp[next_idx]:
+                 dp[next_idx] = new_score
+                 parent[next_idx] = node_idx
+
+     # Select the best end node at the final position
+     end_nodes = []
+     max_end_pos = max(m.get("end_pos", 0) for m in morphemes)
+     for i, m in enumerate(morphemes):
+         if m.get("end_pos", 0) == max_end_pos:
+             end_nodes.append(i)
+
+     best_end_idx = -1
+     best_score = -float("inf")
+     for i in end_nodes:
+         if dp[i] > best_score:
+             best_score = dp[i]
+             best_end_idx = i
+
+     # Backtracking with a safety cap to avoid infinite loops
+     path = []
+     current = best_end_idx
+     max_iterations = n * 2  # safety cap
+     iteration_count = 0
+     visited = set()
+
+     while current != -1 and iteration_count < max_iterations:
+         if current in visited:
+             print(f"Warning: Detected cycle during backtracking (node {current})")
+             break
+         visited.add(current)
+         path.append(current)
+         current = parent[current]
+         iteration_count += 1
+
+     if iteration_count >= max_iterations:
+         print(f"Warning: Backtracking reached max iterations ({max_iterations})")
+
+     path.reverse()
+
+     # Display the selected path
+     if path and not silent:
+         total_score = sum(morphemes[idx].get("logit", 0.0) for idx in path)
+         print(f"\nOptimal path (total score: {total_score:.3f}):")
+         for idx in path:
+             morpheme = morphemes[idx]
+             logit = morpheme.get("logit", 0.0)
+             print(f"  {morpheme['surface']} (logit: {logit:.3f})")
+
+     return path
+
+
+ # Global singletons (lazy initialization)
+ _analyzer = None
+ _data_module_cache = {}
+
+
+ def predict_morphemes_from_text(text, model=None, experiment_info=None, silent=False):
+     """Predict morpheme boundaries from text.
+
+     Steps:
+         1. Analyze with MeCab to get candidates.
+         2. Build nodes/edges from morphemes and connections.
+         3. Run the model to get per-node scores.
+         4. Run Viterbi decoding over nodes and edges.
+
+     Args:
+         text: Input text.
+         model: Model to use.
+         experiment_info: Experiment metadata.
+         silent: If True, suppress prints.
+     """
+     global _analyzer
+
+     if model is None:
+         result = load_model()
+         if result is None:
+             return [], []
+         model, experiment_info = result
+
+     if not silent:
+         print(f"Input text: {text}")
+
+     # 1) Get morpheme candidates (initialize analyzer on first use)
+     if _analyzer is None:
+         _analyzer = MeCabAnalyzer()
+
+     # Fetch candidates directly via the analyzer and deduplicate
+     candidates = _analyzer.get_morpheme_candidates(text)
+     candidates = normalize_mecab_candidates(candidates)
+     candidates = dedup_morphemes(candidates)
+
+     if not candidates:
+         print("Error: Failed to obtain morpheme candidates")
+         return [], []
+
+     if not silent:
+         print(f"#Candidates: {len(candidates)}")
+
+     # 2) Use candidates as morphemes
+     morphemes = candidates
+
+     # Validate type
+     if not isinstance(morphemes, list):
+         print(f"Warning: morphemes is not a list: {type(morphemes)}")
+         morphemes = []
+
+     # Add lexical features using the shared DataModule implementation
+     dm_tmp = DataModule(annotations_dir="dummy", batch_size=1, num_workers=0, lexical_feature_dim=100000, silent=True)
+     morphemes = dm_tmp.compute_lexical_features(morphemes, text)
+
+     # Build edges (adjacent only)
+     edges = build_adjacent_edges(morphemes)
+
+     # Mark annotation as '?' (unknown) for inference
+     for morpheme in morphemes:
+         if "annotation" not in morpheme:
+             morpheme["annotation"] = "?"
+
+     if not silent:
+         print(f"Unified graph: {len(morphemes)} nodes, {len(edges)} edges")
+
+     # 3) Initialize DataModule per experiment settings
+     features_config = experiment_info["config"].get("features", {})
+     training_config = experiment_info["config"].get("training", {})
+     edge_config = experiment_info["config"].get("edge_features", {})
+
+     # Cache DataModule by annotations_dir
+     global _data_module_cache
+     cache_key = str(training_config.get("annotations_dir", "annotations_kwdlc"))
+
+     if cache_key not in _data_module_cache:
+         # Always use lexical features
+         _data_module_cache[cache_key] = DataModule(
+             annotations_dir=training_config.get("annotations_dir", "annotations_kwdlc"),
+             batch_size=1,
+             num_workers=0,
+             silent=silent,
+             lexical_feature_dim=features_config.get("lexical_feature_dim", 100000),
+             use_bidirectional_edges=edge_config.get("use_bidirectional_edges", True),
+         )
+
+     data_module = _data_module_cache[cache_key]
+
+     # Build graph using the same public API as preprocessing
+     graph = data_module.create_graph_from_morphemes_data(
+         morphemes=morphemes,
+         edges=edges,
+         text=text,
+         for_training=False,
+     )
+
+     if graph is None:
+         print("Error: Failed to create PyTorch graph")
+         return [], []
+
+     # Inference device (CPU by default; respect an explicit device from experiment_info)
+     device = torch.device("cpu")
+     if experiment_info and "device" in experiment_info:
+         device = experiment_info["device"]
+
+     with torch.no_grad():
+         # Ensure lexical feature tensors exist
+         if not hasattr(graph, "lexical_indices") or graph.lexical_indices is None:
+             print("Error: lexical_indices not found")
+             return [], []
+
+         logits = model(
+             graph.lexical_indices.to(device),
+             graph.lexical_values.to(device),
+             graph.edge_index.to(device),
+             None,
+             graph.edge_attr.to(device) if graph.edge_attr is not None else None,
+         ).squeeze()
+
+         if logits.dim() == 0:
+             logits = logits.unsqueeze(0)
+         probabilities = torch.sigmoid(logits)
+         predictions = (probabilities >= 0.5).float()
+
+         # Move back to CPU for post-processing
+         logits = logits.cpu()
+         probabilities = probabilities.cpu()
+         predictions = predictions.cpu()
+
+     # Attach predictions to morphemes
+     for i, morpheme in enumerate(morphemes):
+         if i < len(predictions):
+             morpheme["predicted_annotation"] = "+" if predictions[i] == 1 else "-"
+             morpheme["logit"] = logits[i].item()
+             morpheme["probability"] = probabilities[i].item()
+
+     # 4) Viterbi decode over nodes/edges (no CRF)
+     optimal_path = viterbi_decode_from_morphemes(logits, morphemes, edges, silent=silent)
+
+     # Format results
+     results = []
+     for i, morpheme in enumerate(morphemes):
+         is_in_optimal_path = bool(optimal_path) and i in optimal_path
+
+         result = {
+             "surface": morpheme["surface"],
+             "pos": morpheme["pos"],
+             "reading": morpheme["reading"],
+             "predicted_annotation": morpheme.get("predicted_annotation", "?"),
+             "logit": morpheme.get("logit", 0.0),
+             "probability": morpheme.get("probability", 0.5),
+             "in_optimal_path": is_in_optimal_path,
+         }
+         results.append(result)
+
+     # Collect morphemes on the optimal path
+     optimal_morphemes = []
+     if optimal_path:
+         # Group candidate indices by span
+         position_candidates = {}
+         for i, m in enumerate(morphemes):
+             pos_key = (m.get("start_pos", 0), m.get("end_pos", 0))
+             position_candidates.setdefault(pos_key, []).append(i)
+
+         for idx in optimal_path:
+             if idx < len(morphemes):
+                 morph = morphemes[idx].copy()
+                 # Add candidate count and selected rank for this span
+                 pos_key = (morph.get("start_pos", 0), morph.get("end_pos", 0))
+                 if pos_key in position_candidates:
+                     candidates_at_pos = position_candidates[pos_key]
+                     morph["num_candidates"] = len(candidates_at_pos)
+                     morph["selected_rank"] = candidates_at_pos.index(idx) + 1 if idx in candidates_at_pos else 0
+                 optimal_morphemes.append(morph)
+
+     return results, optimal_morphemes
+
+
+ def print_results(results, optimal_morphemes=None, verbose: bool = False):
+     """Print morphemes in MeCab-like format (surface\tCSV features)."""
+     if not results:
+         return
+
+     def mecab_features(m):
+         pos = m.get("pos", "*")
+         pos1 = m.get("pos_detail1", "*")
+         pos2 = m.get("pos_detail2", "*")
+         ctype = m.get("inflection_type", "*")
+         cform = m.get("inflection_form", "*")
+         base = m.get("base_form", m.get("lemma", "*")) or "*"
+         reading = m.get("reading", "*") or "*"
+         return f"{pos},{pos1},{pos2},{ctype},{cform},{base},{reading}"
+
+     items = (
+         optimal_morphemes
+         if optimal_morphemes
+         else [
+             {
+                 "surface": r.get("surface", ""),
+                 "pos": r.get("pos", "*"),
+                 "pos_detail1": "*",
+                 "pos_detail2": "*",
+                 "inflection_type": "*",
+                 "inflection_form": "*",
+                 "base_form": r.get("surface", ""),
+                 "reading": r.get("reading", "*"),
+             }
+             for r in results
+         ]
+     )
+
+     for m in items:
+         print(f"{m.get('surface', '')}\t{mecab_features(m)}")
+     print("EOS")
+
+
+ def main():
+     """Main inference entrypoint."""
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Mecari morphological analysis inference")
+     parser.add_argument("--text", "-t", help="Input text directly")
+     parser.add_argument("--experiment", "-e", help="Experiment name to load (e.g., gat_20250730_145624)")
+     parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output (include UD POS)")
+     args = parser.parse_args()
+
+     if args.experiment:
+         result = load_model(experiment_name=args.experiment)
+     else:
+         result = load_model()
+
+     if result is None:
+         return
+
+     model, experiment_info = result
+
+     if args.text:
+         result = predict_morphemes_from_text(args.text, model, experiment_info, silent=not args.verbose)
+         if result:
+             results, optimal_morphemes = result
+             print_results(results, optimal_morphemes, verbose=args.verbose)
+         else:
+             print("Inference failed.")
+     else:
+         print("\nMecari morphological inference")
+         print("Enter text (e.g., Tokyo is nice)")
+         print("Type 'quit' or 'exit' to finish.\n")
+
+         while True:
+             try:
+                 user_input = input("Input: ").strip()
+
+                 if user_input.lower() in ["quit", "exit", "q"]:
+                     print("Exiting.")
+                     break
+
+                 if not user_input:
+                     continue
+
+                 print(f"Text: {user_input}")
+
+                 result = predict_morphemes_from_text(user_input, model, experiment_info, silent=not args.verbose)
+                 if result:
+                     results, optimal_morphemes = result
+                     print_results(results, optimal_morphemes, verbose=args.verbose)
+                 else:
+                     print("Inference failed.")
+
+                 print()
+
+             except EOFError:
+                 print("\nExiting.")
+                 break
+             except KeyboardInterrupt:
+                 print("\nExiting.")
+                 break
+             except Exception as e:
+                 import traceback
+
+                 print(f"\nAn error occurred: {e}")
+                 traceback.print_exc()
+                 continue
+
+
+ if __name__ == "__main__":
+     main()
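The edge-based DP in `viterbi_decode_from_morphemes` above can be sketched on a toy lattice. The following is a minimal, self-contained reimplementation (the `viterbi_toy` helper and its candidate data are invented for illustration, not part of the module) showing how forward edges plus per-node logits select one segmentation out of overlapping span candidates:

```python
# Toy sketch of the edge-based Viterbi used above (hypothetical data, no model).
# Each node is a candidate span with a score; forward edges connect a span's end
# to the next span's start; DP picks the highest-scoring cover of the sentence.

def viterbi_toy(morphemes, edges):
    n = len(morphemes)
    adj = [[] for _ in range(n)]
    for s, t in edges:
        # Keep forward edges only, mirroring the source.end <= target.start check
        if morphemes[s]["end"] <= morphemes[t]["start"]:
            adj[s].append(t)
    dp = [float("-inf")] * n
    parent = [-1] * n
    # Initialize all nodes starting at the earliest position
    min_start = min(m["start"] for m in morphemes)
    for i, m in enumerate(morphemes):
        if m["start"] == min_start:
            dp[i] = m["logit"]
    # Relax edges in (start, end) order, like the topological-style pass above
    order = sorted(range(n), key=lambda i: (morphemes[i]["start"], morphemes[i]["end"]))
    for i in order:
        if dp[i] == float("-inf"):
            continue
        for j in adj[i]:
            score = dp[i] + morphemes[j]["logit"]
            if score > dp[j]:
                dp[j] = score
                parent[j] = i
    # Backtrack from the best node that reaches the final position
    max_end = max(m["end"] for m in morphemes)
    best = max((i for i, m in enumerate(morphemes) if m["end"] == max_end), key=lambda i: dp[i])
    path = []
    while best != -1:
        path.append(best)
        best = parent[best]
    return path[::-1]

# Candidate lattice over a 3-character sentence (indices are character offsets)
morphemes = [
    {"surface": "AB", "start": 0, "end": 2, "logit": 1.0},  # node 0
    {"surface": "A", "start": 0, "end": 1, "logit": 0.2},   # node 1
    {"surface": "B", "start": 1, "end": 2, "logit": 0.1},   # node 2
    {"surface": "C", "start": 2, "end": 3, "logit": 0.5},   # node 3
]
edges = [(0, 3), (1, 2), (2, 3)]
print(viterbi_toy(morphemes, edges))  # [0, 3]: "AB"+"C" (1.5) beats "A"+"B"+"C" (0.8)
```

The real function adds safety caps and cycle detection during backtracking; those are omitted here because the toy lattice is acyclic by construction.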
mecari/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Mecari - Japanese Morphological Analysis with Graph Neural Networks"""
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ # Export minimal API (avoid heavy imports at package import time)
6
+ from mecari.config.config import get_model_config, load_config, override_config, save_config # noqa: F401
7
+ from mecari.data.data_module import DataModule # noqa: F401
8
+
9
+ __all__ = ["DataModule", "get_model_config", "override_config", "save_config", "load_config"]
mecari/analyzers/mecab.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import subprocess
6
+ import tempfile
7
+ from typing import Dict, List
8
+
9
+ from mecari.utils.signature import signature_key
10
+
11
+
12
+ def _byte_to_char_map(text: str) -> dict[int, int]:
13
+ mapping: dict[int, int] = {}
14
+ cpos = 0
15
+ bpos = 0
16
+ for ch in text:
17
+ mapping[bpos] = cpos
18
+ bpos += len(ch.encode("utf-8"))
19
+ cpos += 1
20
+ mapping[bpos] = cpos
21
+ return mapping
22
+
23
+
24
+ class MeCabAnalyzer:
25
+ """Obtain morpheme candidates for building graph.
26
+
27
+ Args:
28
+ jumandic_path: Filesystem path to the JUMANDIC dictionary used by MeCab.
29
+ mecab_bin: Optional MeCab binary name or full path. If None, resolves
30
+ from the MECAB_BIN environment variable or defaults to "mecab".
31
+
32
+ Methods:
33
+ version(): Return the MeCab version string, or an empty string on error.
34
+ get_morpheme_candidates(text): Analyze text and return a list of
35
+ morpheme dicts with fields such as:
36
+ - surface, base_form, reading
37
+ - pos, pos_detail1/2/3
38
+ - inflection_type, inflection_form
39
+ - start_pos, end_pos (character offsets)
40
+ Unknown or unavailable values are filled with "*" or empty strings.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ jumandic_path: str | None = None,
46
+ mecab_bin: str | None = None,
47
+ ) -> None:
48
+ # Prefer JUMANDIC if present; otherwise fall back to IPADIC
49
+ if jumandic_path is None:
50
+ candidates = [
51
+ "/var/lib/mecab/dic/juman-utf8",
52
+ "/usr/lib/x86_64-linux-gnu/mecab/dic/juman-utf8",
53
+ ]
54
+ ipadic_candidates = [
55
+ "/var/lib/mecab/dic/ipadic",
56
+ "/usr/lib/x86_64-linux-gnu/mecab/dic/ipadic",
57
+ ]
58
+ chosen = next((p for p in candidates if os.path.isdir(p)), None)
59
+ if chosen is None:
60
+ chosen = next((p for p in ipadic_candidates if os.path.isdir(p)), None)
61
+ self.jumandic_path = chosen # may be None; handled below
62
+ else:
63
+ self.jumandic_path = jumandic_path
64
+
65
+ # Allow selecting a specific mecab binary via arg or env var; default to common path
66
+ self.mecab_bin = mecab_bin or os.getenv("MECAB_BIN") or (
67
+ "/usr/bin/mecab" if os.path.exists("/usr/bin/mecab") else "mecab"
68
+ )
69
+
70
+ def version(self) -> str:
71
+ try:
72
+ out = subprocess.run([self.mecab_bin, "-v"], capture_output=True, text=True)
73
+ return (out.stdout or out.stderr).strip()
74
+ except Exception:
75
+ return ""
76
+
77
+ def get_morpheme_candidates(self, text: str) -> List[Dict]:
78
+ """Return a flat list of JUMANDIC candidates (robust %H format)."""
79
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, encoding="utf-8") as f:
80
+ f.write(text)
81
+ temp_file = f.name
82
+ try:
83
+ fmt = "%pi\t%m\t%H\t%ps\t%pe\n"
84
+ cmd = [self.mecab_bin]
85
+ # Pass dictionary only if we have a resolvable path
86
+ if isinstance(self.jumandic_path, str) and os.path.isdir(self.jumandic_path):
87
+ cmd += ["-d", self.jumandic_path]
88
+ cmd += ["-F", fmt, "-E", "", "-a", temp_file]
89
+ result = subprocess.run(cmd, capture_output=True, text=True, encoding="utf-8", errors="ignore")
90
+ stdout = result.stdout
91
+ finally:
92
+ try:
93
+ import os
94
+
95
+ os.unlink(temp_file)
96
+ except Exception:
97
+ pass
98
+ if result.returncode != 0:
99
+ return []
100
+ byte_to_char = _byte_to_char_map(text)
101
+ out: list[dict] = []
102
+ seen = set()
103
+ for line in stdout.strip().split("\n"):
104
+ if not line:
105
+ continue
106
+ parts = line.split("\t")
107
+ if len(parts) < 5:
108
+ continue
109
+ node_id, surface, features, sb, eb = parts[0], parts[1], parts[2], parts[3], parts[4]
110
+ if surface in ("BOS", "EOS"):
111
+ continue
112
+ if not surface.strip():
113
+ continue
114
+ try:
115
+ start_byte = int(sb)
116
+ end_byte = int(eb)
117
+ except ValueError:
118
+ continue
119
+ start_pos = byte_to_char.get(start_byte, 0)
120
+ end_pos = byte_to_char.get(end_byte, len(text))
121
+ fs = features.split(",")
122
+ pos = fs[0] if len(fs) > 0 else "*"
123
+ pos1 = fs[1] if len(fs) > 1 else "*"
124
+ is_conj = pos in ("動詞", "形容詞", "助動詞")
125
+ ctype = fs[2] if len(fs) > 2 and fs[2] != "*" and is_conj else "*"
126
+ cform = fs[3] if len(fs) > 3 and fs[3] != "*" and is_conj else "*"
127
+ pos2 = (fs[2] if len(fs) > 2 else "*") if not is_conj else "*"
128
+ pos3 = (fs[3] if len(fs) > 3 else "*") if not is_conj else "*"
129
+ base = fs[4] if len(fs) > 4 and fs[4] != "*" else ""
130
+ reading = fs[5] if len(fs) > 5 and fs[5] != "*" else ""
131
+ m = {
132
+ "surface": surface,
133
+ "pos": pos,
134
+ "pos_detail1": pos1,
135
+ "pos_detail2": pos2,
136
+ "pos_detail3": pos3,
137
+ "base_form": base,
138
+ "reading": reading,
139
+ "inflection_type": ctype,
140
+ "inflection_form": cform,
141
+ "start_pos": start_pos,
142
+ "end_pos": end_pos,
143
+ "annotation": "?",
144
+ "node_id": node_id,
145
+ }
146
+ key = signature_key(m)
147
+ if key in seen:
148
+ continue
149
+ seen.add(key)
150
+ out.append(m)
151
+ return out
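The byte-to-character mapping above matters because MeCab's `%ps`/`%pe` format specifiers report UTF-8 byte offsets, while the graph builder works with character offsets. A self-contained sketch mirroring `_byte_to_char_map` (reimplemented here so it runs standalone) shows the mapping for Japanese text, where each kana/kanji occupies three UTF-8 bytes:

```python
# Map UTF-8 byte offsets (as reported by MeCab's %ps/%pe) to character offsets.
def byte_to_char_map(text: str) -> dict:
    mapping, cpos, bpos = {}, 0, 0
    for ch in text:
        mapping[bpos] = cpos                 # byte position where this char starts
        bpos += len(ch.encode("utf-8"))      # advance by the char's UTF-8 width
        cpos += 1
    mapping[bpos] = cpos                     # sentinel for end-of-string
    return mapping

m = byte_to_char_map("東京は")  # 3 characters, 9 UTF-8 bytes
print(m)  # {0: 0, 3: 1, 6: 2, 9: 3}
```

Mixed-width text works the same way: for `"a東"` the map is `{0: 0, 1: 1, 4: 2}`, since ASCII `a` is one byte.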
mecari/config/config.py ADDED
@@ -0,0 +1,84 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ import copy
+ import os
+ from typing import Any, Dict
+
+ import yaml
+
+
+ def load_config(config_path: str) -> Dict[str, Any]:
+     """Load a YAML config with inheritance (defaults/extends)."""
+     if not os.path.exists(config_path):
+         raise FileNotFoundError(f"Config file not found: {config_path}")
+
+     with open(config_path, "r", encoding="utf-8") as f:
+         config = yaml.safe_load(f)
+
+     # Handle inheritance (Hydra-style defaults or legacy extends)
+     if "defaults" in config:
+         # Hydra-style defaults (list format)
+         defaults = config["defaults"]
+         if isinstance(defaults, list):
+             base_config = {}
+             for default_item in defaults:
+                 if isinstance(default_item, str):
+                     base_config_path = default_item
+                 else:
+                     continue
+
+                 if not os.path.isabs(base_config_path):
+                     config_dir = os.path.dirname(config_path)
+                     base_config_path = os.path.join(config_dir, base_config_path + ".yaml")
+
+                 if os.path.exists(base_config_path):
+                     loaded = load_config(base_config_path)
+                     base_config = override_config(base_config, loaded)
+
+             child_config = {k: v for k, v in config.items() if k != "defaults" and v is not None}
+             config = override_config(base_config, child_config)
+     elif "extends" in config:
+         # Legacy extends format
+         base_config_path = config["extends"]
+         if not os.path.isabs(base_config_path):
+             config_dir = os.path.dirname(config_path)
+             base_config_path = os.path.join(config_dir, base_config_path)
+
+         base_config = load_config(base_config_path)
+
+         child_config = {k: v for k, v in config.items() if k != "extends" and v is not None}
+         config = override_config(base_config, child_config)
+
+     return config
+
+
+ def get_model_config(model_type: str) -> Dict[str, Any]:
+     """Return the config for a given model type."""
+     config_path = f"configs/{model_type}.yaml"
+     return load_config(config_path)
+
+
+ def override_config(config: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
+     """Deep-override a config with values from overrides."""
+
+     def deep_update(base_dict, update_dict):
+         for key, value in update_dict.items():
+             if isinstance(value, dict) and key in base_dict and isinstance(base_dict[key], dict):
+                 deep_update(base_dict[key], value)
+             else:
+                 base_dict[key] = value
+
+     result = copy.deepcopy(config)
+     deep_update(result, overrides)
+     return result
+
+
+ def save_config(config: Dict[str, Any], output_path: str):
+     """Save a config as YAML."""
+     os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+     with open(output_path, "w", encoding="utf-8") as f:
+         yaml.dump(config, f, default_flow_style=False, allow_unicode=True, indent=2)
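The deep-merge semantics of `override_config` above are worth seeing on data: nested dicts merge key by key, while scalars (and lists) in the child simply replace the base value. A self-contained sketch with the same merge logic (inlined here so it runs without the package; the config values are invented examples):

```python
import copy

# Deep-merge semantics matching override_config above: recurse into dicts,
# replace everything else; the base config is never mutated.
def deep_override(config: dict, overrides: dict) -> dict:
    def deep_update(base, update):
        for key, value in update.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                deep_update(base[key], value)
            else:
                base[key] = value

    result = copy.deepcopy(config)
    deep_update(result, overrides)
    return result

base = {"model": {"type": "gatv2", "hidden_dim": 64}, "training": {"lr": 1e-3}}
child = {"model": {"hidden_dim": 128}}  # e.g. a Hydra-style defaults child
merged = deep_override(base, child)
print(merged["model"])  # {'type': 'gatv2', 'hidden_dim': 128}
```

This is why a child config only needs to list the keys it changes; sibling keys like `model.type` and whole sections like `training` survive the merge untouched.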
mecari/data/data_module.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ from typing import Dict, List, Optional
6
+
7
+ import pytorch_lightning as pl
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+ from torch_geometric.data import Data, DataLoader
11
+
12
+ # Required import for lexical feature computation
13
+ from mecari.featurizers.lexical import (
14
+ LexicalNGramFeaturizer as LexFeaturizer,
15
+ Morpheme as LexMorpheme,
16
+ )
17
+
18
+
19
+ """Data module for lexical-graph training using prebuilt .pt graphs only."""
20
+
21
+
22
+ # Prebuilt .pt graph dataset
23
+ class _PtGraphDataset(Dataset):
24
+ """Prebuilt PyG graph tensors saved as .pt per sentence.
25
+
26
+ Each file is expected to be a dict with keys:
27
+ - 'graph': torch_geometric.data.Data
28
+ - 'source_id': str (used for split)
29
+ - optional: 'text'
30
+ """
31
+
32
+ def __init__(self, files: List[str]) -> None:
33
+ self.files = files
34
+
35
+ def __len__(self) -> int:
36
+ return len(self.files)
37
+
38
+ def __getitem__(self, idx: int) -> Data:
39
+ path = self.files[idx]
40
+ obj = torch.load(path, map_location="cpu")
41
+ if isinstance(obj, dict) and "graph" in obj:
42
+ data = obj["graph"]
43
+ else:
44
+ data = obj
45
+ if not isinstance(data, Data):
46
+ raise RuntimeError(f"Invalid graph object in: {path}")
47
+ data.data_index = idx
48
+ return data
49
+
50
+
51
+ # Safe globals registration for PyTorch 2.6+
52
+ try:
53
+ import torch.serialization
54
+ from torch_geometric.data.data import DataEdgeAttr
55
+
56
+ torch.serialization.add_safe_globals([DataEdgeAttr, Data])
57
+ except (ImportError, AttributeError):
58
+ pass
59
+
60
+
61
+ class DataModule(pl.LightningDataModule):
62
+ """Loads .pt graphs and builds lexical graph features for training."""
63
+
64
+ def __init__(
65
+ self,
66
+ annotations_dir: str = "annotations",
67
+ batch_size: int = 32,
68
+ num_workers: int = 0,
69
+ max_files: Optional[int] = None,
70
+ use_bidirectional_edges: bool = True,
71
+ annotations_override_dir: Optional[str] = None,
72
+ silent: bool = False,
73
+ lexical_feature_dim: int = 100000,
74
+ lexical_max_features: int = 20,
75
+ ) -> None:
76
+ super().__init__()
77
+ self.annotations_dir = annotations_dir
78
+ self.annotations_override_dir = annotations_override_dir
79
+ self.batch_size = batch_size
80
+ self.num_workers = num_workers
81
+ self.max_files = max_files
83
+ self.silent = silent
84
+ self.lexical_feature_dim = lexical_feature_dim
85
+ self.lexical_max_features = int(lexical_max_features)
86
+ self.use_bidirectional_edges = bool(use_bidirectional_edges)
87
+
88
+ # Initialized in setup()
89
+ self.train_dataset = []
90
+ self.val_dataset = []
91
+ self.test_dataset = []
92
+ # Eagerly initialize lexical featurizer (small and picklable)
93
+ self._lex_featurizer = LexFeaturizer(dim=int(self.lexical_feature_dim), add_bias=True)
94
+ # POS mapping for evaluation breakdown
95
+ self.pos_to_id = {
96
+ "名詞": 1,
97
+ "動詞": 2,
98
+ "形容詞": 3,
99
+ "副詞": 4,
100
+ "助詞": 5,
101
+ "助動詞": 6,
102
+ "接続詞": 7,
103
+ "連体詞": 8,
104
+ "感動詞": 9,
105
+ "形状詞": 10,
106
+ "補助記号": 11,
107
+ "接頭辞": 12,
108
+ "接尾辞": 13,
109
+ "特殊": 14,
110
+ }
111
+ self.id_to_pos = {v: k for k, v in self.pos_to_id.items()}
112
+
113
+ def create_graph_from_morphemes_data(self, *args, **kwargs) -> Optional[Data]:
114
+ """Create a lexical graph from morpheme data (or candidates)."""
115
+ if "candidates" in kwargs:
116
+ candidates = kwargs.pop("candidates")
117
+ text = kwargs.get("text", "")
118
+ morphemes_edges = self._build_graph_from_candidates(candidates, text)
119
+ if not morphemes_edges:
120
+ return None
121
+ kwargs["morphemes"] = morphemes_edges["morphemes"]
122
+ kwargs["edges"] = morphemes_edges["edges"]
123
+ return self._create_lexical_graph(*args, **kwargs)
124
+
125
+ # --- Lexical features helper (for preprocessing) ---
126
+ def compute_lexical_features(self, morphemes: List[Dict], text: str) -> List[Dict]:
127
+ """Add lexical_features to each morpheme using Mecari's lexical featurizer.
128
+
129
+ Requires mecari.featurizers.lexical to be importable. Raises a clear error
130
+ if the featurizer is unavailable (training/inference depend on it).
131
+ """
132
+ if not morphemes:
133
+ return morphemes
134
+
135
+ for m in morphemes:
136
+ try:
137
+ morph_obj = LexMorpheme(
138
+ surf=m.get("surface", ""),
139
+ lemma=m.get("base_form", ""),
140
+ pos=m.get("pos", "*"),
141
+ pos1=m.get("pos_detail1", "*"),
142
+ ctype=m.get("inflection_type", "*"),
143
+ cform=m.get("inflection_form", "*"),
144
+ reading=m.get("reading", "*"),
145
+ )
146
+ st = m.get("start_pos", 0)
147
+ ed = m.get("end_pos", st + len(m.get("surface", "")))
148
+ prev_char = text[st - 1] if st > 0 else None
149
+ next_char = text[ed] if ed < len(text) else None
150
+ feats = self._lex_featurizer.unigram_feats(morph_obj, prev_char, next_char)
151
+ m["lexical_features"] = feats
152
+ except Exception:
153
+ # on any failure, leave unchanged
154
+ pass
155
+ return morphemes
156
+
157
+ def _create_lexical_graph(
158
+ self, morphemes: List[Dict], edges: List[Dict], text: str, for_training: bool = True
159
+ ) -> Optional[Data]:
160
+ """Build a graph using lexical features."""
161
+ if not morphemes:
162
+ return None
163
+
164
+ # Sparse lexical features per node
165
+ all_indices = []
166
+ all_values = []
167
+ all_lengths = []
168
+ annotations = []
169
+ valid_mask = []
170
+
171
+ max_features = 0
172
+ for morpheme in morphemes:
173
+ lexical_feats = morpheme.get("lexical_features", [])
174
+ indices = []
175
+ values = []
176
+ for idx, val in lexical_feats:
177
+ if 0 <= idx < self.lexical_feature_dim:
178
+ indices.append(idx)
179
+ values.append(val)
180
+ all_lengths.append(len(indices))
181
+ max_features = max(max_features, len(indices))
182
+
183
+ all_indices.append(indices)
184
+ all_values.append(values)
185
+
186
+ if for_training:
187
+ annotation = morpheme.get("annotation", "?")
188
+ if annotation == "+":
189
+ annotations.append(1)
190
+ valid_mask.append(True)
191
+ elif annotation == "-":
192
+ annotations.append(0)
193
+ valid_mask.append(True)
194
+ else:
195
+ annotations.append(0)
196
+ valid_mask.append(False)
197
+
198
+ # Fixed-size padding/truncation for batching
199
+ FIXED_MAX_FEATURES = int(getattr(self, "lexical_max_features", 20))
200
+
201
+ padded_indices = []
202
+ padded_values = []
203
+ for indices, values in zip(all_indices, all_values):
204
+ if len(indices) > FIXED_MAX_FEATURES:
205
+ padded_indices.append(indices[:FIXED_MAX_FEATURES])
206
+ padded_values.append(values[:FIXED_MAX_FEATURES])
207
+ else:
208
+ pad_length = FIXED_MAX_FEATURES - len(indices)
209
+ padded_indices.append(indices + [0] * pad_length)
210
+ padded_values.append(values + [0.0] * pad_length)
211
+
212
+ edge_index = self._build_edge_index(edges, len(morphemes))
213
+
214
+ # POS ids per node (for evaluation breakdown)
215
+ pos_ids = []
216
+ for m in morphemes:
217
+ pos = m.get("pos", "*")
218
+ pos_ids.append(self.pos_to_id.get(pos, 0))
219
+
220
+ graph_data = Data(
221
+ lexical_indices=torch.tensor(padded_indices, dtype=torch.long),
222
+ lexical_values=torch.tensor(padded_values, dtype=torch.float32),
223
+ lexical_lengths=torch.tensor(all_lengths, dtype=torch.long),
224
+ edge_index=edge_index,
225
+ num_nodes=len(morphemes),
226
+ )
227
+ graph_data.pos_ids = torch.tensor(pos_ids, dtype=torch.long)
228
+ if for_training:
229
+ graph_data.y = torch.tensor(annotations, dtype=torch.float32)
230
+ graph_data.valid_mask = torch.tensor(valid_mask, dtype=torch.bool)
231
+
232
+ return graph_data
233
+
234
+ def _build_edge_index(self, edges: List[Dict], num_nodes: int) -> torch.Tensor:
235
+ """Build a PyG edge_index tensor from edge dicts."""
236
+ if not edges:
237
+ return torch.tensor([[], []], dtype=torch.long)
238
+
239
+ source_indices = []
240
+ target_indices = []
241
+
242
+ for edge in edges:
243
+ source = edge.get("source_idx", 0)
244
+ target = edge.get("target_idx", 0)
245
+
246
+ if 0 <= source < num_nodes and 0 <= target < num_nodes:
247
+ source_indices.append(source)
248
+ target_indices.append(target)
249
+ if self.use_bidirectional_edges:
250
+ source_indices.append(target)
251
+ target_indices.append(source)
252
+
253
+ if not source_indices:
254
+ return torch.tensor([[], []], dtype=torch.long)
255
+
256
+ return torch.tensor([source_indices, target_indices], dtype=torch.long)
257
+
258
+ def _load_kwdlc_ids(self, ids_file: str) -> set:
259
+ """Load KWDLC ID list (one ID per line)."""
260
+ ids = set()
261
+ if ids_file and os.path.exists(ids_file):
262
+ with open(ids_file, "r") as f:
263
+ for line in f:
264
+ ids.add(line.strip())
265
+ return ids
266
+
267
+ def load_annotation_data(self, max_files: Optional[int] = None) -> List[Dict]:
268
+ """Detect and list available .pt annotation graph files."""
269
+ if os.path.isdir(self.annotations_dir):
270
+ pt_files = [
271
+ os.path.join(self.annotations_dir, fn)
272
+ for fn in sorted(os.listdir(self.annotations_dir))
273
+ if fn.endswith(".pt")
274
+ ]
275
+ if pt_files:
276
+ if max_files is not None:
277
+ pt_files = pt_files[:max_files]
278
+ return [{"_mode": "pt", "_pt_files": pt_files}]
279
+ raise FileNotFoundError(f"No annotation graphs found under: {self.annotations_dir}")
280
+
281
+ def setup(self, stage: Optional[str] = None) -> None:
282
+ """Build train/val/test datasets from discovered .pt files."""
283
+ annotation_data = self.load_annotation_data(max_files=self.max_files)
284
+
285
+ if not annotation_data:
286
+ self.train_dataset = []
287
+ self.val_dataset = []
288
+ self.test_dataset = []
289
+ return
290
+
294
+ mode = annotation_data[0].get("_mode")
295
+ if mode == "pt":
296
+ files: List[str] = annotation_data[0]["_pt_files"]
297
+ train_files: List[str] = []
298
+ val_files: List[str] = []
299
+ test_files: List[str] = []
300
+
301
+ # Use KWDLC split ids (mandatory)
302
+ dev_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "dev.id"))
303
+ test_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "test.id"))
304
+
305
+ for fp in files:
306
+ sid = None
307
+ try:
308
+ obj = torch.load(fp, map_location="cpu")
309
+ if isinstance(obj, dict):
310
+ sid = obj.get("source_id")
311
+ except Exception:
312
+ pass
313
+ if sid and (dev_ids or test_ids):
314
+ if sid in test_ids:
315
+ test_files.append(fp)
316
+ elif sid in dev_ids:
317
+ val_files.append(fp)
318
+ else:
319
+ train_files.append(fp)
320
+ else:
321
+ train_files.append(fp)
322
+
323
+ # Build datasets strictly based on KWDLC dev/test ids
324
+ self.train_dataset = _PtGraphDataset(train_files)
325
+ self.val_dataset = _PtGraphDataset(val_files)
326
+ self.test_dataset = _PtGraphDataset(test_files)
327
+
328
+ if len(self.val_dataset) == 0 or len(self.test_dataset) == 0:
329
+ raise RuntimeError(
330
+ "KWDLC dev/test split produced empty val/test datasets. Ensure KWDLC id files exist and source_id is set in .pt files."
331
+ )
332
+ else:
333
+ raise RuntimeError("Unsupported annotation mode; expected pt")
334
+
335
+ print(
336
+ f"Data split: train={len(self.train_dataset)}, val={len(self.val_dataset)}, test={len(self.test_dataset)}"
337
+ )
338
+
339
+ def _create_dataloader(self, dataset: List[Data], batch_size: int, shuffle: bool = False) -> DataLoader:
340
+ """Create a DataLoader with optional workers/prefetching."""
341
+ return DataLoader(
342
+ dataset,
343
+ batch_size=batch_size,
344
+ shuffle=shuffle,
345
+ num_workers=self.num_workers,
346
+ pin_memory=False,
347
+ persistent_workers=True if self.num_workers > 0 else False,
348
+ prefetch_factor=2 if self.num_workers > 0 else None,
349
+ )
350
+
351
+ def train_dataloader(self) -> DataLoader:
352
+ """Return train DataLoader."""
353
+ return self._create_dataloader(self.train_dataset, self.batch_size, shuffle=True)
354
+
355
+ def val_dataloader(self) -> DataLoader:
356
+ """Return val DataLoader."""
357
+ return self._create_dataloader(self.val_dataset, self.batch_size, shuffle=False)
358
+
359
+ def test_dataloader(self) -> DataLoader:
360
+ """Return test DataLoader."""
361
+ return self._create_dataloader(self.test_dataset, self.batch_size, shuffle=False)
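The partial-annotation scheme above ('+'/'-' supervised, '?' ignored) reduces to a label plus a validity mask per node; masked-out nodes never reach the loss. A minimal standalone sketch of that mapping (plain Python, hypothetical helper name):

```python
def build_labels(annotations):
    """Map per-node '+'/'-'/'?' marks to (labels, valid_mask).

    '+' -> label 1.0 and '-' -> label 0.0, both with valid_mask=True;
    '?' (unannotated) gets a placeholder label and valid_mask=False,
    so the node is excluded from the BCE loss.
    """
    labels, valid_mask = [], []
    for mark in annotations:
        if mark == "+":
            labels.append(1.0)
            valid_mask.append(True)
        elif mark == "-":
            labels.append(0.0)
            valid_mask.append(True)
        else:
            labels.append(0.0)
            valid_mask.append(False)
    return labels, valid_mask
```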
mecari/featurizers/lexical.py ADDED
@@ -0,0 +1,116 @@
1
+ import hashlib
2
+ from dataclasses import dataclass
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+
6
+ # -------- Basic data structures --------
7
+ @dataclass
8
+ class Morpheme:
9
+ surf: str # surface
10
+ lemma: str # lemma (base form)
11
+ pos: str # POS (coarse)
12
+ pos1: str = "*" # POS (fine)
13
+ ctype: str = "*" # conjugation type
14
+ cform: str = "*" # conjugation form
15
+ reading: str = "*" # reading (if any)
16
+
17
+
18
+ # -------- Utilities --------
19
+ def _stable_hash(s: str, dim: int) -> int:
20
+ # md5 stable hash -> lower 8 bytes -> modulo by dim
21
+ d = hashlib.md5(s.encode("utf-8")).digest()
22
+ return int.from_bytes(d[:8], "little") % dim
23
+
24
+
25
+ def _charclass(ch: str) -> str:
26
+ # Simple character classes (for boundary features)
27
+ if not ch:
28
+ return "O"
29
+ try:
30
+ o = ord(ch)
31
+ except Exception:
32
+ return "O"
33
+ if 0x3040 <= o <= 0x309F:
34
+ return "H" # hiragana
35
+ if 0x30A0 <= o <= 0x30FF:
36
+ return "K" # katakana
37
+ if 0x4E00 <= o <= 0x9FFF or 0x3400 <= o <= 0x4DBF:
38
+ return "C" # kanji
39
+ if 0x0030 <= o <= 0x0039 or 0xFF10 <= o <= 0xFF19:
40
+ return "D" # digits
41
+ if 0x0041 <= o <= 0x007A or 0xFF21 <= o <= 0xFF5A:
42
+ return "A" # letters
43
+ if ch.isspace():
44
+ return "S"
45
+ return "O" # other
46
+
47
+
48
+ def _affix(s: str, n: int) -> str:
49
+ return s[:n] if len(s) >= n else s
50
+
51
+
52
+ def _suffix(s: str, n: int) -> str:
53
+ return s[-n:] if len(s) >= n else s
54
+
55
+
56
+ # -------- Lexical n-gram featurizer --------
57
+ class LexicalNGramFeaturizer:
58
+ """Build unigram + boundary features as (index, value) pairs."""
59
+
60
+ def __init__(self, dim: int = 1_000_000, add_bias: bool = True):
61
+ self.dim = dim
62
+ self.add_bias = add_bias
63
+
64
+ def _push(self, feats: List[Tuple[int, float]], key: str, val: float = 1.0):
65
+ feats.append((_stable_hash(key, self.dim), val))
66
+
67
+ def unigram_feats(self, m: Morpheme, prev_char: Optional[str], next_char: Optional[str]) -> List[Tuple[int, float]]:
68
+ f: List[Tuple[int, float]] = []
69
+ # POS
70
+ self._push(f, f"U:POS={m.pos}")
71
+ self._push(f, f"U:POS1={m.pos}:{m.pos1}")
72
+ # Lexicalized (surface/lemma) + POS
73
+ self._push(f, f"U:LEM={m.lemma}")
74
+ self._push(f, f"U:SURF={m.surf}")
75
+ self._push(f, f"U:LEM+POS={m.lemma}|{m.pos}")
76
+ self._push(f, f"U:SURF+POS1={m.surf}|{m.pos}:{m.pos1}")
77
+ # Conjugation
78
+ self._push(f, f"U:CFORM={m.ctype}:{m.cform}")
79
+ # Reading (coarse)
80
+ if m.reading and m.reading != "*":
81
+ self._push(f, f"U:READ={m.reading}")
82
+ # Prefix/Suffix (string n-grams)
83
+ self._push(f, f"U:PREF2={_affix(m.surf, 2)}")
84
+ self._push(f, f"U:SUF2={_suffix(m.surf, 2)}")
85
+ # Boundary char types (1 char left/right)
86
+ if prev_char:
87
+ self._push(f, f"U:BTYPE_L={_charclass(prev_char)}->{_charclass(m.surf[:1])}")
88
+ if next_char:
89
+ self._push(f, f"U:BTYPE_R={_charclass(m.surf[-1:])}->{_charclass(next_char)}")
90
+ if self.add_bias:
91
+ self._push(f, "U:BIAS")
92
+ return f
93
+
94
+     def featurize_sequence(
+         self, morphs: List[Morpheme], raw_sentence: Optional[str] = None
+     ) -> List[List[Tuple[int, float]]]:
+         """Return the (index, value) feature list for every morpheme in morphs."""
+         if raw_sentence is None:
+             raw_sentence = "".join(m.surf for m in morphs)
+         spans = []
+         cur = 0
+         for m in morphs:
+             st, ed = cur, cur + len(m.surf)
+             spans.append((st, ed))
+             cur = ed
+
+         feats: List[List[Tuple[int, float]]] = []
+         for i, m in enumerate(morphs):
+             st, ed = spans[i]
+             prev_char = raw_sentence[st - 1] if st > 0 else None
+             next_char = raw_sentence[ed] if ed < len(raw_sentence) else None
+             feats.append(self.unigram_feats(m, prev_char, next_char))
+
+         return feats
113
+
114
+
115
+ if __name__ == "__main__":
116
+ pass
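The featurizer maps every string feature key into a fixed index space with the hashing trick, so no vocabulary file is needed and indices are stable across processes. A standalone sketch of the same md5-based scheme used by `_stable_hash`:

```python
import hashlib

def stable_hash(key: str, dim: int) -> int:
    """Map a feature key to a stable index in [0, dim) via md5.

    md5 of the UTF-8 key -> lower 8 bytes -> modulo dim, matching
    the _stable_hash utility above.
    """
    digest = hashlib.md5(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "little") % dim
```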
mecari/models/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from .base import BaseMecariGNN # noqa: F401
5
+ from .gatv2 import MecariGATv2 # noqa: F401
6
+ __all__ = ["BaseMecariGNN", "MecariGATv2"]
mecari/models/base.py ADDED
@@ -0,0 +1,214 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """Base model with lexical features only."""
4
+
5
+ from typing import Optional
6
+
7
+ import pytorch_lightning as pl
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class BaseMecariGNN(pl.LightningModule):
14
+ """Base class for Mecari morpheme GNNs."""
15
+
16
+ def __init__(
17
+ self,
18
+ hidden_dim: int = 512,
19
+ num_classes: int = 1,
20
+ learning_rate: float = 1e-3,
21
+ lexical_feature_dim: int = 100000,
22
+ ) -> None:
23
+ super().__init__()
24
+ self.save_hyperparameters()
25
+
26
+ self.hidden_dim = hidden_dim
27
+ self.num_classes = num_classes
28
+ self.learning_rate = learning_rate
29
+ self.lexical_feature_dim = lexical_feature_dim
30
+
31
+ self.lexical_embedding = nn.Embedding(
32
+ num_embeddings=lexical_feature_dim, embedding_dim=hidden_dim, padding_idx=0, sparse=False
33
+ )
34
+ nn.init.xavier_uniform_(self.lexical_embedding.weight[1:])
35
+ self.lexical_embedding.weight.data[0].fill_(0)
36
+
37
+ self.lexical_norm = nn.LayerNorm(hidden_dim)
38
+ self.lexical_dropout = nn.Dropout(0.2)
39
+
40
+ self.residual_proj = nn.Linear(hidden_dim, hidden_dim)
41
+
42
+ self.node_classifier = nn.Sequential(
43
+ nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden_dim, 1)
44
+ )
45
+
46
+ def _process_features(
47
+ self, lexical_indices: torch.Tensor, lexical_values: torch.Tensor, bert_features: Optional[torch.Tensor] = None
48
+ ) -> torch.Tensor:
49
+ """Process lexical features."""
50
+ embedded = self.lexical_embedding(lexical_indices)
51
+ weighted = embedded * lexical_values.unsqueeze(-1)
52
+ aggregated = weighted.sum(dim=1)
53
+ processed = self.lexical_dropout(self.lexical_norm(aggregated))
54
+ return processed
55
+
56
+ def forward(self, lexical_indices, lexical_values, edge_index, bert_features=None, edge_attr=None):
57
+ """Forward pass (implemented in subclasses)."""
58
+ raise NotImplementedError("Subclasses must implement forward method")
59
+
60
+ def training_step(self, batch, batch_idx):
61
+ node_predictions = self(
62
+ batch.lexical_indices,
63
+ batch.lexical_values,
64
+ batch.edge_index,
65
+ None,
66
+ batch.edge_attr if hasattr(batch, "edge_attr") else None,
67
+ ).squeeze()
68
+
69
+ valid_mask = batch.valid_mask
70
+ valid_predictions = node_predictions[valid_mask]
71
+ valid_targets = batch.y[valid_mask]
72
+
73
+ loss = self._compute_bce_loss(valid_predictions, valid_targets, stage="train")
74
+
75
+ with torch.no_grad():
76
+ pred_probs = torch.sigmoid(valid_predictions)
77
+ pred_binary = (pred_probs > 0.5).float()
78
+ correct = (pred_binary == valid_targets).sum()
79
+ accuracy = correct / valid_targets.numel()
80
+ error_rate = 1.0 - accuracy
81
+
82
+ self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
83
+ self.log("train_error", error_rate, prog_bar=True, on_step=True, on_epoch=True)
84
+
85
+ if self.trainer and self.trainer.optimizers:
86
+ current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
87
+ self.log("learning_rate", current_lr, on_step=True, on_epoch=False)
88
+
89
+ return loss
90
+
91
+ def validation_step(self, batch, batch_idx):
92
+ node_predictions = self(
93
+ batch.lexical_indices,
94
+ batch.lexical_values,
95
+ batch.edge_index,
96
+ None,
97
+ batch.edge_attr if hasattr(batch, "edge_attr") else None,
98
+ ).squeeze()
99
+
100
+ valid_mask = batch.valid_mask
101
+ valid_predictions = node_predictions[valid_mask]
102
+ valid_targets = batch.y[valid_mask]
103
+
104
+ loss = self._compute_bce_loss(valid_predictions, valid_targets, stage="val")
105
+
106
+ with torch.no_grad():
107
+ pred_probs = torch.sigmoid(valid_predictions)
108
+ pred_binary = (pred_probs > 0.5).float()
109
+ correct = (pred_binary == valid_targets).sum()
110
+ accuracy = correct / valid_targets.numel()
111
+ error_rate = 1.0 - accuracy
112
+
113
+ self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
114
+ self.log("val_error", error_rate, prog_bar=True, on_step=True, on_epoch=True)
115
+
116
+ self.log("val_loss_epoch", loss, on_step=False, on_epoch=True)
117
+ self.log("val_error_epoch", error_rate, on_step=False, on_epoch=True)
118
+
119
+ return loss
120
+
121
+ def configure_optimizers(self):
122
+ """Configure optimizer."""
123
+ optimizer_config = getattr(self, "training_config", {}).get("optimizer", {})
124
+ optimizer_type = optimizer_config.get("type", "adamw")
125
+
126
+ if optimizer_type == "adamw":
127
+ optimizer = torch.optim.AdamW(
128
+ self.parameters(), lr=self.learning_rate, weight_decay=optimizer_config.get("weight_decay", 0.01)
129
+ )
130
+ else:
131
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
132
+ # Optional warmup scheduler (linear warmup to base LR)
133
+ tc = getattr(self, "training_config", {}) or {}
134
+ warmup_steps = int(tc.get("warmup_steps", 0) or 0)
135
+ warmup_start_lr = float(tc.get("warmup_start_lr", 0.0) or 0.0)
136
+ if warmup_steps > 0 and self.learning_rate > 0.0:
137
+ start_factor = max(0.0, min(1.0, warmup_start_lr / float(self.learning_rate)))
138
+
139
+ def lr_lambda(step: int):
140
+ if step <= 0:
141
+ return start_factor
142
+ if step < warmup_steps:
143
+ return start_factor + (1.0 - start_factor) * (step / float(warmup_steps))
144
+ return 1.0
145
+
146
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
147
+ return {
148
+ "optimizer": optimizer,
149
+ "lr_scheduler": {
150
+ "scheduler": scheduler,
151
+ "interval": "step",
152
+ "frequency": 1,
153
+ "name": "linear_warmup",
154
+ },
155
+ }
156
+ return {"optimizer": optimizer}
157
+
158
+ def test_step(self, batch, batch_idx):
159
+ node_predictions = self(
160
+ batch.lexical_indices,
161
+ batch.lexical_values,
162
+ batch.edge_index,
163
+ None,
164
+ batch.edge_attr if hasattr(batch, "edge_attr") else None,
165
+ ).squeeze()
166
+
167
+ valid_mask = batch.valid_mask
168
+ valid_predictions = node_predictions[valid_mask]
169
+ valid_targets = batch.y[valid_mask]
170
+
171
+ with torch.no_grad():
172
+ pred_probs = torch.sigmoid(valid_predictions)
173
+ pred_binary = (pred_probs > 0.5).float()
174
+ correct = (pred_binary == valid_targets).sum()
175
+ accuracy = correct / valid_targets.numel()
176
+ error_rate = 1.0 - accuracy
177
+
178
+ self.log("test_error", error_rate, on_step=False, on_epoch=True)
179
+ self.log("test_accuracy", accuracy, on_step=False, on_epoch=True)
180
+
181
+ return error_rate
182
+
183
+ def _compute_bce_loss(self, logits: torch.Tensor, targets: torch.Tensor, stage: str = "train") -> torch.Tensor:
184
+ """BCEWithLogits loss with optional label smoothing and pos_weight.
185
+
186
+ - label_smoothing: smooth targets toward 0.5 by epsilon.
187
+ - pos_weight: handle class imbalance using ratio (neg/pos) per batch, robustly.
188
+ """
189
+ loss_cfg = getattr(self, "training_config", {}).get("loss", {})
190
+ eps = float(loss_cfg.get("label_smoothing", 0.0) or 0.0)
191
+ use_pos_weight = bool(loss_cfg.get("use_pos_weight", True))
192
+
193
+ # Compute pos_weight from unsmoothed targets
194
+ pos = torch.clamp(targets.sum(), min=0.0)
195
+ total = torch.tensor(targets.numel(), device=targets.device, dtype=targets.dtype)
196
+ neg = total - pos
197
+ pos_weight = None
198
+ if use_pos_weight and pos > 0 and neg > 0:
199
+ # pos_weight = neg/pos; clamp to avoid extreme values
200
+ pw = (neg / pos).detach()
201
+ pw = torch.clamp(pw, 0.5, 50.0) # safety bounds
202
+ pos_weight = pw
203
+
204
+ # Apply label smoothing to targets: y' = (1-eps)*y + 0.5*eps
205
+ if eps > 0.0:
206
+ targets = (1.0 - eps) * targets + 0.5 * eps
207
+
208
+ loss = F.binary_cross_entropy_with_logits(
209
+ logits,
210
+ targets,
211
+ pos_weight=pos_weight,
212
+ )
213
+
214
+ return loss
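The optional warmup in `configure_optimizers` multiplies the base learning rate by a linear factor per step. That multiplier can be sketched in isolation (same logic as the `lr_lambda` closure above):

```python
def warmup_factor(step: int, warmup_steps: int, start_factor: float) -> float:
    """Linear warmup multiplier for the base learning rate.

    start_factor is warmup_start_lr / base_lr clamped to [0, 1];
    the factor rises linearly to 1.0 over warmup_steps, then stays there.
    """
    if step <= 0:
        return start_factor
    if step < warmup_steps:
        return start_factor + (1.0 - start_factor) * (step / float(warmup_steps))
    return 1.0
```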
mecari/models/gatv2.py ADDED
@@ -0,0 +1,139 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """GATv2 model for morpheme graph classification."""
4
+
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch_geometric.nn import GATv2Conv
8
+ from torch_geometric.utils import add_self_loops, dropout_adj
9
+
10
+ from .base import BaseMecariGNN
11
+
12
+
13
+ class MecariGATv2(BaseMecariGNN):
14
+ """Graph Attention Network v2 for morpheme analysis"""
15
+
16
+ def __init__(
17
+ self,
18
+ hidden_dim: int = 512,
19
+ num_heads: int = 8,
20
+ num_layers: int = 4,
21
+ num_classes: int = 1,
22
+ learning_rate: float = 1e-3,
23
+ lexical_feature_dim: int = 100000,
24
+ share_weights: bool = False, # share-weights option of GATv2
25
+ # New knobs
26
+ dropout: float = 0.1,
27
+ attn_dropout: float = 0.1,
28
+ add_self_loops_flag: bool = True,
29
+ edge_dropout: float = 0.0,
30
+ norm: str = "layer",
31
+ **kwargs, # Ignore extra params for config compatibility
32
+ ):
33
+ super().__init__(
34
+ hidden_dim=hidden_dim,
35
+ num_classes=num_classes,
36
+ learning_rate=learning_rate,
37
+ lexical_feature_dim=lexical_feature_dim,
38
+ )
39
+ self.num_heads = num_heads
40
+ self.num_layers = num_layers
41
+ self.share_weights = share_weights
42
+ self.feat_dropout_p = dropout
43
+ self.attn_dropout_p = attn_dropout
44
+ self.add_self_loops_flag = add_self_loops_flag
45
+ self.edge_dropout_p = edge_dropout
46
+ self.norm_type = (norm or "layer").lower()
47
+
48
+ # GATv2 layers
49
+ self.gatv2_layers = nn.ModuleList()
50
+ self.layer_norms = nn.ModuleList()
51
+
52
+ for i in range(num_layers):
53
+ if i == 0:
54
+ # First layer
55
+ self.gatv2_layers.append(
56
+ GATv2Conv(
57
+ hidden_dim,
58
+ hidden_dim,
59
+ heads=num_heads,
60
+ dropout=self.attn_dropout_p,
61
+ share_weights=share_weights,
62
+ add_self_loops=False,
63
+ )
64
+ )
65
+ elif i == num_layers - 1:
66
+ # Last layer - single head
67
+ self.gatv2_layers.append(
68
+ GATv2Conv(
69
+ hidden_dim * num_heads,
70
+ hidden_dim,
71
+ heads=1,
72
+ concat=False,
73
+ dropout=self.attn_dropout_p,
74
+ share_weights=share_weights,
75
+ add_self_loops=False,
76
+ )
77
+ )
78
+ else:
79
+ # Middle layers
80
+ self.gatv2_layers.append(
81
+ GATv2Conv(
82
+ hidden_dim * num_heads,
83
+ hidden_dim,
84
+ heads=num_heads,
85
+ dropout=self.attn_dropout_p,
86
+ share_weights=share_weights,
87
+ add_self_loops=False,
88
+ )
89
+ )
90
+
91
+ # Layer normalization (all layers)
92
+ if i < num_layers - 1:
93
+ self.layer_norms.append(
94
+ nn.LayerNorm(hidden_dim * num_heads)
95
+ if self.norm_type == "layer"
96
+ else nn.BatchNorm1d(hidden_dim * num_heads)
97
+ )
98
+ else:
99
+ self.layer_norms.append(
100
+ nn.LayerNorm(hidden_dim) if self.norm_type == "layer" else nn.BatchNorm1d(hidden_dim)
101
+ )
102
+
103
+ def forward(self, lexical_indices, lexical_values, edge_index, bert_features=None, edge_attr=None):
104
+ """Forward pass of GATv2"""
105
+ x = self._process_features(lexical_indices, lexical_values, bert_features)
106
+
107
+ residual = self.residual_proj(x)
108
+
109
+ ei = edge_index
110
+ if self.add_self_loops_flag:
111
+ ei, _ = add_self_loops(ei, num_nodes=x.size(0))
112
+ if self.edge_dropout_p > 0 and self.training:
113
+ ei, _ = dropout_adj(ei, p=self.edge_dropout_p, force_undirected=False, training=True)
114
+
115
+ # Apply GATv2 layers
116
+ prev = None
117
+ for i in range(self.num_layers):
118
+ prev = x
119
+ x = self.gatv2_layers[i](x, ei)
120
+ x = self.layer_norms[i](x)
121
+
122
+ # Per-layer residual if dimension matches (middle layers)
123
+ if x.shape == prev.shape and i < self.num_layers - 1:
124
+ x = x + prev
125
+
126
+ # Add residual at last layer
127
+ if i == self.num_layers - 1:
128
+ x = x + residual
129
+
130
+ x = F.elu(x)
131
+
132
+ # Dropout except last layer
133
+ if i < self.num_layers - 1:
134
+ x = F.dropout(x, p=self.feat_dropout_p, training=self.training)
135
+
136
+ # Classification
137
+ logits = self.node_classifier(x)
138
+
139
+ return logits
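`add_self_loops` in the forward pass guarantees every node attends to itself even though each `GATv2Conv` is built with `add_self_loops=False`. The effect on a plain edge list can be sketched as (hypothetical helper over (src, dst) pairs instead of a (2, E) tensor):

```python
def add_self_loop_pairs(edges, num_nodes):
    """Append an (i, i) pair for every node.

    Plain-list analogue of torch_geometric.utils.add_self_loops, which
    performs the same augmentation on an edge_index tensor.
    """
    return list(edges) + [(i, i) for i in range(num_nodes)]
```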
mecari/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ __all__ = []
mecari/utils/morph_utils.py ADDED
@@ -0,0 +1,51 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, List, Any, Optional
6
+
7
+ from mecari.utils.signature import signature_key
8
+
9
+
10
+ def dedup_morphemes(morphs: List[Dict]) -> List[Dict]:
11
+ seen = set()
12
+ out: List[Dict] = []
13
+ for m in morphs:
14
+ key = signature_key(m)
15
+ if key in seen:
16
+ continue
17
+ seen.add(key)
18
+ out.append(m)
19
+ out.sort(key=lambda m: (
20
+ m.get("start_pos", 0),
21
+ -(m.get("end_pos", 0) - m.get("start_pos", 0)),
22
+ m.get("surface", ""),
23
+ m.get("reading", ""),
24
+ m.get("pos", "*"),
25
+ ))
26
+ return out
27
+
28
+
29
+ def build_adjacent_edges(morphs: List[Dict]) -> List[Dict]:
30
+ edges: List[Dict] = []
31
+ for i, s in enumerate(morphs):
32
+ for j, t in enumerate(morphs):
33
+ if i >= j:
34
+ continue
35
+ if s.get("end_pos", 0) == t.get("start_pos", 0):
36
+ edges.append({"source_idx": i, "target_idx": j, "edge_type": "forward"})
37
+ return edges
38
+
39
+
40
+ def normalize_mecab_candidates(candidates: List[Dict]) -> List[Dict]:
41
+ """Normalize MeCab candidates consistently for preprocessing/inference.
42
+
43
+ - If surface is digit-only and base_form is empty/missing, set base_form = surface.
44
+ Modifies candidates in place and returns the list for convenience.
45
+ """
46
+ for c in candidates:
47
+ surf = c.get("surface", "")
48
+ bf = c.get("base_form")
49
+ if (bf is None or bf == "") and surf and surf.isdigit():
50
+ c["base_form"] = surf
51
+ return candidates
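`build_adjacent_edges` links candidate i to candidate j exactly when i ends where j starts, which is what turns the candidate lattice into a graph of compatible segmentations. A minimal sketch over plain (start, end) spans (hypothetical helper):

```python
def adjacent_edges(spans):
    """Connect candidate i -> j when i ends exactly where j starts.

    spans: (start, end) pairs sorted by start position; mirrors the
    adjacency rule of build_adjacent_edges.
    """
    edges = []
    for i, (_, i_end) in enumerate(spans):
        for j, (j_start, _) in enumerate(spans):
            if i < j and i_end == j_start:
                edges.append((i, j))
    return edges
```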
mecari/utils/signature.py ADDED
@@ -0,0 +1,39 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, Tuple
6
+
7
+
8
+ def to_katakana(s) -> str:
9
+ """Robust hiragana->katakana conversion for str or sequence.
10
+
11
+ Accepts str, list, tuple; concatenates string elements when a sequence is given.
12
+ Non-string inputs are stringified; None becomes empty string.
13
+ """
14
+ if isinstance(s, (list, tuple)):
15
+ s = "".join(x for x in s if isinstance(x, str))
16
+ elif not isinstance(s, str):
17
+ s = str(s) if s is not None else ""
18
+ out = []
19
+ for ch in s:
20
+ if not ch:
21
+ continue
22
+ o = ord(ch)
23
+ if 0x3041 <= o <= 0x3096:
24
+ out.append(chr(o + 0x60))
25
+ else:
26
+ out.append(ch)
27
+ return "".join(out)
28
+
29
+
30
+ def signature_key(m: Dict) -> Tuple:
31
+ """Stable deduplication key for a morpheme dict (POS up to pos1)."""
32
+ surface = m.get("surface", "")
33
+ pos = m.get("pos", "*")
34
+ pos1 = m.get("pos_detail1", "*")
35
+ base = m.get("base_form") or m.get("lemma") or ""
36
+ read = to_katakana(m.get("reading") or "")
37
+ st = m.get("start_pos", 0)
38
+ ed = m.get("end_pos", st + len(surface))
39
+ return (st, ed, surface, pos, pos1, base, read)
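`to_katakana` relies on the fixed 0x60 code-point offset between the hiragana block (U+3041–U+3096) and the corresponding katakana. A minimal sketch of the same shift on a plain string:

```python
def hira_to_kata(s: str) -> str:
    """Shift hiragana code points (U+3041..U+3096) up by 0x60 to katakana.

    Characters outside the hiragana block pass through unchanged.
    """
    return "".join(
        chr(ord(ch) + 0x60) if 0x3041 <= ord(ch) <= 0x3096 else ch
        for ch in s
    )
```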
packages.txt ADDED
@@ -0,0 +1,5 @@
1
+ mecab
2
+ mecab-utils
3
+ libmecab-dev
4
+ mecab-jumandic-utf8
5
+ mecab-ipadic-utf8
preprocess.py ADDED
@@ -0,0 +1,366 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ 
+ """
+ Build training graphs from KWDLC with JUMANDIC.
+ 
+ Pipeline:
+ 1) Read gold morphemes from KNP files
+ 2) Parse text with MeCab (JUMANDIC) to get candidate morphemes
+ 3) Match candidates to gold and assign annotations ('+', '-', '?')
+ 4) Save graph data as .pt
+ """
+ 
+ import argparse
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import Dict, List
+ 
+ import torch
+ import yaml
+ from tqdm import tqdm
+ 
+ from mecari.analyzers.mecab import MeCabAnalyzer
+ from mecari.data.data_module import DataModule
+ from mecari.featurizers.lexical import LexicalNGramFeaturizer as LexicalFeaturizer
+ from mecari.featurizers.lexical import Morpheme
+ from mecari.utils.morph_utils import build_adjacent_edges, dedup_morphemes, normalize_mecab_candidates
+ 
+ 
+ def add_lexical_features(morphemes: List[Dict], text: str, feature_dim: int = 100000) -> List[Dict]:
+     """Add lexical (index, value) pairs to morphemes. Not used when saving JSON.
+ 
+     Kept for backward compatibility and test equivalence.
+     """
+     featurizer = LexicalFeaturizer(dim=feature_dim, add_bias=True)
+     for m in morphemes:
+         surf = m.get("surface", "")
+         morph_obj = Morpheme(
+             surf=surf,
+             lemma=m.get("base_form", surf),
+             pos=m.get("pos", "*"),
+             pos1=m.get("pos_detail1", "*"),
+             ctype="*",
+             cform="*",
+             reading=m.get("reading", "*"),
+         )
+         st = m.get("start_pos", 0)
+         ed = m.get("end_pos", st + len(surf))
+         prev_char = text[st - 1] if st > 0 and st <= len(text) else None
+         next_char = text[ed] if ed < len(text) else None
+         feats = featurizer.unigram_feats(morph_obj, prev_char, next_char)
+         m["lexical_features"] = feats
+     return morphemes
+ 
+ 
+ def hiragana_to_katakana(text: str) -> str:
+     """Convert hiragana (U+3041-U+3096) to katakana."""
+     return "".join([chr(ord(c) + 0x60) if "ぁ" <= c <= "ゖ" else c for c in text])
+ 
+ 
+ def _load_gold_with_kyoto(knp_path: Path) -> List[Dict]:
+     """Load sentences and morphemes from a KNP file using kyoto-reader (required)."""
+     try:
+         from kyoto_reader import KyotoReader  # type: ignore
+     except Exception as e:  # pragma: no cover
+         raise RuntimeError("kyoto-reader is required for gold loading. Install it (pip install kyoto-reader).") from e
+ 
+     try:
+         try:
+             reader = KyotoReader(str(knp_path), n_jobs=0)
+         except TypeError:
+             reader = KyotoReader(str(knp_path))
+         sents: List[Dict] = []
+         for doc in reader.process_all_documents(n_jobs=0):
+             if doc is None:
+                 continue
+             for sent in doc.sentences:
+                 text = sent.surf
+                 morphemes: List[Dict] = []
+                 pos = 0
+                 for mrph in sent.mrph_list():
+                     surf = getattr(mrph, "midasi", "") or ""
+                     read = getattr(mrph, "yomi", surf) or surf
+                     lemma = getattr(mrph, "genkei", surf) or surf
+                     pos_main = getattr(mrph, "hinsi", "*") or "*"
+                     pos1 = getattr(mrph, "bunrui", "*") or "*"
+                     st = pos
+                     ed = st + len(surf)
+                     pos = ed
+                     morphemes.append(
+                         {
+                             "surface": surf,
+                             "reading": read,
+                             "base_form": lemma,
+                             "pos": pos_main,
+                             "pos_detail1": pos1,
+                             "pos_detail2": "*",
+                             "pos_detail3": "*",
+                             "start_pos": st,
+                             "end_pos": ed,
+                         }
+                     )
+                 sents.append({"text": text, "morphemes": morphemes})
+         return sents
+     except Exception as e:
+         raise RuntimeError(f"Failed to parse KNP with kyoto-reader: {knp_path}") from e
+ 
+ 
+ def match_morphemes_with_gold(candidates: List[Dict], gold_morphemes: List[Dict], text: str) -> List[Dict]:
+     """Match candidate morphemes to gold and assign annotations ('?', '+', '-').
+ 
+     Policy:
+     - Initialize every candidate as '?'
+     - Mark '+' for candidates that strictly match gold (surface, POS, base, reading)
+     - Mark '-' for candidates that overlap any '+' span
+     """
+     # Reconstruct gold spans in character offsets
+     gold_details = []
+     cur = 0
+     for g in gold_morphemes:
+         surf = g.get("surface", "")
+         st, ed = cur, cur + len(surf)
+         cur = ed
+         gold_details.append(
+             {
+                 "start_pos": st,
+                 "end_pos": ed,
+                 "surface": surf,
+                 "pos": g.get("pos", "*"),
+                 "pos_detail1": g.get("pos_detail1", "*"),
+                 "pos_detail2": g.get("pos_detail2", "*"),
+                 "base_form": g.get("base_form", ""),
+                 "reading": hiragana_to_katakana(g.get("reading", "")),
+             }
+         )
+ 
+     # Initialize all candidates with '?'
+     annotated: List[Dict] = []
+     for cand in candidates:
+         a = {**cand}
+         a["annotation"] = "?"
+         if "inflection_type" not in a:
+             a["inflection_type"] = "*"
+         if "inflection_form" not in a:
+             a["inflection_form"] = "*"
+         annotated.append(a)
+ 
+     # Match by strict equality first; allow reading mismatch as fallback
+     span_to_cands: dict[tuple[int, int], list[Dict]] = {}
+     for a in annotated:
+         cs = a.get("start_pos", 0)
+         ce = a.get("end_pos", cs + len(a.get("surface", "")))
+         span_to_cands.setdefault((cs, ce), []).append(a)
+ 
+     matched_spans: List[tuple[int, int]] = []
+     for g in gold_details:
+         span = (g["start_pos"], g["end_pos"])
+         cands = span_to_cands.get(span, [])
+         if not cands:
+             continue
+         strict = []
+         fallback = []
+         for a in cands:
+             if a.get("surface", "") != g["surface"]:
+                 continue
+             if a.get("pos", "*") != g["pos"]:
+                 continue
+             if a.get("pos_detail1", "*") != g.get("pos_detail1", "*"):
+                 continue
+             if a.get("base_form", "") != g["base_form"]:
+                 continue
+             if hiragana_to_katakana(a.get("reading", "")) == g["reading"]:
+                 strict.append(a)
+             else:
+                 fallback.append(a)
+         chosen_list = strict if strict else fallback
+         if chosen_list:
+             for a in chosen_list:
+                 a["annotation"] = "+"
+             matched_spans.append(span)
+             for a in cands:
+                 if (a not in chosen_list) and a.get("annotation") != "+":
+                     a["annotation"] = "-"
+ 
+     # Demote any morpheme that overlaps (by at least 1 char) with any '+' span.
+     plus_spans = []
+     for a in annotated:
+         if a.get("annotation") == "+":
+             cs = a.get("start_pos", 0)
+             ce = a.get("end_pos", cs + len(a.get("surface", "")))
+             plus_spans.append((cs, ce))
+ 
+     def _strict_overlap(st1: int, ed1: int, st2: int, ed2: int) -> bool:
+         # overlap only if intersection length > 0 (touching is not overlap)
+         return max(st1, st2) < min(ed1, ed2)
+ 
+     for a in annotated:
+         if a.get("annotation") == "+":
+             continue
+         cs = a.get("start_pos", 0)
+         ce = a.get("end_pos", cs + len(a.get("surface", "")))
+         for ms, me in plus_spans:
+             if _strict_overlap(cs, ce, ms, me):
+                 a["annotation"] = "-"
+                 break
+     return annotated
+ 
+ 
+ def main():
+     parser = argparse.ArgumentParser(description="Create training data from KWDLC (JUMANDIC)")
+     parser.add_argument("--input-dir", type=str, default="KWDLC/knp", help="Directory containing KNP files")
+     parser.add_argument("--config", type=str, default="configs/gat.yaml", help="Path to config file")
+     parser.add_argument("--limit", type=int, help="Max number of files to process")
+     parser.add_argument("--test-only", action="store_true", help="Process only test split IDs")
+     parser.add_argument("--jumandic-path", type=str, default="/var/lib/mecab/dic/juman-utf8", help="Path to JUMANDIC")
+     args = parser.parse_args()
+ 
+     config = {}
+     if args.config and Path(args.config).exists():
+         with open(args.config, "r") as f:
+             config = yaml.safe_load(f)
+ 
+     if "extends" in config:
+         parent_config_path = Path(args.config).parent / config["extends"]
+         if parent_config_path.exists():
+             with open(parent_config_path, "r") as f:
+                 parent_config = yaml.safe_load(f)
+ 
+             def deep_merge(base, override):
+                 for key, value in override.items():
+                     if key in base and isinstance(base[key], dict) and isinstance(value, dict):
+                         deep_merge(base[key], value)
+                     else:
+                         base[key] = value
+                 return base
+ 
+             config = deep_merge(parent_config, config)
+ 
+     features_config = config.get("features", {})
+     feature_dim = features_config.get("lexical_feature_dim", 100000)
+     training_config = config.get("training", {})
+ 
+     if training_config.get("annotations_dir"):
+         output_dir = Path(training_config.get("annotations_dir"))
+     else:
+         output_dir = Path("annotations_kwdlc_juman")
+     output_dir.mkdir(parents=True, exist_ok=True)
+     print(f"Lexical features: using {feature_dim} dims")
+     print(f"Output directory: {output_dir}")
+ 
+     analyzer = MeCabAnalyzer(
+         jumandic_path=args.jumandic_path,
+     )
+ 
+     knp_files = []
+ 
+     if args.test_only:
+         test_id_file = Path("KWDLC/id/split_for_pas/test.id")
+         if test_id_file.exists():
+             with open(test_id_file, "r") as f:
+                 test_ids = [line.strip() for line in f if line.strip()]
+ 
+             knp_base_dir = Path(args.input_dir)
+             for file_id in test_ids:
+                 dir_name = file_id[:13]
+                 file_name = f"{file_id}.knp"
+                 knp_path = knp_base_dir / dir_name / file_name
+                 if knp_path.exists():
+                     knp_files.append(knp_path)
+     else:
+         knp_dir = Path(args.input_dir)
+         knp_files = sorted(knp_dir.glob("**/*.knp"))
+ 
+     if args.limit:
+         knp_files = knp_files[: args.limit]
+ 
+     print(f"Files to process: {len(knp_files)}")
+     print(f"JUMANDIC: {args.jumandic_path}")
+     print(f"Output to: {output_dir}")
+ 
+     total_stats = defaultdict(int)
+     annotation_idx = 0
+ 
+     dm = DataModule(
+         annotations_dir=str(output_dir),
+         lexical_feature_dim=int(feature_dim),
+         use_bidirectional_edges=bool(config.get("edge_features", {}).get("use_bidirectional_edges", True)),
+     )
+ 
+     # Save .pt files directly under the output_dir
+ 
+     for knp_path in tqdm(knp_files, desc="processing"):
+         try:
+             sentences = _load_gold_with_kyoto(knp_path)
+             if not sentences:
+                 continue
+ 
+             doc_id = knp_path.stem
+             for s in sentences:
+                 s["source_id"] = doc_id
+ 
+             for sent_idx, sentence in enumerate(sentences):
+                 text = sentence["text"]
+                 gold_morphemes = sentence["morphemes"]
+                 source_id = sentence.get("source_id", doc_id)
+ 
+                 candidates = analyzer.get_morpheme_candidates(text)
+                 candidates = normalize_mecab_candidates(candidates)
+                 candidates = dedup_morphemes(candidates)
+                 if not candidates:
+                     continue
+ 
+                 annotated_morphemes = match_morphemes_with_gold(candidates, gold_morphemes, text)
+ 
+                 edges = build_adjacent_edges(annotated_morphemes)
+ 
+                 for m in annotated_morphemes:
+                     if "lexical_features" in m:
+                         m.pop("lexical_features", None)
+ 
+                 morphemes_with_feats = dm.compute_lexical_features(annotated_morphemes, text)
+                 graph = dm.create_graph_from_morphemes_data(
+                     morphemes=morphemes_with_feats,
+                     edges=edges,
+                     text=text,
+                     for_training=True,
+                 )
+                 if graph is None:
+                     continue
+ 
+                 graph_file = output_dir / f"graph_{annotation_idx:04d}.pt"
+                 payload = {
+                     "graph": graph,
+                     "source_id": source_id,
+                     "text": text,
+                 }
+                 torch.save(payload, graph_file)
+ 
+                 total_stats["sentences"] += 1
+                 total_stats["morphemes"] += len(annotated_morphemes)
+                 total_stats["positive"] += sum(1 for m in annotated_morphemes if m.get("annotation") == "+")
+                 total_stats["negative"] += sum(1 for m in annotated_morphemes if m.get("annotation") == "-")
+ 
+                 annotation_idx += 1
+ 
+             total_stats["files"] += 1
+ 
+         except Exception as e:
+             print(f"Error ({knp_path}): {e}")
+             total_stats["errors"] += 1
+ 
+     print("\n" + "=" * 50)
+     print("Processing complete")
+     print("=" * 50)
+     print(f"Files: {total_stats['files']}")
+     print(f"Sentences: {total_stats['sentences']}")
+     print(f"Morphemes: {total_stats['morphemes']}")
+     print(f"Positive (+): {total_stats['positive']}")
+     print(f"Negative (-): {total_stats['negative']}")
+     if total_stats["errors"] > 0:
+         print(f"Errors: {total_stats['errors']}")
+ 
+ 
+ if __name__ == "__main__":
+     main()
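The '+'/'-'/'?' labeling policy implemented by `match_morphemes_with_gold` above can be illustrated on toy data. This is a simplified sketch (spans are `(start, end)` character offsets; the `annotate` helper and sample dicts are invented for illustration and ignore POS/reading matching):

```python
# Toy illustration of the annotation policy: exact-span gold matches
# become '+', anything overlapping a '+' span becomes '-', everything
# else stays '?' (unknown, ignored during training).

def annotate(cands, gold_spans):
    for c in cands:
        c["annotation"] = "+" if (c["start"], c["end"]) in gold_spans else "?"
    plus = [(c["start"], c["end"]) for c in cands if c["annotation"] == "+"]
    for c in cands:
        # strict overlap: intersection length > 0 (touching is not overlap)
        if c["annotation"] != "+" and any(max(c["start"], s) < min(c["end"], e) for s, e in plus):
            c["annotation"] = "-"
    return cands

cands = [
    {"surface": "東京", "start": 0, "end": 2},  # matches a gold span  -> '+'
    {"surface": "東", "start": 0, "end": 1},    # overlaps the '+' span -> '-'
    {"surface": "都", "start": 2, "end": 3},    # touches but no overlap -> '?'
]
annotate(cands, gold_spans={(0, 2)})
print([c["annotation"] for c in cands])  # → ['+', '-', '?']
```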
pyproject.toml ADDED
@@ -0,0 +1,56 @@
+ [project]
+ name = "mecari-morpheme"
+ version = "0.1.0"
+ description = "Japanese morphological analysis using Graph Neural Networks"
+ readme = "README.md"
+ requires-python = ">=3.11,<3.12"
+ dependencies = [
+     "torch>=2.2,<2.3",
+     "pytorch-lightning>=2.0.0",
+     "torch-geometric>=2.4,<2.5",
+     "numpy>=1.24,<2.0",
+     "pyyaml>=6.0",
+     "tqdm>=4.65.0",
+     "kyoto-reader>=2.5.0",
+     # Optional: enabled by default via config
+     "wandb>=0.15.0",
+ ]
+ 
+ [project.optional-dependencies]
+ dev = [
+     "ipython>=8.14.0",
+     "jupyter>=1.0.0",
+     "notebook>=7.0.0",
+     "pytest>=7.4.0",
+     "black>=23.0.0",
+     "ruff>=0.1.0",
+ ]
+ 
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+ 
+ [tool.setuptools]
+ packages = ["mecari"]
+ 
+ [tool.uv]
+ index-url = "https://pypi.org/simple"
+ # Use CUDA 12.1 compatible PyG wheels (matches torch 2.2.x + cu121 environment)
+ find-links = ["https://data.pyg.org/whl/torch-2.2.0+cu121.html"]
+ 
+ # torch is required at build time for extensions such as torch-cluster
+ [tool.uv.extra-build-dependencies]
+ # Ensure torch is available when resolving extension wheels
+ torch-geometric = ["torch"]
+ 
+ [tool.ruff]
+ line-length = 120
+ target-version = "py311"
+ 
+ [tool.ruff.lint]
+ select = ["E", "F", "I"]
+ ignore = ["E501"]  # line too long
+ 
+ [tool.black]
+ line-length = 120
+ target-version = ['py311']
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ --find-links https://data.pyg.org/whl/torch-2.2.0+cpu.html
+ 
+ # Core runtime
+ torch==2.2.2
+ torch-scatter
+ torch-sparse
+ torch-cluster
+ torch-spline-conv
+ torch-geometric==2.4.0
+ pytorch-lightning==2.5.2
+ numpy>=1.24,<2.1
+ pyyaml>=6.0
+ tqdm>=4.65.0
+ kyoto-reader>=2.5.0
+ 
+ # UI
+ gradio>=4.37.0
+ 
+ # Optional logger (disabled at runtime)
+ wandb>=0.15.0
runtime.txt ADDED
@@ -0,0 +1 @@
+ python-3.11
sample_model/config.yaml ADDED
@@ -0,0 +1,39 @@
+ edge_features:
+   use_bidirectional_edges: true
+ features:
+   lexical_feature_dim: 100000
+ inference:
+   checkpoint_dir: experiments
+   experiment_name: null
+ loss:
+   label_smoothing: 0.0
+   use_pos_weight: true
+ model:
+   dropout: 0.1
+   hidden_dim: 64
+   num_classes: 1
+   num_heads: 4
+   num_layers: 4
+   share_weights: false
+   type: gatv2
+ training:
+   accumulate_grad_batches: 1
+   annotations_dir: annotations_new
+   batch_size: 128
+   deterministic: false
+   gradient_clip_algorithm: norm
+   gradient_clip_val: 0.5
+   learning_rate: 0.001
+   log_every_n_steps: 50
+   max_steps: 10000
+   num_workers: 4
+   optimizer:
+     type: adamw
+     weight_decay: 0.001
+   patience: 10
+   project_name: mecari
+   seed: 42
+   use_wandb: true
+   val_check_interval: 1.0
+   warmup_start_lr: 0.0
+   warmup_steps: 500
sample_model/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ccfc112d4a0dcdc0b087c9cabf1b45f1aeae2f1cb8a6f86196a115aa594f68d7
+ size 26975745
train.py ADDED
@@ -0,0 +1,388 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ 
+ import os
+ 
+ # Disable tokenizer parallelism warning
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ 
+ import argparse
+ import random
+ from datetime import datetime
+ from importlib import import_module
+ from typing import Optional
+ 
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+ from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
+ 
+ from mecari.config.config import get_model_config, override_config, save_config
+ from mecari.data.data_module import DataModule
+ 
+ 
+ def set_seed(seed: int = 42, deterministic: bool = True) -> None:
+     """Set random seeds for reproducibility.
+ 
+     Args:
+         seed: Random seed value.
+         deterministic: If True, enforce deterministic behavior (slower).
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = deterministic
+     torch.backends.cudnn.benchmark = not deterministic
+     pl.seed_everything(seed)
+ 
+ 
+ def get_config_sections(config: dict) -> dict:
+     """Extract structured sections from a unified config dict."""
+     return {
+         "model": config["model"],
+         "training": config["training"],
+         "features": config.get("features", {}),
+         "edge": config.get("edge_features", {}),
+     }
+ 
+ 
+ def calculate_feature_dim(config: dict) -> int:
+     """Return feature dimension from config (lexical features by default)."""
+     features_cfg = config.get("features", {})
+ 
+     lexical_dim = features_cfg.get("lexical_feature_dim", 100000)
+     return lexical_dim
+ 
+ 
+ def create_data_module(config: dict) -> DataModule:
+     """Create DataModule from config (lexical-only pipeline)."""
+     features_cfg = config.get("features", {})
+     training_cfg = config["training"]
+     edge_cfg = config.get("edge_features", {})
+ 
+     lexical_feature_dim = features_cfg.get("lexical_feature_dim", 100000)
+ 
+     return DataModule(
+         annotations_dir=training_cfg["annotations_dir"],
+         batch_size=training_cfg["batch_size"],
+         num_workers=training_cfg["num_workers"],
+         max_files=training_cfg.get("max_files"),
+         use_bidirectional_edges=edge_cfg.get("use_bidirectional_edges", True),
+         annotations_override_dir=training_cfg.get("annotations_override_dir"),
+         lexical_feature_dim=lexical_feature_dim,
+     )
+ 
+ 
+ def setup_loggers(config: dict, experiment_name: str):
+     """Configure optional loggers (e.g., Weights & Biases)."""
+     import subprocess
+ 
+     from pytorch_lightning.loggers import WandbLogger
+ 
+     loggers = []
+ 
+     if config["training"]["use_wandb"]:
+         try:
+             tags = []
+             try:
+                 branch = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"], text=True).strip()
+                 tags.append(f"branch:{branch}")
+                 commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], text=True).strip()
+                 tags.append(f"commit:{commit}")
+             except Exception:
+                 pass
+ 
+             wandb_logger = WandbLogger(
+                 project=config["training"]["project_name"],
+                 name=experiment_name,
+                 save_dir=f"experiments/{experiment_name}",
+                 save_code=True,
+                 log_model=False,
+                 tags=tags,
+             )
+             loggers.append(wandb_logger)
+             print("✓ Added WandB logger (metrics only)")
+         except Exception as e:
+             print(f"WandbLogger initialization error: {e}")
+     else:
+         print("WandB logging disabled")
+ 
+     if not loggers:
+         loggers = False
+ 
+     return loggers
+ 
+ 
+ def create_trainer(config: dict, callbacks: list, loggers, deterministic: bool) -> pl.Trainer:
+     """Create a PyTorch Lightning Trainer."""
+     if torch.cuda.is_available():
+         accelerator = "gpu"
+         devices = 1
+     else:
+         accelerator = "cpu"
+         devices = 1
+ 
+     max_steps = config["training"].get("max_steps", 8600)
+     max_epochs = -1  # use max_steps only
+ 
+     trainer_kwargs = {
+         "max_epochs": max_epochs,
+         "max_steps": max_steps,
+         "callbacks": callbacks,
+         "logger": loggers,
+         "accelerator": accelerator,
+         "devices": devices,
+         "log_every_n_steps": config["training"]["log_every_n_steps"],
+         "val_check_interval": config["training"]["val_check_interval"],
+         "gradient_clip_val": config["training"]["gradient_clip_val"],
+         "enable_checkpointing": True,
+         "enable_progress_bar": True,
+         "limit_train_batches": 1.0,
+         "limit_val_batches": 1.0,
+         "limit_test_batches": 1.0,
+         "limit_predict_batches": 1.0,
+         "fast_dev_run": False,
+         "deterministic": deterministic,
+         "benchmark": not deterministic,
+         "precision": "16-mixed",
+     }
+ 
+     if "gradient_clip_algorithm" in config["training"]:
+         trainer_kwargs["gradient_clip_algorithm"] = config["training"]["gradient_clip_algorithm"]
+ 
+     if "accumulate_grad_batches" in config["training"]:
+         trainer_kwargs["accumulate_grad_batches"] = config["training"]["accumulate_grad_batches"]
+ 
+     return pl.Trainer(**trainer_kwargs)
+ 
+ 
+ def create_model_and_datamodule(config: dict, feature_dim: int, data_module: Optional[DataModule] = None):
+     """Create model and ensure DataModule is available (lexical-only)."""
+     cfg = get_config_sections(config)
+     model_cfg = cfg["model"]
+     training_cfg = cfg["training"]
+     features_cfg = cfg["features"]
+ 
+     if data_module is None:
+         data_module = create_data_module(config)
+ 
+     common_params = {
+         "hidden_dim": model_cfg["hidden_dim"],
+         "num_classes": model_cfg["num_classes"],
+         "learning_rate": training_cfg["learning_rate"],
+         "lexical_feature_dim": features_cfg.get("lexical_feature_dim", 100000),
+     }
+ 
+     if model_cfg["type"] == "gatv2":
+         MecariGATv2 = getattr(import_module("mecari.models.gatv2"), "MecariGATv2")
+         model = MecariGATv2(
+             **common_params,
+             num_heads=model_cfg["num_heads"],
+             share_weights=model_cfg.get("share_weights", False),
+             dropout=model_cfg.get("dropout", 0.1),
+             attn_dropout=model_cfg.get("attn_dropout", model_cfg.get("attention_dropout", 0.1)),
+             add_self_loops_flag=model_cfg.get("add_self_loops", True),
+             edge_dropout=model_cfg.get("edge_dropout", 0.0),
+             norm=model_cfg.get("norm", "layer"),
+         )
+     else:
+         raise ValueError(f"Unsupported model type: {model_cfg['type']}")
+ 
+     return model, data_module
+ 
+ 
+ def main():
+     parser = argparse.ArgumentParser(description="Train the morphological analysis model")
+     parser.add_argument(
+         "--model",
+         "-m",
+         choices=["gatv2"],
+         default="gatv2",
+         help="Model type (only gatv2 supported). If a config is provided, config.model.type takes precedence.",
+     )
+     parser.add_argument("--config", "-c", help="Path to config file (overrides model type if present)")
+     parser.add_argument("--batch-size", "-b", type=int, help="Batch size")
+     parser.add_argument("--steps", "-s", type=int, help="Max training steps")
+     parser.add_argument("--lr", type=float, help="Learning rate")
+     parser.add_argument("--hidden-dim", type=int, help="Hidden dimension size")
+     parser.add_argument("--patience", type=int, help="Early stopping patience")
+     parser.add_argument("--weight-decay", type=float, help="Weight decay")
+     parser.add_argument("--no-wandb", action="store_true", help="Disable Weights & Biases logging")
+     parser.add_argument("--seed", type=int, help="Random seed")
+     parser.add_argument("--no-deterministic", action="store_true", help="Disable deterministic mode for speed")
+     parser.add_argument("--resume", type=str, help="Experiment name to resume (e.g., gatv2_20250806_162945)")
+     args = parser.parse_args()
+ 
+     # Load/merge config
+     if args.config:
+         from mecari.config.config import load_config
+ 
+         config = load_config(args.config)
+         if "model" in config and "type" in config["model"]:
+             args.model = config["model"]["type"]
+     else:
+         config = get_model_config(args.model)
+ 
+     overrides = {}
+ 
+     # Training overrides
+     training_overrides = {}
+     if args.batch_size:
+         training_overrides["batch_size"] = args.batch_size
+     if args.steps:
+         training_overrides["max_steps"] = args.steps
+     if args.lr:
+         training_overrides["learning_rate"] = args.lr
+     if args.no_wandb:
+         training_overrides["use_wandb"] = False
+     if args.patience:
+         training_overrides["patience"] = args.patience
+     if args.seed:
+         training_overrides["seed"] = args.seed
+     if args.no_deterministic:
+         training_overrides["deterministic"] = False
+ 
+     if training_overrides:
+         overrides["training"] = training_overrides
+ 
+     # Model overrides
+     if args.hidden_dim:
+         overrides["model"] = {"hidden_dim": args.hidden_dim}
+ 
+     # Optimizer overrides
+     if args.weight_decay:
+         overrides.setdefault("training", {})
+         overrides["training"]["optimizer"] = {"weight_decay": args.weight_decay}
+ 
+     if overrides:
+         config = override_config(config, overrides)
+ 
+     deterministic = config["training"].get("deterministic", True)
+     set_seed(config["training"]["seed"], deterministic=deterministic)
+ 
+     if not deterministic:
+         print("⚡ Performance mode: deterministic=False (reproducibility not guaranteed)")
+ 
+     resume_from_checkpoint = None
+     experiment_name = None
+     if args.resume:
+         experiment_path = os.path.join("experiments", args.resume)
+         if os.path.exists(experiment_path):
+             checkpoint_dir = os.path.join(experiment_path, "checkpoints")
+             if os.path.exists(checkpoint_dir):
+                 checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".ckpt")]
+                 if checkpoints:
+                     checkpoints.sort()
+                     resume_from_checkpoint = os.path.join(checkpoint_dir, checkpoints[-1])
+                     print(f"Resuming training from: {resume_from_checkpoint}")
+                     experiment_name = args.resume
+ 
+                     config_path = os.path.join(experiment_path, "config.yaml")
+                     if os.path.exists(config_path):
+                         from mecari.config.config import load_config
+ 
+                         config = load_config(config_path)
+                         print(f"Restored config from: {config_path}")
+                 else:
+                     print(f"Warning: No checkpoints found in: {checkpoint_dir}")
+             else:
+                 print(f"Warning: Checkpoint directory not found: {checkpoint_dir}")
+         else:
+             print(f"Warning: Experiment directory not found: {experiment_path}")
+     else:
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         experiment_name = f"{config['model']['type']}_{timestamp}"
+ 
+     print(f"Experiment: {experiment_name}")
+     print(f"Model: {config['model']['type'].upper()}")
+     print("Lexical features: enabled (default)")
+ 
+     if torch.cuda.is_available():
+         print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")
+         print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
+     else:
+         print("💻 Using CPU")
+ 
+     data_module = create_data_module(config)
+ 
+     feature_dim = calculate_feature_dim(config)
+ 
+     model, _ = create_model_and_datamodule(config, feature_dim, data_module)
+ 
+     # Attach training config for schedulers, etc.
+     model.training_config = config["training"]
+ 
+     experiment_dir = f"experiments/{experiment_name}"
+     if not args.resume:
+         os.makedirs(experiment_dir, exist_ok=True)
+         save_config(config, f"{experiment_dir}/config.yaml")
+ 
+     checkpoint_callback_error = ModelCheckpoint(
+         dirpath=f"experiments/{experiment_name}/checkpoints",
+         filename=f"{config['model']['type']}-{{epoch:02d}}-{{val_error_epoch:.3f}}",
+         monitor="val_error_epoch",
+         mode="min",
+         save_top_k=1,
+         save_last=True,
+     )
+ 
+     early_stopping = EarlyStopping(
+         monitor="val_error_epoch", mode="min", patience=config["training"]["patience"], verbose=True, strict=False
+     )
+ 
+     loggers = setup_loggers(config, experiment_name)
+ 
+     callbacks = [checkpoint_callback_error, early_stopping]
+     try:
+         if loggers:
+             lr_monitor = LearningRateMonitor(logging_interval="step")
+             callbacks.append(lr_monitor)
+     except Exception:
+         pass
+     trainer = create_trainer(config, callbacks, loggers, deterministic)
+ 
+     print("Starting training...")
+ 
+     try:
+         if resume_from_checkpoint:
+             trainer.fit(model, data_module, ckpt_path=resume_from_checkpoint)
+         else:
+             trainer.fit(model, data_module)
+         training_status = "completed"
+ 
+         if data_module.test_dataset:
+             print("Evaluating on test data...")
+             trainer.test(model, data_module)
+         print("Training complete!")
+     except KeyboardInterrupt:
+         print("\nTraining interrupted...")
+         training_status = "interrupted"
+     except Exception as e:
+         print(f"\nError during training: {e}")
+         import traceback
+ 
+         traceback.print_exc()
+         training_status = "error"
+ 
+     print(f"Experiment: {experiment_name}")
+     print(f"Experiment dir: experiments/{experiment_name}")
+ 
+     print("\n=== Saved models ===")
+ 
+     if checkpoint_callback_error.best_model_path:
+         best_error = (
+             float(checkpoint_callback_error.best_model_score)
+             if checkpoint_callback_error.best_model_score is not None
+             else 1.0
+         )
+         print(f"  Best val_error: {best_error:.6f}")
+         print(f"  → {os.path.basename(checkpoint_callback_error.best_model_path)}")
+ 
+     print(f"\nFinal epoch: {trainer.current_epoch}")
+     print(f"Training status: {training_status}")
+ 
+ 
+ if __name__ == "__main__":
+     main()
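The `--resume` path above picks the lexicographically last `.ckpt` file in the experiment's checkpoint directory. A small sketch of that selection (filenames are invented; string sort matches training order only when epoch numbers are zero-padded, as in the `{epoch:02d}` checkpoint filename pattern):

```python
# Sketch: choose the checkpoint to resume from by sorting filenames.

def latest_checkpoint(filenames):
    ckpts = sorted(f for f in filenames if f.endswith(".ckpt"))
    return ckpts[-1] if ckpts else None

files = ["gatv2-epoch=01-val_error_epoch=0.123.ckpt",
         "gatv2-epoch=10-val_error_epoch=0.098.ckpt",
         "config.yaml"]
print(latest_checkpoint(files))  # → gatv2-epoch=10-val_error_epoch=0.098.ckpt
```

Note that `save_last=True` also writes a `last.ckpt`, which sorts after the `gatv2-*` names and would be preferred by this scheme.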
up_hf.py ADDED
@@ -0,0 +1,18 @@
+ import os
+ from huggingface_hub import HfApi
+ api = HfApi()
+ repo_id = "zbller/Mecari"
+ api.create_repo(repo_id=repo_id, repo_type="space", space_sdk="gradio", private=False, exist_ok=True, token=os.environ["HF_TOKEN"])
+ api.upload_folder(
+     repo_id=repo_id,
+     repo_type="space",
+     folder_path=".",
+     path_in_repo=".",
+     ignore_patterns=[
+         ".git", ".git/**", ".venv", ".venv/**", "__pycache__", "**/__pycache__",
+         "KWDLC", "KWDLC/**", "annotations", "annotations/**", "experiments", "experiments/**",
+         "mecari_morpheme.egg-info", "mecari_morpheme.egg-info/**",
+     ],
+     token=os.environ["HF_TOKEN"],
+ )
+ print(f"Uploaded to https://huggingface.co/spaces/{repo_id}")
uv.lock ADDED
The diff for this file is too large to render. See raw diff