| import os | |
def slow_rerank(args, query, pids, passages):
    """Rerank candidate passages for one query by scoring each with ColBERT.

    The query and every passage are encoded from raw text through
    ``args.inference``. They are scored with ``args.colbert.score``, and the
    candidates are returned ordered by descending score.

    Args:
        args: Object exposing ``colbert`` (with a ``score(Q, D)`` method),
            ``inference`` (with ``queryFromText`` and ``docFromText``) and
            ``bsize`` (batch size forwarded to ``docFromText``).
        query: The query text.
        pids: Passage ids, parallel to ``passages``. They must be unique
            and hashable.
        passages: Passage texts to be scored against ``query``.

    Returns:
        A list of ``(score, pid, passage)`` tuples, highest score first.
        The list is empty when ``passages`` is empty.

    Raises:
        ValueError: If ``pids`` and ``passages`` differ in length, or if
            ``pids`` contains duplicates.
    """
    # Validate before running the model. These checks are explicit raises
    # rather than ``assert`` so they still run under ``python -O``.
    if len(pids) != len(passages):
        raise ValueError(
            "pids and passages must be parallel sequences "
            f"(got {len(pids)} pids and {len(passages)} passages)"
        )
    if len(set(pids)) != len(pids):
        raise ValueError("pids must be unique")
    if not passages:
        return []

    colbert = args.colbert
    inference = args.inference

    Q = inference.queryFromText([query])
    D_ = inference.docFromText(passages, bsize=args.bsize)
    # NOTE(review): presumably a 1-D tensor with one score per passage, in
    # input order. The indexing below relies on that; confirm against colbert.score.
    scores = colbert.score(Q, D_).cpu()
    ranked = scores.sort(descending=True)

    return [
        (score, pids[position], passages[position])
        for score, position in zip(ranked.values.tolist(), ranked.indices.tolist())
    ]