use our blip

lllyasviel 2023-12-12 21:07:39 -08:00
parent c175afb394
commit 322aa5a724
10 changed files with 25 additions and 31 deletions

extras/BLIP/models/blip.py

@@ -8,8 +8,8 @@
 import warnings
 warnings.filterwarnings("ignore")

-from models.vit import VisionTransformer, interpolate_pos_embed
-from models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.vit import VisionTransformer, interpolate_pos_embed
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer

 import torch
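
The same rewrite repeats across every vendored BLIP module in this commit: intra-repo "models.*" imports, which only resolve when the BLIP checkout root itself is on sys.path, become "extras.BLIP.models.*" imports that resolve as an ordinary package from the Fooocus root. A rough sketch of the mechanics the old scheme relied on (the paths here are illustrative, not part of the diff):

import os
import sys

# Old scheme (removed by this commit): put the BLIP checkout root on sys.path
# so that BLIP's internal "from models.vit import ..." imports resolve.
blip_repo_root = os.path.join('extras', 'BLIP')  # illustrative location
sys.path.append(blip_repo_root)

# New scheme: no path mutation; the vendored tree is imported as a package,
# e.g. "from extras.BLIP.models.vit import VisionTransformer", which resolves
# from the application root like any other module.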

extras/BLIP/models/blip_itm.py

@@ -1,11 +1,11 @@
-from models.med import BertConfig, BertModel
+from extras.BLIP.models.med import BertConfig, BertModel
 from transformers import BertTokenizer

 import torch
 from torch import nn
 import torch.nn.functional as F

-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint

 class BLIP_ITM(nn.Module):
     def __init__(self,

extras/BLIP/models/blip_nlvr.py

@@ -1,7 +1,7 @@
-from models.med import BertConfig
-from models.nlvr_encoder import BertModel
-from models.vit import interpolate_pos_embed
-from models.blip import create_vit, init_tokenizer, is_url
+from extras.BLIP.models.med import BertConfig
+from extras.BLIP.models.nlvr_encoder import BertModel
+from extras.BLIP.models.vit import interpolate_pos_embed
+from extras.BLIP.models.blip import create_vit, init_tokenizer, is_url

 from timm.models.hub import download_cached_file

@@ -10,6 +10,8 @@ from torch import nn
 import torch.nn.functional as F
 from transformers import BertTokenizer
 import numpy as np
+import os
+

 class BLIP_NLVR(nn.Module):
     def __init__(self,

extras/BLIP/models/blip_pretrain.py

@@ -5,7 +5,7 @@
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
 '''
-from models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer
 import transformers
 transformers.logging.set_verbosity_error()

@@ -14,7 +14,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F

-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint

 class BLIP_Pretrain(nn.Module):
     def __init__(self,

@@ -270,7 +270,7 @@ from typing import List
 def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
     uninitialized_encoder_weights: List[str] = []
     if decoder.__class__ != encoder.__class__:
-        logger.info(
+        print(
             f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
         )
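
A side note on the last hunk: tie_encoder_decoder_weights appears adapted from encoder-decoder weight-tying code in transformers, and the vendored copy defines no module-level logger, so the class-mismatch warning is downgraded to a plain print. A minimal runnable sketch of the guard (the helper name is hypothetical):

from torch import nn

def warn_on_class_mismatch(encoder: nn.Module, decoder: nn.Module) -> None:
    # Same guard as in tie_encoder_decoder_weights: warn only when encoder and
    # decoder are different classes; print avoids needing a configured logger.
    if decoder.__class__ != encoder.__class__:
        print(f"{decoder.__class__} and {encoder.__class__} are not equal. "
              f"In this case make sure that all encoder weights are correctly initialized.")

warn_on_class_mismatch(nn.Linear(2, 2), nn.ReLU())  # prints the mismatch warning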

extras/BLIP/models/blip_retrieval.py

@@ -1,11 +1,11 @@
-from models.med import BertConfig, BertModel
+from extras.BLIP.models.med import BertConfig, BertModel
 from transformers import BertTokenizer

 import torch
 from torch import nn
 import torch.nn.functional as F

-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint

 class BLIP_Retrieval(nn.Module):
     def __init__(self,

extras/BLIP/models/blip_vqa.py

@@ -1,5 +1,5 @@
-from models.med import BertConfig, BertModel, BertLMHeadModel
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint

 import torch
 from torch import nn

extras/BLIP/models/vit.py

@@ -18,7 +18,10 @@ from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_, DropPath
 from timm.models.helpers import named_apply, adapt_input_conv

-from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+def checkpoint_wrapper(x):
+    return x
+

 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks

extras/interrogate.py

@@ -1,5 +1,4 @@
 import os
-import sys
 import torch
 import ldm_patched.modules.model_management as model_management

@@ -8,19 +7,11 @@ from torchvision.transforms.functional import InterpolationMode
 from modules.model_loader import load_file_from_url
 from modules.config import path_clip_vision
 from ldm_patched.modules.model_patcher import ModelPatcher
+from extras.BLIP.models.blip import blip_decoder

 blip_image_eval_size = 384
 blip_repo_root = os.path.join(os.path.dirname(__file__), 'BLIP')
-sys.path.append(blip_repo_root)
-
-
-class FakeFairscale:
-    def checkpoint_wrapper(self):
-        pass
-
-
-sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale

 class Interrogator:

@@ -34,16 +25,14 @@ class Interrogator:
     @torch.inference_mode()
     def interrogate(self, img_rgb):
         if self.blip_model is None:
-            import models.blip
-
             filename = load_file_from_url(
                 url='https://huggingface.co/lllyasviel/misc/resolve/main/model_base_caption_capfilt_large.pth',
                 model_dir=path_clip_vision,
                 file_name='model_base_caption_capfilt_large.pth',
             )

-            model = models.blip.blip_decoder(pretrained=filename, image_size=blip_image_eval_size, vit='base',
-                                             med_config=os.path.join(blip_repo_root, "configs", "med_config.json"))
+            model = blip_decoder(pretrained=filename, image_size=blip_image_eval_size, vit='base',
+                                 med_config=os.path.join(blip_repo_root, "configs", "med_config.json"))

             model.eval()
             self.load_device = model_management.text_encoder_device()
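
With BLIP importable as a normal package, the loader drops both the sys.path.append and the FakeFairscale stand-in that previously kept fairscale out of the import chain (the checkpoint_wrapper stub in the vendored vit.py now covers that). A hypothetical end-to-end sketch of the new load path, assuming the checkpoint has already been downloaded (the filename value is a placeholder):

import os
from extras.BLIP.models.blip import blip_decoder  # direct package import, no path hacks

blip_repo_root = os.path.join(os.path.dirname(__file__), 'BLIP')
filename = '/path/to/model_base_caption_capfilt_large.pth'  # placeholder checkpoint path

model = blip_decoder(
    pretrained=filename,
    image_size=384,  # blip_image_eval_size
    vit='base',
    med_config=os.path.join(blip_repo_root, 'configs', 'med_config.json'),
)
model.eval()  # inference only, as in interrogate()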

fooocus_version.py

@@ -1 +1 @@
-version = '2.1.832'
+version = '2.1.834'