maomao88 committed on
Commit
119afbd
·
1 Parent(s): 5850346

more updates

Browse files
README.md CHANGED
@@ -57,7 +57,7 @@ This app is developed using React + Fastapi. You can run this app locally with t
57
 
58
  4. **(Optional) Install GPU-related packages**
59
 
60
- If you are running on a GPU-enabled device, you can install additional packages (e.g., for faster inference and support for more models):
61
 
62
  ```commandline
63
  python install_gpu_packages.py
 
57
 
58
  4. **(Optional) Install GPU-related packages**
59
 
60
+ If you are running on a GPU-enabled device, you can install additional packages to enable support for more models:
61
 
62
  ```commandline
63
  python install_gpu_packages.py
backend/__pycache__/hf_model_utils.cpython-313.pyc CHANGED
Binary files a/backend/__pycache__/hf_model_utils.cpython-313.pyc and b/backend/__pycache__/hf_model_utils.cpython-313.pyc differ
 
backend/hf_model_utils.py CHANGED
@@ -3,7 +3,7 @@ import torch.nn as nn
3
  import json
4
  import hashlib
5
  import gc
6
- from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForMaskedLM
7
  from accelerate import init_empty_weights
8
 
9
 
@@ -80,10 +80,30 @@ def get_model_structure(model_name: str, model_type: str | None):
80
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
81
  with init_empty_weights():
82
  model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
83
- if model_type == "masked":
84
- config = AutoConfig.from_pretrained("model_name")
85
  with init_empty_weights():
86
  model = AutoModelForMaskedLM.from_config(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  else:
88
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
89
  with torch.device("meta"):
 
3
  import json
4
  import hashlib
5
  import gc
6
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForImageClassification
7
  from accelerate import init_empty_weights
8
 
9
 
 
80
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
81
  with init_empty_weights():
82
  model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
83
+ elif model_type == "masked":
84
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
85
  with init_empty_weights():
86
  model = AutoModelForMaskedLM.from_config(config)
87
+ elif model_type == "sequence":
88
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
89
+ with init_empty_weights():
90
+ model = AutoModelForSequenceClassification.from_config(config)
91
+ elif model_type == "token":
92
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
93
+ with init_empty_weights():
94
+ model = AutoModelForTokenClassification.from_config(config)
95
+ elif model_type == "qa":
96
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
97
+ with init_empty_weights():
98
+ model = AutoModelForQuestionAnswering.from_config(config)
99
+ elif model_type == "s2s":
100
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
101
+ with init_empty_weights():
102
+ model = AutoModelForSeq2SeqLM.from_config(config)
103
+ elif model_type == "vision":
104
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
105
+ with init_empty_weights():
106
+ model = AutoModelForImageClassification.from_config(config)
107
  else:
108
  config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
109
  with torch.device("meta"):
frontend/src/components/ModelInputBar.jsx CHANGED
@@ -5,7 +5,7 @@ export default function ModelInputBar({ loading, fetchModelStructure }) {
5
  const options = [
6
  { label: "Not Sure", value: "none", default: "deepseek-ai/DeepSeek-V3.1" },
7
  { label: "Causal Language Models (e.g. GPT, LLaMA, Phi, Mistral)", value: "causal", default: "gpt2"},
8
- { label: "Masked Language Models (BERT, RoBERTa, DistilBERT)", value: "masked", default: "distilbert-base-uncased" },
9
  { label: "Sequence Classification (text classification, sentiment analysis)", value: "sequence", default: "distilbert-base-uncased" },
10
  { label: "Token Classification (NER, POS tagging)", value: "token", default: "dbmdz/bert-large-cased-finetuned-conll03-english" },
11
  { label: "Question Answering Models (e.g. BERT QA, RoBERTa QA)", value: "qa", default: "distilbert-base-uncased-distilled-squad" },
 
5
  const options = [
6
  { label: "Not Sure", value: "none", default: "deepseek-ai/DeepSeek-V3.1" },
7
  { label: "Causal Language Models (e.g. GPT, LLaMA, Phi, Mistral)", value: "causal", default: "gpt2"},
8
+ { label: "Masked Language Models (BERT, RoBERTa, DistilBERT)", value: "masked", default: "google-bert/bert-base-uncased" },
9
  { label: "Sequence Classification (text classification, sentiment analysis)", value: "sequence", default: "distilbert-base-uncased" },
10
  { label: "Token Classification (NER, POS tagging)", value: "token", default: "dbmdz/bert-large-cased-finetuned-conll03-english" },
11
  { label: "Question Answering Models (e.g. BERT QA, RoBERTa QA)", value: "qa", default: "distilbert-base-uncased-distilled-squad" },