replaced concatenation by the f-strings to avoid potential type's mismatch and make code more clear

This commit is contained in:
igeni 2024-03-22 19:22:44 +03:00
parent 978267f461
commit c00559958d

View File

@ -301,7 +301,7 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod
len(encoder_modules) > 0 len(encoder_modules) > 0
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) all_encoder_weights = set([f"{module_name}/{sub_name}" for sub_name in encoder_modules.keys()])
encoder_layer_pos = 0 encoder_layer_pos = 0
for name, module in decoder_modules.items(): for name, module in decoder_modules.items():
if name.isdigit(): if name.isdigit():
@ -326,12 +326,12 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod
tie_encoder_to_decoder_recursively( tie_encoder_to_decoder_recursively(
decoder_modules[decoder_name], decoder_modules[decoder_name],
encoder_modules[encoder_name], encoder_modules[encoder_name],
module_name + "/" + name, f"{module_name}/{name}",
uninitialized_encoder_weights, uninitialized_encoder_weights,
skip_key, skip_key,
depth=depth + 1, depth=depth + 1,
) )
all_encoder_weights.remove(module_name + "/" + encoder_name) all_encoder_weights.remove(f"{module_name}/{encoder_name}")
uninitialized_encoder_weights += list(all_encoder_weights) uninitialized_encoder_weights += list(all_encoder_weights)