wang_init

wang_init(dim, num_layers)

Depth-scaled initializer (Wang et al.).

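Computes std = 2 / (num_layers * sqrt(dim)) and returns an initializer that samples from a zero-mean normal distribution with that standard deviation. The standard deviation therefore shrinks as the network gets deeper (larger num_layers) or wider (larger dim).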
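To make the scaling concrete, the arithmetic below evaluates the formula for a 32-layer model with dim = 4096. The helper name wang_std is hypothetical and exists only for this illustration; it is not part of the documented API.

```python
import math


def wang_std(dim: int, num_layers: int) -> float:
    """Hypothetical helper: the std used by the depth-scaled initializer."""
    return 2.0 / (num_layers * math.sqrt(dim))


# sqrt(4096) = 64, so std = 2 / (32 * 64) = 2 / 2048.
print(wang_std(4096, 32))  # 0.0009765625

# Doubling the depth halves the standard deviation.
print(wang_std(4096, 64))  # 0.00048828125
```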

Parameters:
  • dim (int) – Layer width used to compute the standard deviation.

  • num_layers (int) – Total number of layers in the network.

Returns:

A callable fn(tensor) -> tensor that fills the tensor in place with normal_(mean=0, std=2 / (num_layers * sqrt(dim))) and returns it.

Return type:

Callable[[Tensor], Tensor]
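The documented contract (build a closure over the computed std, then fill a tensor in place) can be sketched in PyTorch as below. This is a minimal sketch assuming the initializer is built on torch.nn.init.normal_; wang_init_sketch is a hypothetical name and not the library's actual implementation.

```python
import math
from typing import Callable

import torch
from torch import Tensor


def wang_init_sketch(dim: int, num_layers: int) -> Callable[[Tensor], Tensor]:
    """Hypothetical re-implementation of the documented behavior."""
    std = 2.0 / (num_layers * math.sqrt(dim))

    def init_(tensor: Tensor) -> Tensor:
        # torch.nn.init.normal_ fills the tensor in place (without tracking
        # gradients) and returns the same tensor.
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


# Usage: initialize a projection weight for a 24-layer model with dim = 1024.
torch.manual_seed(0)
proj = torch.nn.Linear(1024, 1024, bias=False)
init_fn = wang_init_sketch(dim=1024, num_layers=24)
init_fn(proj.weight)

# The sample std should be close to 2 / (24 * 32), roughly 0.0026.
print(proj.weight.std().item())
```

Returning a closure rather than initializing directly lets the same callable be passed wherever a model expects an init_method and applied to many weight tensors.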