use our blip
parent c175afb394
commit 322aa5a724
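Summary: the vendored BLIP code under extras/BLIP is now imported directly. All `models.*` imports in the BLIP model files become `extras.BLIP.models.*`, the hard dependency on fairscale is replaced with a no-op `checkpoint_wrapper` stub, the `sys.path` and `FakeFairscale` shims in the interrogator are dropped, and the version is bumped to 2.1.834.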
@@ -8,8 +8,8 @@
 import warnings
 warnings.filterwarnings("ignore")
 
-from models.vit import VisionTransformer, interpolate_pos_embed
-from models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.vit import VisionTransformer, interpolate_pos_embed
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer
 
 import torch
@@ -1,11 +1,11 @@
-from models.med import BertConfig, BertModel
+from extras.BLIP.models.med import BertConfig, BertModel
 from transformers import BertTokenizer
 
 import torch
 from torch import nn
 import torch.nn.functional as F
 
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 
 class BLIP_ITM(nn.Module):
     def __init__(self,
@@ -1,7 +1,7 @@
-from models.med import BertConfig
-from models.nlvr_encoder import BertModel
-from models.vit import interpolate_pos_embed
-from models.blip import create_vit, init_tokenizer, is_url
+from extras.BLIP.models.med import BertConfig
+from extras.BLIP.models.nlvr_encoder import BertModel
+from extras.BLIP.models.vit import interpolate_pos_embed
+from extras.BLIP.models.blip import create_vit, init_tokenizer, is_url
 
 from timm.models.hub import download_cached_file
 
@@ -10,6 +10,8 @@ from torch import nn
 import torch.nn.functional as F
 from transformers import BertTokenizer
+import numpy as np
+import os
 
 
 class BLIP_NLVR(nn.Module):
     def __init__(self,
@@ -5,7 +5,7 @@
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
 '''
-from models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer
 import transformers
 transformers.logging.set_verbosity_error()
@@ -14,7 +14,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 
 class BLIP_Pretrain(nn.Module):
     def __init__(self,
@@ -270,7 +270,7 @@ from typing import List
 def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
     uninitialized_encoder_weights: List[str] = []
     if decoder.__class__ != encoder.__class__:
-        logger.info(
+        print(
             f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
         )
 
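Note: the vendored copy swaps `logger.info` for a plain `print`, presumably because no `logger` is configured in this stripped-down module; the warning text itself is unchanged.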
@@ -1,11 +1,11 @@
-from models.med import BertConfig, BertModel
+from extras.BLIP.models.med import BertConfig, BertModel
 from transformers import BertTokenizer
 
 import torch
 from torch import nn
 import torch.nn.functional as F
 
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 
 class BLIP_Retrieval(nn.Module):
     def __init__(self,
@@ -1,5 +1,5 @@
-from models.med import BertConfig, BertModel, BertLMHeadModel
-from models.blip import create_vit, init_tokenizer, load_checkpoint
+from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
+from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
 
 import torch
 from torch import nn
@@ -18,7 +18,10 @@ from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_, DropPath
 from timm.models.helpers import named_apply, adapt_input_conv
 
-from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+def checkpoint_wrapper(x):
+    return x
+
 
 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
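The stub above unconditionally disables activation checkpointing, which is fine for inference-only use. A minimal alternative sketch, assuming one wanted to keep using fairscale whenever it happens to be installed (this is not what the commit does):

# Hypothetical guarded import; falls back to an identity wrapper.
try:
    # Use the real wrapper when fairscale is available.
    from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
except ImportError:
    def checkpoint_wrapper(x):
        # Identity fallback: return the module unwrapped, so activation
        # checkpointing is simply skipped.
        return x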
@@ -1,5 +1,4 @@
 import os
-import sys
 import torch
 import ldm_patched.modules.model_management as model_management
 
@@ -8,19 +7,11 @@ from torchvision.transforms.functional import InterpolationMode
 from modules.model_loader import load_file_from_url
 from modules.config import path_clip_vision
 from ldm_patched.modules.model_patcher import ModelPatcher
+from extras.BLIP.models.blip import blip_decoder
 
 
 blip_image_eval_size = 384
 blip_repo_root = os.path.join(os.path.dirname(__file__), 'BLIP')
-sys.path.append(blip_repo_root)
-
-
-class FakeFairscale:
-    def checkpoint_wrapper(self):
-        pass
-
-
-sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
 
 
 class Interrogator:
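With vit.py no longer importing fairscale, the `FakeFairscale` shim (which injected the class itself into `sys.modules` so that `from fairscale... import checkpoint_wrapper` would resolve to a class attribute) becomes unnecessary, and `sys.path.append(blip_repo_root)` is likewise obsolete now that the BLIP modules are imported through the regular `extras.BLIP.models` package path.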
@@ -34,16 +25,14 @@ class Interrogator:
     @torch.inference_mode()
     def interrogate(self, img_rgb):
         if self.blip_model is None:
-            import models.blip
-
             filename = load_file_from_url(
                 url='https://huggingface.co/lllyasviel/misc/resolve/main/model_base_caption_capfilt_large.pth',
                 model_dir=path_clip_vision,
                 file_name='model_base_caption_capfilt_large.pth',
             )
 
-            model = models.blip.blip_decoder(pretrained=filename, image_size=blip_image_eval_size, vit='base',
-                                             med_config=os.path.join(blip_repo_root, "configs", "med_config.json"))
+            model = blip_decoder(pretrained=filename, image_size=blip_image_eval_size, vit='base',
+                                 med_config=os.path.join(blip_repo_root, "configs", "med_config.json"))
             model.eval()
 
             self.load_device = model_management.text_encoder_device()
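The lazy `import models.blip` inside `interrogate()` is gone; `blip_decoder` now comes from the module-level import, while model construction still happens lazily on the first call (guarded by `self.blip_model is None`).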
@@ -1 +1 @@
-version = '2.1.832'
+version = '2.1.834'