### Train Dataset Means and stds
# Per-coordinate mean / standard deviation of the training set's GPS labels.
# The model predicts z-scored coordinates; these values map them back to degrees.
lat_mean, lat_std = 39.951572994535354, 0.0006556104083785816
lon_mean, lon_std = -75.19137012508818, 0.0006895844560639971
### Custom Model Class
from torch import nn
from transformers import ViTModel


class ViTGPSModel(nn.Module):
    """ViT backbone plus a linear head that regresses GPS coordinates from an image.

    The head's outputs are normalized (z-scored) coordinates: multiply by the
    training-set std and add the training-set mean to recover degrees.
    """

    def __init__(self, output_size=2, model_name="google/vit-base-patch16-224-in21k"):
        """Build the backbone and regression head.

        Args:
            output_size: Number of regression outputs (2 = latitude, longitude).
            model_name: Hugging Face checkpoint id for the ViT backbone. The head
                is sized from the backbone's config, so other ViT checkpoints
                can be swapped in.
        """
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        self.regression_head = nn.Linear(self.vit.config.hidden_size, output_size)

    def forward(self, x):
        """Predict normalized coordinates for a batch of images.

        Args:
            x: Pixel tensor accepted by ``ViTModel`` -- presumably
                (batch, 3, 224, 224), preprocessed for the backbone; TODO confirm.

        Returns:
            Tensor of shape (batch, output_size).
        """
        # The [CLS] token (sequence position 0) of the last hidden layer serves
        # as the whole-image embedding.
        cls_embedding = self.vit(pixel_values=x).last_hidden_state[:, 0, :]
        return self.regression_head(cls_embedding)
### Running Inference
# NOTE(review): assumes `torch`, `device` and `dataloader` are defined earlier, and
# that `dataloader` yields batches of preprocessed image tensors only (no labels).
import torch
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="Latitude-Attitude/vit-gps-coordinates-predictor",
    filename="vit-gps-coordinates-predictor.pth",
)
# The checkpoint is a whole pickled ViTGPSModel (not a state_dict), so it needs
# weights_only=False, which PyTorch >= 2.6 no longer uses by default.
# SECURITY: unpickling can execute arbitrary code -- only load trusted checkpoints.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)  # weights must live on the same device as the input batches
model.eval()

# De-normalization stats are built once, outside the loop, and kept in float64.
# In float32, values near 40 / 75 degrees are only resolvable to roughly
# 4e-6 / 8e-6 degrees.
coord_std = torch.tensor([lat_std, lon_std], dtype=torch.float64)
coord_mean = torch.tensor([lat_mean, lon_mean], dtype=torch.float64)

batch_preds = []
with torch.no_grad():
    for images in dataloader:
        images = images.to(device)
        outputs = model(images)
        # The model predicts z-scored (lat, lon); map back to degrees.
        batch_preds.append(outputs.cpu().double() * coord_std + coord_mean)

# Predictions for every batch (not just the last one): an (N, 2) tensor of
# (latitude, longitude) in degrees.
preds = torch.cat(batch_preds) if batch_preds else torch.empty(0, 2, dtype=torch.float64)