Reality8081 commited on
Commit
7296cca
·
1 Parent(s): 55e47b2

Update App.py

Browse files
Files changed (1) hide show
  1. app.py +129 -44
app.py CHANGED
@@ -5,14 +5,12 @@ from transformers import BartForConditionalGeneration, BartTokenizer
5
  import re
6
  import numpy as np
7
  import networkx as nx
 
8
  from typing import List, Dict
9
  from src.utils.get_model import get_summarizer, get_extractive_model, get_extractive_abstractive
10
  from src.preprocessing.edu_sentences import preprocess_external_text
11
  from src.model.baseline_extractive_model import get_trigrams
12
- from dotenv import load_dotenv
13
- load_dotenv() # Tải biến môi trường từ file .env
14
- os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN") # Thay bằng token của bạn nếu cần
15
- PORT = int(os.environ.get("PORT", 7860))
16
 
17
 
18
  REPO_ID_baseline_model = "Reality8081/bart-base"
@@ -21,11 +19,39 @@ REPO_ID_baseline_extractive_model = "Reality8081/bart_extractive"
21
  REPO_ID_baseline_extractive_model_edu = "Reality8081/bart_extractive-edu"
22
  REPO_ID_Extabs_model = "Reality8081/bart-encoder-decoder"
23
  REPO_ID_Extabs_model_edu = "Reality8081/bart-encoder-decoder-edu"
 
24
 
 
 
 
 
 
 
 
 
25
 
26
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- def model_baseline(prepro_dict: Dict) -> str:
 
29
  text_to_summarize = prepro_dict.get("article", "")
30
 
31
  if not text_to_summarize.strip():
@@ -38,15 +64,15 @@ def model_baseline(prepro_dict: Dict) -> str:
38
  repo_id = REPO_ID_baseline_model
39
  summarizer = get_summarizer(repo_id)
40
  summary = summarizer.summarize(text_to_summarize)
41
- return summary
42
 
43
 
44
 
45
- def model_baseline_extractive(prepro_dict: Dict, top_n = 5) -> str:
46
  tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
47
  segments = prepro_dict["segments"]
48
  if not segments:
49
- return "Không thể phân tách văn bản thành các câu/EDU."
50
  input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
51
  attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
52
  segmentation_method = prepro_dict.get("segmentation_method")
@@ -100,12 +126,16 @@ def model_baseline_extractive(prepro_dict: Dict, top_n = 5) -> str:
100
  # Sắp xếp lại thứ tự index xuất hiện của các câu trong văn bản gốc để tóm tắt được mạch lạc
101
  selected_indices = sorted(selected_indices)
102
 
 
103
  # Lắp ráp kết quả
104
  extractive_summary = " ".join([segments[i] for i in selected_indices if i < len(segments)])
105
- return extractive_summary
 
 
106
 
107
- def model_extractive_abstract(prepro_dict: Dict) -> str:
108
  tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
 
109
  input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
110
  attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
111
  segmentation_method = prepro_dict.get("segmentation_method")
@@ -114,7 +144,30 @@ def model_extractive_abstract(prepro_dict: Dict) -> str:
114
  else:
115
  repo_id = REPO_ID_Extabs_model
116
  model = get_extractive_abstractive(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)
 
 
117
  with torch.no_grad():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  summary_ids = model.generate_summary(
119
  input_ids=input_ids,
120
  attention_mask=attention_mask,
@@ -124,9 +177,11 @@ def model_extractive_abstract(prepro_dict: Dict) -> str:
124
  early_stopping=True
125
  )
126
 
127
- # 4. Giải mã chuỗi
128
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
129
- return summary
 
 
130
 
131
 
132
  # ====================== MAIN FUNCTION ======================
@@ -136,22 +191,48 @@ def ATS(
136
  model: str = None,
137
  reference_summary: str = None
138
  ) -> str:
