Jason-thingnario's picture
upload DDPM inference script
be89dda
import torch
import numpy as np
def variance_scaling(scale, mode, distribution,
in_axis=1, out_axis=0,
dtype=torch.float32,
device='cpu'):
"""Ported from JAX. """
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
if mode == "fan_in":
denominator = fan_in
elif mode == "fan_out":
denominator = fan_out
elif mode == "fan_avg":
denominator = (fan_in + fan_out) / 2
else:
raise ValueError(
"invalid mode for variance scaling initializer: {}".format(mode))
variance = scale / denominator
if distribution == "normal":
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
elif distribution == "uniform":
return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
else:
raise ValueError("invalid distribution for variance scaling initializer")
return init
def default_init(scale=1.):
"""The same initialization used in DDPM."""
scale = 1e-10 if scale == 0 else scale
return variance_scaling(scale, 'fan_avg', 'uniform')