Spaces:
Paused
Paused
"""
Hugging Face Space — EMG → Animated Hand Pose (UmeTrack Mesh → MP4)
"""
| import subprocess | |
# Install Chrome for kaleido (required for fig.to_image on Hugging Face).
# Best effort: any failure (kaleido missing, install error) is only reported,
# and start-up continues.
try:
    import kaleido
    kaleido.get_chrome_sync()
except Exception as e:
    print(f"Chrome install warning: {e}")
else:
    print("Chrome ready for kaleido.")
| import json | |
| import types | |
| import importlib.util | |
| import sys | |
| import tempfile | |
| import os | |
| import numpy as np | |
| import h5py | |
| import cv2 | |
| import torch | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from model import EMGPoseLSTM | |
# ── Config ────────────────────────────────────────────────────────────────────
MODEL_PATH = "best_model.pt"
DEVICE = torch.device("cpu")

# Sampling rates, in Hz.
EMG_HZ = 2000
MODEL_HZ = 50
ANIMATION_HZ = 8  # frames per second in output video
# Keep every ANIMATION_STRIDE-th MODEL_HZ prediction (never fewer than 1).
# NOTE(review): integer division makes the sampled rate
# MODEL_HZ / ANIMATION_STRIDE (50 / 6 ≈ 8.33 Hz with these values), slightly
# above the ANIMATION_HZ that build_video uses as the video frame rate.
ANIMATION_STRIDE = max(MODEL_HZ // ANIMATION_HZ, 1)

# Upper bound on how much EMG is read from an upload.
MAX_SECONDS = 10
MAX_SAMPLES = EMG_HZ * MAX_SECONDS

# Rendered frame size in pixels (smaller = faster rendering).
FRAME_W, FRAME_H = 320, 240
# ── Load UmeTrack FK pipeline ─────────────────────────────────────────────────
def load_umetrack():
    """Load the UmeTrack hand model and its mesh-skinning function.

    The UmeTrack source files ``pytorch3d_transforms_so3.py``, ``hand.py`` and
    ``hand_skinning.py`` are imported directly from the current working
    directory and registered under the dotted names ``lib.common.*``
    (NOTE(review): presumably the package path those files import each other
    by — confirm against hand.py / hand_skinning.py).

    Returns:
        ``(hand_model, skin_points)``: a ``HandModel`` built from
        ``generic_hand_model.json`` (JSON lists converted with
        ``torch.Tensor``, other values passed through unchanged) and the
        ``_skin_points`` function from ``hand_skinning.py``.

    Raises:
        ImportError: if no import spec can be created for a helper file.
    """
    def _load_module(fullname, filepath):
        """Import the source file at *filepath* as module *fullname*."""
        spec = importlib.util.spec_from_file_location(fullname, filepath)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot create an import spec for {filepath!r}")
        mod = importlib.util.module_from_spec(spec)
        # Register before executing (as in the importlib "import a source file
        # directly" recipe) so the module is visible while its code runs.
        sys.modules[fullname] = mod
        spec.loader.exec_module(mod)
        return mod

    # Empty placeholder parents for the ``lib.common.*`` names. setdefault
    # keeps any ``lib`` / ``lib.common`` module already imported in this
    # process instead of silently replacing it for everyone else.
    sys.modules.setdefault('lib', types.ModuleType('lib'))
    sys.modules.setdefault('lib.common', types.ModuleType('lib.common'))

    # Loaded first and not used directly here — presumably a dependency of
    # the two modules below (TODO confirm).
    _load_module('lib.common.pytorch3d_transforms_so3', 'pytorch3d_transforms_so3.py')
    hand_mod = _load_module('lib.common.hand', 'hand.py')
    skinning_mod = _load_module('lib.common.hand_skinning', 'hand_skinning.py')

    with open('generic_hand_model.json', encoding='utf-8') as f:
        hm_dict = json.load(f)
    hand_model = hand_mod.HandModel(**{
        k: torch.Tensor(v) if isinstance(v, list) else v
        for k, v in hm_dict.items()
    })
    return hand_model, skinning_mod._skin_points


HAND_MODEL, _SKIN_POINTS = load_umetrack()
# ── Load model ────────────────────────────────────────────────────────────────
def load_model():
    """Build EMGPoseLSTM, load the weights stored at MODEL_PATH, set eval mode.

    The LSTM's hidden size and layer count are recovered from the checkpoint's
    state dict, so the constructed architecture matches the saved weights.
    """
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    state = checkpoint["model_state_dict"]
    # Assuming `lstm` is a torch.nn.LSTM: weight_hh_l0 is
    # (gates * hidden_size, hidden_size), so its column count is the hidden size.
    n_hidden = state["lstm.weight_hh_l0"].shape[1]
    # One weight_hh_l<k> entry per stacked layer (unidirectional — TODO confirm).
    n_layers = len([key for key in state if key.startswith("lstm.weight_hh_l")])
    print(f"Loading model: hidden_size={n_hidden}, num_layers={n_layers}")
    net = EMGPoseLSTM(
        in_channels=16,
        out_channels=20,
        hidden_size=n_hidden,
        num_layers=n_layers,
    )
    net.load_state_dict(state)
    net.eval()
    return net


MODEL = load_model()
# ── HDF5 → EMG tensor ─────────────────────────────────────────────────────────
def load_emg_from_hdf5(path: str) -> torch.Tensor:
    """Read, truncate and per-channel z-score the EMG of an emg2pose HDF5 file.

    Args:
        path: Path to an HDF5 file holding the structured dataset
            ``emg2pose/timeseries`` with an ``emg`` field.

    Returns:
        A float32 tensor of shape ``(1, C, T)`` with ``T <= MAX_SAMPLES``.
        Each channel has zero mean and (up to the 1e-8 epsilon) unit std.

    Raises:
        ValueError: if the EMG field is not 2-D ``(samples, channels)`` or
            holds no samples.
        Errors raised by h5py (e.g. missing dataset, unreadable file)
        propagate unchanged.
    """
    with h5py.File(path, "r") as f:
        # Slice rows *before* materialising, so only the first MAX_SAMPLES
        # rows are read instead of the entire recording.
        emg = f["emg2pose/timeseries"][:MAX_SAMPLES]["emg"].astype(np.float32)
    if emg.ndim != 2:
        raise ValueError(
            f"Expected 2-D EMG array (samples, channels), got shape {emg.shape}")
    if emg.shape[0] == 0:
        raise ValueError("EMG recording contains no samples")
    # Per-channel z-score; the epsilon avoids division by zero on flat channels.
    mean = emg.mean(axis=0, keepdims=True)
    std = emg.std(axis=0, keepdims=True)
    emg = (emg - mean) / (std + 1e-8)
    # (T, C) -> (C, T) -> (1, C, T): a batch of one, channels first.
    return torch.from_numpy(emg).T.unsqueeze(0)
# ── Inference ─────────────────────────────────────────────────────────────────
def run_inference(emg_t: torch.Tensor) -> np.ndarray:
    """Run the LSTM on one EMG batch and subsample its output for animation.

    Args:
        emg_t: EMG batch as produced by ``load_emg_from_hdf5``, ``(1, C, T)``.

    Returns:
        Joint-angle predictions of shape ``(T_anim, 20)``. The per-sample
        output is decimated by ``EMG_HZ // MODEL_HZ`` and then by
        ``ANIMATION_STRIDE``.
    """
    # Inference only: skip autograd. This saves memory, and .numpy() raises on
    # tensors that still require grad.
    with torch.no_grad():
        out = MODEL(emg_t)
    preds = out[0].cpu().numpy().T       # (T_emg, 20)
    preds = preds[::EMG_HZ // MODEL_HZ]  # → 50 Hz
    preds = preds[::ANIMATION_STRIDE]    # → animation rate
    if len(preds):
        print(f"Pred shape: {preds.shape} std: {preds.std(axis=0).mean():.4f}")
    else:
        # std over zero rows would only yield NaN plus a RuntimeWarning.
        print(f"Pred shape: {preds.shape} (no frames)")
    return preds  # (T_anim, 20)
# ── Joint angles → PNG frame (numpy array) ────────────────────────────────────
def _skinned_hand_mesh(ja_20):
    """Pose the UmeTrack mesh with 20 joint angles; return (vertices, triangles)."""
    # The model predicts 20 angles; the skinning call is given 22, so the
    # last two are held at zero (NOTE(review): confirm which joints these are).
    ja_22 = np.concatenate([ja_20, np.zeros(2)]).astype(np.float32)
    vertices = _SKIN_POINTS(
        HAND_MODEL.joint_rest_positions,
        HAND_MODEL.joint_rotation_axes,
        HAND_MODEL.dense_bone_weights,
        torch.from_numpy(ja_22),
        HAND_MODEL.mesh_vertices,
        torch.eye(4),  # identity 4x4 wrist transform
    )
    return vertices.numpy(), HAND_MODEL.mesh_triangles.numpy()


def _hand_figure(verts, tris):
    """Build a black-background Plotly figure of the mesh, FRAME_W x FRAME_H px."""
    x, y, z = verts.T
    i, j, k = tris.T.astype(int)
    mesh = go.Mesh3d(
        x=x.tolist(), y=y.tolist(), z=z.tolist(),
        i=i.tolist(), j=j.tolist(), k=k.tolist(),
        color="lightpink", opacity=1.0,
        lighting=dict(ambient=0.85, diffuse=0.2, specular=0.5, roughness=1.0),
        lightposition=dict(x=10, y=-500, z=-1),
    )
    # All three axes are fully hidden so only the mesh is drawn.
    hidden_axis = dict(showbackground=False, showticklabels=False,
                       showgrid=False, title="")
    fig = go.Figure(data=[mesh])
    fig.update_layout(
        paper_bgcolor="black",
        height=FRAME_H, width=FRAME_W,
        margin=dict(l=0, r=0, t=0, b=0),
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
            camera=dict(eye=dict(x=-0.45, y=-1.6, z=0.55),
                        projection=dict(type="perspective")),
            aspectmode="data",
            bgcolor="black",
        ),
    )
    return fig


def render_frame_np(ja_20: np.ndarray) -> np.ndarray:
    """Render one hand pose to an image array via Plotly + kaleido.

    Args:
        ja_20: The 20 predicted joint angles for one frame.

    Returns:
        A ``(FRAME_H, FRAME_W, 3)`` uint8 array in OpenCV **BGR** channel
        order (as produced by ``cv2.imdecode``), ready for ``cv2.VideoWriter``.

    Raises:
        RuntimeError: if the PNG returned by kaleido cannot be decoded.
    """
    verts, tris = _skinned_hand_mesh(ja_20)
    fig = _hand_figure(verts, tris)
    # Pin the export size explicitly (scale=1 ⇒ pixels == layout size) so
    # every frame matches the video writer's frame size.
    png = fig.to_image(format="png", width=FRAME_W, height=FRAME_H, scale=1)
    frame = cv2.imdecode(np.frombuffer(png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        raise RuntimeError("Failed to decode the PNG frame rendered by kaleido")
    return frame
# ── Render all frames → MP4 video file ────────────────────────────────────────
def build_video(joint_angles: np.ndarray) -> str:
    """Render every pose with ``render_frame_np`` and encode the frames as MP4.

    Args:
        joint_angles: Joint angles of shape ``(T, 20)``, one row per frame.

    Returns:
        Path of a new ``.mp4`` file (``mp4v`` codec, ANIMATION_HZ fps,
        FRAME_W x FRAME_H). The file is left for the caller; it is not
        deleted here.

    Raises:
        RuntimeError: if OpenCV cannot open a video writer for the file.
    """
    n_frames = len(joint_angles)
    # mkstemp creates the file atomically, unlike the deprecated, race-prone
    # mktemp. Close our descriptor and let OpenCV reopen the file by path.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    size = (FRAME_W, FRAME_H)
    writer = cv2.VideoWriter(
        out_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        ANIMATION_HZ,
        size,
    )
    if not writer.isOpened():
        os.remove(out_path)
        raise RuntimeError(f"Could not open a video writer for {out_path}")
    print(f"Rendering {n_frames} frames to video...")
    try:
        for t, angles in enumerate(joint_angles):
            frame = render_frame_np(angles)
            # VideoWriter does not reliably encode frames of a different size
            # (they are typically skipped silently), so normalise defensively.
            if frame.shape[1::-1] != size:
                frame = cv2.resize(frame, size)
            writer.write(frame)
            if t % 10 == 0:
                print(f"  Frame {t}/{n_frames}")
    finally:
        # Always finalise and close the container, even if rendering fails.
        writer.release()
    print(f"Video saved to {out_path}")
    return out_path
# ── Gradio pipeline ───────────────────────────────────────────────────────────
def predict(hdf5_file):
    """Gradio handler: uploaded HDF5 file → (video path, info text).

    Args:
        hdf5_file: The ``gr.File`` value. It is either a path string or an
            object exposing the path as ``.name``, or ``None`` when nothing
            was uploaded.

    Returns:
        ``(video_path, info)``. On any failure ``video_path`` is ``None`` and
        ``info`` holds a human-readable error message.
    """
    if hdf5_file is None:
        return None, "Please upload an HDF5 file."
    # Accept both a plain path string and a wrapper carrying the path in .name.
    path = getattr(hdf5_file, "name", hdf5_file)
    try:
        emg_t = load_emg_from_hdf5(path)
    except Exception as e:
        return None, f"Failed to read HDF5: {e}"
    duration_s = emg_t.shape[-1] / EMG_HZ
    try:
        joint_angles = run_inference(emg_t)
        video_path = build_video(joint_angles)
    except Exception as e:
        # Report in the info box, like the HDF5 error path above, rather than
        # surfacing a raw exception in the UI. Also log it for the Space logs.
        print(f"Inference/rendering failed: {e!r}")
        return None, f"Inference failed: {e}"
    info = (f"EMG duration: {duration_s:.1f}s | "
            f"Frames: {len(joint_angles)} | "
            f"Video rate: {ANIMATION_HZ} fps")
    return video_path, info
# ── UI ────────────────────────────────────────────────────────────────────────
# Layout: upload + run button + info box on the left, output video on the right.
with gr.Blocks(title="EMG → Hand Pose") as demo:
    gr.Markdown(
        "## 🖐️ EMG → Animated Hand Pose\n"
        "Upload an **emg2pose-format HDF5 file**. "
        "The model predicts 20 joint angles per timestep and renders "
        "an animated 3D hand mesh as a video."
    )
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload HDF5 file",
                                 file_types=[".h5", ".hdf5"])
            run_btn = gr.Button("Run Inference", variant="primary")
            info_box = gr.Textbox(label="Info", interactive=False)
        with gr.Column(scale=2):
            video_out = gr.Video(label="Predicted Hand Pose",
                                 autoplay=True)
    # predict returns (video_path, info), matching the outputs order here.
    run_btn.click(
        fn=predict,
        inputs=file_input,
        outputs=[video_out, info_box],
    )
    gr.Markdown(
        "---\n"
        "**Model:** LSTM trained on the [emg2pose](https://arxiv.org/abs/2412.02725) dataset (Salter et al., 2024). \n\n"
        "**Model development:** Brian Mullen, Dayoung Lee, Kristin Dona, Sero Toriano Parel \n\n"
        "**Hand visualization:** 3D mesh rendered using forward kinematics from [UmeTrack](https://dl.acm.org/doi/10.1145/3550469.3555378) "
        "(Han et al., 2022). \n\n"
        "**License:** CC-BY-NC-SA 4.0 — non-commercial research use only."
    )
demo.launch(theme=gr.themes.Base())