partial_wang_init_fn_with_num_layers

partial_wang_init_fn_with_num_layers(num_layers)

Factory that returns partial(wang_init, num_layers=num_layers): wang_init with num_layers pre-bound, so the caller only has to supply dim.
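A minimal sketch of the factory pattern described above. The body of wang_init here is a hypothetical stand-in. It assumes the GPT-NeoX-style Wang formula std = 2 / (num_layers * sqrt(dim)), and the library's own wang_init may differ. The initializer is duck-typed: it calls a torch-style in-place normal_(mean, std) on whatever tensor it receives.

```python
import math
from functools import partial
from typing import Any, Callable


def wang_init(dim: int, num_layers: int) -> Callable[[Any], Any]:
    """Hypothetical stand-in for wang_init (assumed GPT-NeoX-style formula)."""
    std = 2.0 / (num_layers * math.sqrt(dim))

    def init_(tensor: Any) -> Any:
        # In-place normal fill, e.g. torch.Tensor.normal_(mean, std); returns the tensor.
        return tensor.normal_(0.0, std)

    return init_


def partial_wang_init_fn_with_num_layers(num_layers: int) -> Callable[[int], Callable[[Any], Any]]:
    # Bind num_layers now; dim is supplied later by the consumer (e.g. an MLP).
    return partial(wang_init, num_layers=num_layers)
```

Using functools.partial rather than a lambda keeps the bound value inspectable through the partial's .keywords attribute. It also keeps the result picklable as long as wang_init is a module-level function.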

Useful with LazyConfig so that num_layers can be provided via OmegaConf interpolation (e.g., "${net.num_blocks}") and resolved before constructing the callable.

Parameters:

num_layers (int) – Total number of layers in the network; bound into the returned callable.

Returns:

A callable fn(dim) that returns an initializer fn(tensor) -> tensor. This is the signature MLP expects for init_method_in / init_method_out.

Return type:

Callable[[int], Callable[[Tensor], Tensor]]
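This sketch shows how the returned callable plugs into init_method_in / init_method_out. TinyMLP is a hypothetical stand-in for the library's MLP, and wang_init again assumes the GPT-NeoX-style formula std = 2 / (num_layers * sqrt(dim)). The sketch also assumes the MLP passes each layer's input width as dim. The real MLP may pass a different dimension.

```python
import math
from functools import partial
from typing import Callable

import torch
from torch import Tensor, nn


def wang_init(dim: int, num_layers: int) -> Callable[[Tensor], Tensor]:
    # Assumed GPT-NeoX-style Wang init; the library's formula may differ.
    std = 2.0 / (num_layers * math.sqrt(dim))

    def init_(tensor: Tensor) -> Tensor:
        return nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def partial_wang_init_fn_with_num_layers(num_layers: int) -> Callable[[int], Callable[[Tensor], Tensor]]:
    return partial(wang_init, num_layers=num_layers)


class TinyMLP(nn.Module):
    """Hypothetical MLP taking fn(dim) -> fn(tensor) initializer factories."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        init_method_in: Callable[[int], Callable[[Tensor], Tensor]],
        init_method_out: Callable[[int], Callable[[Tensor], Tensor]],
    ) -> None:
        super().__init__()
        self.fc_in = nn.Linear(dim, hidden_dim, bias=False)
        self.fc_out = nn.Linear(hidden_dim, dim, bias=False)
        with torch.no_grad():
            # Assumption: each factory is called with the layer's input width.
            init_method_in(dim)(self.fc_in.weight)
            init_method_out(hidden_dim)(self.fc_out.weight)

    def forward(self, x: Tensor) -> Tensor:
        return self.fc_out(torch.relu(self.fc_in(x)))


torch.manual_seed(0)
init_fn = partial_wang_init_fn_with_num_layers(num_layers=12)
mlp = TinyMLP(dim=64, hidden_dim=256, init_method_in=init_fn, init_method_out=init_fn)
```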