diff --git a/extras/resampler.py b/extras/resampler.py
index 539f309..10e5897 100644
--- a/extras/resampler.py
+++ b/extras/resampler.py
@@ -1,6 +1,4 @@
 # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
-import math
-
 import torch
 import torch.nn as nn
 
@@ -65,7 +63,7 @@ class PerceiverAttention(nn.Module):
         v = reshape_tensor(v, self.heads)
 
         # attention
-        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        scale = 1 / (self.dim_head ** 0.25)
         weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v
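
Note on the scale factor: it is applied to both q and k before the matmul, so each operand must carry the fourth root of dim_head; the product then recovers the usual 1/sqrt(dim_head) attention scaling while keeping the intermediate values small enough to be stable in f16. Replacing math.sqrt(math.sqrt(...)) with ** 0.5 would square the effective temperature, so the exponent stays 0.25. A minimal sketch (not part of the patch, with arbitrary shapes) that checks the equivalence:

# Not part of the patch: verifies that pre-scaling q and k by dim_head ** -0.25
# reproduces the standard 1/sqrt(dim_head) attention logits.
import torch

dim_head = 64
q = torch.randn(2, 8, 16, dim_head)
k = torch.randn(2, 8, 16, dim_head)

scale = 1 / (dim_head ** 0.25)                       # fourth root, applied to both operands
split = (q * scale) @ (k * scale).transpose(-2, -1)  # what PerceiverAttention computes
plain = (q @ k.transpose(-2, -1)) / (dim_head ** 0.5)

print(torch.allclose(split, plain, atol=1e-5))  # True (up to float error)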