replaced concatenation by the f-strings to avoid potential type's mismatch and make code more clear
This commit is contained in:
parent
978267f461
commit
c00559958d
@ -301,7 +301,7 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod
|
|||||||
len(encoder_modules) > 0
|
len(encoder_modules) > 0
|
||||||
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
|
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
|
||||||
|
|
||||||
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
|
all_encoder_weights = set([f"{module_name}/{sub_name}" for sub_name in encoder_modules.keys()])
|
||||||
encoder_layer_pos = 0
|
encoder_layer_pos = 0
|
||||||
for name, module in decoder_modules.items():
|
for name, module in decoder_modules.items():
|
||||||
if name.isdigit():
|
if name.isdigit():
|
||||||
@ -326,12 +326,12 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod
|
|||||||
tie_encoder_to_decoder_recursively(
|
tie_encoder_to_decoder_recursively(
|
||||||
decoder_modules[decoder_name],
|
decoder_modules[decoder_name],
|
||||||
encoder_modules[encoder_name],
|
encoder_modules[encoder_name],
|
||||||
module_name + "/" + name,
|
f"{module_name}/{name}",
|
||||||
uninitialized_encoder_weights,
|
uninitialized_encoder_weights,
|
||||||
skip_key,
|
skip_key,
|
||||||
depth=depth + 1,
|
depth=depth + 1,
|
||||||
)
|
)
|
||||||
all_encoder_weights.remove(module_name + "/" + encoder_name)
|
all_encoder_weights.remove(f"{module_name}/{encoder_name}")
|
||||||
|
|
||||||
uninitialized_encoder_weights += list(all_encoder_weights)
|
uninitialized_encoder_weights += list(all_encoder_weights)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user