139
- """Main workflow: Raw Text Preprocessing → Model"""
140
- # Step 1: Preprocessing
141
- if segmentation_method == "Sentence-based Preprocessing":
142
- prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='sentence')
143
- else:
144
- prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='edu')
 
145
 
146
- # Step 2: Chọn model
147
- if model == "Baseline Model":
148
- result = model_baseline(prepro_dict)
149
- elif model == "Baseline Model with Extractive":
150
- result = model_baseline_extractive(prepro_dict)
151
- else:
152
- result = model_extractive_abstract(prepro_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- return result
155
 
156
 
157
  # ====================== GRADIO INTERFACE ======================
@@ -167,13 +248,19 @@ with gr.Blocks(
167
  )
168
 
169
  with gr.Row():
170
- with gr.Column(scale=3):
171
  input_text = gr.Textbox(
172
  label="📝 Text to Summarize",
173
  placeholder="Paste your long text here (up to several thousand words)...",
174
- lines=12,
175
  max_lines=30,
176
  )
 
 
 
 
 
 
177
 
178
  with gr.Column(scale=1):
179
  gr.Markdown("### ⚙️ Settings")
@@ -198,25 +285,23 @@ with gr.Blocks(
198
  label="Summarization Model",
199
  info="Select the model you want to use"
200
  )
201
-
 
202
  with gr.Row():
203
- btn_tom_tat = gr.Button(
204
- "🔍 Summarize Now",
205
- variant="primary",
206
- size="large"
207
- )
208
-
209
- output_text = gr.Textbox(
210
- label="📄 Summary Result",
211
- lines=10,
212
- placeholder="The result will appear here...",
213
- )
214
 
215
  # Connect button click
216
- btn_tom_tat.click(
217
  fn=ATS,
218
  inputs=[input_text, method, model],
219
- outputs=output_text
220
  )
221
 
222
  # Examples
 
5
  import re
6
  import numpy as np
7
  import networkx as nx
8
+ import plotly.graph_objects as go
9
  from typing import List, Dict
10
  from src.utils.get_model import get_summarizer, get_extractive_model, get_extractive_abstractive
11
  from src.preprocessing.edu_sentences import preprocess_external_text
12
  from src.model.baseline_extractive_model import get_trigrams
13
+ from typing import List, Dict, Tuple
 
 
 
14
 
15
 
16
  REPO_ID_baseline_model = "Reality8081/bart-base"
 
19
  REPO_ID_baseline_extractive_model_edu = "Reality8081/bart_extractive-edu"
20
  REPO_ID_Extabs_model = "Reality8081/bart-encoder-decoder"
21
  REPO_ID_Extabs_model_edu = "Reality8081/bart-encoder-decoder-edu"
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
 
24
def create_saliency_plot(segment_scores: List[float], selected_indices: List[int], segments: List[str]) -> go.Figure:
    """Build a bar chart of per-segment saliency scores.

    Bars for segments chosen for the summary are drawn in red; the remaining
    segments are blue. Hovering over a bar shows the segment's index, its
    score, and a truncated preview of its text.

    Args:
        segment_scores: Saliency score for each segment (expected in [0, 1]).
        selected_indices: Indices of the segments selected for the summary.
        segments: Segment texts, aligned index-for-index with ``segment_scores``.

    Returns:
        A Plotly figure suitable for a Gradio ``gr.Plot`` component.
    """
    # Red (#EF4444) marks selected segments, blue (#3B82F6) the rest.
    # NOTE: the original also built an unused `labels` list here; removed as dead code.
    colors = ['#EF4444' if i in selected_indices else '#3B82F6'
              for i in range(len(segment_scores))]

    # Truncate each segment's text so the hover tooltip stays compact.
    hover_texts = [f"<b>Segment {i}</b><br>Score: {score:.3f}<br>Text: {seg[:60]}..."
                   for i, (score, seg) in enumerate(zip(segment_scores, segments))]

    fig = go.Figure(data=[go.Bar(
        x=[f"Seg {i}" for i in range(len(segment_scores))],
        y=segment_scores,
        marker_color=colors,
        text=[f"{s:.2f}" for s in segment_scores],
        textposition='auto',
        hovertext=hover_texts,
        hoverinfo="text"
    )])

    fig.update_layout(
        title="Saliency Scores trên từng Câu/EDU",
        xaxis_title="Vị trí Câu / EDU",
        yaxis_title="Saliency Score (0-1)",
        template="plotly_white",
        margin=dict(l=40, r=40, t=40, b=40),
        height=350
    )
    return fig
52
 
53
+
54
+ def model_baseline(prepro_dict: Dict) -> Tuple[str, float, go.Figure]:
55
  text_to_summarize = prepro_dict.get("article", "")
56
 
57
  if not text_to_summarize.strip():
 
64
  repo_id = REPO_ID_baseline_model
65
  summarizer = get_summarizer(repo_id)
66
  summary = summarizer.summarize(text_to_summarize)
67
+ return summary, None, None
68
 
69
 
70
 
71
+ def model_baseline_extractive(prepro_dict: Dict, top_n = 5) -> Tuple[str, float, go.Figure]:
72
  tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
73
  segments = prepro_dict["segments"]
74
  if not segments:
75
+ raise ValueError("Không thể phân tách văn bản thành các câu/EDU.")
76
  input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
77
  attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
78
  segmentation_method = prepro_dict.get("segmentation_method")
 
126
  # Sắp xếp lại thứ tự index xuất hiện của các câu trong văn bản gốc để tóm tắt được mạch lạc
127
  selected_indices = sorted(selected_indices)
128
 
129
+
130
  # Lắp ráp kết quả
131
  extractive_summary = " ".join([segments[i] for i in selected_indices if i < len(segments)])
132
+ avg_confidence = float(np.mean([segment_scores[i] for i in selected_indices])) if selected_indices else 0.0
133
+ fig = create_saliency_plot(segment_scores, selected_indices, segments)
134
+ return extractive_summary, avg_confidence, fig
135
 
136
+ def model_extractive_abstract(prepro_dict: Dict) -> Tuple[str, float, go.Figure]:
137
  tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
138
+ segments = prepro_dict["segments"]
139
  input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
140
  attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
141
  segmentation_method = prepro_dict.get("segmentation_method")
 
144
  else:
145
  repo_id = REPO_ID_Extabs_model
146
  model = get_extractive_abstractive(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)
147
+
148
+
149
  with torch.no_grad():
150
+ encoder_outputs = model.bart.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
151
+
152
+ hidden_states = encoder_outputs.last_hidden_state
153
+ ext_logits = model.ext_head(hidden_states).squeeze(-1)
154
+ probs = torch.sigmoid(ext_logits).squeeze(0).cpu().numpy()
155
+ segment_scores = []
156
+ current_idx = 1 # Bỏ qua token đặc biệt <s> ở đầu chuỗi
157
+ for seg in segments:
158
+ seg_len = len(tokenizer.encode(seg, add_special_tokens=False))
159
+ end_idx = min(current_idx + seg_len, len(probs))
160
+
161
+ if current_idx < len(probs):
162
+ seg_score = float(np.mean(probs[current_idx:end_idx]))
163
+ else:
164
+ seg_score = 0.0
165
+
166
+ segment_scores.append(seg_score)
167
+ current_idx += seg_len
168
+
169
+ selected_indices = [i for i, score in enumerate(segment_scores) if score >= 0.5]
170
+ avg_confidence = float(np.mean([segment_scores[i] for i in selected_indices])) if selected_indices else 0.0
171
  summary_ids = model.generate_summary(
172
  input_ids=input_ids,
173
  attention_mask=attention_mask,
 
177
  early_stopping=True
178
  )
179
 
180
+
181
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
182
+
183
+ fig = create_saliency_plot(segment_scores, selected_indices, segments) # Hiện tại chưa chọn câu nào nên selected_indices rỗng
184
+ return summary, avg_confidence, fig
185
 
186
 
187
  # ====================== MAIN FUNCTION ======================
 
191
  model: str = None,
192
  reference_summary: str = None
193
  ) -> str:
