emg-pose / app.py
dayounglee's picture
Update app.py
8ac208c verified
"""
Hugging Face Space — EMG → Animated Hand Pose (UmeTrack Mesh → MP4)
"""
import subprocess
# Kaleido needs a Chrome binary before fig.to_image can be used (required on
# Hugging Face). This is best-effort: a failure is only reported so that the
# rest of the app can still start.
try:
    import kaleido

    kaleido.get_chrome_sync()
except Exception as err:
    print(f"Chrome install warning: {err}")
else:
    print("Chrome ready for kaleido.")
import json
import types
import importlib.util
import sys
import tempfile
import os
import numpy as np
import h5py
import cv2
import torch
import gradio as gr
import plotly.graph_objects as go
from model import EMGPoseLSTM
# ── Config ─────────────────────────────────────────────────────────────────────
# Checkpoint file and the device inference runs on.
MODEL_PATH = "best_model.pt"
DEVICE = torch.device("cpu")

# Rates in Hz: raw EMG input, model-rate predictions, and output video frames.
EMG_HZ = 2000
MODEL_HZ = 50
ANIMATION_HZ = 8
# Every ANIMATION_STRIDE-th model-rate prediction becomes one video frame.
# Integer division: the kept-frame rate is MODEL_HZ / ANIMATION_STRIDE, which
# only equals ANIMATION_HZ when ANIMATION_HZ divides MODEL_HZ evenly.
ANIMATION_STRIDE = max(MODEL_HZ // ANIMATION_HZ, 1)

# Uploaded recordings are truncated to this many seconds / raw EMG samples.
MAX_SECONDS = 10
MAX_SAMPLES = EMG_HZ * MAX_SECONDS

# Output frame size in pixels; smaller frames render faster.
FRAME_W, FRAME_H = 320, 240
# ── Load UmeTrack FK pipeline ──────────────────────────────────────────────────
def load_umetrack():
    """Load the UmeTrack hand model and skinning function from local files.

    Three source files in the working directory are imported under
    ``lib.common.*`` module names, with empty ``lib`` and ``lib.common``
    modules registered in ``sys.modules`` first (presumably because the
    files import each other through that package path — confirm against
    the UmeTrack sources).

    Returns:
        ``(hand_model, skin_points)``: a ``HandModel`` built from
        ``generic_hand_model.json`` and the ``_skin_points`` function from
        ``hand_skinning.py``.
    """
    def _import_from_file(dotted_name, source_path):
        # Register the module before executing it so that imports of
        # `dotted_name` made during execution resolve to this module.
        module_spec = importlib.util.spec_from_file_location(dotted_name, source_path)
        module = importlib.util.module_from_spec(module_spec)
        sys.modules[dotted_name] = module
        module_spec.loader.exec_module(module)
        return module

    # Placeholder parent packages for the dotted module names used below.
    for package_name in ('lib', 'lib.common'):
        sys.modules[package_name] = types.ModuleType(package_name)

    _import_from_file('lib.common.pytorch3d_transforms_so3', 'pytorch3d_transforms_so3.py')
    hand_module = _import_from_file('lib.common.hand', 'hand.py')
    skinning_module = _import_from_file('lib.common.hand_skinning', 'hand_skinning.py')
    hand_model_cls = hand_module.HandModel
    skin_points = skinning_module._skin_points

    with open('generic_hand_model.json') as fh:
        raw_fields = json.load(fh)

    # JSON lists become tensors; every other value is passed through as-is.
    model_kwargs = {}
    for field_name, field_value in raw_fields.items():
        if isinstance(field_value, list):
            model_kwargs[field_name] = torch.Tensor(field_value)
        else:
            model_kwargs[field_name] = field_value

    return hand_model_cls(**model_kwargs), skin_points
# Loaded once at import time; render_frame_np reuses both objects for every frame.
HAND_MODEL, _SKIN_POINTS = load_umetrack()
# ── Load model ─────────────────────────────────────────────────────────────────
def load_model():
    """Rebuild ``EMGPoseLSTM`` from the checkpoint at ``MODEL_PATH``.

    The LSTM dimensions are read from the saved weights rather than
    hard-coded: the hidden size is the second dimension of
    ``lstm.weight_hh_l0`` and the layer count is the number of state-dict
    keys starting with ``lstm.weight_hh_l``.

    Returns:
        The model with weights loaded, switched to eval mode.
    """
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    state = checkpoint["model_state_dict"]

    hidden_size = state["lstm.weight_hh_l0"].shape[1]
    recurrent_keys = [name for name in state if name.startswith("lstm.weight_hh_l")]
    num_layers = len(recurrent_keys)
    print(f"Loading model: hidden_size={hidden_size}, num_layers={num_layers}")

    net = EMGPoseLSTM(
        in_channels=16,
        out_channels=20,
        hidden_size=hidden_size,
        num_layers=num_layers,
    )
    net.load_state_dict(state)
    net.eval()
    return net
# Loaded once at import time; run_inference calls this shared instance.
MODEL = load_model()
# ── HDF5 β†’ EMG tensor ─────────────────────────────────────────────────────────
def load_emg_from_hdf5(path: str) -> torch.Tensor:
    """Read, truncate and normalize the EMG signal of an emg2pose HDF5 file.

    Only the first ``MAX_SAMPLES`` records of the ``emg2pose/timeseries``
    dataset are read. Each channel of the ``emg`` field is then z-scored
    using the mean and standard deviation of the truncated clip.

    Args:
        path: Path of the HDF5 file.

    Returns:
        Float32 tensor of shape ``(1, channels, T)`` with ``T <= MAX_SAMPLES``
        (the original code documents this as ``(1, 16, T)``).

    Raises:
        ValueError: If the recording contains no samples.
    """
    with h5py.File(path, "r") as f:
        # Slice on the dataset itself so that at most MAX_SAMPLES records are
        # read, instead of loading the whole recording and truncating after.
        records = f["emg2pose/timeseries"][:MAX_SAMPLES]
    emg = records["emg"].astype(np.float32)

    if emg.shape[0] == 0:
        # Mean/std of an empty clip would be NaN; fail with a clear message.
        raise ValueError("recording contains no EMG samples")

    mean = emg.mean(axis=0, keepdims=True)
    std = np.sqrt(np.maximum(((emg - mean) ** 2).mean(axis=0, keepdims=True), 0.0))
    # The epsilon keeps constant (zero-variance) channels from dividing by zero.
    emg = (emg - mean) / (std + 1e-8)
    return torch.from_numpy(emg).T.unsqueeze(0)  # (1, 16, T)
# ── Inference ──────────────────────────────────────────────────────────────────
@torch.no_grad()
def run_inference(emg_t: torch.Tensor) -> np.ndarray:
    """Run the model on one EMG clip and subsample the output for animation.

    The first batch item of the model output is transposed so that time is
    the leading axis, then every ``EMG_HZ // MODEL_HZ``-th row is kept,
    followed by every ``ANIMATION_STRIDE``-th row of that result.

    Args:
        emg_t: Model input (the original comments describe it as
            ``(1, 16, T)`` — confirm against ``EMGPoseLSTM``).

    Returns:
        Array of subsampled predictions, one row per animation frame
        (documented in the original as ``(T_anim, 20)``).
    """
    first_item = MODEL(emg_t)[0]
    time_major = first_item.cpu().numpy().T  # (T_emg, 20)

    model_rate_step = EMG_HZ // MODEL_HZ
    at_model_rate = time_major[::model_rate_step]
    preds = at_model_rate[::ANIMATION_STRIDE]

    # Mean per-column standard deviation, logged as a quick sanity check.
    spread = preds.std(axis=0).mean()
    print(f"Pred shape: {preds.shape} std: {spread:.4f}")
    return preds
# ── Joint angles β†’ PNG frame (numpy array) ────────────────────────────────────
def render_frame_np(ja_20: np.ndarray) -> np.ndarray:
    """Render one hand pose to an image array via Plotly + kaleido.

    The joint angles are padded with two trailing zeros and passed to the
    skinning function together with an identity 4x4 wrist transform. The
    posed mesh is drawn on a black background at ``FRAME_W`` x ``FRAME_H``
    and exported as PNG.

    Args:
        ja_20: Joint angles for a single frame (20 values per the name).

    Returns:
        The decoded image. It is produced by ``cv2.imdecode`` with
        ``IMREAD_COLOR``, so the channel order is BGR.
    """
    padded_angles = np.concatenate([ja_20, np.zeros(2)]).astype(np.float32)
    posed_vertices = _SKIN_POINTS(
        HAND_MODEL.joint_rest_positions,
        HAND_MODEL.joint_rotation_axes,
        HAND_MODEL.dense_bone_weights,
        torch.from_numpy(padded_angles),
        HAND_MODEL.mesh_vertices,
        torch.eye(4),
    )

    xs, ys, zs = posed_vertices.numpy().T
    faces = HAND_MODEL.mesh_triangles.numpy()
    # Plotly expects integer vertex indices for the triangle corners.
    corner_a, corner_b, corner_c = faces.T.astype(int)

    hand_mesh = go.Mesh3d(
        x=xs.tolist(), y=ys.tolist(), z=zs.tolist(),
        i=corner_a.tolist(), j=corner_b.tolist(), k=corner_c.tolist(),
        color="lightpink", opacity=1.0,
        lighting=dict(ambient=0.85, diffuse=0.2, specular=0.5, roughness=1.0),
        lightposition=dict(x=10, y=-500, z=-1),
    )

    # All three axes are hidden the same way; each gets its own dict copy.
    hidden_axis = dict(showbackground=False, showticklabels=False,
                       showgrid=False, title="")
    scene_settings = dict(
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        zaxis=dict(hidden_axis),
        camera=dict(eye=dict(x=-0.45, y=-1.6, z=0.55),
                    projection=dict(type="perspective")),
        aspectmode="data",
        bgcolor="black",
    )

    fig = go.Figure(data=[hand_mesh])
    fig.update_layout(
        paper_bgcolor="black",
        height=FRAME_H, width=FRAME_W,
        margin=dict(l=0, r=0, t=0, b=0),
        scene=scene_settings,
    )

    png_bytes = fig.to_image(format="png")
    encoded = np.frombuffer(png_bytes, dtype=np.uint8)
    frame_bgr = cv2.imdecode(encoded, cv2.IMREAD_COLOR)  # BGR
    return frame_bgr
# ── Render all frames β†’ MP4 video file ────────────────────────────────────────
def build_video(joint_angles: np.ndarray) -> str:
    """Render every pose to a frame and encode the sequence as an MP4.

    Each row of ``joint_angles`` is rendered with ``render_frame_np`` and
    written with OpenCV's ``mp4v`` codec at ``ANIMATION_HZ`` frames per
    second and ``FRAME_W`` x ``FRAME_H`` pixels.

    Args:
        joint_angles: Per-frame joint-angle vectors, one row per frame.

    Returns:
        Path of the written ``.mp4`` file in the system temp directory.
        The file is not deleted by this function.

    Raises:
        RuntimeError: If OpenCV cannot open the video writer.
    """
    n_frames = len(joint_angles)

    # tempfile.mktemp() only invents a name (deprecated and racy). Create the
    # file for real and keep it so the path is reserved for the writer.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name

    writer = cv2.VideoWriter(
        out_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        ANIMATION_HZ,
        (FRAME_W, FRAME_H),
    )
    if not writer.isOpened():
        # Fail loudly rather than hand back a path to an unusable file.
        writer.release()
        raise RuntimeError(f"Could not open video writer for {out_path}")

    print(f"Rendering {n_frames} frames to video...")
    try:
        for t, pose in enumerate(joint_angles):
            frame = render_frame_np(pose)
            # The writer requires frames of exactly the size it was opened with.
            if frame.shape[:2] != (FRAME_H, FRAME_W):
                frame = cv2.resize(frame, (FRAME_W, FRAME_H))
            writer.write(frame)
            if t % 10 == 0:
                print(f" Frame {t}/{n_frames}")
    finally:
        # Release even when rendering fails so the file handle is not leaked.
        writer.release()

    print(f"Video saved to {out_path}")
    return out_path
# ── Gradio pipeline ────────────────────────────────────────────────────────────
def predict(hdf5_file):
    """Run the upload -> inference -> video pipeline for the Gradio button.

    Args:
        hdf5_file: Value delivered by the file input; ``None`` when nothing
            was uploaded. Either a path string or an object exposing the
            path as ``.name`` is accepted.

    Returns:
        ``(video_path, info)``. ``video_path`` is ``None`` and ``info`` holds
        a message when there is no upload, the file cannot be read, or it
        yields no frames. Otherwise ``info`` summarizes duration, frame count
        and video rate.
    """
    if hdf5_file is None:
        return None, "Please upload an HDF5 file."

    # A plain string is already the path; other upload objects carry it in
    # `.name`. Checking for str first avoids an AttributeError on plain paths.
    path = hdf5_file if isinstance(hdf5_file, str) else hdf5_file.name
    try:
        emg_t = load_emg_from_hdf5(path)
    except Exception as e:
        return None, f"Failed to read HDF5: {e}"

    n_samples = emg_t.shape[-1]
    if n_samples == 0:
        return None, "The uploaded file contains no EMG samples."

    duration_s = n_samples / EMG_HZ
    joint_angles = run_inference(emg_t)
    if len(joint_angles) == 0:
        # Nothing to render; an empty video would be unusable.
        return None, "The recording produced no frames to render."

    video_path = build_video(joint_angles)
    info = (f"EMG duration: {duration_s:.1f}s | "
            f"Frames: {len(joint_angles)} | "
            f"Video rate: {ANIMATION_HZ} fps")
    return video_path, info
# ── UI ─────────────────────────────────────────────────────────────────────────
# Page layout: upload + run button + info text on the left, the rendered
# video on the right, and an attribution footer underneath.
with gr.Blocks(title="EMG → Hand Pose") as demo:
    gr.Markdown(
        "## 🖐️ EMG → Animated Hand Pose\n"
        "Upload an **emg2pose-format HDF5 file**. "
        "The model predicts 20 joint angles per timestep and renders "
        "an animated 3D hand mesh as a video."
    )
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload HDF5 file",
                                 file_types=[".h5", ".hdf5"])
            run_btn = gr.Button("Run Inference", variant="primary")
            info_box = gr.Textbox(label="Info", interactive=False)
        with gr.Column(scale=2):
            video_out = gr.Video(label="Predicted Hand Pose",
                                 autoplay=True)
    # predict returns (video_path, info), matching the two outputs in order.
    run_btn.click(
        fn=predict,
        inputs=file_input,
        outputs=[video_out, info_box],
    )
    gr.Markdown(
        "---\n"
        "**Model:** LSTM trained on the [emg2pose](https://arxiv.org/abs/2412.02725) dataset (Salter et al., 2024). \n\n"
        "**Model development:** Brian Mullen, Dayoung Lee, Kristin Dona, Sero Toriano Parel \n\n"
        "**Hand visualization:** 3D mesh rendered using forward kinematics from [UmeTrack](https://dl.acm.org/doi/10.1145/3550469.3555378) "
        "(Han et al., 2022). \n\n"
        "**License:** CC-BY-NC-SA 4.0 — non-commercial research use only."
    )
# NOTE(review): passing `theme` to launch() needs a Gradio release whose
# launch() accepts it; older releases take it in gr.Blocks(...) instead —
# confirm against the pinned Gradio version.
demo.launch(theme=gr.themes.Base())