use SOTA sampling for GPT2

lvmin 2023-09-13 18:10:17 -07:00
parent 0f09b61ce5
commit f7f548ff35
2 changed files with 13 additions and 14 deletions

View File

@@ -1 +1 @@
-version = '2.0.8'
+version = '2.0.9'

View File

@@ -1,7 +1,6 @@
 import torch
 import comfy.model_management as model_management
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed
+from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
 from modules.path import fooocus_expansion_path
 from comfy.sd import ModelPatcher
@@ -28,35 +27,35 @@ def remove_pattern(x, pattern):
 class FooocusExpansion:
     def __init__(self):
-        use_fp16 = model_management.should_use_fp16()
+        self.use_fp16 = model_management.should_use_fp16()
         self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path)
         self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
-        if use_fp16:
+        if self.use_fp16:
             self.model.half()
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
-        self.pipe = pipeline('text-generation',
-                             model=self.model,
-                             tokenizer=self.tokenizer,
-                             device='cpu',
-                             torch_dtype=torch.float16 if use_fp16 else torch.float32)
         print(f'Fooocus Expansion engine loaded.')
     def __call__(self, prompt, seed):
         model_management.load_model_gpu(self.patcher)
-        self.pipe.device = self.patcher.load_device
         seed = int(seed)
         set_seed(seed)
         origin = safe_str(prompt)
         prompt = origin + fooocus_magic_split[seed % len(fooocus_magic_split)]
-        response = self.pipe(prompt, max_length=len(prompt) + 256)
-        result = response[0]['generated_text'][len(origin):]
+        tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
+        tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device)
+        tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device)
+        features = self.model.generate(**tokenized_kwargs, num_beams=5, do_sample=True, max_new_tokens=256)
+        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
+        result = response[0][len(origin):]
         result = safe_str(result)
         result = remove_pattern(result, dangrous_patterns)
         return result
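
The new __call__ path is ordinary Hugging Face beam-sample decoding (num_beams=5 combined with do_sample=True) called directly through model.generate instead of going through the text-generation pipeline. Below is a minimal standalone sketch of that decoding call; the "gpt2" checkpoint name and the example prompt are placeholders, not from the commit, which loads the Fooocus expansion model from fooocus_expansion_path and routes devices through comfy's model management.

# Minimal sketch of the same decoding strategy outside Fooocus.
# Assumptions: "gpt2" and the example prompt stand in for the
# Fooocus expansion checkpoint and the real user prompt.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)

set_seed(12345)                    # same role as set_seed(seed) in __call__
prompt = 'a photo of a cat, '      # placeholder prompt

# Tokenize and move tensors to the target device, as the diff does
# with tokenized_kwargs.data[...].to(self.patcher.load_device).
inputs = tokenizer(prompt, return_tensors='pt')
inputs = {k: v.to(device) for k, v in inputs.items()}

# Beam-sample decoding: num_beams > 1 together with do_sample=True,
# the same combination the new __call__ passes to model.generate.
features = model.generate(**inputs,
                          num_beams=5,
                          do_sample=True,
                          max_new_tokens=256)

print(tokenizer.batch_decode(features, skip_special_tokens=True)[0])

In transformers this num_beams/do_sample combination selects the beam-sample decoding mode, and the fixed seed passed to set_seed keeps the sampled expansion reproducible for a given prompt.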