replaced concatenation by the f-strings to avoid potential type's mismatch and make code more clear
This commit is contained in:
parent
978267f461
commit
c00559958d
@ -301,7 +301,7 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod
|
||||
len(encoder_modules) > 0
|
||||
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
|
||||
|
||||
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
|
||||
all_encoder_weights = set([f"{module_name}/{sub_name}" for sub_name in encoder_modules.keys()])
|
||||
encoder_layer_pos = 0
|
||||
for name, module in decoder_modules.items():
|
||||
if name.isdigit():
|
||||
@ -326,12 +326,12 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod
|
||||
tie_encoder_to_decoder_recursively(
|
||||
decoder_modules[decoder_name],
|
||||
encoder_modules[encoder_name],
|
||||
module_name + "/" + name,
|
||||
f"{module_name}/{name}",
|
||||
uninitialized_encoder_weights,
|
||||
skip_key,
|
||||
depth=depth + 1,
|
||||
)
|
||||
all_encoder_weights.remove(module_name + "/" + encoder_name)
|
||||
all_encoder_weights.remove(f"{module_name}/{encoder_name}")
|
||||
|
||||
uninitialized_encoder_weights += list(all_encoder_weights)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user