commit ca22cdc697
Author: Dr. Christoph Mittendorf
Date:   2024-03-21 20:47:38 +02:00
Committed by: GitHub


@@ -1,6 +1,4 @@
 # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
-import math
-
 import torch
 import torch.nn as nn
@@ -65,7 +63,7 @@ class PerceiverAttention(nn.Module):
         v = reshape_tensor(v, self.heads)
 
         # attention
-        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        scale = 1 / (self.dim_head ** 0.5)
         weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v
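
For context, here is a minimal standalone sketch (not part of this commit; shapes and values are arbitrary) showing the overall factor each scaling expression produces when, as in the line changed above, it is applied to both q and k before the matmul: a per-side factor s scales q @ k^T by s**2.

```python
# Standalone sketch with assumed tensor shapes (batch, heads, tokens, dim_head).
# The old expression, 1 / math.sqrt(math.sqrt(dim_head)), applied to both q and k,
# yields an overall 1/sqrt(dim_head) on q @ k^T; 1 / dim_head**0.5 applied to both
# sides yields 1/dim_head overall.
import math

import torch

dim_head = 64
q = torch.randn(1, 8, 4, dim_head)
k = torch.randn(1, 8, 4, dim_head)

old_scale = 1 / math.sqrt(math.sqrt(dim_head))  # dim_head ** -0.25 per side
new_scale = 1 / (dim_head ** 0.5)               # dim_head ** -0.5 per side

w_old = (q * old_scale) @ (k * old_scale).transpose(-2, -1)  # q @ k^T / sqrt(dim_head)
w_new = (q * new_scale) @ (k * new_scale).transpose(-2, -1)  # q @ k^T / dim_head

print(torch.allclose(w_old, q @ k.transpose(-2, -1) / dim_head ** 0.5, atol=1e-5))  # True
print(torch.allclose(w_new, q @ k.transpose(-2, -1) / dim_head, atol=1e-5))         # True
```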