"""Load the DeepNAPSI nail-classification checkpoints and expose a predict API."""

import os

from huggingface_hub import hf_hub_download

from nail_classification.inference import Inference

# Checkpoint versions loaded by ``Model``. Shared by the local (DEBUG) and the
# Hugging Face Hub code paths so the two lists cannot drift apart.
_CHECKPOINT_VERSIONS = (10, 11, 12, 13, 14)

# Developer-machine directory used when ``Model(DEBUG=True)``.
_DEBUG_CHECKPOINT_DIR = r"C:\Users\follels\Documents\hand-ki-model-weights\DeepNAPSIModel\inference_checkpoints_v1"

# Hub repository holding the ``version_<v>.ckpt`` files.
_HF_REPO_ID = "lfolle/DeepNAPSIModel"

# Name of the environment variable holding the Hub access token.
_HF_TOKEN_ENV_VAR = "DeepNAPSIModel"


def _hf_hub_download_compat(repo_id: str, filename: str, token: str) -> str:
    """Download *filename* from the Hub repo *repo_id* and return its local path.

    Newer ``huggingface_hub`` releases take the access token as ``token=``.
    Older releases only accept ``use_auth_token=``. This helper tries the new
    keyword first. It falls back to the old keyword only when the failure is
    about that keyword.

    Raises:
        TypeError: if ``hf_hub_download`` raises a TypeError that is not about
            the ``token`` keyword.
    """
    try:
        return hf_hub_download(repo_id, filename, token=token)
    except TypeError as err:
        # Old releases reject the keyword with
        # "... got an unexpected keyword argument 'token'". Any other TypeError
        # is a real error and must propagate instead of being masked by a retry.
        if "'token'" not in str(err):
            raise
    # Backward compatibility for older huggingface_hub releases.
    return hf_hub_download(repo_id, filename, use_auth_token=token)


class Model:
    """Thin wrapper that builds an ``Inference`` object from the DeepNAPSI checkpoints.

    Args:
        DEBUG: if truthy, use checkpoint paths under the local developer
            directory ``_DEBUG_CHECKPOINT_DIR``. Otherwise download
            ``version_<v>.ckpt`` for each version in ``_CHECKPOINT_VERSIONS``
            from the Hub repo ``_HF_REPO_ID``. The download uses the token
            stored in the ``DeepNAPSIModel`` environment variable.

    Raises:
        KeyError: if ``DEBUG`` is falsy and the ``DeepNAPSIModel`` environment
            variable is not set.
    """

    def __init__(self, DEBUG) -> None:
        if DEBUG:
            # NOTE(review): the Hub files are named ``version_<v>.ckpt`` but these
            # local paths carry no extension -- confirm the local layout matches.
            file_paths = [
                os.path.join(_DEBUG_CHECKPOINT_DIR, f"version_{v}")
                for v in _CHECKPOINT_VERSIONS
            ]
        else:
            # Read the token once. A missing variable raises KeyError before
            # any download starts.
            token = os.environ[_HF_TOKEN_ENV_VAR]
            file_paths = [
                _hf_hub_download_compat(_HF_REPO_ID, f"version_{v}.ckpt", token)
                for v in _CHECKPOINT_VERSIONS
            ]
        self.inference = Inference(file_paths)

    def predict(self, x):
        """Run inference on *x* and return the pair ``(y_hat, uncertainty)``.

        The accepted input format is whatever ``Inference.predict`` expects;
        it is not visible from this module.
        """
        y_hat, uncertainty = self.inference.predict(x)
        return y_hat, uncertainty

    def __call__(self, x):
        """Alias for :meth:`predict`, so a ``Model`` can be called like a function."""
        return self.predict(x)