from transformers import AutoModelForCausalLM, AutoTokenizer


class LLM:
    def __init__(self):
        # Load the fine-tuned TNG script-generator model and its tokenizer from the Hub.
        self.model = AutoModelForCausalLM.from_pretrained('progs2002/star-trek-tng-script-generator')
        self.tokenizer = AutoTokenizer.from_pretrained('progs2002/star-trek-tng-script-generator')

    def generate(self, text, max_len=512, temp=1.0, k=50, p=0.95):
        # Encode the prompt into a tensor of token ids.
        encoded_prompt = self.tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")
        # Sample a continuation using temperature, top-k, and nucleus (top-p) filtering.
        output_tokens = self.model.generate(
            input_ids=encoded_prompt,
            max_new_tokens=max_len,
            do_sample=True,
            num_return_sequences=1,
            pad_token_id=self.model.config.eos_token_id,
            temperature=temp,
            top_k=k,
            top_p=p,
        )
        # Decode the generated ids back to text, dropping special tokens.
        text_out = self.tokenizer.decode(output_tokens[0], clean_up_tokenization_spaces=True, skip_special_tokens=True)
        return text_out
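

# Minimal usage sketch (assumes the checkpoint above is reachable on the Hub;
# the prompt string is illustrative only).
if __name__ == "__main__":
    llm = LLM()
    print(llm.generate("PICARD: ", max_len=128, temp=0.9, k=50, p=0.95))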