import torch import numpy as np def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device='cpu'): """Ported from JAX. """ def _compute_fans(shape, in_axis=1, out_axis=0): receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] fan_in = shape[in_axis] * receptive_field_size fan_out = shape[out_axis] * receptive_field_size return fan_in, fan_out def init(shape, dtype=dtype, device=device): fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) if mode == "fan_in": denominator = fan_in elif mode == "fan_out": denominator = fan_out elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2 else: raise ValueError( "invalid mode for variance scaling initializer: {}".format(mode)) variance = scale / denominator if distribution == "normal": return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) elif distribution == "uniform": return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance) else: raise ValueError("invalid distribution for variance scaling initializer") return init def default_init(scale=1.): """The same initialization used in DDPM.""" scale = 1e-10 if scale == 0 else scale return variance_scaling(scale, 'fan_avg', 'uniform')