"""Gradio chat UI that streams completions from a Hugging Face causal LM."""

from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# --- Configuration ---
MODEL_NAME = "NorwAI/NorwAI-Llama2-7B"  # alternative: "google/gemma-2-9b"

# --- Model Loading (Explicit) ---
# Loading happens at import time; a failure is reported and re-raised so the
# app never starts without a usable model.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",  # let Transformers decide device placement
        torch_dtype=torch.bfloat16,  # reduced memory use (if the hardware supports it)
    )
except Exception as e:
    print(f"Error loading model: {e}")
    raise


# --- Prompt construction ---
def _format_history(history):
    """Render prior turns as alternating ``user`` / ``model`` sections.

    Accepts both history shapes: ``(user_msg, model_msg)`` pairs, and
    role/content dicts. For dicts, the ``assistant`` role is rendered as
    ``model``; every other role is rendered as ``user``.
    NOTE(review): dict content is interpolated as-is; confirm it is always
    plain text for this app.
    """
    sections = []
    for turn in history or []:
        if isinstance(turn, dict):
            content = turn.get("content")
            if content:
                role = "model" if turn.get("role") == "assistant" else "user"
                sections.append(f"{role}\n{content}\n")
            continue
        user_msg, model_msg = turn
        sections.append(f"user\n{user_msg}\n")
        if model_msg:  # the model side of a pair may be None or empty
            sections.append(f"model\n{model_msg}\n")
    return "".join(sections)


def _build_prompt(message, history, system_message):
    """Combine system message, formatted history and the new user message."""
    return (
        f"system\n{system_message}\n"
        f"{_format_history(history)}"
        f"user\n{message}\nmodel\n"
    )


# --- Inference Function ---
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Stream the model's reply to ``message``.

    Args:
        message: The new user message.
        history: Previous turns, as ``(user, model)`` pairs or role/content
            dicts.
        system_message: Text placed in the ``system`` section of the prompt.
        max_tokens: Upper bound on newly generated tokens (coerced to int).
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.

    Yields:
        The accumulated response text after each decoded chunk, or a fixed
        error message if anything fails.
    """
    try:
        prompt = _build_prompt(message, history, system_message)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # `generate` blocks until completion, so it runs in a worker thread
        # while this generator consumes decoded text from the streamer.
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,  # emit only the reply, not the echoed prompt
            skip_special_tokens=True,
        )
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,  # sampling for more diverse responses
            streamer=streamer,
            pad_token_id=tokenizer.eos_token_id,
        )
        worker_errors = []

        def _generate():
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                # Unblock the consumer loop, which would otherwise wait
                # forever for text that will never arrive.
                streamer.end()

        worker = Thread(target=_generate, daemon=True)
        worker.start()

        response = ""
        for new_text in streamer:
            if new_text:
                response += new_text
                yield response

        worker.join()
        if worker_errors:
            raise worker_errors[0]
    except Exception as e:
        print(f"Error during inference: {e}")
        yield "An error occurred during generation."


# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()