From 72a07190bf34e678b9bc8cfcfa932180ea245fe1 Mon Sep 17 00:00:00 2001
From: "Dr. Christoph Mittendorf" <34183942+Cassini-chris@users.noreply.github.com>
Date: Fri, 5 Jan 2024 21:37:33 +0100
Subject: [PATCH] Drop the math import; use the ** power operator instead

scale = 1 / (self.dim_head ** 0.25)

does the same as

scale = 1 / math.sqrt(math.sqrt(self.dim_head))

Note that the equivalent exponent is 0.25 (the fourth root), not 0.5:
the scale is applied to both q and k before the matmul
((q * scale) @ (k * scale).transpose(...)), so each factor must
contribute dim_head ** -0.25 to yield the standard combined attention
scaling of 1 / sqrt(dim_head). Using ** 0.25 means we no longer need to
import math here and save some work.

Recommended: import libraries only where you need them; this reduces the
number of times the interpreter has to load a library's code.
---
 extras/resampler.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/extras/resampler.py b/extras/resampler.py
index 539f309..10e5897 100644
--- a/extras/resampler.py
+++ b/extras/resampler.py
@@ -1,6 +1,4 @@
 # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
-import math
-
 import torch
 import torch.nn as nn
@@ -65,7 +63,7 @@ class PerceiverAttention(nn.Module):
         v = reshape_tensor(v, self.heads)
 
         # attention
-        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        scale = 1 / (self.dim_head ** 0.25)
         weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v