4. You have this Transformer decoder code snippet that throws an error:
import torch
from torch import nn
class SimpleDecoder(nn.Module):
    """Minimal decoder block consisting of a single attention layer.

    Wraps one ``nn.MultiheadAttention`` (``embed_dim=8``, ``num_heads=4``)
    and uses it to attend from a target sequence to a memory sequence.
    """

    def __init__(self) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=8, num_heads=4)

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Attend from ``tgt`` (query) to ``memory`` (key and value).

        Args:
            tgt: Query tensor; the calling snippet lays it out as
                (seq_len, batch, embed).
            memory: Tensor passed as both key and value; the calling
                snippet lays it out as (seq_len, batch, embed).

        Returns:
            The attention output. The second value returned by the
            attention layer is discarded.
        """
        attn_output, _ = self.attention(tgt, memory, memory)
        return attn_output
# Target: seq len=10, batch=2, embed=8
tgt = torch.rand((10, 2, 8))
# Memory: seq len=5, batch=3, embed=8
memory = torch.rand((5, 3, 8))

# Build the decoder and run one forward pass over the two tensors.
model = SimpleDecoder()
output = model(tgt, memory)
print(output.size())
What is the likely cause of the error?