194
+ if not text or len(text.split()) < 5:
195
+ return "Văn bản quá ngắn hoặc rỗng, vui lòng nhập thêm nội dung.", "⚠️ Không có dữ liệu", None
196
+ try:
197
+ if segmentation_method == "Sentence-based Preprocessing":
198
+ prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='sentence')
199
+ else:
200
+ prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='edu')
201
 
202
+ # Step 2: Chọn model
203
+ if model == "Baseline Model":
204
+ result, avg_conf, plot_fig = model_baseline(prepro_dict)
205
+ elif model == "Baseline Model with Extractive":
206
+ result, avg_conf, plot_fig = model_baseline_extractive(prepro_dict)
207
+ else:
208
+ result, avg_conf, plot_fig = model_extractive_abstract(prepro_dict)
209
+
210
+ origin_words = len(text.split())
211
+ sum_words = len(result.split())
212
+ comp_ratio = (sum_words / origin_words * 100) if origin_words > 0 else 0
213
+ conf_str = f"**{avg_conf * 100:.2f}%**" if avg_conf is not None else "*N/A (Không tính Saliency Score)*"
214
+
215
+ metrics_md = f"""
216
+ ### 📊 Thống kê kết quả
217
+ - **Tổng số từ văn bản gốc:** {origin_words} từ
218
+ - **Số từ bản tóm tắt:** {sum_words} từ
219
+ - **Tỉ lệ nén:** {comp_ratio:.1f}% *(Giữ lại {comp_ratio:.1f}% dung lượng gốc)*
220
+ - **Độ tin cậy trung bình (Confidence):** {conf_str}
221
+ """
222
+
223
+ if plot_fig is None:
224
+ plot_fig = go.Figure()
225
+ plot_fig.update_layout(title="Mô hình Baseline không hỗ trợ Saliency Plot.", template="plotly_white")
226
+
227
+ return result, metrics_md, plot_fig
228
+ except Exception as e:
229
+ return f"Đã xảy ra lỗi: {str(e)}", "⚠️ Lỗi hệ thống", None
230
+
231
+
232
+
233
+
234
+
235
 
 
236
 
