diff --git a/extras/BLIP/models/__init__.py b/extras/BLIP/models/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/extras/BLIP/models/blip.py b/extras/BLIP/models/blip.py
index 38678f6..ea8de37 100644
--- a/extras/BLIP/models/blip.py
+++ b/extras/BLIP/models/blip.py
@@ -8,8 +8,8 @@
 import warnings
 warnings.filterwarnings("ignore")
 
-from models.vit import VisionTransformer, interpolate_pos_embed
-from models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.vit import VisionTransformer, interpolate_pos_embed
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer
 
 import torch
diff --git a/extras/BLIP/models/blip_itm.py b/extras/BLIP/models/blip_itm.py
index cf354c8..6f4da82 100644
--- a/extras/BLIP/models/blip_itm.py
+++ b/extras/BLIP/models/blip_itm.py
@@ -1,11 +1,11 @@
-from models.med import BertConfig, BertModel
+from extras.BLIP.models.med import BertConfig, BertModel
 from transformers import BertTokenizer
 
 import torch
 from torch import nn
 import torch.nn.functional as F
 
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 
 class BLIP_ITM(nn.Module):
     def __init__(self,
diff --git a/extras/BLIP/models/blip_nlvr.py b/extras/BLIP/models/blip_nlvr.py
index 8483716..0eb9eaa 100644
--- a/extras/BLIP/models/blip_nlvr.py
+++ b/extras/BLIP/models/blip_nlvr.py
@@ -1,7 +1,7 @@
-from models.med import BertConfig
-from models.nlvr_encoder import BertModel
-from models.vit import interpolate_pos_embed
-from models.blip import create_vit, init_tokenizer, is_url
+from extras.BLIP.models.med import BertConfig
+from extras.BLIP.models.nlvr_encoder import BertModel
+from extras.BLIP.models.vit import interpolate_pos_embed
+from extras.BLIP.models.blip import create_vit, init_tokenizer, is_url
 
 from timm.models.hub import download_cached_file
 
@@ -10,6 +10,8 @@ from torch import nn
 import torch.nn.functional as F
 from transformers import BertTokenizer
 import numpy as np
+import os
+
 
 class BLIP_NLVR(nn.Module):
     def __init__(self,
diff --git a/extras/BLIP/models/blip_pretrain.py b/extras/BLIP/models/blip_pretrain.py
index e42ce5f..9b8a3a4 100644
--- a/extras/BLIP/models/blip_pretrain.py
+++ b/extras/BLIP/models/blip_pretrain.py
@@ -5,7 +5,7 @@
  * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
  * By Junnan Li
 '''
-from models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer
 import transformers
 transformers.logging.set_verbosity_error()
@@ -14,7 +14,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 
 class BLIP_Pretrain(nn.Module):
     def __init__(self,
@@ -270,7 +270,7 @@
 from typing import List
 def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
     uninitialized_encoder_weights: List[str] = []
     if decoder.__class__ != encoder.__class__:
-        logger.info(
+        print(
             f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
         )
diff --git a/extras/BLIP/models/blip_retrieval.py b/extras/BLIP/models/blip_retrieval.py
index 1debe7e..0949358 100644
--- a/extras/BLIP/models/blip_retrieval.py
+++ b/extras/BLIP/models/blip_retrieval.py
@@ -1,11 +1,11 @@
-from models.med import BertConfig, BertModel
+from extras.BLIP.models.med import BertConfig, BertModel
 from transformers import BertTokenizer
 
 import torch
 from torch import nn
 import torch.nn.functional as F
 
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 
 class BLIP_Retrieval(nn.Module):
     def __init__(self,
diff --git a/extras/BLIP/models/blip_vqa.py b/extras/BLIP/models/blip_vqa.py
index d4cb368..99928a8 100644
--- a/extras/BLIP/models/blip_vqa.py
+++ b/extras/BLIP/models/blip_vqa.py
@@ -1,5 +1,5 @@
-from models.med import BertConfig, BertModel, BertLMHeadModel
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 
 import torch
 from torch import nn
diff --git a/extras/BLIP/models/vit.py b/extras/BLIP/models/vit.py
index cec3d8e..91c0ada 100644
--- a/extras/BLIP/models/vit.py
+++ b/extras/BLIP/models/vit.py
@@ -18,7 +18,10 @@
 from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_, DropPath
 from timm.models.helpers import named_apply, adapt_input_conv
-from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+def checkpoint_wrapper(x):
+    return x
+
 
 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
diff --git a/extras/interrogate.py b/extras/interrogate.py
index 484c5bf..410d685 100644
--- a/extras/interrogate.py
+++ b/extras/interrogate.py
@@ -1,5 +1,4 @@
 import os
-import sys
 
 import torch
 import ldm_patched.modules.model_management as model_management
@@ -8,19 +7,11 @@
 from torchvision.transforms.functional import InterpolationMode
 from modules.model_loader import load_file_from_url
 from modules.config import path_clip_vision
 from ldm_patched.modules.model_patcher import ModelPatcher
+from extras.BLIP.models.blip import blip_decoder
 
 blip_image_eval_size = 384
 blip_repo_root = os.path.join(os.path.dirname(__file__), 'BLIP')
-sys.path.append(blip_repo_root)
-
-
-class FakeFairscale:
-    def checkpoint_wrapper(self):
-        pass
-
-
-sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
 
 
 class Interrogator:
@@ -34,16 +25,14 @@
     @torch.inference_mode()
     def interrogate(self, img_rgb):
         if self.blip_model is None:
-            import models.blip
-
             filename = load_file_from_url(
                 url='https://huggingface.co/lllyasviel/misc/resolve/main/model_base_caption_capfilt_large.pth',
                 model_dir=path_clip_vision,
                 file_name='model_base_caption_capfilt_large.pth',
             )
 
-            model = models.blip.blip_decoder(pretrained=filename, image_size=blip_image_eval_size, vit='base',
-                                             med_config=os.path.join(blip_repo_root, "configs", "med_config.json"))
+            model = blip_decoder(pretrained=filename, image_size=blip_image_eval_size, vit='base',
+                                 med_config=os.path.join(blip_repo_root, "configs", "med_config.json"))
             model.eval()
 
             self.load_device = model_management.text_encoder_device()
diff --git a/fooocus_version.py b/fooocus_version.py
index 2284de5..e7ababc 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.832'
+version = '2.1.834'