use SOTA sampling for GPT2
This commit is contained in:
		
							parent
							
								
									1a088db0eb
								
							
						
					
					
						commit
						1964aec7f8
					
				@ -50,9 +50,11 @@ class FooocusExpansion:
 | 
			
		||||
        tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device)
 | 
			
		||||
        tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device)
 | 
			
		||||
 | 
			
		||||
        features = self.model.generate(**tokenized_kwargs, num_beams=5, do_sample=True, max_new_tokens=256)
 | 
			
		||||
        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
 | 
			
		||||
        # https://huggingface.co/blog/introducing-csearch
 | 
			
		||||
        # https://huggingface.co/docs/transformers/generation_strategies
 | 
			
		||||
        features = self.model.generate(**tokenized_kwargs, penalty_alpha=0.8, top_k=8, max_new_tokens=256)
 | 
			
		||||
 | 
			
		||||
        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
 | 
			
		||||
        result = response[0][len(origin):]
 | 
			
		||||
        result = safe_str(result)
 | 
			
		||||
        result = remove_pattern(result, dangrous_patterns)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user