rafmacalaba commited on
Commit
dcddeb3
·
verified ·
1 Parent(s): 18addf6

Mirror rafmacalaba/gliner2-datause-large-v1-deval-synth-v2 -> production

Browse files
Files changed (1) hide show
  1. README.md +50 -51
README.md CHANGED
@@ -45,8 +45,8 @@ from huggingface_hub import snapshot_download
45
  # Install the patched GLiNER2 library:
46
  # pip install git+https://github.com/rafmacalaba/GLiNER2.git@feat/main-mirror
47
 
48
- BASE_MODEL = "fastino/gliner2-large-v1"
49
- ADAPTER_ID = "ai4data/datause-extraction"
50
 
51
  extractor = GLiNER2.from_pretrained(BASE_MODEL)
52
  extractor.load_adapter(snapshot_download(ADAPTER_ID))
@@ -62,60 +62,59 @@ CLASSIFICATION_TASKS = {
62
  "usage_context": ["primary", "supporting", "background"],
63
  }
64
 
65
- text = "We use the Demographic and Health Survey (DHS) 2020 as our primary data source."
66
-
67
- # Pass 1 — extract entity spans
68
- entity_result = extractor.extract_entities(
69
- text, ["data_mention"], threshold=0.3, include_confidence=True
70
- )
71
- spans = entity_result.get("entities", {}).get("data_mention", [])
72
-
73
- # Pass 2 — classify each span using its context window
74
- CONTEXT = 150
75
- results = []
76
- for span in spans:
77
- mention = span.get("text", "")
78
- start = text.find(mention)
79
- ctx = text[max(0, start - CONTEXT) : start + len(mention) + CONTEXT]
80
- context_str = f"Mention: {mention} | Context: {ctx}"
81
-
82
- classes = extractor.classify_text(context_str, CLASSIFICATION_TASKS, threshold=0.3)
83
- results.append({
84
- "mention_name": mention,
85
- "confidence": span.get("confidence", 0),
86
- "specificity_tag": classes.get("specificity_tag", ("", 0))[0],
87
- "typology_tag": classes.get("typology_tag", ("", 0))[0],
88
- "is_used": classes.get("is_used", ("", 0))[0],
89
- "usage_context": classes.get("usage_context", ("", 0))[0],
90
- })
91
-
92
- print(results)
93
- ```
94
 
95
- ### Batch inference (recommended for documents)
96
 
97
- ```python
98
- # Pass 1 — batched
99
- all_res_ent = extractor.batch_extract_entities(
100
- texts, ["data_mention"], threshold=0.3, batch_size=8, include_confidence=True
101
- )
 
 
 
 
 
 
102
 
103
- # Build context strings for every extracted span, then Pass 2 — batched
104
  classification_queue = []
105
  for idx, (res_ent, text) in enumerate(zip(all_res_ent, texts)):
106
- for span in res_ent.get("entities", {}).get("data_mention", []):
107
- mention = span.get("text", "")
108
- start = text.find(mention)
109
- ctx = text[max(0, start - 150) : start + len(mention) + 150]
110
- classification_queue.append((idx, mention, span.get("confidence", 0),
111
- f"Mention: {mention} | Context: {ctx}"))
112
-
113
- all_classes = extractor.batch_classify_text(
114
- [q[3] for q in classification_queue],
115
- CLASSIFICATION_TASKS,
116
- threshold=0.3,
117
- batch_size=8,
118
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  ```
120
 
121
  ## Training Details
 
45
  # Install the patched GLiNER2 library:
46
  # pip install git+https://github.com/rafmacalaba/GLiNER2.git@feat/main-mirror
47
 
48
+ BASE_MODEL = "fastino/gliner2-large-v1"
49
+ ADAPTER_ID = "ai4data/datause-extraction"
50
 
51
  extractor = GLiNER2.from_pretrained(BASE_MODEL)
52
  extractor.load_adapter(snapshot_download(ADAPTER_ID))
 
62
  "usage_context": ["primary", "supporting", "background"],
63
  }
64
 
65
+ # texts: list of passage strings to run extraction on
66
+ texts = ["We use the Demographic and Health Survey (DHS) 2020 as our primary data source."]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ BZ = 8 # batch size
69
 
70
+ # Pass 1: batched entity extraction
71
+ all_res_ent = []
72
+ for i in range(0, len(texts), BZ):
73
+ batch = texts[i : i + BZ]
74
+ res = extractor.batch_extract_entities(
75
+ batch, ["data_mention"],
76
+ threshold=0.3,
77
+ batch_size=BZ,
78
+ include_confidence=True,
79
+ )
80
+ all_res_ent.extend(res)
81
 
82
+ # Build classification queue one entry per valid extracted span
83
  classification_queue = []
84
  for idx, (res_ent, text) in enumerate(zip(all_res_ent, texts)):
85
+ spans = (
86
+ res_ent.get("entities", {}).get("data_mention", [])
87
+ if isinstance(res_ent, dict)
88
+ else res_ent
89
+ )
90
+ for span_data in spans:
91
+ span_text = span_data.get("text", "") if isinstance(span_data, dict) else str(span_data)
92
+ span_conf = span_data.get("confidence", 0.0) if isinstance(span_data, dict) else 1.0
93
+ if len(span_text) < 3:
94
+ continue
95
+ start = text.find(span_text)
96
+ ctx_start = max(0, start - 150) if start != -1 else 0
97
+ ctx_end = min(len(text), start + len(span_text) + 150) if start != -1 else len(text)
98
+ context_str = f"Mention: {span_text} | Context: {text[ctx_start:ctx_end]}"
99
+ classification_queue.append((idx, span_text, span_conf, context_str))
100
+
101
+ # Pass 2: batched zero-shot classification on context windows
102
+ all_classes = []
103
+ for i in range(0, len(classification_queue), BZ):
104
+ batch_ctx = [q[3] for q in classification_queue[i : i + BZ]]
105
+ res = extractor.batch_classify_text(
106
+ batch_ctx, CLASSIFICATION_TASKS, threshold=0.3, batch_size=BZ
107
+ )
108
+ all_classes.extend(res)
109
+
110
+ # Assemble results grouped by source chunk index
111
+ chunk_results = {i: [] for i in range(len(texts))}
112
+ for q_item, classes in zip(classification_queue, all_classes):
113
+ idx, span_text, conf, _ = q_item
114
+ mention = {"mention_name": span_text, "confidence": conf}
115
+ for task, out in classes.items():
116
+ mention[task] = out[0] if isinstance(out, tuple) and len(out) == 2 else out
117
+ chunk_results[idx].append(mention)
118
  ```
119
 
120
  ## Training Details