use our blip

lllyasviel 2023-12-12 21:07:39 -08:00
parent c175afb394
commit 322aa5a724
10 changed files with 25 additions and 31 deletions
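The hunks below repoint the vendored BLIP modules at the copy bundled under extras/BLIP, so the upstream BLIP repository no longer needs to be importable (or appended to sys.path) at runtime. A minimal sketch of the import rewrite applied across the files, using module names taken from the hunks themselves:

# Before: expects the upstream BLIP repo on sys.path as a top-level 'models' package.
# from models.med import BertConfig, BertModel, BertLMHeadModel
# from models.blip import create_vit, init_tokenizer, load_checkpoint

# After: import the copy vendored with this repository instead.
from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint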

View File

@@ -8,8 +8,8 @@
 import warnings
 warnings.filterwarnings("ignore")
-from models.vit import VisionTransformer, interpolate_pos_embed
-from models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.vit import VisionTransformer, interpolate_pos_embed
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer
 import torch

View File

@@ -1,11 +1,11 @@
-from models.med import BertConfig, BertModel
+from extras.BLIP.models.med import BertConfig, BertModel
 from transformers import BertTokenizer
 import torch
 from torch import nn
 import torch.nn.functional as F
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 class BLIP_ITM(nn.Module):
     def __init__(self,

View File

@@ -1,7 +1,7 @@
-from models.med import BertConfig
-from models.nlvr_encoder import BertModel
-from models.vit import interpolate_pos_embed
-from models.blip import create_vit, init_tokenizer, is_url
+from extras.BLIP.models.med import BertConfig
+from extras.BLIP.models.nlvr_encoder import BertModel
+from extras.BLIP.models.vit import interpolate_pos_embed
+from extras.BLIP.models.blip import create_vit, init_tokenizer, is_url
 from timm.models.hub import download_cached_file
@@ -10,6 +10,8 @@ from torch import nn
 import torch.nn.functional as F
 from transformers import BertTokenizer
 import numpy as np
+import os
 class BLIP_NLVR(nn.Module):
     def __init__(self,

View File

@@ -5,7 +5,7 @@
  * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
  * By Junnan Li
 '''
-from models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer
 import transformers
 transformers.logging.set_verbosity_error()
@@ -14,7 +14,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 class BLIP_Pretrain(nn.Module):
     def __init__(self,
@@ -270,7 +270,7 @@ from typing import List
 def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
     uninitialized_encoder_weights: List[str] = []
     if decoder.__class__ != encoder.__class__:
-        logger.info(
+        print(
             f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
         )

View File

@@ -1,11 +1,11 @@
-from models.med import BertConfig, BertModel
+from extras.BLIP.models.med import BertConfig, BertModel
 from transformers import BertTokenizer
 import torch
 from torch import nn
 import torch.nn.functional as F
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 class BLIP_Retrieval(nn.Module):
     def __init__(self,

View File

@@ -1,5 +1,5 @@
-from models.med import BertConfig, BertModel, BertLMHeadModel
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 import torch
 from torch import nn

View File

@@ -18,7 +18,10 @@ from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_, DropPath
 from timm.models.helpers import named_apply, adapt_input_conv
-from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+def checkpoint_wrapper(x):
+    return x
 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks

View File

@@ -1,5 +1,4 @@
 import os
-import sys
 import torch
 import ldm_patched.modules.model_management as model_management
@@ -8,19 +7,11 @@ from torchvision.transforms.functional import InterpolationMode
 from modules.model_loader import load_file_from_url
 from modules.config import path_clip_vision
 from ldm_patched.modules.model_patcher import ModelPatcher
+from extras.BLIP.models.blip import blip_decoder
 blip_image_eval_size = 384
 blip_repo_root = os.path.join(os.path.dirname(__file__), 'BLIP')
-sys.path.append(blip_repo_root)
-class FakeFairscale:
-    def checkpoint_wrapper(self):
-        pass
-sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
 class Interrogator:
@@ -34,16 +25,14 @@ class Interrogator:
     @torch.inference_mode()
     def interrogate(self, img_rgb):
         if self.blip_model is None:
-            import models.blip
             filename = load_file_from_url(
                 url='https://huggingface.co/lllyasviel/misc/resolve/main/model_base_caption_capfilt_large.pth',
                 model_dir=path_clip_vision,
                 file_name='model_base_caption_capfilt_large.pth',
             )
-            model = models.blip.blip_decoder(pretrained=filename, image_size=blip_image_eval_size, vit='base',
-                                             med_config=os.path.join(blip_repo_root, "configs", "med_config.json"))
+            model = blip_decoder(pretrained=filename, image_size=blip_image_eval_size, vit='base',
+                                 med_config=os.path.join(blip_repo_root, "configs", "med_config.json"))
             model.eval()
             self.load_device = model_management.text_encoder_device()
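With blip_decoder imported directly from extras.BLIP, the sys.path.append and FakeFairscale workarounds removed above become unnecessary, while the model is still built lazily on the first interrogate() call. A rough sketch of that lazy-initialization pattern, assuming the vendored extras.BLIP.models.blip module is importable from the repo root; the LazyCaptioner name and ensure_model helper are hypothetical, not part of this file:

import os
from extras.BLIP.models.blip import blip_decoder  # vendored copy, no sys.path hack needed

class LazyCaptioner:
    def __init__(self, blip_repo_root):
        self.blip_model = None
        self.med_config = os.path.join(blip_repo_root, 'configs', 'med_config.json')

    def ensure_model(self, checkpoint_path):
        # Build the captioning decoder only when it is first needed.
        if self.blip_model is None:
            model = blip_decoder(pretrained=checkpoint_path, image_size=384, vit='base',
                                 med_config=self.med_config)
            model.eval()
            self.blip_model = model
        return self.blip_model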

View File

@@ -1 +1 @@
-version = '2.1.832'
+version = '2.1.834'