Merge 72a07190bf
into 978267f461
This commit is contained in:
commit
ca22cdc697
@ -1,6 +1,4 @@
|
|||||||
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
@ -65,7 +63,7 @@ class PerceiverAttention(nn.Module):
|
|||||||
v = reshape_tensor(v, self.heads)
|
v = reshape_tensor(v, self.heads)
|
||||||
|
|
||||||
# attention
|
# attention
|
||||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
scale = 1 / (self.dim_head ** 0.5)
|
||||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||||
out = weight @ v
|
out = weight @ v
|
||||||
|
Loading…
Reference in New Issue
Block a user