Fooocus/modules/expansion.py
2023-09-10 12:16:35 -07:00

47 lines
1.3 KiB
Python

import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed

from modules.path import fooocus_expansion_path
# Candidate suffixes appended to a prompt before it is expanded.
# FooocusExpansion.__call__ picks one with ``seed % len(fooocus_magic_split)``,
# so the order of this list affects which suffix a given seed receives.
fooocus_magic_split = [
    ', extremely',
    ', trending',
    ', best',
    ', with',
    ', perfect',
    ', harmonious',
    ', consistent',
    ', intricate',
    '. The',
]
def safe_str(x):
    """Convert *x* to a normalized string.

    Every run of two or more spaces is collapsed to a single space, and any
    trailing commas, periods, spaces, carriage returns and newlines are
    removed. Leading whitespace is not stripped (a leading run of spaces is
    only collapsed to one space).

    Args:
        x: Any object; it is converted with ``str()`` first.

    Returns:
        The normalized string.
    """
    # A single regex pass collapses space runs of any length, instead of
    # repeatedly replacing pairs of spaces a fixed number of times.
    x = re.sub(r' {2,}', ' ', str(x))
    return x.rstrip(",. \r\n")
class FooocusExpansion:
    """Prompt-expansion engine backed by a causal language model.

    Loads a tokenizer and a causal LM from ``fooocus_expansion_path`` and
    wraps them in a Hugging Face ``text-generation`` pipeline that runs on the
    CPU in float32. Calling an instance appends a seed-selected suffix from
    ``fooocus_magic_split`` to the prompt and lets the model continue it.
    """

    # transformers.set_seed also seeds numpy's global RNG, which only accepts
    # seeds in [0, 2**32); anything outside that range raises ValueError.
    _SEED_LIMIT_NUMPY = 2 ** 32

    def __init__(self):
        """Load the tokenizer and model and build the generation pipeline."""
        self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path)
        self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
        self.pipe = pipeline('text-generation',
                             model=self.model,
                             tokenizer=self.tokenizer,
                             device='cpu',
                             torch_dtype=torch.float32)
        print('Fooocus Expansion engine loaded.')

    def __call__(self, prompt, seed):
        """Expand *prompt* with model-generated text.

        Args:
            prompt: Text to expand; normalized with ``safe_str`` first.
            seed: Integer-like seed. It is reduced modulo 2**32 so that any
                int (including negative or very large values) can seed the
                global RNGs. The reduced value also selects which entry of
                ``fooocus_magic_split`` is appended to the prompt.

        Returns:
            The pipeline's ``generated_text`` for the first result, normalized
            with ``safe_str``. With the pipeline's default settings this text
            includes the input prompt followed by the continuation.
        """
        # Seeds already in [0, 2**32) are unchanged, so outputs for those
        # seeds are identical to before; out-of-range seeds no longer crash.
        seed = int(seed) % self._SEED_LIMIT_NUMPY
        set_seed(seed)
        suffix = fooocus_magic_split[seed % len(fooocus_magic_split)]
        prompt = safe_str(prompt) + suffix
        # NOTE(review): max_length is measured in tokens and includes the
        # prompt, but len(prompt) counts characters. For typical English text
        # the character count exceeds the token count, so this allows at least
        # 256 new tokens; confirm that is intended and that very long prompts
        # stay within the model's context window.
        response = self.pipe(prompt, max_length=len(prompt) + 256)
        result = response[0]['generated_text']
        return safe_str(result)