use SOTA sampling for GPT2
parent 1a088db0eb
commit 1964aec7f8
@@ -50,9 +50,11 @@ class FooocusExpansion:
         tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device)
         tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device)
 
-        features = self.model.generate(**tokenized_kwargs, num_beams=5, do_sample=True, max_new_tokens=256)
-        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
+        # https://huggingface.co/blog/introducing-csearch
+        # https://huggingface.co/docs/transformers/generation_strategies
+        features = self.model.generate(**tokenized_kwargs, penalty_alpha=0.8, top_k=8, max_new_tokens=256)
+        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
 
         result = response[0][len(origin):]
         result = safe_str(result)
         result = remove_pattern(result, dangrous_patterns)
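The change swaps beam-sample decoding (num_beams=5, do_sample=True) for contrastive search, which transformers enables when penalty_alpha and top_k are both set, following the linked Hugging Face blog post. Below is a minimal standalone sketch of the same decoding call against vanilla GPT-2; the model name and prompt are illustrative and not from this commit, which uses Fooocus's own fine-tuned prompt-expansion checkpoint on self.patcher.load_device.

    # Sketch: contrastive-search decoding with Hugging Face transformers (assumed setup).
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    inputs = tokenizer("a photo of a cat", return_tensors="pt")

    # penalty_alpha > 0 together with top_k triggers contrastive search:
    # candidates come from the top_k most probable tokens and are re-ranked
    # by a degeneration penalty that discourages repeating the prior context.
    outputs = model.generate(**inputs, penalty_alpha=0.8, top_k=8, max_new_tokens=256)

    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])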