Spaces:
Sleeping
Sleeping
Commit ·
7296cca
1
Parent(s): 55e47b2
Update App.py
Browse files
app.py
CHANGED
|
@@ -5,14 +5,12 @@ from transformers import BartForConditionalGeneration, BartTokenizer
|
|
| 5 |
import re
|
| 6 |
import numpy as np
|
| 7 |
import networkx as nx
|
|
|
|
| 8 |
from typing import List, Dict
|
| 9 |
from src.utils.get_model import get_summarizer, get_extractive_model, get_extractive_abstractive
|
| 10 |
from src.preprocessing.edu_sentences import preprocess_external_text
|
| 11 |
from src.model.baseline_extractive_model import get_trigrams
|
| 12 |
-
from
|
| 13 |
-
load_dotenv() # Tải biến môi trường từ file .env
|
| 14 |
-
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN") # Thay bằng token của bạn nếu cần
|
| 15 |
-
PORT = int(os.environ.get("PORT", 7860))
|
| 16 |
|
| 17 |
|
| 18 |
REPO_ID_baseline_model = "Reality8081/bart-base"
|
|
@@ -21,11 +19,39 @@ REPO_ID_baseline_extractive_model = "Reality8081/bart_extractive"
|
|
| 21 |
REPO_ID_baseline_extractive_model_edu = "Reality8081/bart_extractive-edu"
|
| 22 |
REPO_ID_Extabs_model = "Reality8081/bart-encoder-decoder"
|
| 23 |
REPO_ID_Extabs_model_edu = "Reality8081/bart-encoder-decoder-edu"
|
|
|
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
|
|
|
| 29 |
text_to_summarize = prepro_dict.get("article", "")
|
| 30 |
|
| 31 |
if not text_to_summarize.strip():
|
|
@@ -38,15 +64,15 @@ def model_baseline(prepro_dict: Dict) -> str:
|
|
| 38 |
repo_id = REPO_ID_baseline_model
|
| 39 |
summarizer = get_summarizer(repo_id)
|
| 40 |
summary = summarizer.summarize(text_to_summarize)
|
| 41 |
-
return summary
|
| 42 |
|
| 43 |
|
| 44 |
|
| 45 |
-
def model_baseline_extractive(prepro_dict: Dict, top_n = 5) -> str:
|
| 46 |
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
| 47 |
segments = prepro_dict["segments"]
|
| 48 |
if not segments:
|
| 49 |
-
|
| 50 |
input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
|
| 51 |
attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
|
| 52 |
segmentation_method = prepro_dict.get("segmentation_method")
|
|
@@ -100,12 +126,16 @@ def model_baseline_extractive(prepro_dict: Dict, top_n = 5) -> str:
|
|
| 100 |
# Sắp xếp lại thứ tự index xuất hiện của các câu trong văn bản gốc để tóm tắt được mạch lạc
|
| 101 |
selected_indices = sorted(selected_indices)
|
| 102 |
|
|
|
|
| 103 |
# Lắp ráp kết quả
|
| 104 |
extractive_summary = " ".join([segments[i] for i in selected_indices if i < len(segments)])
|
| 105 |
-
|
|
|
|
|
|
|
| 106 |
|
| 107 |
-
def model_extractive_abstract(prepro_dict: Dict) -> str:
|
| 108 |
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
|
|
|
| 109 |
input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
|
| 110 |
attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
|
| 111 |
segmentation_method = prepro_dict.get("segmentation_method")
|
|
@@ -114,7 +144,30 @@ def model_extractive_abstract(prepro_dict: Dict) -> str:
|
|
| 114 |
else:
|
| 115 |
repo_id = REPO_ID_Extabs_model
|
| 116 |
model = get_extractive_abstractive(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)
|
|
|
|
|
|
|
| 117 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
summary_ids = model.generate_summary(
|
| 119 |
input_ids=input_ids,
|
| 120 |
attention_mask=attention_mask,
|
|
@@ -124,9 +177,11 @@ def model_extractive_abstract(prepro_dict: Dict) -> str:
|
|
| 124 |
early_stopping=True
|
| 125 |
)
|
| 126 |
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
# ====================== MAIN FUNCTION ======================
|
|
@@ -136,22 +191,48 @@ def ATS(
|
|
| 136 |
model: str = None,
|
| 137 |
reference_summary: str = None
|
| 138 |
) -> str:
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
-
return result
|
| 155 |
|
| 156 |
|
| 157 |
# ====================== GRADIO INTERFACE ======================
|
|
@@ -167,13 +248,19 @@ with gr.Blocks(
|
|
| 167 |
)
|
| 168 |
|
| 169 |
with gr.Row():
|
| 170 |
-
with gr.Column(scale=
|
| 171 |
input_text = gr.Textbox(
|
| 172 |
label="📝 Text to Summarize",
|
| 173 |
placeholder="Paste your long text here (up to several thousand words)...",
|
| 174 |
-
lines=
|
| 175 |
max_lines=30,
|
| 176 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
with gr.Column(scale=1):
|
| 179 |
gr.Markdown("### ⚙️ Settings")
|
|
@@ -198,25 +285,23 @@ with gr.Blocks(
|
|
| 198 |
label="Summarization Model",
|
| 199 |
info="Select the model you want to use"
|
| 200 |
)
|
| 201 |
-
|
|
|
|
| 202 |
with gr.Row():
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
lines=10,
|
| 212 |
-
placeholder="The result will appear here...",
|
| 213 |
-
)
|
| 214 |
|
| 215 |
# Connect button click
|
| 216 |
-
|
| 217 |
fn=ATS,
|
| 218 |
inputs=[input_text, method, model],
|
| 219 |
-
outputs=output_text
|
| 220 |
)
|
| 221 |
|
| 222 |
# Examples
|
|
|
|
| 5 |
import re
|
| 6 |
import numpy as np
|
| 7 |
import networkx as nx
|
| 8 |
+
import plotly.graph_objects as go
|
| 9 |
from typing import List, Dict
|
| 10 |
from src.utils.get_model import get_summarizer, get_extractive_model, get_extractive_abstractive
|
| 11 |
from src.preprocessing.edu_sentences import preprocess_external_text
|
| 12 |
from src.model.baseline_extractive_model import get_trigrams
|
| 13 |
+
from typing import List, Dict, Tuple
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
REPO_ID_baseline_model = "Reality8081/bart-base"
|
|
|
|
| 19 |
REPO_ID_baseline_extractive_model_edu = "Reality8081/bart_extractive-edu"
|
| 20 |
REPO_ID_Extabs_model = "Reality8081/bart-encoder-decoder"
|
| 21 |
REPO_ID_Extabs_model_edu = "Reality8081/bart-encoder-decoder-edu"
|
| 22 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
|
| 24 |
+
def create_saliency_plot(segment_scores: List[float], selected_indices: List[int], segments: List[str]) -> go.Figure:
|
| 25 |
+
"""Tạo biểu đồ cột thể hiện Saliency Score của từng Segment"""
|
| 26 |
+
colors = ['#EF4444' if i in selected_indices else '#3B82F6' for i in range(len(segment_scores))]
|
| 27 |
+
labels = ["Chọn tóm tắt" if i in selected_indices else "Bỏ qua" for i in range(len(segment_scores))]
|
| 28 |
+
|
| 29 |
+
# Rút gọn text để hiển thị khi hover chuột
|
| 30 |
+
hover_texts = [f"<b>Segment {i}</b><br>Score: {score:.3f}<br>Text: {seg[:60]}..."
|
| 31 |
+
for i, (score, seg) in enumerate(zip(segment_scores, segments))]
|
| 32 |
|
| 33 |
+
fig = go.Figure(data=[go.Bar(
|
| 34 |
+
x=[f"Seg {i}" for i in range(len(segment_scores))],
|
| 35 |
+
y=segment_scores,
|
| 36 |
+
marker_color=colors,
|
| 37 |
+
text=[f"{s:.2f}" for s in segment_scores],
|
| 38 |
+
textposition='auto',
|
| 39 |
+
hovertext=hover_texts,
|
| 40 |
+
hoverinfo="text"
|
| 41 |
+
)])
|
| 42 |
+
|
| 43 |
+
fig.update_layout(
|
| 44 |
+
title="Saliency Scores trên từng Câu/EDU",
|
| 45 |
+
xaxis_title="Vị trí Câu / EDU",
|
| 46 |
+
yaxis_title="Saliency Score (0-1)",
|
| 47 |
+
template="plotly_white",
|
| 48 |
+
margin=dict(l=40, r=40, t=40, b=40),
|
| 49 |
+
height=350
|
| 50 |
+
)
|
| 51 |
+
return fig
|
| 52 |
|
| 53 |
+
|
| 54 |
+
def model_baseline(prepro_dict: Dict) -> Tuple[str, float, go.Figure]:
|
| 55 |
text_to_summarize = prepro_dict.get("article", "")
|
| 56 |
|
| 57 |
if not text_to_summarize.strip():
|
|
|
|
| 64 |
repo_id = REPO_ID_baseline_model
|
| 65 |
summarizer = get_summarizer(repo_id)
|
| 66 |
summary = summarizer.summarize(text_to_summarize)
|
| 67 |
+
return summary, None, None
|
| 68 |
|
| 69 |
|
| 70 |
|
| 71 |
+
def model_baseline_extractive(prepro_dict: Dict, top_n = 5) -> Tuple[str, float, go.Figure]:
|
| 72 |
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
| 73 |
segments = prepro_dict["segments"]
|
| 74 |
if not segments:
|
| 75 |
+
raise ValueError("Không thể phân tách văn bản thành các câu/EDU.")
|
| 76 |
input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
|
| 77 |
attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
|
| 78 |
segmentation_method = prepro_dict.get("segmentation_method")
|
|
|
|
| 126 |
# Sắp xếp lại thứ tự index xuất hiện của các câu trong văn bản gốc để tóm tắt được mạch lạc
|
| 127 |
selected_indices = sorted(selected_indices)
|
| 128 |
|
| 129 |
+
|
| 130 |
# Lắp ráp kết quả
|
| 131 |
extractive_summary = " ".join([segments[i] for i in selected_indices if i < len(segments)])
|
| 132 |
+
avg_confidence = float(np.mean([segment_scores[i] for i in selected_indices])) if selected_indices else 0.0
|
| 133 |
+
fig = create_saliency_plot(segment_scores, selected_indices, segments)
|
| 134 |
+
return extractive_summary, avg_confidence, fig
|
| 135 |
|
| 136 |
+
def model_extractive_abstract(prepro_dict: Dict) -> Tuple[str, float, go.Figure]:
|
| 137 |
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
| 138 |
+
segments = prepro_dict["segments"]
|
| 139 |
input_ids = torch.tensor([prepro_dict["input_ids"]]).to(device)
|
| 140 |
attention_mask = torch.tensor([prepro_dict["attention_mask"]]).to(device)
|
| 141 |
segmentation_method = prepro_dict.get("segmentation_method")
|
|
|
|
| 144 |
else:
|
| 145 |
repo_id = REPO_ID_Extabs_model
|
| 146 |
model = get_extractive_abstractive(repo_id=repo_id, base_model_name="facebook/bart-large", device=device)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
with torch.no_grad():
|
| 150 |
+
encoder_outputs = model.bart.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
| 151 |
+
|
| 152 |
+
hidden_states = encoder_outputs.last_hidden_state
|
| 153 |
+
ext_logits = model.ext_head(hidden_states).squeeze(-1)
|
| 154 |
+
probs = torch.sigmoid(ext_logits).squeeze(0).cpu().numpy()
|
| 155 |
+
segment_scores = []
|
| 156 |
+
current_idx = 1 # Bỏ qua token đặc biệt <s> ở đầu chuỗi
|
| 157 |
+
for seg in segments:
|
| 158 |
+
seg_len = len(tokenizer.encode(seg, add_special_tokens=False))
|
| 159 |
+
end_idx = min(current_idx + seg_len, len(probs))
|
| 160 |
+
|
| 161 |
+
if current_idx < len(probs):
|
| 162 |
+
seg_score = float(np.mean(probs[current_idx:end_idx]))
|
| 163 |
+
else:
|
| 164 |
+
seg_score = 0.0
|
| 165 |
+
|
| 166 |
+
segment_scores.append(seg_score)
|
| 167 |
+
current_idx += seg_len
|
| 168 |
+
|
| 169 |
+
selected_indices = [i for i, score in enumerate(segment_scores) if score >= 0.5]
|
| 170 |
+
avg_confidence = float(np.mean([segment_scores[i] for i in selected_indices])) if selected_indices else 0.0
|
| 171 |
summary_ids = model.generate_summary(
|
| 172 |
input_ids=input_ids,
|
| 173 |
attention_mask=attention_mask,
|
|
|
|
| 177 |
early_stopping=True
|
| 178 |
)
|
| 179 |
|
| 180 |
+
|
| 181 |
+
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
| 182 |
+
|
| 183 |
+
fig = create_saliency_plot(segment_scores, selected_indices, segments) # Hiện tại chưa chọn câu nào nên selected_indices rỗng
|
| 184 |
+
return summary, avg_confidence, fig
|
| 185 |
|
| 186 |
|
| 187 |
# ====================== MAIN FUNCTION ======================
|
|
|
|
| 191 |
model: str = None,
|
| 192 |
reference_summary: str = None
|
| 193 |
) -> str:
|
| 194 |
+
if not text or len(text.split()) < 5:
|
| 195 |
+
return "Văn bản quá ngắn hoặc rỗng, vui lòng nhập thêm nội dung.", "⚠️ Không có dữ liệu", None
|
| 196 |
+
try:
|
| 197 |
+
if segmentation_method == "Sentence-based Preprocessing":
|
| 198 |
+
prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='sentence')
|
| 199 |
+
else:
|
| 200 |
+
prepro_dict = preprocess_external_text(text, reference_summary, segmentation_method='edu')
|
| 201 |
|
| 202 |
+
# Step 2: Chọn model
|
| 203 |
+
if model == "Baseline Model":
|
| 204 |
+
result, avg_conf, plot_fig = model_baseline(prepro_dict)
|
| 205 |
+
elif model == "Baseline Model with Extractive":
|
| 206 |
+
result, avg_conf, plot_fig = model_baseline_extractive(prepro_dict)
|
| 207 |
+
else:
|
| 208 |
+
result, avg_conf, plot_fig = model_extractive_abstract(prepro_dict)
|
| 209 |
+
|
| 210 |
+
origin_words = len(text.split())
|
| 211 |
+
sum_words = len(result.split())
|
| 212 |
+
comp_ratio = (sum_words / origin_words * 100) if origin_words > 0 else 0
|
| 213 |
+
conf_str = f"**{avg_conf * 100:.2f}%**" if avg_conf is not None else "*N/A (Không tính Saliency Score)*"
|
| 214 |
+
|
| 215 |
+
metrics_md = f"""
|
| 216 |
+
### 📊 Thống kê kết quả
|
| 217 |
+
- **Tổng số từ văn bản gốc:** {origin_words} từ
|
| 218 |
+
- **Số từ bản tóm tắt:** {sum_words} từ
|
| 219 |
+
- **Tỉ lệ nén:** {comp_ratio:.1f}% *(Giữ lại {comp_ratio:.1f}% dung lượng gốc)*
|
| 220 |
+
- **Độ tin cậy trung bình (Confidence):** {conf_str}
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
if plot_fig is None:
|
| 224 |
+
plot_fig = go.Figure()
|
| 225 |
+
plot_fig.update_layout(title="Mô hình Baseline không hỗ trợ Saliency Plot.", template="plotly_white")
|
| 226 |
+
|
| 227 |
+
return result, metrics_md, plot_fig
|
| 228 |
+
except Exception as e:
|
| 229 |
+
return f"Đã xảy ra lỗi: {str(e)}", "⚠️ Lỗi hệ thống", None
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
|
| 235 |
|
|
|
|
| 236 |
|
| 237 |
|
| 238 |
# ====================== GRADIO INTERFACE ======================
|
|
|
|
| 248 |
)
|
| 249 |
|
| 250 |
with gr.Row():
|
| 251 |
+
with gr.Column(scale=2):
|
| 252 |
input_text = gr.Textbox(
|
| 253 |
label="📝 Text to Summarize",
|
| 254 |
placeholder="Paste your long text here (up to several thousand words)...",
|
| 255 |
+
lines=15,
|
| 256 |
max_lines=30,
|
| 257 |
)
|
| 258 |
+
with gr.Row():
|
| 259 |
+
btn_summary = gr.Button(
|
| 260 |
+
"🔍 Summarize Now",
|
| 261 |
+
variant="primary",
|
| 262 |
+
size="large"
|
| 263 |
+
)
|
| 264 |
|
| 265 |
with gr.Column(scale=1):
|
| 266 |
gr.Markdown("### ⚙️ Settings")
|
|
|
|
| 285 |
label="Summarization Model",
|
| 286 |
info="Select the model you want to use"
|
| 287 |
)
|
| 288 |
+
output_metrics = gr.Markdown("### 📊 Thống kê kết quả\n*Chờ xử lý...*")
|
| 289 |
+
gr.Markdown("---")
|
| 290 |
with gr.Row():
|
| 291 |
+
with gr.Column(scale=1):
|
| 292 |
+
output_text = gr.Textbox(
|
| 293 |
+
label="📄 Summary Result",
|
| 294 |
+
lines=10,
|
| 295 |
+
placeholder="The result will appear here...",
|
| 296 |
+
)
|
| 297 |
+
with gr.Column(scale=1):
|
| 298 |
+
output_plot = gr.Plot(label="📈 Saliency Score Plot")
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
# Connect button click
|
| 301 |
+
btn_summary.click(
|
| 302 |
fn=ATS,
|
| 303 |
inputs=[input_text, method, model],
|
| 304 |
+
outputs=[output_text, output_metrics, output_plot]
|
| 305 |
)
|
| 306 |
|
| 307 |
# Examples
|