"""
Hugging Face Space — EMG → Animated Hand Pose (UmeTrack Mesh → MP4).

Pipeline: upload an emg2pose-format HDF5 file → z-score the EMG per channel →
run an LSTM that predicts 20 joint angles per time step → decimate to the
animation rate → skin the UmeTrack generic hand mesh for each frame → render
each frame with Plotly + kaleido → encode the frames into an MP4 with OpenCV.
"""

import subprocess

# Install Chrome for kaleido (required for fig.to_image on Hugging Face).
# Best-effort: a failure is only reported here; rendering will fail later
# (and be reported in the UI) if no browser is available.
try:
    import kaleido

    kaleido.get_chrome_sync()
    print("Chrome ready for kaleido.")
except Exception as e:
    print(f"Chrome install warning: {e}")

import json
import types
import importlib.util
import sys
import tempfile
import os
import traceback

import numpy as np
import h5py
import cv2
import torch
import gradio as gr
import plotly.graph_objects as go

from model import EMGPoseLSTM

# ── Config ─────────────────────────────────────────────────────────────────────
MODEL_PATH = "best_model.pt"
DEVICE = torch.device("cpu")

EMG_HZ = 2000
MODEL_HZ = 50
ANIMATION_HZ = 8  # target frames per second in output video
ANIMATION_STRIDE = max(1, MODEL_HZ // ANIMATION_HZ)
# Actual frame rate after integer-stride decimation (50 // 8 = 6 → 8.33 fps).
# The video is written at this rate so playback stays in real time; writing at
# ANIMATION_HZ would play the animation slower than the recording.
VIDEO_FPS = MODEL_HZ / ANIMATION_STRIDE

MAX_SECONDS = 10
MAX_SAMPLES = MAX_SECONDS * EMG_HZ

FRAME_W = 320  # smaller = faster rendering
FRAME_H = 240


# ── Load UmeTrack FK pipeline ──────────────────────────────────────────────────
def load_umetrack():
    """Load the UmeTrack hand model and skinning function from local files.

    Empty placeholder packages ``lib`` and ``lib.common`` are registered in
    ``sys.modules`` and the local files are loaded under ``lib.common.*``
    names — presumably so the UmeTrack sources' own ``lib.common.*`` imports
    resolve to these local copies (TODO confirm against hand.py).

    Returns:
        (hand_model, skin_points): a ``HandModel`` built from
        ``generic_hand_model.json`` (list values converted to ``torch.Tensor``)
        and the ``_skin_points`` function from ``hand_skinning.py``.
    """
    def _load_module(fullname, filepath):
        # Load *filepath* as a module registered under *fullname*.
        spec = importlib.util.spec_from_file_location(fullname, filepath)
        mod = importlib.util.module_from_spec(spec)
        # Register before executing so later modules can import it by name.
        sys.modules[fullname] = mod
        spec.loader.exec_module(mod)
        return mod

    sys.modules['lib'] = types.ModuleType('lib')
    sys.modules['lib.common'] = types.ModuleType('lib.common')

    # Loaded first because the modules below appear to depend on it.
    _load_module('lib.common.pytorch3d_transforms_so3', 'pytorch3d_transforms_so3.py')
    hand_mod = _load_module('lib.common.hand', 'hand.py')
    skinning_mod = _load_module('lib.common.hand_skinning', 'hand_skinning.py')

    HandModel = hand_mod.HandModel
    _skin_points = skinning_mod._skin_points

    with open('generic_hand_model.json', encoding="utf-8") as f:
        hm_dict = json.load(f)

    hand_model = HandModel(**{
        k: torch.Tensor(v) if isinstance(v, list) else v
        for k, v in hm_dict.items()
    })
    return hand_model, _skin_points


HAND_MODEL, _SKIN_POINTS = load_umetrack()


# ── Load model ─────────────────────────────────────────────────────────────────
def load_model():
    """Build an ``EMGPoseLSTM`` from MODEL_PATH, sizing the LSTM from the checkpoint.

    ``hidden_size`` and ``num_layers`` are not stored explicitly, so they are
    recovered from the ``lstm.*`` tensor shapes / key names in the state dict.
    """
    ckpt = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    state = ckpt["model_state_dict"]
    # torch.nn.LSTM: weight_hh_l0 has shape (4 * hidden_size, hidden_size).
    hidden_size = state["lstm.weight_hh_l0"].shape[1]
    # One weight_hh_l{k} tensor per stacked layer.
    # NOTE(review): a bidirectional LSTM would also have *_reverse keys and be
    # double-counted here; fine as long as EMGPoseLSTM is unidirectional.
    num_layers = sum(1 for k in state if k.startswith("lstm.weight_hh_l"))
    print(f"Loading model: hidden_size={hidden_size}, num_layers={num_layers}")

    model = EMGPoseLSTM(in_channels=16, out_channels=20,
                        hidden_size=hidden_size, num_layers=num_layers)
    model.load_state_dict(state)
    model.eval()
    return model


MODEL = load_model()


# ── HDF5 → EMG tensor ─────────────────────────────────────────────────────────
def load_emg_from_hdf5(path: str) -> torch.Tensor:
    """Read up to MAX_SAMPLES EMG samples from an emg2pose HDF5 file and z-score them.

    Args:
        path: path to an HDF5 file containing the ``emg2pose/timeseries``
            dataset with an ``emg`` field.

    Returns:
        float32 tensor of shape (1, channels, T); channels is presumably 16
        (the model's ``in_channels``). Each channel has zero mean and unit
        (population) std over the loaded window.

    Raises:
        KeyError: if the dataset or ``emg`` field is missing.
        ValueError: if the file contains no EMG samples.
    """
    with h5py.File(path, "r") as f:
        # Slice rows before extracting the field so only the first
        # MAX_SAMPLES rows are read from disk (h5py clamps the slice).
        emg = f["emg2pose/timeseries"][:MAX_SAMPLES]["emg"].astype(np.float32)

    if emg.shape[0] == 0:
        # An empty window would produce NaN statistics and an empty video.
        raise ValueError("file contains no EMG samples")

    mean = emg.mean(axis=0, keepdims=True)
    std = emg.std(axis=0, keepdims=True)  # population std (ddof=0)
    emg = (emg - mean) / (std + 1e-8)  # epsilon guards flat channels
    return torch.from_numpy(emg).T.unsqueeze(0)  # (1, C, T)


# ── Inference ──────────────────────────────────────────────────────────────────
@torch.no_grad()
def run_inference(emg_t: torch.Tensor) -> np.ndarray:
    """Predict joint angles and decimate them to the animation frame rate.

    Args:
        emg_t: (1, C, T_emg) normalized EMG tensor.

    Returns:
        (T_anim, 20) array of joint angles, one row per video frame
        (VIDEO_FPS frames per second of recording).
    """
    preds = MODEL(emg_t)[0].cpu().numpy().T  # (T_emg, 20)
    preds = preds[::EMG_HZ // MODEL_HZ]      # → MODEL_HZ
    preds = preds[::ANIMATION_STRIDE]        # → animation rate (VIDEO_FPS)
    print(f"Pred shape: {preds.shape} std: {preds.std(axis=0).mean():.4f}")
    return preds  # (T_anim, 20)


# ── Joint angles → PNG frame (numpy array) ────────────────────────────────────
def render_frame_np(ja_20: np.ndarray) -> np.ndarray:
    """Render one hand pose as a BGR uint8 image of shape (FRAME_H, FRAME_W, 3).

    The 20 predicted joint angles are padded with two zeros to the 22 angles
    the UmeTrack skinning function is given, the generic hand mesh is skinned
    with an identity wrist transform, and the mesh is rendered via Plotly +
    kaleido. BGR channel order is what ``cv2.VideoWriter`` expects.

    Raises:
        RuntimeError: if the PNG produced by kaleido cannot be decoded.
    """
    ja_22 = np.concatenate([ja_20, np.zeros(2)]).astype(np.float32)
    ja_t = torch.from_numpy(ja_22)
    wrist = torch.eye(4)

    vertices = _SKIN_POINTS(
        HAND_MODEL.joint_rest_positions,
        HAND_MODEL.joint_rotation_axes,
        HAND_MODEL.dense_bone_weights,
        ja_t,
        HAND_MODEL.mesh_vertices,
        wrist,
    )
    verts = vertices.numpy()
    tris = HAND_MODEL.mesh_triangles.numpy()

    x, y, z = verts.T
    i, j, k = tris.T.astype(int)

    mesh = go.Mesh3d(
        x=x.tolist(), y=y.tolist(), z=z.tolist(),
        i=i.tolist(), j=j.tolist(), k=k.tolist(),
        color="lightpink",
        opacity=1.0,
        lighting=dict(ambient=0.85, diffuse=0.2, specular=0.5, roughness=1.0),
        lightposition=dict(x=10, y=-500, z=-1),
    )

    hidden_axis = dict(showbackground=False, showticklabels=False,
                       showgrid=False, title="")
    fig = go.Figure(data=[mesh])
    fig.update_layout(
        paper_bgcolor="black",
        height=FRAME_H,
        width=FRAME_W,
        margin=dict(l=0, r=0, t=0, b=0),
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
            camera=dict(eye=dict(x=-0.45, y=-1.6, z=0.55),
                        projection=dict(type="perspective")),
            aspectmode="data",
            bgcolor="black",
        ),
    )

    img_bytes = fig.to_image(format="png")
    img_bgr = cv2.imdecode(np.frombuffer(img_bytes, dtype=np.uint8),
                           cv2.IMREAD_COLOR)  # OpenCV decodes to BGR
    if img_bgr is None:
        raise RuntimeError("Could not decode the PNG frame rendered by kaleido")
    # cv2.VideoWriter may silently drop frames whose size differs from the
    # size it was opened with, so enforce the expected frame size.
    if img_bgr.shape[:2] != (FRAME_H, FRAME_W):
        img_bgr = cv2.resize(img_bgr, (FRAME_W, FRAME_H),
                             interpolation=cv2.INTER_AREA)
    return img_bgr


# ── Render all frames → MP4 video file ────────────────────────────────────────
def build_video(joint_angles: np.ndarray) -> str:
    """Render each pose as a frame and encode the frames into an MP4 file.

    Args:
        joint_angles: (T, 20) array, one row per frame.

    Returns:
        Path to a newly created ``.mp4`` file (caller owns it), encoded with
        the ``mp4v`` codec at VIDEO_FPS.

    Raises:
        RuntimeError: if OpenCV cannot open a video writer for the file.
    """
    T = len(joint_angles)
    # mkstemp creates the file atomically (no mktemp race); only the unique
    # path is needed because VideoWriter opens the file itself.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)

    writer = cv2.VideoWriter(
        out_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        VIDEO_FPS,
        (FRAME_W, FRAME_H),
    )
    if not writer.isOpened():
        os.remove(out_path)
        raise RuntimeError(f"Could not open video writer for {out_path}")

    print(f"Rendering {T} frames to video...")
    try:
        for t, ja in enumerate(joint_angles):
            writer.write(render_frame_np(ja))
            if t % 10 == 0:
                print(f" Frame {t}/{T}")
    finally:
        # Always finalize/close the container, even if a frame fails to render.
        writer.release()

    print(f"Video saved to {out_path}")
    return out_path


# ── Gradio pipeline ────────────────────────────────────────────────────────────
def predict(hdf5_file):
    """Gradio callback: uploaded HDF5 file → (video path or None, info text).

    Errors are reported in the info text instead of being raised, so the UI
    always receives a message.
    """
    if hdf5_file is None:
        return None, "Please upload an HDF5 file."

    # gr.File passes a filepath string in Gradio 4+ and a tempfile-like
    # object exposing .name in Gradio 3; accept both.
    path = getattr(hdf5_file, "name", hdf5_file)

    try:
        emg_t = load_emg_from_hdf5(path)
    except Exception as e:
        return None, f"Failed to read HDF5: {e}"

    duration_s = emg_t.shape[-1] / EMG_HZ

    try:
        joint_angles = run_inference(emg_t)
        video_path = build_video(joint_angles)
    except Exception as e:
        traceback.print_exc()  # keep the full trace in the Space logs
        return None, f"Inference/rendering failed: {e}"

    info = (f"EMG duration: {duration_s:.1f}s | "
            f"Frames: {len(joint_angles)} | "
            f"Video rate: {VIDEO_FPS:.2f} fps")
    return video_path, info


# ── UI ─────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="EMG → Hand Pose") as demo:
    gr.Markdown(
        "## 🖐️ EMG → Animated Hand Pose\n"
        "Upload an **emg2pose-format HDF5 file**. "
        "The model predicts 20 joint angles per timestep and renders "
        "an animated 3D hand mesh as a video."
    )

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload HDF5 file", file_types=[".h5", ".hdf5"])
            run_btn = gr.Button("Run Inference", variant="primary")
            info_box = gr.Textbox(label="Info", interactive=False)
        with gr.Column(scale=2):
            video_out = gr.Video(label="Predicted Hand Pose", autoplay=True)

    run_btn.click(
        fn=predict,
        inputs=file_input,
        outputs=[video_out, info_box],
    )

    gr.Markdown(
        "---\n"
        "**Model:** LSTM trained on the [emg2pose](https://arxiv.org/abs/2412.02725) dataset (Salter et al., 2024). \n\n"
        "**Model development:** Brian Mullen, Dayoung Lee, Kristin Dona, Sero Toriano Parel \n\n"
        "**Hand visualization:** 3D mesh rendered using forward kinematics from [UmeTrack](https://dl.acm.org/doi/10.1145/3550469.3555378) "
        "(Han et al., 2022). \n\n"
        "**License:** CC-BY-NC-SA 4.0 — non-commercial research use only."
    )

# NOTE(review): `theme=` on launch() is the Gradio 6 API; on Gradio 4/5 the
# theme must be passed to gr.Blocks(...) instead — confirm the pinned version.
demo.launch(theme=gr.themes.Base())