Shaoan commited on
Commit
cc9a502
·
verified ·
1 Parent(s): 6f0d4df

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. text_encoder.py +2 -2
text_encoder.py CHANGED
@@ -154,7 +154,7 @@ class T5Embedder(torch.nn.Module):
154
  self.max_length = max_length
155
  dtype = torch.bfloat16
156
  self.dtype = dtype
157
- t5_version = './t5-v1_1-xxl'
158
  self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
159
  self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device)
160
  self.t5_encoder = self.t5_encoder.eval().requires_grad_(False)
@@ -214,7 +214,7 @@ class LoraT5EmbedderNoGradientCheck(torch.nn.Module):
214
  self.max_length = max_length
215
  dtype = torch.bfloat16
216
  self.dtype = dtype
217
- t5_version = './t5-v1_1-xxl'
218
  self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
219
  self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device).to(dtype)
220
  self.t5_encoder.gradient_checkpointing_enable()
 
154
  self.max_length = max_length
155
  dtype = torch.bfloat16
156
  self.dtype = dtype
157
+ t5_version = 'google/t5-v1_1-xxl'
158
  self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
159
  self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device)
160
  self.t5_encoder = self.t5_encoder.eval().requires_grad_(False)
 
214
  self.max_length = max_length
215
  dtype = torch.bfloat16
216
  self.dtype = dtype
217
+ t5_version = 'google/t5-v1_1-xxl'
218
  self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
219
  self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device).to(dtype)
220
  self.t5_encoder.gradient_checkpointing_enable()