-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel_classifier.py
More file actions
33 lines (29 loc) · 846 Bytes
/
model_classifier.py
File metadata and controls
33 lines (29 loc) · 846 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from modules import ISAB, PMA, SAB
class SetTransformer(nn.Module):
    """Set Transformer classifier: two ISAB encoder layers and a PMA-pooled decoder.

    The encoder stacks two ``ISAB`` blocks (``dim_input -> dim_hidden ->
    dim_hidden``). The decoder applies dropout, a ``PMA`` block, dropout again,
    and a final ``nn.Linear(dim_hidden, dim_output)`` head.

    ``ISAB`` and ``PMA`` come from the project's ``modules`` package and are not
    visible here. The shape notes below assume the usual Set Transformer layout
    ``(batch, set_size, features)``; TODO confirm against ``modules``.

    Args:
        dim_input: Feature size of each set element fed to the first ISAB.
        num_outputs: Passed to PMA; presumably the number of pooled seed vectors.
        dim_output: Output features of the final linear layer.
        num_inds: Passed to each ISAB; presumably the number of inducing points.
        dim_hidden: Hidden feature size used by the encoder, by PMA and as the
            linear head's input size.
        num_heads: Number of attention heads passed to every ISAB and PMA.
        ln: Forwarded as ``ln=`` to every ISAB and PMA (presumably toggles
            layer normalization inside those blocks).
        dropout: Drop probability of both decoder ``nn.Dropout`` layers.
            Defaults to 0.5, the same as ``nn.Dropout()``'s own default.
    """

    def __init__(
        self,
        dim_input=2,
        num_outputs=1,
        dim_output=4,
        num_inds=32,
        dim_hidden=128,
        num_heads=4,
        ln=False,
        dropout=0.5,
    ):
        super().__init__()
        self.enc = nn.Sequential(
            ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
            ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln),
        )
        self.dec = nn.Sequential(
            nn.Dropout(p=dropout),
            PMA(dim_hidden, num_heads, num_outputs, ln=ln),
            nn.Dropout(p=dropout),
            nn.Linear(dim_hidden, dim_output),
        )

    def forward(self, X):
        """Encode the set batch ``X`` and decode it to output features.

        Size-1 dimensions are squeezed out of the decoder output, except the
        leading (batch) dimension, which is always kept. A bare ``.squeeze()``
        would also drop the batch dimension when the batch holds a single set,
        turning ``(1, dim_output)`` into ``(dim_output,)``.

        Args:
            X: Input batch; assumed ``(batch, set_size, dim_input)``. TODO confirm.

        Returns:
            The decoder output with every size-1 non-batch dimension removed.
            Assuming PMA returns ``(batch, num_outputs, dim_hidden)``, this is
            ``(batch, dim_output)`` when ``num_outputs == 1`` and
            ``dim_output > 1``.
        """
        out = self.dec(self.enc(X))
        # Keep dim 0 unconditionally; drop only the size-1 trailing dimensions.
        kept = [size for size in out.shape[1:] if size != 1]
        return out.reshape(out.shape[0], *kept)