Spaces:
Runtime error
Runtime error
more updates
Browse files
README.md
CHANGED
|
@@ -57,7 +57,7 @@ This app is developed using React + Fastapi. You can run this app locally with t
|
|
| 57 |
|
| 58 |
4. **(Optional) Install GPU-related packages**
|
| 59 |
|
| 60 |
-
If you are running on a GPU-enabled device, you can install additional packages (
|
| 61 |
|
| 62 |
```commandline
|
| 63 |
python install_gpu_packages.py
|
|
|
|
| 57 |
|
| 58 |
4. **(Optional) Install GPU-related packages**
|
| 59 |
|
| 60 |
+
If you are running on a GPU-enabled device, you can install additional packages (support for more models):
|
| 61 |
|
| 62 |
```commandline
|
| 63 |
python install_gpu_packages.py
|
backend/__pycache__/hf_model_utils.cpython-313.pyc
CHANGED
|
Binary files a/backend/__pycache__/hf_model_utils.cpython-313.pyc and b/backend/__pycache__/hf_model_utils.cpython-313.pyc differ
|
|
|
backend/hf_model_utils.py
CHANGED
|
@@ -3,7 +3,7 @@ import torch.nn as nn
|
|
| 3 |
import json
|
| 4 |
import hashlib
|
| 5 |
import gc
|
| 6 |
-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForMaskedLM
|
| 7 |
from accelerate import init_empty_weights
|
| 8 |
|
| 9 |
|
|
@@ -80,10 +80,30 @@ def get_model_structure(model_name: str, model_type: str | None):
|
|
| 80 |
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 81 |
with init_empty_weights():
|
| 82 |
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
| 83 |
-
|
| 84 |
-
config = AutoConfig.from_pretrained(
|
| 85 |
with init_empty_weights():
|
| 86 |
model = AutoModelForMaskedLM.from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
else:
|
| 88 |
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 89 |
with torch.device("meta"):
|
|
|
|
| 3 |
import json
|
| 4 |
import hashlib
|
| 5 |
import gc
|
| 6 |
+
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForImageClassification
|
| 7 |
from accelerate import init_empty_weights
|
| 8 |
|
| 9 |
|
|
|
|
| 80 |
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 81 |
with init_empty_weights():
|
| 82 |
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
| 83 |
+
elif model_type == "masked":
|
| 84 |
+
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 85 |
with init_empty_weights():
|
| 86 |
model = AutoModelForMaskedLM.from_config(config)
|
| 87 |
+
elif model_type == "sequence":
|
| 88 |
+
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 89 |
+
with init_empty_weights():
|
| 90 |
+
model = AutoModelForSequenceClassification.from_config(config)
|
| 91 |
+
elif model_type == "token":
|
| 92 |
+
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 93 |
+
with init_empty_weights():
|
| 94 |
+
model = AutoModelForTokenClassification.from_config(config)
|
| 95 |
+
elif model_type == "qa":
|
| 96 |
+
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 97 |
+
with init_empty_weights():
|
| 98 |
+
model = AutoModelForQuestionAnswering.from_config(config)
|
| 99 |
+
elif model_type == "s2s":
|
| 100 |
+
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 101 |
+
with init_empty_weights():
|
| 102 |
+
model = AutoModelForSeq2SeqLM.from_config(config)
|
| 103 |
+
elif model_type == "vision":
|
| 104 |
+
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 105 |
+
with init_empty_weights():
|
| 106 |
+
model = AutoModelForImageClassification.from_config(config)
|
| 107 |
else:
|
| 108 |
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
| 109 |
with torch.device("meta"):
|
frontend/src/components/ModelInputBar.jsx
CHANGED
|
@@ -5,7 +5,7 @@ export default function ModelInputBar({ loading, fetchModelStructure }) {
|
|
| 5 |
const options = [
|
| 6 |
{ label: "Not Sure", value: "none", default: "deepseek-ai/DeepSeek-V3.1" },
|
| 7 |
{ label: "Causal Language Models (e.g. GPT, LLaMA, Phi, Mistral)", value: "causal", default: "gpt2"},
|
| 8 |
-
{ label: "Masked Language Models (BERT, RoBERTa, DistilBERT)", value: "masked", default: "
|
| 9 |
{ label: "Sequence Classification (text classification, sentiment analysis)", value: "sequence", default: "distilbert-base-uncased" },
|
| 10 |
{ label: "Token Classification (NER, POS tagging)", value: "token", default: "dbmdz/bert-large-cased-finetuned-conll03-english" },
|
| 11 |
{ label: "Question Answering Models (e.g. BERT QA, RoBERTa QA)", value: "qa", default: "distilbert-base-uncased-distilled-squad" },
|
|
|
|
| 5 |
const options = [
|
| 6 |
{ label: "Not Sure", value: "none", default: "deepseek-ai/DeepSeek-V3.1" },
|
| 7 |
{ label: "Causal Language Models (e.g. GPT, LLaMA, Phi, Mistral)", value: "causal", default: "gpt2"},
|
| 8 |
+
{ label: "Masked Language Models (BERT, RoBERTa, DistilBERT)", value: "masked", default: "google-bert/bert-base-uncased" },
|
| 9 |
{ label: "Sequence Classification (text classification, sentiment analysis)", value: "sequence", default: "distilbert-base-uncased" },
|
| 10 |
{ label: "Token Classification (NER, POS tagging)", value: "token", default: "dbmdz/bert-large-cased-finetuned-conll03-english" },
|
| 11 |
{ label: "Question Answering Models (e.g. BERT QA, RoBERTa QA)", value: "qa", default: "distilbert-base-uncased-distilled-squad" },
|