import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed

from modules.path import fooocus_expansion_path


# Suffixes appended to the user prompt before expansion; the seed picks one,
# nudging the language model toward a different continuation style each time.
magic_split = [
    ', extremely',
    ', trending',
    ', best',
    ', with',
    ', detailed',
    '. ',
    '. The',
]


def safe_str(x):
    # Collapse runs of spaces (16 halving passes handle runs of up to
    # 2**16 spaces) and strip trailing punctuation and whitespace.
    x = str(x)
    for _ in range(16):
        x = x.replace('  ', ' ')
    return x.rstrip(",. \r\n")


class FooocusExpansion:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path)
        self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
        self.pipe = pipeline('text-generation',
                             model=self.model,
                             tokenizer=self.tokenizer,
                             device='cpu',
                             torch_dtype=torch.float32)
        print('Fooocus Expansion engine loaded.')

    def __call__(self, prompt, seed):
        seed = int(seed)
        set_seed(seed)

        # Append a seed-dependent suffix so identical prompts with different
        # seeds expand differently.
        prompt = safe_str(prompt) + magic_split[seed % len(magic_split)]

        # max_length is measured in tokens; using the character length of the
        # prompt plus 256 gives the generation a generous upper bound.
        response = self.pipe(prompt, max_length=len(prompt) + 256)
        result = response[0]['generated_text']
        result = safe_str(result)
        return result
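

# Minimal usage sketch (not part of the original module). It assumes the
# Fooocus expansion checkpoint already exists at fooocus_expansion_path,
# which Fooocus downloads on first launch.
if __name__ == '__main__':
    expansion = FooocusExpansion()
    print(expansion('a photo of a cat', seed=12345))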