diff --git a/extras/BLIP/models/blip_pretrain.py b/extras/BLIP/models/blip_pretrain.py index 9b8a3a4..9aba015 100644 --- a/extras/BLIP/models/blip_pretrain.py +++ b/extras/BLIP/models/blip_pretrain.py @@ -301,7 +301,7 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod len(encoder_modules) > 0 ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" - all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) + all_encoder_weights = set([f"{module_name}/{sub_name}" for sub_name in encoder_modules.keys()]) encoder_layer_pos = 0 for name, module in decoder_modules.items(): if name.isdigit(): @@ -326,12 +326,12 @@ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_mod tie_encoder_to_decoder_recursively( decoder_modules[decoder_name], encoder_modules[encoder_name], - module_name + "/" + name, + f"{module_name}/{name}", uninitialized_encoder_weights, skip_key, depth=depth + 1, ) - all_encoder_weights.remove(module_name + "/" + encoder_name) + all_encoder_weights.remove(f"{module_name}/{encoder_name}") uninitialized_encoder_weights += list(all_encoder_weights)