Udayshankar Ravikumar commited on
Commit
7bea7be
·
unverified ·
1 Parent(s): a86d37e

Updated GUI.

Browse files
Files changed (1) hide show
  1. app.py +13 -27
app.py CHANGED
@@ -241,8 +241,7 @@ def run_inference_fast(file):
241
  save_output_csv(agg, "ranked"),
242
  save_output_csv(topk, "topk"),
243
  save_output_csv(df, "per_benchmark"),
244
- topk.to_dict("records"),
245
- "Inference complete. Generating SHAP plots..."
246
  )
247
 
248
 
@@ -251,11 +250,8 @@ def run_inference_fast(file):
251
  # =========================================================
252
  def generate_shap(topk_records):
253
  topk_df = pd.DataFrame(topk_records)
254
-
255
  if len(topk_df) < 2:
256
- return [None] * (len(WORKLOADS) * len(TARGETS) * 2) + [
257
- "Need at least 2 configs for SHAP comparison."
258
- ]
259
 
260
  rank1 = topk_df.iloc[0]
261
  rank2 = topk_df.iloc[1]
@@ -275,7 +271,7 @@ def generate_shap(topk_records):
275
  f"model_{workload}_{target}"
276
  ).shap_values(X)[0]
277
 
278
- plt.figure(figsize=(7, 4))
279
 
280
  shap.plots.bar(
281
  shap.Explanation(
@@ -290,12 +286,16 @@ def generate_shap(topk_records):
290
  path = OUTPUT_DIR / f"shap_{uuid.uuid4().hex[:8]}.png"
291
 
292
  plt.tight_layout()
293
- plt.savefig(path, dpi=120, bbox_inches="tight")
 
 
 
 
 
294
  plt.close()
295
 
296
  outputs.append(str(path))
297
 
298
- outputs.append("SHAP generation complete.")
299
  return outputs
300
 
301
 
@@ -304,20 +304,9 @@ def generate_shap(topk_records):
304
  # =========================================================
305
  CUSTOM_CSS = """
306
  .gradio-container {
307
- max-width: 1600px !important;
308
- margin: auto !important;
309
- }
310
-
311
- .gr-image,
312
- .gr-image img,
313
- .gr-dataframe {
314
- max-width: 100% !important;
315
  }
316
-
317
- .gr-row {
318
- flex-wrap: wrap !important;
319
- }
320
-
321
  footer {
322
  display: none !important;
323
  }
@@ -333,8 +322,6 @@ Upload cache configurations once — automatically evaluated across all workload
333
  file_input = gr.File(label="Upload Cache Config CSV")
334
  run_btn = gr.Button("Run Inference", variant="primary")
335
 
336
- status_md = gr.Markdown()
337
-
338
  with gr.Tabs():
339
  with gr.Tab("📊 Summary"):
340
  summary_df = gr.Dataframe(interactive=False)
@@ -379,15 +366,14 @@ Upload cache configurations once — automatically evaluated across all workload
379
  ranked_csv,
380
  topk_csv,
381
  per_benchmark_csv,
382
- topk_state,
383
- status_md
384
  ]
385
  )
386
 
387
  inference_event.then(
388
  fn=generate_shap,
389
  inputs=topk_state,
390
- outputs=shap_image_components + [status_md]
391
  )
392
 
393
  demo.launch(
 
241
  save_output_csv(agg, "ranked"),
242
  save_output_csv(topk, "topk"),
243
  save_output_csv(df, "per_benchmark"),
244
+ topk.to_dict("records")
 
245
  )
246
 
247
 
 
250
  # =========================================================
251
  def generate_shap(topk_records):
252
  topk_df = pd.DataFrame(topk_records)
 
253
  if len(topk_df) < 2:
254
+ return [None] * (len(WORKLOADS) * len(TARGETS) * 2)
 
 
255
 
256
  rank1 = topk_df.iloc[0]
257
  rank2 = topk_df.iloc[1]
 
271
  f"model_{workload}_{target}"
272
  ).shap_values(X)[0]
273
 
274
+ plt.figure(figsize=(7, 4), dpi=180)
275
 
276
  shap.plots.bar(
277
  shap.Explanation(
 
286
  path = OUTPUT_DIR / f"shap_{uuid.uuid4().hex[:8]}.png"
287
 
288
  plt.tight_layout()
289
+ plt.savefig(
290
+ path,
291
+ dpi=180,
292
+ bbox_inches="tight",
293
+ pad_inches=0.05
294
+ )
295
  plt.close()
296
 
297
  outputs.append(str(path))
298
 
 
299
  return outputs
300
 
301
 
 
304
  # =========================================================
305
  CUSTOM_CSS = """
306
  .gradio-container {
307
+ max-width: 1450px !important;
308
+ margin: auto;
 
 
 
 
 
 
309
  }
 
 
 
 
 
310
  footer {
311
  display: none !important;
312
  }
 
322
  file_input = gr.File(label="Upload Cache Config CSV")
323
  run_btn = gr.Button("Run Inference", variant="primary")
324
 
 
 
325
  with gr.Tabs():
326
  with gr.Tab("📊 Summary"):
327
  summary_df = gr.Dataframe(interactive=False)
 
366
  ranked_csv,
367
  topk_csv,
368
  per_benchmark_csv,
369
+ topk_state
 
370
  ]
371
  )
372
 
373
  inference_event.then(
374
  fn=generate_shap,
375
  inputs=topk_state,
376
+ outputs=shap_image_components
377
  )
378
 
379
  demo.launch(