trunc_normal_init_factory

trunc_normal_init_factory(std=0.02)

Factory that returns a curried initializer, fn(dim) -> fn(tensor), which initializes a tensor with values drawn from a truncated normal distribution.

The dim argument is accepted but ignored: the standard deviation stays fixed at std regardless of the layer dimension. Accepting dim anyway makes the returned callable compatible with the curried signature expected by MLP's init_method_in / init_method_out.
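The curried shape described above can be illustrated with a small, dependency-free sketch. It is a hypothetical re-implementation, not this library's code. To keep it runnable without PyTorch, a plain Python list stands in for a Tensor, and truncation at two standard deviations (the bound parameter) is an assumption, since the actual truncation limits are not documented here. The point it shows is that the middle layer takes dim only to satisfy the signature, so the same std is used for every dimension.

```python
import random
from typing import Callable, List

# Assumption: a list of floats stands in for a Tensor so the sketch runs without PyTorch.
Tensor = List[float]


def trunc_normal_init_factory_sketch(
    std: float = 0.02, bound: float = 2.0
) -> Callable[[int], Callable[[Tensor], Tensor]]:
    """Hypothetical sketch of the curried factory.

    `bound` is an assumed truncation limit, expressed in multiples of `std`.
    """

    def init_for_dim(dim: int) -> Callable[[Tensor], Tensor]:
        del dim  # accepted only for signature compatibility; std does not depend on it

        def init(tensor: Tensor) -> Tensor:
            limit = bound * std
            for i in range(len(tensor)):
                # Rejection sampling: redraw until the sample lies in [-limit, limit].
                while True:
                    x = random.gauss(0.0, std)
                    if -limit <= x <= limit:
                        break
                tensor[i] = x
            return tensor  # filled in place and also returned

        return init

    return init_for_dim


init_method = trunc_normal_init_factory_sketch(std=0.02)

# Same seed, different dims: the results match because dim is ignored.
random.seed(0)
a = init_method(128)([0.0] * 1000)
random.seed(0)
b = init_method(4096)([0.0] * 1000)
```

With PyTorch, the innermost function would more plausibly delegate to torch.nn.init.trunc_normal_(tensor, std=std). Note that the a and b arguments of that function are absolute bounds (defaulting to -2.0 and 2.0), not multiples of std, so they would need scaling to reproduce the two-standard-deviation cutoff assumed here.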

Parameters:

std (float) – Standard deviation for the truncated normal distribution. Defaults to 0.02.

Returns:

A callable fn(dim) -> fn(tensor) compatible with MLP’s init_method_in / init_method_out signature.

Return type:

Callable[[int], Callable[[Tensor], Tensor]]