diff --git a/entry.py b/entry.py index 574e8f9..8b0b60b 100644 --- a/entry.py +++ b/entry.py @@ -130,10 +130,14 @@ def denoiser(input, sigma, c): return model.denoiser(model.model, input, sigma, c) -model.model.to(torch.float16).cuda() -model.denoiser.to(torch.float16).cuda() -samples_z = sampler(denoiser, randn, cond=c, uc=uc) -model.model.cpu() -model.denoiser.cpu() +with torch.no_grad(): + model.model.to(torch.float16).cuda() + model.denoiser.to(torch.float16).cuda() + samples_z = sampler(denoiser, randn, cond=c, uc=uc) + model.model.cpu() + model.denoiser.cpu() + +torch.cuda.empty_cache() +torch.cuda.ipc_collect() a = 0