237
 
238
  # ====================== GRADIO INTERFACE ======================
 
248
  )
249
 
250
  with gr.Row():
251
+ with gr.Column(scale=2):
252
  input_text = gr.Textbox(
253
  label="📝 Text to Summarize",
254
  placeholder="Paste your long text here (up to several thousand words)...",
255
+ lines=15,
256
  max_lines=30,
257
  )
258
+ with gr.Row():
259
+ btn_summary = gr.Button(
260
+ "🔍 Summarize Now",
261
+ variant="primary",
262
+ size="large"
263
+ )
264
 
265
  with gr.Column(scale=1):
266
  gr.Markdown("### ⚙️ Settings")
 
285
  label="Summarization Model",
286
  info="Select the model you want to use"
287
  )
288
+ output_metrics = gr.Markdown("### 📊 Thống kê kết quả\n*Chờ xử lý...*")
289
+ gr.Markdown("---")
290
  with gr.Row():
291
+ with gr.Column(scale=1):
292
+ output_text = gr.Textbox(
293
+ label="📄 Summary Result",
294
+ lines=10,
295
+ placeholder="The result will appear here...",
296
+ )
297
+ with gr.Column(scale=1):
298
+ output_plot = gr.Plot(label="📈 Saliency Score Plot")
 
 
 
299
 
300
  # Connect button click
301
+ btn_summary.click(
302
  fn=ATS,
303
  inputs=[input_text, method, model],
304
+ outputs=[output_text, output_metrics, output_plot]
305
  )
306
 
307
  # Examples