import torch
import numpy as np


def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
    """Variance-scaling initializer, ported from JAX."""

    def _compute_fans(shape, in_axis=1, out_axis=0):
        # Product of the remaining dimensions, e.g. the kernel spatial size
        # for a conv weight of shape (out_channels, in_channels, *kernel).
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        return fan_in, fan_out

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(
                "invalid mode for variance scaling initializer: {}".format(mode))
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        elif distribution == "uniform":
            # Uniform on [-a, a] has variance a**2 / 3, so a = sqrt(3 * variance).
            return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
        else:
            raise ValueError(
                "invalid distribution for variance scaling initializer: {}".format(distribution))

    return init
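

# Usage sketch (illustrative; the helper name and shapes below are hypothetical,
# not from the original source). variance_scaling returns a closure that maps a
# weight shape to a freshly sampled tensor. With the default in_axis=1 /
# out_axis=0 it matches PyTorch's (out_channels, in_channels, *kernel) layout,
# so scale=2., mode='fan_in', distribution='normal' behaves like Kaiming/He
# normal initialization.
def _example_kaiming_like():
    init_fn = variance_scaling(2., 'fan_in', 'normal')
    # Conv weight of shape (out, in, kh, kw): fan_in = 32 * 3 * 3 = 288.
    return init_fn((64, 32, 3, 3))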


def default_init(scale=1.):
    """The same initialization used in DDPM: fan-average, uniform."""
    # A scale of exactly 0 would yield all-zero weights; clamp it to a tiny
    # positive value so near-zero initialization still breaks symmetry.
    scale = 1e-10 if scale == 0 else scale
    return variance_scaling(scale, 'fan_avg', 'uniform')
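

if __name__ == '__main__':
    # Minimal sanity check (not in the original): initialize a hypothetical
    # Linear weight of shape (out_features, in_features) = (256, 128).
    w = default_init()((256, 128))
    # fan_avg with scale=1 gives variance 2 / (fan_in + fan_out), so the
    # empirical std should be close to sqrt(2 / (128 + 256)) ~= 0.072.
    print(w.shape, float(w.std()), np.sqrt(2. / (128 + 256)))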