| from abc import ABC, abstractmethod | |
| import torch | |
| from transformers import PreTrainedTokenizerBase | |
class SensorLocFinder(ABC):
    """Abstract interface for locating "sensor" positions in tokenized input.

    Concrete subclasses must implement :meth:`find_sensor_locs`. Calling an
    instance directly delegates to that method, so a finder can be used as a
    plain callable.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizerBase", **kwargs):
        """Accept the tokenizer and any subclass-specific options.

        The base class stores nothing and ignores both arguments; subclasses
        presumably use the tokenizer (e.g. to look up special token ids) —
        TODO confirm against concrete implementations.

        Args:
            tokenizer: Hugging Face tokenizer associated with ``input_ids``.
            **kwargs: Extra configuration, ignored by the base class.
        """

    @abstractmethod
    def find_sensor_locs(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return sensor locations for ``input_ids``.

        Marked abstract so a subclass that forgets to implement it fails at
        construction time instead of silently returning ``None``.

        Args:
            input_ids: Token ids; presumably shaped (batch, seq_len) — TODO
                confirm with callers.

        Returns:
            A tensor of sensor locations; exact shape/dtype is defined by the
            concrete subclass.
        """
        raise NotImplementedError

    def __call__(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`find_sensor_locs`."""
        return self.find_sensor_locs(input_ids)