hiitsmeme
initial commit
b25d2b6
raw
history blame
698 Bytes
import math
import random
from itertools import product
def _combination_at(pools, index):
    """Return the ``index``-th tuple that ``itertools.product(*pools)`` would yield."""
    # product() varies the last pool fastest, so decode the index in mixed
    # radix starting from the last pool.
    combo = []
    for pool in reversed(pools):
        index, offset = divmod(index, len(pool))
        combo.append(pool[offset])
    combo.reverse()
    return tuple(combo)


def generate_random_search(grid, num_trials=300, seed=42):
    """Sample a random subset of a hyperparameter grid.

    Args:
        grid: Mapping from hyperparameter name to an iterable of candidate
            values. It must contain ``"max_lr"``, ``"init_lr"`` and
            ``"final_lr"``. ``init_lr`` and ``final_lr`` are divisors applied
            to ``max_lr``.
        num_trials: Number of distinct combinations to draw. If it exceeds
            the grid size, every combination is returned (in random order).
        seed: Seed for the draw. Note that this reseeds the module-global
            ``random`` generator, which later ``random`` calls will see.

    Returns:
        A list of parameter dicts. Each dict holds one value per grid key
        plus ``"real_init_lr"`` (``max_lr / init_lr``) and
        ``"real_final_lr"`` (``max_lr / final_lr``).

    Raises:
        KeyError: If a sampled combination lacks one of the LR keys.
        ZeroDivisionError: If a sampled ``init_lr``/``final_lr`` is 0.
    """
    # Kept global (not a local random.Random) so callers that rely on this
    # call seeding their process-wide RNG keep working.
    random.seed(seed)
    keys = list(grid.keys())
    # Freeze each candidate list so it can be indexed and re-read.
    pools = [tuple(values) for values in grid.values()]
    total = math.prod(len(pool) for pool in pools)

    # Sample positions in the (virtual) Cartesian product instead of
    # materializing it. random.sample on a range of the same length draws
    # the same indices the eager version did.
    indices = random.sample(range(total), min(num_trials, total))

    hp_subset = []
    for index in indices:
        params = dict(zip(keys, _combination_at(pools, index)))
        # Convert the ratios into actual LR values.
        params["real_init_lr"] = params["max_lr"] / params["init_lr"]
        params["real_final_lr"] = params["max_lr"] / params["final_lr"]
        hp_subset.append(params)
    return hp_subset