trunc_normal_init

trunc_normal_init(std=0.02)

Create a truncated-normal initializer with mean 0 and a fixed standard deviation.
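For reference, a truncated normal distribution with underlying mean 0 and scale σ (the std argument) has the density below on an interval [a, b], and zero outside it. This page does not state which bounds a and b are used, so they are left symbolic. Here φ and Φ are the standard normal PDF and CDF. When the bounds are tight relative to σ, the empirical standard deviation of the samples is somewhat smaller than σ.

```latex
f(x;\,\sigma, a, b) =
\frac{\dfrac{1}{\sigma}\,\phi\!\left(\dfrac{x}{\sigma}\right)}
     {\Phi\!\left(\dfrac{b}{\sigma}\right) - \Phi\!\left(\dfrac{a}{\sigma}\right)},
\qquad a \le x \le b
```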

Parameters:

std (float) – Standard deviation passed to trunc_normal_. Defaults to 0.02.

Returns:

A callable fn(tensor) -> tensor that initializes the tensor in-place with trunc_normal_(mean=0, std=std).

Return type:

Callable[[Tensor], Tensor]
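The sketch below shows how a factory with this signature could behave. It assumes that the trunc_normal_ mentioned above is PyTorch's torch.nn.init.trunc_normal_, which this page does not confirm. The name trunc_normal_init_sketch and the helper init_linear are hypothetical, used only for illustration. One PyTorch detail is worth knowing here: torch.nn.init.trunc_normal_ defaults to bounds a=-2.0, b=2.0 in absolute units, not multiples of std. With std=0.02 that cutoff is about 100 standard deviations out, so truncation practically never triggers.

```python
from typing import Callable

import torch
from torch import Tensor, nn


def trunc_normal_init_sketch(std: float = 0.02) -> Callable[[Tensor], Tensor]:
    """Hypothetical stand-in for trunc_normal_init.

    Assumes the underlying call is torch.nn.init.trunc_normal_.
    """

    def init_fn(tensor: Tensor) -> Tensor:
        # Fills `tensor` in-place and returns it. PyTorch's default bounds
        # a=-2.0, b=2.0 are absolute values, not multiples of `std`.
        return nn.init.trunc_normal_(tensor, mean=0.0, std=std)

    return init_fn


# Initialize a standalone tensor.
init = trunc_normal_init_sketch(std=0.02)
weight = torch.empty(256, 128)
out = init(weight)  # `out` is the same object as `weight`


# Apply the same initializer to every Linear layer of a model.
def init_linear(module: nn.Module) -> None:
    # Hypothetical helper: truncated-normal weights, zero biases.
    if isinstance(module, nn.Linear):
        init(module.weight)
        nn.init.zeros_(module.bias)


model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
model.apply(init_linear)
```

Returning a closure instead of initializing directly lets the standard deviation be fixed once and the same callable be reused across many tensors or layers. If truncation at, for example, two standard deviations is wanted, torch.nn.init.trunc_normal_ would need explicit bounds such as a=-2 * std, b=2 * std. Whether trunc_normal_init exposes such an option is not documented here.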