File size: 914 Bytes
9bee602
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

# Hugging Face hub identifier of the fine-tuned text-to-SQL model to load.
MODEL_NAME = "Yuk050/gemma-3-1b-text-to-sql-model"
# Directory where the tokenizer/model are saved after the first download
# so later runs can load from disk instead of the hub.
LOCAL_DIR = "./model_cache"

# Process-wide singletons, populated lazily by load_model(); None until then.
_tokenizer = None
_model = None

def load_model():
    """Return the memoized (tokenizer, model) pair, loading it on first call.

    Loads from LOCAL_DIR when a cached copy exists on disk; otherwise
    downloads MODEL_NAME from the Hugging Face hub and persists a copy to
    LOCAL_DIR so subsequent runs skip the download. Repeat calls return
    the already-loaded in-process singletons without touching disk.

    Returns:
        tuple: (_tokenizer, _model) for MODEL_NAME.
    """
    global _tokenizer, _model

    # Fast path: both singletons already initialized in this process.
    if _model is not None and _tokenizer is not None:
        return _tokenizer, _model

    print("πŸ”„ Loading model...")
    cached = os.path.exists(LOCAL_DIR)
    source = LOCAL_DIR if cached else MODEL_NAME
    _tokenizer = AutoTokenizer.from_pretrained(source)
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # model repo at load time — acceptable only for a trusted model source.
    _model = AutoModelForCausalLM.from_pretrained(source, trust_remote_code=True)

    if not cached:
        # First download: write a local copy for future runs.
        os.makedirs(LOCAL_DIR, exist_ok=True)
        _tokenizer.save_pretrained(LOCAL_DIR)
        _model.save_pretrained(LOCAL_DIR)

    print("βœ… Model loaded successfully!")
    return _tokenizer, _model