From f7f548ff355f61e07aa3dd7223d4c25fa8817217 Mon Sep 17 00:00:00 2001
From: lvmin
Date: Wed, 13 Sep 2023 18:10:17 -0700
Subject: [PATCH] use SOTA sampling for GPT2

---
 fooocus_version.py   |  2 +-
 modules/expansion.py | 25 ++++++++++++-------------
 2 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/fooocus_version.py b/fooocus_version.py
index 14c910b..08bb1e6 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.0.8'
+version = '2.0.9'
diff --git a/modules/expansion.py b/modules/expansion.py
index 716d343..f2dba0f 100644
--- a/modules/expansion.py
+++ b/modules/expansion.py
@@ -1,7 +1,6 @@
-import torch
 import comfy.model_management as model_management
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed
+from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
 from modules.path import fooocus_expansion_path
 from comfy.sd import ModelPatcher
 
@@ -28,35 +27,35 @@ def remove_pattern(x, pattern):
 
 class FooocusExpansion:
     def __init__(self):
-        use_fp16 = model_management.should_use_fp16()
+        self.use_fp16 = model_management.should_use_fp16()
 
         self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path)
         self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
 
-        if use_fp16:
+        if self.use_fp16:
             self.model.half()
 
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
 
-        self.pipe = pipeline('text-generation',
-                             model=self.model,
-                             tokenizer=self.tokenizer,
-                             device='cpu',
-                             torch_dtype=torch.float16 if use_fp16 else torch.float32)
-
         print(f'Fooocus Expansion engine loaded.')
 
     def __call__(self, prompt, seed):
         model_management.load_model_gpu(self.patcher)
-        self.pipe.device = self.patcher.load_device
         seed = int(seed)
         set_seed(seed)
         origin = safe_str(prompt)
         prompt = origin + fooocus_magic_split[seed % len(fooocus_magic_split)]
-        response = self.pipe(prompt, max_length=len(prompt) + 256)
-        result = response[0]['generated_text'][len(origin):]
+
+        tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
+        tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device)
+        tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device)
+
+        features = self.model.generate(**tokenized_kwargs, num_beams=5, do_sample=True, max_new_tokens=256)
+        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
+
+        result = response[0][len(origin):]
         result = safe_str(result)
         result = remove_pattern(result, dangrous_patterns)
         return result
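
The patch replaces the transformers text-generation pipeline with a direct
model.generate() call using beam-sample decoding (num_beams=5 with
do_sample=True), which expands each of the 5 beams by multinomial sampling
rather than greedy argmax. A minimal standalone sketch of the same call
pattern, using the stock 'gpt2' checkpoint and a toy prompt as stand-ins
for the Fooocus expansion weights and real input (both illustrative only):

    # Sketch only: 'gpt2' stands in for fooocus_expansion_path.
    from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')

    set_seed(42)  # generation is stochastic because do_sample=True
    inputs = tokenizer('a photo of a cat', return_tensors='pt')

    # Beam-sample decoding: keep 5 beams, grow each by sampling from the
    # next-token distribution, and emit at most 256 new tokens.
    features = model.generate(**inputs,
                              num_beams=5,
                              do_sample=True,
                              max_new_tokens=256,
                              pad_token_id=tokenizer.eos_token_id)  # GPT-2 has no pad token
    print(tokenizer.batch_decode(features, skip_special_tokens=True)[0])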