Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| import torch | |
# Hub repository id (or local path) of the fine-tuned Llama 3.1 classifier.
MODEL_URL = "kingabzpro/Llama-3.1-8B-Instruct-Mental-Health-Classification"

# Load the tokenizer and reuse its end-of-sequence id for padding
# (presumably the checkpoint defines no dedicated pad token — confirm).
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the causal-LM weights in float16, placed entirely on the CPU.
# NOTE(review): an ~8B-parameter model needs roughly 16 GB just for float16
# weights, which may exceed the RAM of small CPU hosts — check the target
# hardware if the app fails to start.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_URL,
    torch_dtype=torch.float16,
    device_map="cpu",
    low_cpu_mem_usage=True,
    return_dict=True,
)
# Text-generation pipeline shared by every request; created on first use by
# _get_pipeline() instead of once per call.
_PIPE = None


def _get_pipeline():
    """Return the shared text-generation pipeline, creating it on first use.

    The pipeline wraps the module-level ``model`` and ``tokenizer``. Building
    it once and reusing it avoids repeating the pipeline setup on every
    classification request.
    """
    global _PIPE
    if _PIPE is None:
        _PIPE = pipeline(
            "text-generation",
            tokenizer=tokenizer,
            model=model,
            torch_dtype=torch.float16,
            device_map="cpu",
        )
    return _PIPE


def prediction(text):
    """Classify *text* as Normal, Depression, Anxiety or Bipolar.

    The model is prompted with a fixed instruction followed by
    ``text: <input>`` and ``label:``. The few tokens it generates after that
    are returned as the label.

    Args:
        text: Free-form text entered by the user.

    Returns:
        The generated label string, or ``None`` when *text* is empty or
        whitespace-only (which leaves the Gradio ``Label`` output blank).
        The model is expected, but not forced, to emit one of the four
        labels.
    """
    # Nothing to classify: don't run the model on an empty prompt.
    if not text or not text.strip():
        return None

    # .strip() also drops the trailing space after "label:", so the prompt
    # ends in "label:".
    prompt = f"""Classify the text into Normal, Depression, Anxiety, Bipolar, and return the answer as the corresponding mental health disorder label.
text: {text}
label: """.strip()

    # Greedy decoding returns the same label for the same input; sampling,
    # even at temperature 0.1, does not. return_full_text=False yields only
    # the newly generated tokens, so the prompt never has to be split back
    # off. (Splitting on "label: " broke whenever the model did not start
    # its answer with a space.)
    outputs = _get_pipeline()(
        prompt,
        max_new_tokens=2,
        do_sample=False,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
# Gradio front end: a multi-line textbox in, the predicted label out.
gradio_ui = gr.Interface(
    fn=prediction,
    title="Mental Health Disorder Classification",
    description=(
        "Input the text to classify it into a mental health disorder category.\n"
        f"For this classification, the {MODEL_URL} model was used."
    ),
    examples=[
        ['trouble sleeping, confused mind, restless heart. All out of tune'],
        ["In the quiet hours, even the shadows seem too heavy to bear."],
        ["Riding a tempest of emotions, where ecstatic highs crash into desolate lows without warning."],
    ],
    inputs=gr.Textbox(lines=10, label="Write the text here"),
    # prediction() returns a single label string, so num_top_classes only
    # matters if it is later changed to return per-class confidences.
    outputs=gr.Label(num_top_classes=4, label="Mental Health Disorder Category"),
    theme=gr.themes.Soft(),
    article="<p style='text-align: center'>Please read the tutorial to fine-tune the Llama 3.1 model on Mental Health Classification <a href='https://www.datacamp.com/tutorial/fine-tuning-llama-3-1' target='_blank'>https://www.datacamp.com/tutorial/fine-tuning-llama-3-1</a></p>",
)

# Start the web server only when this file is executed as a script, so that
# importing the module does not launch a server as a side effect.
if __name__ == "__main__":
    gradio_ui.launch()