diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..73f69e0
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
+# Editor-based HTTP Client requests
+/httpRequests/
diff --git a/.idea/image-captioning.iml b/.idea/image-captioning.iml
new file mode 100644
index 0000000..8b8c395
--- /dev/null
+++ b/.idea/image-captioning.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..eeb3b77
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,87 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..79a4365
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..1197701
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/artimes_experiments/iuxray_rgmg/config.yml b/artimes_experiments/iuxray_rgmg/config.yml
new file mode 100644
index 0000000..5dada3f
--- /dev/null
+++ b/artimes_experiments/iuxray_rgmg/config.yml
@@ -0,0 +1,144 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 16
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 9999
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 16
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 1
+ MAX_FEAT: 50
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XTransformer'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 768
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.1
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 768
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: True
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 768
+ ENCODE_ATT_MID_DIM: [96, 48, 96]
+ DECODE_ATT_MID_DIM: [96, 48, 96]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 768
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 6
+ DECODE_LAYERS: 6
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_FF_DROPOUT: 0.5
+ DECODE_FF_DROPOUT: 0.5
+ ELU_ALPHA: 1.3
+ BIFEAT_EMB_ACT: 'RELU'
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+ DECODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.000001
+ TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM'
+ MAX_EPOCH: 20
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 100
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.98]
+ EPS: 1.0e-9
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 3
+ SETP_TYPE: 'Iter' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 768 # For Noam only
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.1
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 2
+ GREEDY_DECODE: True
diff --git a/artimes_experiments/iuxray_rgmg/train.sh b/artimes_experiments/iuxray_rgmg/train.sh
new file mode 100644
index 0000000..fca3a35
--- /dev/null
+++ b/artimes_experiments/iuxray_rgmg/train.sh
@@ -0,0 +1,7 @@
+#CUDA_VISIBLE_DEVICES=0 -m torch.distributed.launch --nproc_per_node=1
+python3 main.py --folder ./artimes_experiments/iuxray_rgmg --resume 0 --submodel rgmg --KG_path /project/CVML/pretrained_kg/rgmg_iuxray_pretrain.pth --dataset_name IUXRAY --image_dir /project/CVML/Parallel-R2Gen-KG/data/iu_xray/images/ --ann_path /project/CVML/Parallel-R2Gen-KG/data/iuxray/annotation.json
+
+
+
+
+###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number
diff --git a/datasets/__init__.py b/datasets/__init__.py
index e69de29..8a3d705 100755
--- a/datasets/__init__.py
+++ b/datasets/__init__.py
@@ -0,0 +1,19 @@
+from datasets.coco_dataset import CocoDataset
+from datasets.radiology_dataset import IUXRAY
+from datasets.radiology_dataset import MIMICCXR
+from datasets.radiology_dataset import MimiccxrMultiImage
+
+__factory = {
+ 'IUXRAY': IUXRAY,
+ 'MIMICCXR': MIMICCXR,
+ 'MIMICCXR_MultiImages': MimiccxrMultiImage,
+ 'COCO': CocoDataset,
+}
+
+def names():
+ return sorted(__factory.keys())
+
+def create(name, *args, **kwargs):
+ if name not in __factory:
+ raise KeyError("Unknown Dataset:", name)
+ return __factory[name](*args, **kwargs)
\ No newline at end of file
diff --git a/datasets/data_loader.py b/datasets/data_loader.py
index 2dc9476..f147add 100755
--- a/datasets/data_loader.py
+++ b/datasets/data_loader.py
@@ -3,97 +3,118 @@
from torchvision import transforms
from lib.config import cfg
from datasets.coco_dataset import CocoDataset
+from datasets.radiology_dataset import IUXRAY
import samplers.distributed
import numpy as np
+import argparse
+import sys
+
def sample_collate(batch):
+ mask_dim = 70
indices, input_seq, target_seq, gv_feat, att_feats = zip(*batch)
-
- indices = np.stack(indices, axis=0).reshape(-1)
- input_seq = torch.cat([torch.from_numpy(b) for b in input_seq], 0)
- target_seq = torch.cat([torch.from_numpy(b) for b in target_seq], 0)
- gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0)
- atts_num = [x.shape[0] for x in att_feats]
- max_att_num = np.max(atts_num)
+ max_seq_length = max([len(x) for x in input_seq])
+ input_seqs = np.zeros((len(input_seq), max_seq_length), dtype=int)
+ target_seqs = np.zeros((len(target_seq), max_seq_length), dtype=int)
+ # print(max_seq_length)
- feat_arr = []
- mask_arr = []
- for i, num in enumerate(atts_num):
- tmp_feat = np.zeros((1, max_att_num, att_feats[i].shape[1]), dtype=np.float32)
- tmp_feat[:, 0:att_feats[i].shape[0], :] = att_feats[i]
- feat_arr.append(torch.from_numpy(tmp_feat))
+ for i, input in enumerate(input_seq):
+ input_seqs[i, :len(input)] = input
- tmp_mask = np.zeros((1, max_att_num), dtype=np.float32)
- tmp_mask[:, 0:num] = 1
- mask_arr.append(torch.from_numpy(tmp_mask))
+ for i, target in enumerate(target_seq):
+ target_seqs[i, :len(target)] = target
- att_feats = torch.cat(feat_arr, 0)
- att_mask = torch.cat(mask_arr, 0)
+ indices = np.stack(indices, axis=0).reshape(-1)
- return indices, input_seq, target_seq, gv_feat, att_feats, att_mask
+ gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0)
+
+ # IT DOESNT MATTER WHAT THE SHAPE OF MASK IS, WE WILL GET THE MASK LATER
+ mask_arr = torch.ones([len(att_feats), mask_dim]).float()
+
+ att_mask = torch.cat([mask_arr], 0)
+ att_feats = torch.stack(att_feats, 0) # TODO for mimic
+ """
+ indices (40, )
+ input_seq (40, 60)
+ target_seq (40, 60)
+ gv_feat (40, 1)
+ att_feats (40, 2, 3, 224, 224)
+ att_mask (40, 1, 3) => (40, 49)
+ """
+ return indices, torch.LongTensor(input_seqs), torch.LongTensor(target_seqs), gv_feat, att_feats, att_mask
def sample_collate_val(batch):
- indices, gv_feat, att_feats = zip(*batch)
-
- indices = np.stack(indices, axis=0).reshape(-1)
- gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0)
+ mask_dim = 70
+
+ indices, input_seq, target_seq, gv_feat, att_feats = zip(*batch)
+
+ max_seq_length = max([len(x) for x in input_seq])
+ input_seqs = np.zeros((len(input_seq), max_seq_length), dtype=int)
+ target_seqs = np.zeros((len(target_seq), max_seq_length), dtype=int)
+ # print(max_seq_length)
- atts_num = [x.shape[0] for x in att_feats]
- max_att_num = np.max(atts_num)
+ for i, input in enumerate(input_seq):
+ input_seqs[i, :len(input)] = input
- feat_arr = []
- mask_arr = []
- for i, num in enumerate(atts_num):
- tmp_feat = np.zeros((1, max_att_num, att_feats[i].shape[1]), dtype=np.float32)
- tmp_feat[:, 0:att_feats[i].shape[0], :] = att_feats[i]
- feat_arr.append(torch.from_numpy(tmp_feat))
+ for i, target in enumerate(target_seq):
+ target_seqs[i, :len(target)] = target
- tmp_mask = np.zeros((1, max_att_num), dtype=np.float32)
- tmp_mask[:, 0:num] = 1
- mask_arr.append(torch.from_numpy(tmp_mask))
+ indices = np.stack(indices, axis=0).reshape(-1)
+
+ gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0)
- att_feats = torch.cat(feat_arr, 0)
- att_mask = torch.cat(mask_arr, 0)
+ # IT DOESNT MATTER WHAT THE SHAPE OF MASK IS, WE WILL GET THE MASK LATER
+ mask_arr = torch.ones([len(att_feats), mask_dim]).float()
- return indices, gv_feat, att_feats, att_mask
+ att_mask = torch.cat([mask_arr], 0)
+ att_feats = torch.stack(att_feats, 0)
+ """
+ indices (40, )
+ input_seq (40, 60)
+ target_seq (40, 60)
+ gv_feat (40, 1)
+ att_feats (40, 2, 3, 224, 224)
+ att_mask (40, 1, 3) => (40, 49)
+ """
+ return indices, torch.LongTensor(target_seqs), gv_feat, att_feats, att_mask
-def load_train(distributed, epoch, coco_set):
- sampler = samplers.distributed.DistributedSampler(coco_set, epoch=epoch) \
+def load_train(distributed, epoch, dataset):
+ sampler = samplers.distributed.DistributedSampler(dataset, epoch=epoch) \
if distributed else None
shuffle = cfg.DATA_LOADER.SHUFFLE if sampler is None else False
-
+
loader = torch.utils.data.DataLoader(
- coco_set,
+ dataset,
batch_size = cfg.TRAIN.BATCH_SIZE,
- shuffle = shuffle,
- num_workers = cfg.DATA_LOADER.NUM_WORKERS,
- drop_last = cfg.DATA_LOADER.DROP_LAST,
+ shuffle = shuffle,
+ num_workers = cfg.DATA_LOADER.NUM_WORKERS,
+ drop_last = cfg.DATA_LOADER.DROP_LAST,
pin_memory = cfg.DATA_LOADER.PIN_MEMORY,
- sampler = sampler,
+ sampler = sampler,
collate_fn = sample_collate
)
return loader
-def load_val(image_ids_path, gv_feat_path, att_feats_folder):
- coco_set = CocoDataset(
- image_ids_path = image_ids_path,
- input_seq = None,
- target_seq = None,
- gv_feat_path = gv_feat_path,
- att_feats_folder = att_feats_folder,
- seq_per_img = 1,
- max_feat_num = cfg.DATA_LOADER.MAX_FEAT
- )
+def load_val(dataset):
+ # coco_set = CocoDataset(
+ # image_ids_path = image_ids_path,
+ # input_seq = None,
+ # target_seq = None,
+ # gv_feat_path = gv_feat_path,
+ # att_feats_folder = att_feats_folder,
+ # seq_per_img = 1,
+ # max_feat_num = cfg.DATA_LOADER.MAX_FEAT
+ # )
loader = torch.utils.data.DataLoader(
- coco_set,
+ dataset,
batch_size = cfg.TEST.BATCH_SIZE,
- shuffle = False,
- num_workers = cfg.DATA_LOADER.NUM_WORKERS,
- drop_last = False,
- pin_memory = cfg.DATA_LOADER.PIN_MEMORY,
+ shuffle = False,
+ num_workers = cfg.DATA_LOADER.NUM_WORKERS,
+ drop_last = False,
+ pin_memory = cfg.DATA_LOADER.PIN_MEMORY,
collate_fn = sample_collate_val
)
- return loader
\ No newline at end of file
+ return loader
diff --git a/datasets/radiology_dataset.py b/datasets/radiology_dataset.py
new file mode 100644
index 0000000..cee6976
--- /dev/null
+++ b/datasets/radiology_dataset.py
@@ -0,0 +1,217 @@
+import os
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from torchvision import transforms
+import json
+from PIL import Image
+from .tokenizers import Tokenizer
+import random
+import time
+import copy
+
+def random_position(image_1, image_2, thr):
+ if random.random() < thr:
+ img = torch.stack((image_1, image_2), 0)
+ else:
+ img = torch.stack((image_2, image_1), 0)
+ return img
+
+
+
+class BaseDataset(Dataset):
+ def __init__(self, image_dir, ann_path, tokenizer, split, args):
+ self.image_dir = image_dir
+ self.ann_path = ann_path
+ self.max_seq_length = 60 # hardcode
+ self.split = split
+ self.args = args
+ self.tokenizer = tokenizer
+ self.training_ratio = args.training_ratio
+ assert 0.0 <= self.training_ratio <= 1.0
+
+ if split == 'train':
+ self.transform = transforms.Compose([
+ transforms.Resize(256),
+ transforms.RandomCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225))])
+ else:
+ self.transform = transforms.Compose([
+ transforms.Resize((224, 224)),
+ transforms.ToTensor(),
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225))])
+ self.texts = json.loads(open(self.ann_path, 'r').read())
+
+ self.examples = self.texts[self.split]
+ if args.dataset_name == 'MIMICCXR_MultiImages':
+ self.examples = self.convert_to_multi_images(self.examples)
+ if self.split == 'train':
+ self.examples = self.apply_training_ratio(self.examples)
+ for i in range(len(self.examples)):
+ self.examples[i]['ids'] = self.tokenizer(self.examples[i]['report'])[:self.max_seq_length]
+
+ def apply_training_ratio(self, dataset):
+ t = time.time()
+ total = len(dataset)
+ print('{} set: applying training_ratio {} ... '.format(self.split,self.training_ratio), end='', flush=True)
+ select = int(total // (1/self.training_ratio)) + 1
+ dataset = dataset[:select]
+ print('done %d->%d (%.2fs)' % (total, select, time.time() - t), flush=True)
+ return dataset
+
+ def convert_to_multi_images(self, dataset, print_num=True):
+ t = time.time()
+ n = 0
+ if print_num:
+ print('{} set: Converting to multiple image reports ... '.format(self.split), end='', flush=True)
+ mergedDataset = []
+ total = len(dataset)
+
+ buffer = None
+ for i in range(total):
+ document = dataset[i]
+ id = document['id']
+ image_path = document['image_path'][0]
+ # report = document['report']
+ # split = document['split']
+ study_id = document['study_id']
+ # subject_id = document['subject_id']
+
+ if study_id == buffer:
+ mergedDataset[-1]['image_path'].append(image_path)
+ mergedDataset[-1]['id'].append(id)
+ else:
+ newDocument = copy.deepcopy(document)
+ newDocument['id'] = [newDocument['id']]
+ mergedDataset.append(newDocument)
+ n += 1
+ buffer = study_id
+ if print_num:
+ print('done %d->%d (%.2fs)' % (total, n, time.time() - t), flush=True)
+ return mergedDataset
+
+ def __len__(self):
+ return len(self.examples)
+
+
+class IUXRAY(BaseDataset):
+ def __getitem__(self, idx):
+ # indices = np.array([idx]).astype('int') # Modified
+ image_id = self.examples[idx]['id']
+ indices = np.array([image_id])
+ example = self.examples[idx]
+ image_path = example['image_path']
+ image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
+ image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB')
+ if self.transform is not None:
+ image_1 = self.transform(image_1)
+ image_2 = self.transform(image_2)
+ if self.split == 'train':
+ image = random_position(image_1, image_2, 0.5)
+ else:
+ image = torch.stack((image_1, image_2), 0)
+
+
+ report_ids = np.array(example['ids'])
+
+ input_sequence = np.zeros(self.max_seq_length, dtype='int')
+ target_sequence = np.zeros(self.max_seq_length, dtype='int')
+
+ input_sequence[:len(report_ids)] = report_ids
+ target_sequence[:len(report_ids)-1] = report_ids[1:]
+
+ gv_feat = np.zeros((1, 1)) # Never been used
+ # report_masks = example['mask']
+ # seq_length = len(report_ids)
+ return indices, input_sequence, target_sequence, gv_feat, image
+
+
+class MIMICCXR(BaseDataset): # MimiccxrSingleImageDataset
+ def __getitem__(self, idx):
+ # indices = np.array([idx]).astype('int') # Modified
+ image_id = self.examples[idx]['id']
+ indices = np.array([image_id])
+ example = self.examples[idx]
+ image_path = example['image_path']
+ image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
+ if self.transform is not None:
+ image = self.transform(image)
+ report_ids = np.array(example['ids'])
+
+ input_sequence = np.ones(self.max_seq_length, dtype='int')
+ target_sequence = np.ones(self.max_seq_length, dtype='int')
+
+ input_sequence[:len(report_ids)] = report_ids
+ target_sequence[:len(report_ids) - 1] = report_ids[1:]
+
+ gv_feat = np.zeros((1, 1)) # Never been used
+ # report_masks = example['mask']
+ # seq_length = len(report_ids)
+ return indices, input_sequence, target_sequence, gv_feat, image
+
+class MimiccxrMultiImage(BaseDataset): # MimiccxrMultiImageDataset
+ def __getitem__(self, idx):
+ # indices = np.array([idx]).astype('int') # Modified
+ image_id = self.examples[idx]['id']
+ indices = np.array([image_id])
+ example = self.examples[idx]
+ image_path = example['image_path']
+ image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
+ if self.transform is not None:
+ image = self.transform(image)
+ report_ids = np.array(example['ids'])
+
+ input_sequence = np.ones(self.max_seq_length, dtype='int')
+ target_sequence = np.ones(self.max_seq_length, dtype='int')
+
+ input_sequence[:len(report_ids)] = report_ids
+ target_sequence[:len(report_ids) - 1] = report_ids[1:]
+
+ gv_feat = np.zeros((1, 1)) # Never been used
+ # report_masks = example['mask']
+ # seq_length = len(report_ids)
+ return indices, input_sequence, target_sequence, gv_feat, image
+
+class MimiccxrMultiImage(BaseDataset):
+ def __getitem__(self, idx):
+ image_id = str(self.examples[idx]['subject_id']) + '_' + str(self.examples[idx]['study_id'])
+ indices = np.array([image_id])
+ example = self.examples[idx]
+ image_path = example['image_path']
+ image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
+ try:
+ image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB')
+ # if this record only have one image, duplicate image1
+ except:
+ image_2 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
+
+ if self.transform is not None:
+ image_1 = self.transform(image_1)
+ image_2 = self.transform(image_2)
+
+ def random_position(image_1, image_2, thr):
+ if random.random() < thr:
+ img = torch.stack((image_1, image_2), 0)
+ else:
+ img = torch.stack((image_2, image_1), 0)
+ return img
+
+ if self.split == 'train':
+ image = random_position(image_1, image_2, 0.5)
+ else:
+ image = torch.stack((image_1, image_2), 0)
+
+ report_ids = np.array(example['ids'])
+
+ input_sequence = np.zeros(self.max_seq_length, dtype='int')
+ target_sequence = np.zeros(self.max_seq_length, dtype='int')
+
+ input_sequence[:len(report_ids)] = report_ids
+ target_sequence[:len(report_ids)-1] = report_ids[1:]
+
+ gv_feat = np.zeros((1, 1)) # Never been used
+ return indices, input_sequence, target_sequence, gv_feat, image
\ No newline at end of file
diff --git a/datasets/tokenizers.py b/datasets/tokenizers.py
new file mode 100644
index 0000000..36d2261
--- /dev/null
+++ b/datasets/tokenizers.py
@@ -0,0 +1,129 @@
+import json
+import re
+from collections import Counter
+
+
+class Tokenizer(object):
+ def __init__(self, ann_path, dataset_name):
+ self.ann_path = ann_path # todo
+ self.threshold = 3
+ self.dataset_name = dataset_name # todo
+ if self.dataset_name == 'iu_xray':
+ self.clean_report = self.clean_report_iu_xray
+ else:
+ self.clean_report = self.clean_report_mimic_cxr
+ self.ann = json.loads(open(self.ann_path, 'r').read())
+ self.token2idx, self.idx2token = self.create_vocabulary_and_CIDEr_DF()
+
+ def create_vocabulary_and_CIDEr_DF(self):
+ # print('Building Vocabulary and CIDEr DF! ')
+ print('Building Vocabulary! ')
+ total_tokens = []
+
+ for example in self.ann['train']:
+ tokens = self.clean_report(example['report']).split()
+ for token in tokens:
+ total_tokens.append(token)
+
+ counter = Counter(total_tokens)
+ vocab = [k for k, v in counter.items() if v >= self.threshold] + ['']
+ vocab.sort()
+ token2idx, idx2token = {}, {}
+ for idx, token in enumerate(vocab):
+ token2idx[token] = idx + 1
+ idx2token[idx + 1] = token
+ print('Vocab Done. Dataset: {} Vocab Size: {}, include "" , non exists "sos" "eos"'.format(self.dataset_name,len(vocab)))
+
+ # for text in self.ann['train']:
+ # toks = tokenizer.tokenize(textfilter.filter(text))
+ # toks = tokenfilter.filter(toks)
+ # ftext = ' '.join(toks)
+ # ftexts.append(ftext)
+ #
+ # df = GenEval.compute_cider_df(texts)
+ # with gzip.open(args.output, 'w') as f:
+ # pickle.dump(df, f)
+
+
+ return token2idx, idx2token
+
+ def clean_report_iu_xray(self, report):
+ report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
+ .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
+ .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
+ .strip().lower().split('. ')
+ sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '').
+ replace('\\', '').replace("'", '').strip().lower())
+ tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
+ report = ' . '.join(tokens) + ' .'
+ return report
+
+ def clean_report_mimic_cxr(self, report):
+ report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
+ .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \
+ .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \
+ .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
+ .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
+ .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
+ .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
+ .strip().lower().split('. ')
+ sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
+ .replace('\\', '').replace("'", '').strip().lower())
+ tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
+ report = ' . '.join(tokens) + ' .'
+ return report
+
+ # def get_id2word(self):
+
+ def get_token_by_id(self, id):
+ return self.idx2token[id]
+
+ def get_id_by_token(self, token):
+ if token not in self.token2idx:
+ return self.token2idx['']
+ return self.token2idx[token]
+
+ def get_vocab_size(self):
+ return len(self.token2idx)
+
+ def __call__(self, report):
+ tokens = self.clean_report(report).split()
+ ids = []
+ for token in tokens:
+ ids.append(self.get_id_by_token(token))
+ ids = [0] + ids + [0]
+ return ids
+
+ def decode(self, ids):
+ txt = ''
+ for i, idx in enumerate(ids):
+ if idx > 0:
+ if i >= 1:
+ txt += ' '
+ txt += self.idx2token[idx]
+ else:
+ break
+ return txt
+
+ def decode_batch(self, ids_batch):
+ out = []
+ for ids in ids_batch:
+ out.append(self.decode(ids))
+ return out
+
+
+# def cider_df(texts):
+# tokenizer = get_tokenizer('nltk')
+# textfilter = get_textfilter('lower')
+# tokenfilter = get_tokenfilter('none')
+
+# ftexts = []
+# for text in texts:
+# toks = tokenizer.tokenize(textfilter.filter(text))
+# toks = tokenfilter.filter(toks)
+# ftext = ' '.join(toks)
+# ftexts.append(ftext)
+
+# df = GenEval.compute_cider_df(texts) # texts : [report_text, ... ] , where report_text : [filtered_token, ... ]
+# with gzip.open('mimic-cxr_train-df.bin.gz', 'w') as f:
+# pickle.dump(df, f)
\ No newline at end of file
diff --git a/evaluation/evaler.py b/evaluation/evaler.py
index 873062a..0c1fdf5 100755
--- a/evaluation/evaler.py
+++ b/evaluation/evaler.py
@@ -9,56 +9,113 @@
import datasets.data_loader as data_loader
from lib.config import cfg
+from pycocoevalcap.bleu.bleu import Bleu
+from pycocoevalcap.meteor import Meteor
+from pycocoevalcap.rouge import Rouge
+from scorer.cider import Cider
+device = torch.device("cuda")
+
+
+def compute_scores(gts, res):
+ """
+ Performs the MS COCO evaluation using the Python 3 implementation (https://github.com/salaniz/pycocoevalcap)
+
+ :param gts: Dictionary with the image ids and their gold captions,
+ :param res: Dictionary with the image ids ant their generated captions
+ :print: Evaluation score (the mean of the scores of all the instances) for each measure
+ """
+
+ # Set up scorers
+ scorers = [
+ (Bleu(4), ["BLEU_1", "BLEU_2", "BLEU_3", "BLEU_4"]),
+ (Cider(), 'CIDEr'),
+ (Meteor(), "METEOR"),
+ (Rouge(), "ROUGE_L")
+ ]
+ eval_res = {}
+ # Compute score for each metric
+ for scorer, method in scorers:
+ try:
+ score, scores = scorer.compute_score(gts, res, verbose=0)
+ except TypeError:
+ score, scores = scorer.compute_score(gts, res)
+ if type(method) == list:
+ for sc, m in zip(score, method):
+ eval_res[m] = sc
+ else:
+ eval_res[method] = score
+ return eval_res
+
+
class Evaler(object):
def __init__(
- self,
- eval_ids,
- gv_feat,
- att_feats,
- eval_annfile
+ self,
+ dataset,
+ tokenizer
):
super(Evaler, self).__init__()
- self.vocab = utils.load_vocab(cfg.INFERENCE.VOCAB)
+ self.tokenizer = tokenizer
+ # self.vocab = utils.load_vocab(cfg.INFERENCE.VOCAB) # TODO
- self.eval_ids = np.array(utils.load_ids(eval_ids))
- self.eval_loader = data_loader.load_val(eval_ids, gv_feat, att_feats)
- self.evaler = evaluation.create(cfg.INFERENCE.EVAL, eval_annfile)
+ # self.eval_ids = np.array(utils.load_ids(eval_ids))
+ self.eval_loader = data_loader.load_val(dataset)
+ # self.evaler = evaluation.create(cfg.INFERENCE.EVAL, eval_annfile)
+ self.evaler = compute_scores
def make_kwargs(self, indices, ids, gv_feat, att_feats, att_mask):
kwargs = {}
kwargs[cfg.PARAM.INDICES] = indices
kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
kwargs[cfg.PARAM.ATT_FEATS] = att_feats
+ # att_mask_a = torch.ones(16,70).to(device)
kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
+ # kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask_a
kwargs['BEAM_SIZE'] = cfg.INFERENCE.BEAM_SIZE
kwargs['GREEDY_DECODE'] = cfg.INFERENCE.GREEDY_DECODE
return kwargs
-
+
def __call__(self, model, rname):
model.eval()
-
- results = []
+
+ results, golden_sents = {}, {}
with torch.no_grad():
- for _, (indices, gv_feat, att_feats, att_mask) in tqdm.tqdm(enumerate(self.eval_loader)):
- ids = self.eval_ids[indices]
+ for _, (indices, target_seq, gv_feat, att_feats, att_mask) in tqdm.tqdm(enumerate(self.eval_loader)):
+ ids = indices
gv_feat = gv_feat.cuda()
att_feats = att_feats.cuda()
att_mask = att_mask.cuda()
kwargs = self.make_kwargs(indices, ids, gv_feat, att_feats, att_mask)
if kwargs['BEAM_SIZE'] > 1:
- seq, _ = model.module.decode_beam(**kwargs)
+ seq, _ = model.decode_beam(**kwargs) # modified
else:
- seq, _ = model.module.decode(**kwargs)
- sents = utils.decode_sequence(self.vocab, seq.data)
+ seq, _ = model.decode(**kwargs)
+ sents = utils.decode_sequence(self.tokenizer.idx2token, seq.data) # to check
+ # sents: [sent (str), ... ]
+ gold_sents = utils.decode_sequence(self.tokenizer.idx2token,
+ target_seq.data) # to check target_seq callable
+
+
+ # for sid, sent in enumerate(sents):
+ # result = {ids[sid]: [sent]}
+ # results.append(result)
+ # for sid, sent in enumerate(gold_sents):
+ # g_sents = {ids[sid]: [sent]}
+ # golden_sents.append(g_sents)
+
for sid, sent in enumerate(sents):
- result = {cfg.INFERENCE.ID_KEY: int(ids[sid]), cfg.INFERENCE.CAP_KEY: sent}
- results.append(result)
- eval_res = self.evaler.eval(results)
+ results[ids[sid]] = [sent] # sents: [sent (str), ... ]
+ # results.append(result)
+ for sid, sent in enumerate(gold_sents):
+ golden_sents[ids[sid]] = [sent]
+ # golden_sents.append(g_sents)
+ # golden_sents : {image_id, [sent (str) ]}
+ eval_res = self.evaler(golden_sents, results)
result_folder = os.path.join(cfg.ROOT_DIR, 'result')
if not os.path.exists(result_folder):
os.mkdir(result_folder)
- json.dump(results, open(os.path.join(result_folder, 'result_' + rname +'.json'), 'w'))
+ json.dump(results,
+ open(os.path.join(result_folder, 'result_' + rname + '.json'), 'w')) # store the generated sentences
model.train()
- return eval_res
\ No newline at end of file
+ return eval_res
diff --git a/experiments/xlan_rl/train.sh b/experiments/xlan_rl/train.sh
deleted file mode 100644
index 648cb01..0000000
--- a/experiments/xlan_rl/train.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments/xlan_rl --resume 47
-
-# 47 is the epoch number of the pretrained model
diff --git a/experiments/xtransformer_rl/train.sh b/experiments/xtransformer_rl/train.sh
deleted file mode 100644
index 8ecb209..0000000
--- a/experiments/xtransformer_rl/train.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments/xtransformer_rl --resume 39
-
-# 39 is the epoch number of the pretrained model
diff --git a/experiments/xlan/config.yml b/experiments_iuxray/xlan/config.yml
similarity index 95%
rename from experiments/xlan/config.yml
rename to experiments_iuxray/xlan/config.yml
index eedf3e9..8967371 100644
--- a/experiments/xlan/config.yml
+++ b/experiments_iuxray/xlan/config.yml
@@ -39,8 +39,8 @@ DATA_LOADER:
############################ MODEL ############################
MODEL:
TYPE: 'XLAN'
- SEQ_LEN: 17 # include /
- VOCAB_SIZE: 9487 # exclude /
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 9487 # TODO # exclude /
########## word embedding ##########
WORD_EMBED_DIM: 1024
WORD_EMBED_ACT: 'CELU'
@@ -52,7 +52,7 @@ MODEL:
GVFEAT_EMBED_ACT: 'NONE'
DROPOUT_GV_EMBED: 0.0
########## attention features ##########
- ATT_FEATS_DIM: 2048
+ ATT_FEATS_DIM: 1024 # Modified
ATT_FEATS_EMBED_DIM: 1024
ATT_FEATS_EMBED_ACT: 'CELU'
DROPOUT_ATT_EMBED: 0.5
diff --git a/experiments/xlan/train.sh b/experiments_iuxray/xlan/train.sh
similarity index 50%
rename from experiments/xlan/train.sh
rename to experiments_iuxray/xlan/train.sh
index e16d6f5..c498753 100644
--- a/experiments/xlan/train.sh
+++ b/experiments_iuxray/xlan/train.sh
@@ -1 +1 @@
-CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch --nproc_per_node=4 main.py --folder ./experiments/xlan
+CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch --nproc_per_node=4 main.py --folder ./experiments_iuxray/xlan
diff --git a/experiments/xlan_rl/config.yml b/experiments_iuxray/xlan_rl/config.yml
similarity index 95%
rename from experiments/xlan_rl/config.yml
rename to experiments_iuxray/xlan_rl/config.yml
index 861da38..931c356 100644
--- a/experiments/xlan_rl/config.yml
+++ b/experiments_iuxray/xlan_rl/config.yml
@@ -39,8 +39,8 @@ DATA_LOADER:
############################ MODEL ############################
MODEL:
TYPE: 'XLAN'
- SEQ_LEN: 17 # include /
- VOCAB_SIZE: 9487 # exclude /
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 9487 # TODO # exclude /
########## word embedding ##########
WORD_EMBED_DIM: 1024
WORD_EMBED_ACT: 'CELU'
@@ -52,7 +52,7 @@ MODEL:
GVFEAT_EMBED_ACT: 'NONE'
DROPOUT_GV_EMBED: 0.0
########## attention features ##########
- ATT_FEATS_DIM: 2048
+ ATT_FEATS_DIM: 1024 # Modified
ATT_FEATS_EMBED_DIM: 1024
ATT_FEATS_EMBED_ACT: 'CELU' # 'RELU', 'NONE'
DROPOUT_ATT_EMBED: 0.5
diff --git a/experiments_iuxray/xlan_rl/train.sh b/experiments_iuxray/xlan_rl/train.sh
new file mode 100644
index 0000000..5079ea7
--- /dev/null
+++ b/experiments_iuxray/xlan_rl/train.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xlan_rl --resume 47
+
+# 47 is the epoch number of the pretrained model
diff --git a/experiments/xtransformer/config.yml b/experiments_iuxray/xtransformer/config.yml
similarity index 90%
rename from experiments/xtransformer/config.yml
rename to experiments_iuxray/xtransformer/config.yml
index 51d1c6b..3c4e4fd 100644
--- a/experiments/xtransformer/config.yml
+++ b/experiments_iuxray/xtransformer/config.yml
@@ -3,14 +3,14 @@ SEED: 1546884941.160048
############################ TRAIN ############################
TRAIN:
- BATCH_SIZE: 40
+ BATCH_SIZE: 16
#################### REINFORCEMENT ####################
REINFORCEMENT:
START: 9999
-############################ TEST ############################
+############################ TEST ############################
TEST:
- BATCH_SIZE: 36
+ BATCH_SIZE: 16
############################ DATA_LOADER ############################
DATA_LOADER:
@@ -27,18 +27,18 @@ DATA_LOADER:
TEST_ID: './mscoco/txt/coco_test_image_id.txt'
INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
- SEQ_PER_IMG: 5
+ SEQ_PER_IMG: 1
MAX_FEAT: 50
############################ MODEL ############################
MODEL:
TYPE: 'XTransformer'
- SEQ_LEN: 17 # include /
- VOCAB_SIZE: 9487 # exclude /
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760
########## word embedding ##########
WORD_EMBED_DIM: 768
WORD_EMBED_ACT: 'CELU'
- WORD_EMBED_NORM: False
+ WORD_EMBED_NORM: False
DROPOUT_WORD_EMBED: 0.1
########## global features ##########
GVFEAT_DIM: 2048
@@ -46,7 +46,7 @@ MODEL:
GVFEAT_EMBED_ACT: 'NONE'
DROPOUT_GV_EMBED: 0.0
########## attention features ##########
- ATT_FEATS_DIM: 2048
+ ATT_FEATS_DIM: 1024 # Modified
ATT_FEATS_EMBED_DIM: 768
ATT_FEATS_EMBED_ACT: 'CELU'
DROPOUT_ATT_EMBED: 0.5
@@ -87,9 +87,9 @@ MODEL:
ENCODE_BIFEAT_EMB_DROPOUT: 0.3
DECODE_BIFEAT_EMB_DROPOUT: 0.3
-############################ SOLVER ############################
+############################ SOLVER ############################
SOLVER:
- BASE_LR: 0.0005
+ BASE_LR: 0.000001
TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM'
MAX_EPOCH: 70
MAX_ITER: -1
@@ -98,7 +98,7 @@ SOLVER:
WEIGHT_DECAY: 0.0000
WEIGHT_DECAY_BIAS: 0.0
BIAS_LR_FACTOR: 1
- DISPLAY: 20
+ DISPLAY: 100
TEST_INTERVAL: 1
SNAPSHOT_ITERS: 1
diff --git a/experiments_iuxray/xtransformer/train.sh b/experiments_iuxray/xtransformer/train.sh
new file mode 100644
index 0000000..c5b6482
--- /dev/null
+++ b/experiments_iuxray/xtransformer/train.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 main.py --folder ./experiments_iuxray/xtransformer --resume 0
+
+###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number
diff --git a/experiments_iuxray/xtransformer_VSEGCN/config.yml b/experiments_iuxray/xtransformer_VSEGCN/config.yml
new file mode 100644
index 0000000..3c4e4fd
--- /dev/null
+++ b/experiments_iuxray/xtransformer_VSEGCN/config.yml
@@ -0,0 +1,144 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 16
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 9999
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 16
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 1
+ MAX_FEAT: 50
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XTransformer'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 768
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.1
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 768
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: True
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 768
+ ENCODE_ATT_MID_DIM: [96, 48, 96]
+ DECODE_ATT_MID_DIM: [96, 48, 96]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 768
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 6
+ DECODE_LAYERS: 6
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_FF_DROPOUT: 0.5
+ DECODE_FF_DROPOUT: 0.5
+ ELU_ALPHA: 1.3
+ BIFEAT_EMB_ACT: 'RELU'
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+ DECODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.000001
+ TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM'
+ MAX_EPOCH: 70
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 100
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.98]
+ EPS: 1.0e-9
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 3
+ SETP_TYPE: 'Iter' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 768 # For Noam only
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.1
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 2
+ GREEDY_DECODE: True
diff --git a/experiments_iuxray/xtransformer_VSEGCN/train.sh b/experiments_iuxray/xtransformer_VSEGCN/train.sh
new file mode 100644
index 0000000..467a7a3
--- /dev/null
+++ b/experiments_iuxray/xtransformer_VSEGCN/train.sh
@@ -0,0 +1,5 @@
+CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 main.py --folder ./experiments_iuxray/xtransformer_VSEGCN
+ --submodel VSEGCN \
+ --resume 0
+
+###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number
diff --git a/experiments_iuxray/xtransformer_rl/config.yml b/experiments_iuxray/xtransformer_rl/config.yml
new file mode 100644
index 0000000..659e45e
--- /dev/null
+++ b/experiments_iuxray/xtransformer_rl/config.yml
@@ -0,0 +1,147 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 16
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 0
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 16
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 5
+ MAX_FEAT: 50
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XTransformer'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 9487 # TODO # exclude /
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 768
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.1
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 768
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: True
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.0
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 768
+ ENCODE_ATT_MID_DIM: [96, 48, 96]
+ DECODE_ATT_MID_DIM: [96, 48, 96]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 768
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.1
+ ENCODE_LAYERS: 6
+ DECODE_LAYERS: 6
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_FF_DROPOUT: 0.5
+ DECODE_FF_DROPOUT: 0.5
+ ELU_ALPHA: 1.3
+ BIFEAT_EMB_ACT: 'RELU'
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.1
+ DECODE_BIFEAT_EMB_DROPOUT: 0.1
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.000005
+ TYPE: 'RADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP'
+ MAX_EPOCH: 60
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 20
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.999]
+ EPS: 1.0e-8
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Plateau' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 3
+ SETP_TYPE: 'Epoch' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 768 # For Noam only
+
+ PLATEAU_FACTOR: 0.8
+ PLATEAU_PATIENCE: 3
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.1
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 2
+ GREEDY_DECODE: True
diff --git a/experiments_iuxray/xtransformer_rl/train.sh b/experiments_iuxray/xtransformer_rl/train.sh
new file mode 100644
index 0000000..e69c213
--- /dev/null
+++ b/experiments_iuxray/xtransformer_rl/train.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xtransformer_rl --resume 0
+
+# 39 is the epoch number of the pretrained model
diff --git a/experiments_iuxray_dwe/xlan/config.yml b/experiments_iuxray_dwe/xlan/config.yml
new file mode 100644
index 0000000..8967371
--- /dev/null
+++ b/experiments_iuxray_dwe/xlan/config.yml
@@ -0,0 +1,148 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 10
+ #################### SCHEDULED_SAMPLING ####################
+ SCHEDULED_SAMPLING:
+ START: 6
+ INC_EVERY: 5
+ INC_PROB: 0.05
+ MAX_PROB: 0.5
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 9999
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 36
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 5
+ MAX_FEAT: -1
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XLAN'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 9487 # TODO # exclude /
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 1024
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.5
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 1024
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: False
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 1024
+ ENCODE_ATT_MID_DIM: [128, 64, 128]
+ DECODE_ATT_MID_DIM: [128, 64, 128]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 1024
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 4
+ DECODE_LAYERS: 1
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_BLOCK: 'LowRankBilinearEnc'
+ DECODE_BLOCK: 'LowRankBilinearDec'
+ ELU_ALPHA: 1.3
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.0005
+ TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP'
+ MAX_EPOCH: 70
+ MAX_ITER: -1
+ GRAD_CLIP: 0.5 # Norm:5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Norm' # 'Clamp', 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 20
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.98]
+ EPS: 1.0e-9
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 3
+ SETP_TYPE: 'Iter' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 1024 # For Noam only
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.0
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 3
+ GREEDY_DECODE: True
diff --git a/experiments/xtransformer/train.sh b/experiments_iuxray_dwe/xlan/train.sh
similarity index 50%
rename from experiments/xtransformer/train.sh
rename to experiments_iuxray_dwe/xlan/train.sh
index 787b9d8..c498753 100644
--- a/experiments/xtransformer/train.sh
+++ b/experiments_iuxray_dwe/xlan/train.sh
@@ -1 +1 @@
-CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch --nproc_per_node=4 main.py --folder ./experiments/xtransformer
+CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch --nproc_per_node=4 main.py --folder ./experiments_iuxray/xlan
diff --git a/experiments_iuxray_dwe/xlan_rl/config.yml b/experiments_iuxray_dwe/xlan_rl/config.yml
new file mode 100644
index 0000000..931c356
--- /dev/null
+++ b/experiments_iuxray_dwe/xlan_rl/config.yml
@@ -0,0 +1,151 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 20
+ #################### SCHEDULED_SAMPLING ####################
+ SCHEDULED_SAMPLING:
+ START: 6
+ INC_EVERY: 5
+ INC_PROB: 0.05
+ MAX_PROB: 0.5
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 0
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 36
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 5
+ MAX_FEAT: -1
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XLAN'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 9487 # TODO # exclude /
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 1024
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.5
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 1024
+ ATT_FEATS_EMBED_ACT: 'CELU' # 'RELU', 'NONE'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: False
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH' # 'RELU', 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 1024
+ ENCODE_ATT_MID_DIM: [128, 64, 128]
+ DECODE_ATT_MID_DIM: [128, 64, 128]
+ ENCODE_ATT_MID_DROPOUT: 0.0
+ DECODE_ATT_MID_DROPOUT: 0.0
+ ATT_DIM: 1024
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 4
+ DECODE_LAYERS: 1
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_BLOCK: 'LowRankBilinearEnc'
+ DECODE_BLOCK: 'LowRankBilinearDec'
+ ELU_ALPHA: 1.3
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.00001
+ TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP'
+ MAX_EPOCH: 35
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 20
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.999]
+ EPS: 1.0e-8
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Plateau' # 'Fix', 'Step', 'MultiStep', 'Poly', Noam'
+ GAMMA: 0.8
+ STEP_SIZE: 3
+ SETP_TYPE: 'Epoch' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 1024 # For Noam only
+
+ PLATEAU_FACTOR: 0.8
+ PLATEAU_PATIENCE: 3
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.0
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 3
+ GREEDY_DECODE: True
diff --git a/experiments_iuxray_dwe/xlan_rl/train.sh b/experiments_iuxray_dwe/xlan_rl/train.sh
new file mode 100644
index 0000000..5079ea7
--- /dev/null
+++ b/experiments_iuxray_dwe/xlan_rl/train.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xlan_rl --resume 47
+
+# 47 is the epoch number of the pretrained model
diff --git a/experiments_iuxray_dwe/xtransformer/config.yml b/experiments_iuxray_dwe/xtransformer/config.yml
new file mode 100644
index 0000000..2a8a23d
--- /dev/null
+++ b/experiments_iuxray_dwe/xtransformer/config.yml
@@ -0,0 +1,144 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 12
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 9999
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 12
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 1
+ MAX_FEAT: 50
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XTransformer'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 768
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.1
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 768
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: True
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 768
+ ENCODE_ATT_MID_DIM: [96, 48, 96]
+ DECODE_ATT_MID_DIM: [96, 48, 96]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 768
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 6
+ DECODE_LAYERS: 6
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_FF_DROPOUT: 0.5
+ DECODE_FF_DROPOUT: 0.5
+ ELU_ALPHA: 1.3
+ BIFEAT_EMB_ACT: 'RELU'
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+ DECODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.001
+ TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM'
+ MAX_EPOCH: 70
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 100
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.98]
+ EPS: 1.0e-9
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Step' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.9
+ STEP_SIZE: 1
+ SETP_TYPE: 'Iter' # 'Epoch', 'Iter'
+ WARMUP: 1000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 768 # For Noam only
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.1
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 2
+ GREEDY_DECODE: True
diff --git a/experiments_iuxray_dwe/xtransformer/train.sh b/experiments_iuxray_dwe/xtransformer/train.sh
new file mode 100644
index 0000000..07292c5
--- /dev/null
+++ b/experiments_iuxray_dwe/xtransformer/train.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 main.py --folder ./experiments_iuxray_dwe/xtransformer --resume 0 --encoder_mode dualwayencoder --dataset_name IUXRAY --submodel VSEGCN
+
+###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number
diff --git a/experiments_iuxray_dwe/xtransformer_VSEGCN/config.yml b/experiments_iuxray_dwe/xtransformer_VSEGCN/config.yml
new file mode 100644
index 0000000..3c4e4fd
--- /dev/null
+++ b/experiments_iuxray_dwe/xtransformer_VSEGCN/config.yml
@@ -0,0 +1,144 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 16
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 9999
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 16
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 1
+ MAX_FEAT: 50
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XTransformer'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 768
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.1
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 768
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: True
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 768
+ ENCODE_ATT_MID_DIM: [96, 48, 96]
+ DECODE_ATT_MID_DIM: [96, 48, 96]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 768
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 6
+ DECODE_LAYERS: 6
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_FF_DROPOUT: 0.5
+ DECODE_FF_DROPOUT: 0.5
+ ELU_ALPHA: 1.3
+ BIFEAT_EMB_ACT: 'RELU'
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+ DECODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.000001
+ TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM'
+ MAX_EPOCH: 70
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 100
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.98]
+ EPS: 1.0e-9
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 3
+ SETP_TYPE: 'Iter' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 768 # For Noam only
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.1
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 2
+ GREEDY_DECODE: True
diff --git a/experiments_iuxray_dwe/xtransformer_VSEGCN/train.sh b/experiments_iuxray_dwe/xtransformer_VSEGCN/train.sh
new file mode 100644
index 0000000..467a7a3
--- /dev/null
+++ b/experiments_iuxray_dwe/xtransformer_VSEGCN/train.sh
@@ -0,0 +1,5 @@
+CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 main.py --folder ./experiments_iuxray/xtransformer_VSEGCN
+ --submodel VSEGCN \
+ --resume 0
+
+###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number
diff --git a/experiments_iuxray_dwe/xtransformer_rl/config.yml b/experiments_iuxray_dwe/xtransformer_rl/config.yml
new file mode 100644
index 0000000..659e45e
--- /dev/null
+++ b/experiments_iuxray_dwe/xtransformer_rl/config.yml
@@ -0,0 +1,147 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 16
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 0
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 16
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 5
+ MAX_FEAT: 50
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XTransformer'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 9487 # TODO # exclude /
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 768
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.1
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 768
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: True
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.0
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 768
+ ENCODE_ATT_MID_DIM: [96, 48, 96]
+ DECODE_ATT_MID_DIM: [96, 48, 96]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 768
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.1
+ ENCODE_LAYERS: 6
+ DECODE_LAYERS: 6
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_FF_DROPOUT: 0.5
+ DECODE_FF_DROPOUT: 0.5
+ ELU_ALPHA: 1.3
+ BIFEAT_EMB_ACT: 'RELU'
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.1
+ DECODE_BIFEAT_EMB_DROPOUT: 0.1
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.000005
+ TYPE: 'RADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP'
+ MAX_EPOCH: 60
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 20
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.999]
+ EPS: 1.0e-8
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Plateau' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 3
+ SETP_TYPE: 'Epoch' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 768 # For Noam only
+
+ PLATEAU_FACTOR: 0.8
+ PLATEAU_PATIENCE: 3
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.1
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 2
+ GREEDY_DECODE: True
diff --git a/experiments_iuxray_dwe/xtransformer_rl/train.sh b/experiments_iuxray_dwe/xtransformer_rl/train.sh
new file mode 100644
index 0000000..e69c213
--- /dev/null
+++ b/experiments_iuxray_dwe/xtransformer_rl/train.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xtransformer_rl --resume 0
+
+# 39 is the epoch number of the pretrained model
diff --git a/experiments_iuxray_testing/xlan/config.yml b/experiments_iuxray_testing/xlan/config.yml
new file mode 100644
index 0000000..8967371
--- /dev/null
+++ b/experiments_iuxray_testing/xlan/config.yml
@@ -0,0 +1,148 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 10
+ #################### SCHEDULED_SAMPLING ####################
+ SCHEDULED_SAMPLING:
+ START: 6
+ INC_EVERY: 5
+ INC_PROB: 0.05
+ MAX_PROB: 0.5
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 9999
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 36
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 5
+ MAX_FEAT: -1
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XLAN'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 9487 # TODO # exclude /
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 1024
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.5
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 1024
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: False
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 1024
+ ENCODE_ATT_MID_DIM: [128, 64, 128]
+ DECODE_ATT_MID_DIM: [128, 64, 128]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 1024
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 4
+ DECODE_LAYERS: 1
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_BLOCK: 'LowRankBilinearEnc'
+ DECODE_BLOCK: 'LowRankBilinearDec'
+ ELU_ALPHA: 1.3
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.0005
+ TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP'
+ MAX_EPOCH: 70
+ MAX_ITER: -1
+ GRAD_CLIP: 0.5 # Norm:5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Norm' # 'Clamp', 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 20
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.98]
+ EPS: 1.0e-9
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 3
+ SETP_TYPE: 'Iter' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 1024 # For Noam only
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.0
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 3
+ GREEDY_DECODE: True
diff --git a/experiments_iuxray_testing/xlan/train.sh b/experiments_iuxray_testing/xlan/train.sh
new file mode 100644
index 0000000..c498753
--- /dev/null
+++ b/experiments_iuxray_testing/xlan/train.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch --nproc_per_node=4 main.py --folder ./experiments_iuxray/xlan
diff --git a/experiments_iuxray_testing/xlan_rl/config.yml b/experiments_iuxray_testing/xlan_rl/config.yml
new file mode 100644
index 0000000..931c356
--- /dev/null
+++ b/experiments_iuxray_testing/xlan_rl/config.yml
@@ -0,0 +1,151 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 20
+ #################### SCHEDULED_SAMPLING ####################
+ SCHEDULED_SAMPLING:
+ START: 6
+ INC_EVERY: 5
+ INC_PROB: 0.05
+ MAX_PROB: 0.5
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 0
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 36
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 5
+ MAX_FEAT: -1
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XLAN'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 9487 # TODO # exclude /
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 1024
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.5
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 1024
+ ATT_FEATS_EMBED_ACT: 'CELU' # 'RELU', 'NONE'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: False
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH' # 'RELU', 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 1024
+ ENCODE_ATT_MID_DIM: [128, 64, 128]
+ DECODE_ATT_MID_DIM: [128, 64, 128]
+ ENCODE_ATT_MID_DROPOUT: 0.0
+ DECODE_ATT_MID_DROPOUT: 0.0
+ ATT_DIM: 1024
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 4
+ DECODE_LAYERS: 1
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_BLOCK: 'LowRankBilinearEnc'
+ DECODE_BLOCK: 'LowRankBilinearDec'
+ ELU_ALPHA: 1.3
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.00001
+ TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP'
+ MAX_EPOCH: 35
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 20
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.999]
+ EPS: 1.0e-8
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Plateau' # 'Fix', 'Step', 'MultiStep', 'Poly', Noam'
+ GAMMA: 0.8
+ STEP_SIZE: 3
+ SETP_TYPE: 'Epoch' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 1024 # For Noam only
+
+ PLATEAU_FACTOR: 0.8
+ PLATEAU_PATIENCE: 3
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.0
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 3
+ GREEDY_DECODE: True
diff --git a/experiments_iuxray_testing/xlan_rl/train.sh b/experiments_iuxray_testing/xlan_rl/train.sh
new file mode 100644
index 0000000..5079ea7
--- /dev/null
+++ b/experiments_iuxray_testing/xlan_rl/train.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xlan_rl --resume 47
+
+# 47 is the epoch number of the pretrained model
diff --git a/experiments_iuxray_testing/xtransformer/config.yml b/experiments_iuxray_testing/xtransformer/config.yml
new file mode 100644
index 0000000..b6d726f
--- /dev/null
+++ b/experiments_iuxray_testing/xtransformer/config.yml
@@ -0,0 +1,144 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 16
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 9999
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 16
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 1
+ MAX_FEAT: 50
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XTransformer'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 768
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.1
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 768
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: True
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 768
+ ENCODE_ATT_MID_DIM: [96, 48, 96]
+ DECODE_ATT_MID_DIM: [96, 48, 96]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 768
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 3
+ DECODE_LAYERS: 3
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_FF_DROPOUT: 0.5
+ DECODE_FF_DROPOUT: 0.5
+ ELU_ALPHA: 1.3
+ BIFEAT_EMB_ACT: 'RELU'
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+ DECODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.000001
+ TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM'
+ MAX_EPOCH: 70
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 100
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.98]
+ EPS: 1.0e-9
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 3
+ SETP_TYPE: 'Iter' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 768 # For Noam only
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.1
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 2
+ GREEDY_DECODE: True
diff --git a/experiments_iuxray_testing/xtransformer/train.sh b/experiments_iuxray_testing/xtransformer/train.sh
new file mode 100644
index 0000000..873e8e0
--- /dev/null
+++ b/experiments_iuxray_testing/xtransformer/train.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 main.py --folder ./experiments_iuxray/xtransformer
diff --git a/experiments/xtransformer_rl/config.yml b/experiments_iuxray_testing/xtransformer_rl/config.yml
similarity index 95%
rename from experiments/xtransformer_rl/config.yml
rename to experiments_iuxray_testing/xtransformer_rl/config.yml
index 7544a0e..8fb2803 100644
--- a/experiments/xtransformer_rl/config.yml
+++ b/experiments_iuxray_testing/xtransformer_rl/config.yml
@@ -33,8 +33,8 @@ DATA_LOADER:
############################ MODEL ############################
MODEL:
TYPE: 'XTransformer'
- SEQ_LEN: 17 # include /
- VOCAB_SIZE: 9487 # exclude /
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 9487 # TODO # exclude /
########## word embedding ##########
WORD_EMBED_DIM: 768
WORD_EMBED_ACT: 'CELU'
@@ -46,7 +46,7 @@ MODEL:
GVFEAT_EMBED_ACT: 'NONE'
DROPOUT_GV_EMBED: 0.0
########## attention features ##########
- ATT_FEATS_DIM: 2048
+ ATT_FEATS_DIM: 1024 # Modified
ATT_FEATS_EMBED_DIM: 768
ATT_FEATS_EMBED_ACT: 'CELU'
DROPOUT_ATT_EMBED: 0.5
@@ -89,7 +89,7 @@ MODEL:
############################ SOLVER ############################
SOLVER:
- BASE_LR: 0.00001
+ BASE_LR: 0.000005
TYPE: 'RADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP'
MAX_EPOCH: 60
MAX_ITER: -1
diff --git a/experiments_iuxray_testing/xtransformer_rl/train.sh b/experiments_iuxray_testing/xtransformer_rl/train.sh
new file mode 100644
index 0000000..c97b3fd
--- /dev/null
+++ b/experiments_iuxray_testing/xtransformer_rl/train.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xtransformer_rl --resume 39
+
+# 39 is the epoch number of the pretrained model
diff --git a/experiments_mimiccxr/xlan/config.yml b/experiments_mimiccxr/xlan/config.yml
new file mode 100644
index 0000000..e79cc1f
--- /dev/null
+++ b/experiments_mimiccxr/xlan/config.yml
@@ -0,0 +1,148 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 10
+ #################### SCHEDULED_SAMPLING ####################
+ SCHEDULED_SAMPLING:
+ START: 6
+ INC_EVERY: 5
+ INC_PROB: 0.05
+ MAX_PROB: 0.5
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 9999
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 36
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 5
+ MAX_FEAT: -1
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XLAN'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 7863 # TODO # exclude /
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 1024
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.5
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 1024
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: False
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 1024
+ ENCODE_ATT_MID_DIM: [128, 64, 128]
+ DECODE_ATT_MID_DIM: [128, 64, 128]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 1024
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 4
+ DECODE_LAYERS: 1
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_BLOCK: 'LowRankBilinearEnc'
+ DECODE_BLOCK: 'LowRankBilinearDec'
+ ELU_ALPHA: 1.3
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.0005
+ TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP'
+ MAX_EPOCH: 70
+ MAX_ITER: -1
+ GRAD_CLIP: 0.5 # Norm:5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Norm' # 'Clamp', 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 2000
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.98]
+ EPS: 1.0e-9
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 300
+ SETP_TYPE: 'Iter' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 1024 # For Noam only
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.0
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 3
+ GREEDY_DECODE: True
diff --git a/experiments_mimiccxr/xlan/train.sh b/experiments_mimiccxr/xlan/train.sh
new file mode 100644
index 0000000..55a8bec
--- /dev/null
+++ b/experiments_mimiccxr/xlan/train.sh
@@ -0,0 +1,4 @@
+CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch \
+ --nproc_per_node=4 main.py --folder ./experiments_mimiccxr/xlan \
+ --dataset_name MIMICCXR \
+ --image_dir /content/mimic_cxr/images --ann_path /content/mimic_cxr/annotation.json
\ No newline at end of file
diff --git a/experiments_mimiccxr/xlan_rl/config.yml b/experiments_mimiccxr/xlan_rl/config.yml
new file mode 100644
index 0000000..1d7cf07
--- /dev/null
+++ b/experiments_mimiccxr/xlan_rl/config.yml
@@ -0,0 +1,151 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 20
+ #################### SCHEDULED_SAMPLING ####################
+ SCHEDULED_SAMPLING:
+ START: 6
+ INC_EVERY: 5
+ INC_PROB: 0.05
+ MAX_PROB: 0.5
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 0
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 36
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 5
+ MAX_FEAT: -1
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XLAN'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 7863 # TODO # exclude /
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 1024
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.5
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 1024
+ ATT_FEATS_EMBED_ACT: 'CELU' # 'RELU', 'NONE'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: False
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH' # 'RELU', 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 1024
+ ENCODE_ATT_MID_DIM: [128, 64, 128]
+ DECODE_ATT_MID_DIM: [128, 64, 128]
+ ENCODE_ATT_MID_DROPOUT: 0.0
+ DECODE_ATT_MID_DROPOUT: 0.0
+ ATT_DIM: 1024
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 4
+ DECODE_LAYERS: 1
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_BLOCK: 'LowRankBilinearEnc'
+ DECODE_BLOCK: 'LowRankBilinearDec'
+ ELU_ALPHA: 1.3
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.00001
+ TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP'
+ MAX_EPOCH: 35
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 20
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.999]
+ EPS: 1.0e-8
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Plateau' # 'Fix', 'Step', 'MultiStep', 'Poly', Noam'
+ GAMMA: 0.8
+ STEP_SIZE: 300
+ SETP_TYPE: 'Epoch' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 1024 # For Noam only
+
+ PLATEAU_FACTOR: 0.8
+ PLATEAU_PATIENCE: 3
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.0
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 3
+ GREEDY_DECODE: True
diff --git a/experiments_mimiccxr/xlan_rl/train.sh b/experiments_mimiccxr/xlan_rl/train.sh
new file mode 100644
index 0000000..6a3862a
--- /dev/null
+++ b/experiments_mimiccxr/xlan_rl/train.sh
@@ -0,0 +1,5 @@
+CUDA_VISIBLE_DEVICES=0 python3 main.py \
+ --folder ./experiments_mimiccxr/xlan_rl --resume 47 \
+ --dataset_name MIMICCXR \
+ --image_dir /content/mimic_cxr/images --ann_path /content/mimic_cxr/annotation.json
+# 47 is the epoch number of the pretrained model
diff --git a/experiments_mimiccxr/xtransformer/config.yml b/experiments_mimiccxr/xtransformer/config.yml
new file mode 100644
index 0000000..a2a7d12
--- /dev/null
+++ b/experiments_mimiccxr/xtransformer/config.yml
@@ -0,0 +1,144 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 16
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 9999
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 16
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 1
+ MAX_FEAT: 50
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XTransformer'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 7863 # TODO # exclude / IUXRAY: 760
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 768
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.1
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 768
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: True
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.5
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 768
+ ENCODE_ATT_MID_DIM: [96, 48, 96]
+ DECODE_ATT_MID_DIM: [96, 48, 96]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 768
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.5
+ ENCODE_LAYERS: 3
+ DECODE_LAYERS: 3
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_FF_DROPOUT: 0.5
+ DECODE_FF_DROPOUT: 0.5
+ ELU_ALPHA: 1.3
+ BIFEAT_EMB_ACT: 'RELU'
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.3
+ DECODE_BIFEAT_EMB_DROPOUT: 0.3
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.000001
+ TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM'
+ MAX_EPOCH: 70
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 2000
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.98]
+ EPS: 1.0e-9
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 300 # modified
+ SETP_TYPE: 'Iter' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 768 # For Noam only
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.1
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 2
+ GREEDY_DECODE: True
diff --git a/experiments_mimiccxr/xtransformer/train.sh b/experiments_mimiccxr/xtransformer/train.sh
new file mode 100644
index 0000000..066e383
--- /dev/null
+++ b/experiments_mimiccxr/xtransformer/train.sh
@@ -0,0 +1,7 @@
+CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch \
+ --nproc_per_node=1 main.py --folder ./experiments_mimiccxr/xtransformer \
+ --dataset_name MIMICCXR \
+ --image_dir /content/mimic_cxr/images --ann_path /content/mimic_cxr/annotation.json \
+ --submodel VSEGCN --resume 0
+
+###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number
diff --git a/experiments_mimiccxr/xtransformer_rl/config.yml b/experiments_mimiccxr/xtransformer_rl/config.yml
new file mode 100644
index 0000000..f5fa64d
--- /dev/null
+++ b/experiments_mimiccxr/xtransformer_rl/config.yml
@@ -0,0 +1,147 @@
+LOGGER_NAME: 'log'
+SEED: 1546884941.160048
+
+############################ TRAIN ############################
+TRAIN:
+ BATCH_SIZE: 16
+ #################### REINFORCEMENT ####################
+ REINFORCEMENT:
+ START: 0
+
+############################ TEST ############################
+TEST:
+ BATCH_SIZE: 16
+
+############################ DATA_LOADER ############################
+DATA_LOADER:
+ NUM_WORKERS: 4
+ SHUFFLE: True
+ TRAIN_GV_FEAT: ''
+ TRAIN_ATT_FEATS: './mscoco/feature/up_down_100'
+ VAL_GV_FEAT: ''
+ VAL_ATT_FEATS: './mscoco/feature/up_down_100'
+ TEST_GV_FEAT: ''
+ TEST_ATT_FEATS: './mscoco/feature/up_down_100'
+ TRAIN_ID: './mscoco/txt/coco_train_image_id.txt'
+ VAL_ID: './mscoco/txt/coco_val_image_id.txt'
+ TEST_ID: './mscoco/txt/coco_test_image_id.txt'
+ INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl'
+ TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl'
+ SEQ_PER_IMG: 5
+ MAX_FEAT: 50
+
+############################ MODEL ############################
+MODEL:
+ TYPE: 'XTransformer'
+ SEQ_LEN: 60 # Modified # include /
+ VOCAB_SIZE: 7863 # TODO # exclude /
+ ########## word embedding ##########
+ WORD_EMBED_DIM: 768
+ WORD_EMBED_ACT: 'CELU'
+ WORD_EMBED_NORM: False
+ DROPOUT_WORD_EMBED: 0.1
+ ########## global features ##########
+ GVFEAT_DIM: 2048
+ GVFEAT_EMBED_DIM: -1
+ GVFEAT_EMBED_ACT: 'NONE'
+ DROPOUT_GV_EMBED: 0.0
+ ########## attention features ##########
+ ATT_FEATS_DIM: 1024 # Modified
+ ATT_FEATS_EMBED_DIM: 768
+ ATT_FEATS_EMBED_ACT: 'CELU'
+ DROPOUT_ATT_EMBED: 0.5
+ ATT_FEATS_NORM: True
+ ########## attention param ##########
+ ATT_HIDDEN_SIZE: -1
+ ATT_HIDDEN_DROP: 0.0
+ ATT_ACT: 'TANH'
+ ########## rnn param ##########
+ RNN_SIZE: 1024
+ DROPOUT_LM: 0.0
+
+ ########## BOTTOM_UP ##########
+ BOTTOM_UP:
+ DROPOUT_FIRST_INPUT: 0.0
+ DROPOUT_SEC_INPUT: 0.0
+
+ ########## BILINEAR ##########
+ BILINEAR:
+ DIM: 768
+ ENCODE_ATT_MID_DIM: [96, 48, 96]
+ DECODE_ATT_MID_DIM: [96, 48, 96]
+ ENCODE_ATT_MID_DROPOUT: 0.1
+ DECODE_ATT_MID_DROPOUT: 0.1
+ ATT_DIM: 768
+ ACT: 'CELU'
+ ENCODE_DROPOUT: 0.5
+ DECODE_DROPOUT: 0.1
+ ENCODE_LAYERS: 3
+ DECODE_LAYERS: 3
+ TYPE: 'LowRank'
+ ATTTYPE: 'SCAtt' # SCAtt, BasicAtt
+ HEAD: 8
+ ENCODE_FF_DROPOUT: 0.5
+ DECODE_FF_DROPOUT: 0.5
+ ELU_ALPHA: 1.3
+ BIFEAT_EMB_ACT: 'RELU'
+ ENCODE_BIFEAT_EMB_DROPOUT: 0.1
+ DECODE_BIFEAT_EMB_DROPOUT: 0.1
+
+############################ SOLVER ############################
+SOLVER:
+ BASE_LR: 0.000005
+ TYPE: 'RADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP'
+ MAX_EPOCH: 60
+ MAX_ITER: -1
+ GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1
+ GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm'
+ WEIGHT_DECAY: 0.0000
+ WEIGHT_DECAY_BIAS: 0.0
+ BIAS_LR_FACTOR: 1
+ DISPLAY: 20
+ TEST_INTERVAL: 1
+ SNAPSHOT_ITERS: 1
+
+ ########## SGD ##########
+ SGD:
+ MOMENTUM: 0.9
+ ########## ADAM ##########
+ ADAM:
+ BETAS: [0.9, 0.999]
+ EPS: 1.0e-8
+ ########## LR_POLICY ##########
+ LR_POLICY:
+ TYPE: 'Plateau' # 'Fix', 'Step', 'Noam', 'Plateau'
+ GAMMA: 0.8
+ STEP_SIZE: 300
+ SETP_TYPE: 'Epoch' # 'Epoch', 'Iter'
+ WARMUP: 10000 # For Noam only
+ FACTOR: 1.0 # For Noam only
+ MODEL_SIZE: 768 # For Noam only
+
+ PLATEAU_FACTOR: 0.8
+ PLATEAU_PATIENCE: 3
+
+############################ LOSSES ############################
+LOSSES:
+ XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing'
+ LABELSMOOTHING: 0.1
+ RL_TYPE: 'RewardCriterion'
+
+############################ SCORER ############################
+SCORER:
+ TYPES: ['CIDEr']
+ WEIGHTS: [1.0]
+ GT_PATH: './mscoco/misc/coco_train_gts.pkl'
+ CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl'
+
+############################ INFERENCE ############################
+INFERENCE:
+ VOCAB: './mscoco/txt/coco_vocabulary.txt'
+ ID_KEY: 'image_id'
+ CAP_KEY: 'caption'
+ EVAL: 'COCO'
+ VAL_ANNFILE: './mscoco/misc/captions_val5k.json'
+ TEST_ANNFILE: './mscoco/misc/captions_test5k.json'
+ BEAM_SIZE: 2
+ GREEDY_DECODE: True
diff --git a/experiments_mimiccxr/xtransformer_rl/train.sh b/experiments_mimiccxr/xtransformer_rl/train.sh
new file mode 100644
index 0000000..c267ddd
--- /dev/null
+++ b/experiments_mimiccxr/xtransformer_rl/train.sh
@@ -0,0 +1,6 @@
+CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_mimiccxr/xtransformer_rl \
+ --dataset_name MIMICCXR \
+ --image_dir /content/mimic_cxr/images --ann_path /content/mimic_cxr/annotation.json \
+ --submodel VSEGCN
+
+# 39 is the epoch number of the pretrained model
diff --git a/layers/sc_att.py b/layers/sc_att.py
index 7bc4dea..89afadf 100755
--- a/layers/sc_att.py
+++ b/layers/sc_att.py
@@ -12,6 +12,19 @@ def __init__(self, mid_dims, mid_dropout):
self.attention_last2 = nn.Linear(mid_dims[-2], mid_dims[-1])
def forward(self, att_map, att_mask, value1, value2):
+ """
+ att_map, att_mask, value1, value2
+
+ torch.Size([4, 8, 49, 96])
+ torch.Size([4, 49])
+ torch.Size([4, 8, 96])
+ torch.Size([4, 8, 49, 96])
+
+ """
+ # print('att_map, att_mask, value1, value2')
+ # for i in [att_map, att_mask, value1, value2]:
+ # print(i.shape)
+
if self.attention_basic is not None:
att_map = self.attention_basic(att_map)
@@ -37,4 +50,5 @@ def forward(self, att_map, att_mask, value1, value2):
value2 = torch.matmul(alpha_spatial.unsqueeze(-2), value2).squeeze(-2)
attn = value1 * value2 * alpha_channel
+ # raise Exception('lol')
return attn
diff --git a/lib/config.py b/lib/config.py
index a8da4fa..5c04958 100755
--- a/lib/config.py
+++ b/lib/config.py
@@ -87,13 +87,15 @@
# ---------------------------------------------------------------------------- #
__C.MODEL = edict()
+__C.MODEL.PretrainedImageModel = '/content/image-captioning/model_auc14.dict.gz' # TODO # Modified
+
__C.MODEL.TYPE = 'UpDown' # 'UpDown', 'XLAN', 'XTransformer'
-__C.MODEL.SEQ_LEN = 17 # include /
+__C.MODEL.SEQ_LEN = 60 # include / # modified
-__C.MODEL.VOCAB_SIZE = 9487 # exclude /
+__C.MODEL.VOCAB_SIZE = 760 # exclude / # TODO : IUXRAY: 760
-__C.MODEL.WORD_EMBED_DIM = 1000
+__C.MODEL.WORD_EMBED_DIM = 512 # TODO # Modified
__C.MODEL.WORD_EMBED_ACT = 'NONE' # 'RELU', 'CELU', 'NONE'
@@ -101,7 +103,7 @@
__C.MODEL.DROPOUT_WORD_EMBED = 0.0
-__C.MODEL.GVFEAT_DIM = 2048
+__C.MODEL.GVFEAT_DIM = 2048 # TODO
__C.MODEL.GVFEAT_EMBED_DIM = -1
@@ -109,7 +111,7 @@
__C.MODEL.DROPOUT_GV_EMBED = 0.0
-__C.MODEL.ATT_FEATS_DIM = 2048
+__C.MODEL.ATT_FEATS_DIM = 1024 # Not used. Modified on the init stage of model
__C.MODEL.ATT_FEATS_EMBED_DIM = -1
@@ -119,13 +121,13 @@
__C.MODEL.ATT_FEATS_NORM = False
-__C.MODEL.ATT_HIDDEN_SIZE = 512
+__C.MODEL.ATT_HIDDEN_SIZE = 512 # TODO
__C.MODEL.ATT_HIDDEN_DROP = 0.0
__C.MODEL.ATT_ACT = 'RELU' # 'RELU', 'CELU', 'TANH'
-__C.MODEL.RNN_SIZE = 1000
+__C.MODEL.RNN_SIZE = 1000 # TODO
__C.MODEL.DROPOUT_LM = 0.5
@@ -304,7 +306,7 @@
__C.INFERENCE.ID_KEY = 'image_id'
-__C.INFERENCE.CAP_KEY = 'caption'
+__C.INFERENCE.CAP_KEY = 'report' # Modified
__C.INFERENCE.EVAL = 'COCO'
diff --git a/lib/utils.py b/lib/utils.py
index d331c75..8b8c4e7 100755
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -75,6 +75,7 @@ def decode_sequence(vocab, seq):
words = []
for t in range(T):
ix = seq[n, t]
+ ix = ix.item()
if ix == 0:
break
words.append(vocab[ix])
diff --git a/main.py b/main.py
index f478dbb..03fe3a6 100755
--- a/main.py
+++ b/main.py
@@ -16,12 +16,174 @@
import losses
import models
import datasets
+from datasets.radiology_dataset import IUXRAY, MIMICCXR
+from datasets.tokenizers import Tokenizer
import lib.utils as utils
from lib.utils import AverageMeter
-from optimizer.optimizer import Optimizer
+from optimizer.optimizer import Optimizer, build_optimizer
from evaluation.evaler import Evaler
from scorer.scorer import Scorer
from lib.config import cfg, cfg_from_file
+from mlclassifier import GCNClassifier
+device = torch.device('cuda')
+
+def parse_args():
+ """
+ Parse input arguments
+ """
+ parser = argparse.ArgumentParser(description='Image Captioning')
+ parser.add_argument('--folder', dest='folder', type=str, default=None)
+ parser.add_argument("--local_rank", type=int, default=0)
+ parser.add_argument("--resume", type=int, default=-1)
+ parser.add_argument('--image_dir', type=str, default='/content/iu_xray_resized/images/',
+ help='the path to the directory containing the data.')
+ parser.add_argument('--ann_path', type=str, default='/content/iu_xray_resized/annotation.json',
+ help='the path to the directory containing the data.')
+
+ parser.add_argument('--dataset_name', type=str, default='IUXRAY', choices=['IUXRAY', 'MIMICCXR','MIMICCXR_MultiImages'],
+ help='the dataset to be used.')
+ parser.add_argument('--submodel', type=str, default='RGMG', choices=['RGMG', 'VSEGCN'],
+ help='the knowledge graph to be used.')
+ # Encoder Mode
+ parser.add_argument('--encoder_mode', type=str, default='normal', choices=['normal', 'dualwayencoder'],
+ help='Specify the transformer encoder')
+
+ parser.add_argument('--training_ratio', type = float, default = '1.0', help ='Select the training ratio. Recommend: 0.001, 0.005, 0.01, 0.1, 0.5 and 1.0')
+ parser.add_argument('--KG_path', type = str, help='the path to the pretrained kg checkpoint')
+
+
+ if len(sys.argv) == 1:
+ parser.print_help()
+ sys.exit(1)
+
+ args = parser.parse_args()
+ return args
+
+args = parse_args()
+
+if args.submodel =='RGMG' and args.dataset_name =='IUXRAY':
+ fw_adj = torch.tensor([
+ #FOR RGMG on IUXray
+ [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ], dtype=torch.float,device=device)
+elif args.submodel == 'VSEGCN' and args.dataset_name =='IUXRAY':
+ fw_adj = torch.tensor([
+#FOR VSEGCN on IUXray
+ [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08624227881040383, 0.0, 0.0, 0.0, 0.0, 0.08531678128946102, 0.0, 0.0] ,
+ [0.0, 0.0, 0.0, 0.01865074607267221, 0.0, 0.2924299133554616, 0.0, 0.0, 0.21304488410089617, 0.0, 0.0, 0.0, 0.0, 0.0, 0.17180192556684684, 0.6418055548125825, 0.0, 0.5855658364897064, 0.0, 0.8458489347533727, 0.9602592859311171, 0.4274547554463511, 0.0, 0.7595885904689661, 0.0] ,
+ [0.0, 0.0, 0.01865074607267221, 0.0, 0.0, 0.4009853199098073, 1.5255730949811777, 0.0, 0.08521151259101123, 0.631755218959081, 0.0, 0.752383206747696, 0.0, 0.0, 0.331650626508743, 0.426960806313068, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7569183619130871, 0.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.5610118144338234, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49754224048515705, 0.059920020742461305, 0.12193009552667401, 0.0, 0.6916982549261147, 0.0, 0.028649105523448872, 0.0, 1.38484543548606, 0.0, 0.0, 0.0, 0.6395125017555444] ,
+ [0.0, 0.0, 0.2924299133554616, 0.4009853199098073, 0.5610118144338234, 0.0, 1.522720025998771, 0.6039632016294225, 1.1321805681072825, 0.9342837995278563, 1.281557969181883, 0.8673131734216728, 1.2434062032175066, 0.3490255809790961, 0.6754221656115674, 0.4370111421665694, 0.0, 0.0, 0.0, 0.0, 0.40348845012792584, 0.0, 0.0, 0.06928636204125185, 0.0] ,
+ [0.0, 0.0, 0.0, 1.5255730949811777, 0.0, 1.522720025998771, 0.0, 0.0, 0.0, 0.0, 0.0, 2.1794995623878415, 0.0, 0.0, 1.535623430834679, 0.0, 0.0, 0.06231769272515846, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.6039632016294225, 0.0, 0.0, 0.0, 1.5819475025089174, 0.0, 0.3162811291776416, 0.0, 0.5912241758519928, 0.7710172862925886, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8175373019274815, 0.0, 0.0, 0.0, 0.0] ,
+ [0.0, 0.0, 0.21304488410089617, 0.08521151259101123, 0.0, 1.1321805681072825, 0.0, 0.0, 0.0, 0.9061920646608415, 0.0, 0.7391379799976754, 0.0, 0.0, 1.1938741371126225, 0.0952618484445128, 0.0, 0.0, 0.05704063562431511, 0.0, 0.0, 0.16859312153006234, 0.0, 1.376195693906577, 0.0] ,
+ [0.0, 0.0, 0.0, 0.631755218959081, 0.0, 0.9342837995278563, 0.0, 1.5819475025089174, 0.9061920646608415, 0.0, 0.9602592859311171, 1.6911467944739096, 0.3242705192111205, 0.39747392323441544, 0.8649491061267922, 0.50827416218806, 0.0, 0.0, 0.0, 0.0, 0.623787049309904, 0.0, 0.021989647338186796, 0.0, 0.46624078048150774] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 1.281557969181883, 0.0, 0.0, 0.0, 0.9602592859311171, 0.0, 0.793205201267951, 0.0, 1.068148247942302, 1.535623430834679, 0.0, 0.0, 0.46778280083332285, 0.0, 0.0, 1.294461374017791, 0.0, 0.0, 0.0, 0.6669114759436587] ,
+ [0.0, 0.0, 0.0, 0.752383206747696, 0.0, 0.8673131734216728, 2.1794995623878415, 0.3162811291776416, 0.7391379799976754, 1.6911467944739096, 0.793205201267951, 0.0, 0.0, 0.007276287257039512, 1.3910422020235713, 0.0, 0.0, 0.0, 0.0, 0.0, 0.23358941333252825, 0.0, 0.0, 0.3693909544915901, 0.29918669581834156] ,
+ [0.0, 0.0, 0.0, 0.0, 0.49754224048515705, 1.2434062032175066, 0.0, 0.0, 0.0, 0.3242705192111205, 0.0, 0.0, 0.0, 0.4321594812223054, 0.0, 0.20648748355473706, 0.2166398550187551, 0.0, 0.16826627073453929, 0.0, 0.6584726072977941, 0.0, 0.0, 0.0, 1.4172170703435527] ,
+ [0.0, 0.0, 0.0, 0.0, 0.059920020742461305, 0.3490255809790961, 0.0, 0.5912241758519928, 0.0, 0.39747392323441544, 1.068148247942302, 0.007276287257039512, 0.4321594812223054, 0.0, 0.6161631241992449, 0.0, 0.0, 0.0, 0.0, 0.0, 2.1179703724409795, 0.35302216066358155, 0.0, 0.6443340011659412, 0.0] ,
+ [0.0, 0.0, 0.17180192556684684, 0.331650626508743, 0.12193009552667401, 0.6754221656115674, 1.535623430834679, 0.7710172862925886, 1.1938741371126225, 0.8649491061267922, 1.535623430834679, 1.3910422020235713, 0.0, 0.6161631241992449, 0.0, 0.5240225191561991, 0.0, 0.0, 0.0, 0.0, 1.0937906785556397, 0.3096717197899676, 0.0, 1.430262915176853, 0.3484577448251243] ,
+ [0.0, 0.0, 0.6418055548125825, 0.426960806313068, 0.0, 0.4370111421665694, 0.0, 0.0, 0.0952618484445128, 0.50827416218806, 0.0, 0.0, 0.20648748355473706, 0.0, 0.5240225191561991, 0.0, 0.0, 0.0, 0.0, 0.2272906111845001, 0.5060040136535207, 0.0, 0.3096717197899676, 0.4186620034983729, 0.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.6916982549261147, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2166398550187551, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10158746275990369, 0.029803617870273677, 0.0, 0.0, 0.137502534460031, 0.0, 0.07092804383736119] ,
+ [0.0, 0.08624227881040383, 0.5855658364897064, 0.0, 0.0, 0.0, 0.06231769272515846, 0.0, 0.0, 0.0, 0.46778280083332285, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1461991767058608, 0.845848934753373, 0.0, 0.0, 0.9958502310338198, 0.0, 0.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.028649105523448872, 0.0, 0.0, 0.0, 0.05704063562431511, 0.0, 0.0, 0.0, 0.16826627073453929, 0.0, 0.0, 0.0, 0.10158746275990369, 0.1461991767058608, 0.0, 0.08964361822629091, 0.0034771927022252134, 0.0, 0.08912895017581536, 0.0, 0.06907447518803858] ,
+ [0.0, 0.0, 0.8458489347533727, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2272906111845001, 0.029803617870273677, 0.845848934753373, 0.08964361822629091, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
+ [0.0, 0.0, 0.9602592859311171, 0.0, 1.38484543548606, 0.40348845012792584, 0.0, 0.8175373019274815, 0.0, 0.623787049309904, 1.294461374017791, 0.23358941333252825, 0.6584726072977941, 2.1179703724409795, 1.0937906785556397, 0.5060040136535207, 0.0, 0.0, 0.0034771927022252134, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
+ [0.0, 0.0, 0.4274547554463511, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16859312153006234, 0.0, 0.0, 0.0, 0.0, 0.35302216066358155, 0.3096717197899676, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49199327658392233, 0.6449325692248835] ,
+ [0.0, 0.08531678128946102, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.021989647338186796, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3096717197899676, 0.137502534460031, 0.9958502310338198, 0.08912895017581536, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
+ [0.0, 0.0, 0.7595885904689661, 0.7569183619130871, 0.0, 0.06928636204125185, 0.0, 0.0, 1.376195693906577, 0.0, 0.0, 0.3693909544915901, 0.0, 0.6443340011659412, 1.430262915176853, 0.4186620034983729, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49199327658392233, 0.0, 0.0, 0.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.6395125017555444, 0.0, 0.0, 0.0, 0.0, 0.46624078048150774, 0.6669114759436587, 0.29918669581834156, 1.4172170703435527, 0.0, 0.3484577448251243, 0.0, 0.07092804383736119, 0.0, 0.06907447518803858, 0.0, 0.0, 0.6449325692248835, 0.0, 0.0, 0.0] ,
+ ], dtype=torch.float,device=device)
+elif args.submodel == 'VSEGCN' and args.dataset_name != 'IUXRAY':
+ fw_adj = torch.tensor([
+#FOR VSEGCN on mimic
+ [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ,
+ [0.0, 0.0, 0.554004673298568, 0.3230466430334976, 0.2631825039749924, 0.7016287008480984, 0.24746163509836083, 0.0010015098042039912, 0.5068736479106769, 0.0, 0.9524974820767607, 0.0, 0.274863833501621, 0.41494699762748494, 0.5840689600939755, 0.0, 0.0, 0.1302442384969287, 0.0, 0.6339873741551628, 0.0, 0.0, 0.0, 0.0, 0.13757450134915059, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20484800268752174, 0.0, 0.3451355245988348, 0.0, 0.0, 0.31496065589330335] ,
+ [0.0, 0.554004673298568, 0.0, 0.0, 0.8777428349239078, 0.48756689514453816, 0.15609090494487807, 0.0, 0.2680360388191779, 0.0, 0.6199516595911282, 0.42238326753620825, 0.11635380051050957, 0.0020601220064393085, 0.6321690244338739, 0.0, 0.0, 0.08334337914718853, 0.5194350922631013, 0.12854873550846002, 0.6928120918404297, 0.0, 0.0, 0.9837299370816517, 0.2314271028213519, 0.5817251164456204, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33811183663241207, 0.0, 0.13936476949018967, 0.0, 0.0, 0.07089242056885223] ,
+ [0.0, 0.3230466430334976, 0.0, 0.0, 0.49497396267096466, 0.6594324804768135, 0.0, 0.706453671957811, 0.4967156785324145, 0.0, 0.8143326399158364, 0.09215690907124767, 1.0617056232382251, 0.1725070158434512, 0.3677521118316441, 0.0, 0.0, 0.0, 0.0, 0.31306543839429407, 1.8437693102858905, 0.0, 0.0, 0.0, 0.12775432598701517, 0.0, 0.1052712068749145, 0.0, 0.0, 0.0, 0.0, 0.037260086847355586, 0.0, 0.7070285247253186, 0.0, 0.0, 0.36112875829496877] ,
+ [0.0, 0.2631825039749924, 0.8777428349239078, 0.49497396267096466, 0.0, 0.5081730413487658, 0.0, 0.0, 0.4345747745875523, 0.0, 0.9271810641740084, 0.0, 0.43551881399664394, 0.0, 0.32510239740035224, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5363277306682555, 0.0, 0.0, 0.35190615068655606, 0.5705232126845626, 0.5633978920587661, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6523482076409699, 0.05734284387314562, 0.029039130518496915, 0.0, 0.0, 0.47600439204660566] ,
+ [0.0, 0.7016287008480984, 0.48756689514453816, 0.6594324804768135, 0.5081730413487658, 0.0, 0.5879416440743396, 0.7282368105665861, 0.6655288739610595, 0.0, 0.5829843060689561, 1.1084681450499108, 0.23022619090918064, 0.7869908341567744, 0.9237236161240492, 0.0, 0.0, 0.00679805799268029, 0.3508343765662928, 0.7283807004431154, 1.2142749896595786, 0.0, 0.0, 0.21380809279477983, 0.0, 0.0, 0.0, 0.7384514725987071, 0.036197082879009246, 0.024547912417246583, 0.0, 0.0031487780150971255, 0.0, 0.17255665551694127, 0.0, 0.30464217507912805, 0.5800348339999603] ,
+ [0.0, 0.24746163509836083, 0.15609090494487807, 0.0, 0.0, 0.5879416440743396, 0.0, 0.0, 0.0, 0.0, 0.050047418864644054, 1.1040108892060951, 0.0, 1.0758558236410007, 0.20507481280870768, 0.0, 0.0, 0.20482994095079318, 0.20102931464117388, 0.5609723973768375, 0.5907785254095393, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1108353406999546, 2.2722344254125035, 0.6708537856962394, 0.0, 0.0, 0.0, 0.656008131578071, 0.019380062443701017, 0.0] ,
+ [0.0, 0.0010015098042039912, 0.0, 0.706453671957811, 0.0, 0.7282368105665861, 0.0, 0.0, 0.8928317911567151, 0.0, 0.2412417134011427, 1.1756920438484275, 0.45754054092273994, 0.0, 0.0, 0.07862826830847675, 0.0, 0.08084792186159706, 0.17752298133873087, 0.0, 0.23167676395941764, 0.0, 0.0, 0.0, 0.0, 0.0, 0.22902820929658516, 0.0, 0.18822158336999106, 1.3653688666139736, 0.0, 0.0, 0.08262146112756505, 0.8291646636467345, 0.3474221602648512, 0.028679991724856257, 0.15823233737507106] ,
+ [0.0, 0.5068736479106769, 0.2680360388191779, 0.4967156785324145, 0.4345747745875523, 0.6655288739610595, 0.0, 0.8928317911567151, 0.0, 0.0, 0.5266439565166569, 0.7732181640322486, 0.8796823806455875, 0.04357377094721733, 0.38483053231749587, 0.0, 0.0, 0.01715279715950329, 0.0, 0.1935198098014637, 0.7726446050833446, 0.0, 0.0, 0.01203069298044543, 0.009614236915435854, 0.0, 0.028445852424580496, 0.0, 0.017029382779510584, 0.1773653331099141, 0.0, 0.0, 0.0, 0.7478386613505874, 0.0, 0.0, 0.3988597091752934] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.11819146504902717, 0.3285711228697073, 0.0, 0.0, 0.0, 0.0, 0.1284847079738626, 0.3441683893946561, 0.0, 0.0, 0.0, 0.056702571766598944, 0.19444256855484526, 0.0, 0.0, 0.034767887372249416, 0.0, 0.05535863446625411, 0.0, 0.0, 0.2006116936209859, 0.0] ,
+ [0.0, 0.9524974820767607, 0.6199516595911282, 0.8143326399158364, 0.9271810641740084, 0.5829843060689561, 0.050047418864644054, 0.2412417134011427, 0.5266439565166569, 0.0, 0.0, 0.27606588231803425, 0.2702763655065547, 0.5155464612930855, 0.5345046760872463, 0.0, 0.0, 0.0, 0.0, 0.6619248249548156, 0.11676497651446016, 0.0, 0.0, 0.012298886029973912, 0.36531125789699864, 0.22014333096829503, 0.0, 0.0, 0.1421597490181444, 0.0, 0.0, 0.5023838505056603, 0.007094738878332813, 0.16482707917210404, 0.12422641272658975, 0.0, 0.16806740202956802] ,
+ [0.0, 0.0, 0.42238326753620825, 0.09215690907124767, 0.0, 1.1084681450499108, 1.1040108892060951, 1.1756920438484275, 0.7732181640322486, 0.0, 0.27606588231803425, 0.0, 0.4571939169349286, 0.44285881406541816, 0.014435627721654973, 0.0, 0.0, 0.0, 0.35645685756909534, 0.3273858184572163, 0.0, 0.0, 0.0, 0.0, 0.22124160687808572, 0.11022312228841263, 0.04830169724837715, 0.0, 0.92753653417296, 0.8948599965191218, 0.15171053363715095, 0.28220534987677603, 0.3001183720922399, 0.5143515680120436, 0.2778525728199598, 0.0, 0.0] ,
+ [0.0, 0.274863833501621, 0.11635380051050957, 1.0617056232382251, 0.43551881399664394, 0.23022619090918064, 0.0, 0.45754054092273994, 0.8796823806455875, 0.0, 0.2702763655065547, 0.4571939169349286, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.448038949914647, 0.0, 0.0, 0.07261956980307784, 0.0, 0.0, 0.07421899390203622, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.594487035114069, 0.0, 0.0, 0.0] ,
+ [0.0, 0.41494699762748494, 0.0020601220064393085, 0.1725070158434512, 0.0, 0.7869908341567744, 1.0758558236410007, 0.0, 0.04357377094721733, 0.0, 0.5155464612930855, 0.44285881406541816, 0.0, 0.0, 0.9595683261315329, 0.0, 0.0, 0.0, 0.0, 2.4365238654909045, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4720717303053189, 0.5727697527229471, 0.0, 0.0, 0.0, 0.9575685889709006, 1.2377447423336991, 0.0, 0.0] ,
+ [0.0, 0.5840689600939755, 0.6321690244338739, 0.3677521118316441, 0.32510239740035224, 0.9237236161240492, 0.20507481280870768, 0.0, 0.38483053231749587, 0.0, 0.5345046760872463, 0.014435627721654973, 0.0, 0.9595683261315329, 0.0, 0.17363312653927349, 0.19895016785061964, 0.0, 0.0, 1.208717109379973, 0.0, 0.015167532069747917, 0.0, 0.06703478683378344, 0.6177554201058338, 0.4406078996794569, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3022525474987056, 0.0884563575795982, 0.0, 1.2708509083599127] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.07862826830847675, 0.0, 0.11819146504902717, 0.0, 0.0, 0.0, 0.0, 0.17363312653927349, 0.0, 2.447688346482624, 0.17983209096647654, 0.0, 0.0, 0.0, 0.28338141933933964, 0.0, 0.0, 0.29702177034947524, 0.07753776902134689, 0.007825343616291445, 0.0, 0.0, 0.0, 0.0, 0.0, 1.5708374900647346, 0.0, 0.15510336927054139, 0.3542350803376888, 0.32669031419868527] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3285711228697073, 0.0, 0.0, 0.0, 0.0, 0.19895016785061964, 2.447688346482624, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14503213452982808, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.139880360686599, 0.0, 0.1439133818303898, 0.0, 0.0] ,
+ [0.0, 0.1302442384969287, 0.08334337914718853, 0.0, 0.0, 0.00679805799268029, 0.20482994095079318, 0.08084792186159706, 0.01715279715950329, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.17983209096647654, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4124407401173898, 0.27031693276904833, 0.6962440948684386, 0.01868026565576617, 0.0, 0.058711117532272095, 0.3108414982982819, 0.8917732120346288, 0.5028865181216288, 0.3710258416111474, 0.0, 0.0, 0.0, 0.0] ,
+ [0.0, 0.0, 0.5194350922631013, 0.0, 0.0, 0.3508343765662928, 0.20102931464117388, 0.17752298133873087, 0.0, 0.0, 0.0, 0.35645685756909534, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.3095615625659667, 0.0, 0.0, 0.33960729056256184, 0.14426353106038403, 0.48295034078748456, 0.04096791345986367, 0.0, 0.21459856249590378, 0.3394630945304, 0.9426984665379273, 0.22121090818098985, 0.3817453264511766, 0.3146902017201184, 0.2525866980495684, 0.041488578522214464, 0.0] ,
+ [0.0, 0.6339873741551628, 0.12854873550846002, 0.31306543839429407, 0.0, 0.7283807004431154, 0.5609723973768375, 0.0, 0.1935198098014637, 0.0, 0.6619248249548156, 0.3273858184572163, 0.0, 2.4365238654909045, 1.208717109379973, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.023645690589139633, 0.15053136759027846, 0.0, 0.24360268371473795, 0.0, 0.8821797270781289, 1.442822630962518, 0.0, 0.21531865324924443] ,
+ [0.0, 0.0, 0.6928120918404297, 1.8437693102858905, 0.5363277306682555, 1.2142749896595786, 0.5907785254095393, 0.23167676395941764, 0.7726446050833446, 0.0, 0.11676497651446016, 0.0, 1.448038949914647, 0.0, 0.0, 0.0, 0.0, 0.0, 1.3095615625659667, 0.0, 0.0, 0.0, 0.0, 0.33727296191859424, 0.0, 0.11424727258813801, 0.004340630185881583, 0.0, 0.0, 0.17398826794432173, 0.699350130525858, 0.0, 0.0, 0.4114035987596012, 0.0, 0.8603703454658308, 0.6698830923247899] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1284847079738626, 0.0, 0.0, 0.0, 0.0, 0.015167532069747917, 0.28338141933933964, 0.14503213452982808, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6128893031292931, 0.0, 0.0, 0.0, 0.3311971887232972, 0.0, 0.221278095246663, 0.7821347726775432, 0.4309320418653175] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3441683893946561, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04939114176166614, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
+ [0.0, 0.0, 0.9837299370816517, 0.0, 0.35190615068655606, 0.21380809279477983, 0.0, 0.0, 0.01203069298044543, 0.0, 0.012298886029973912, 0.0, 0.07261956980307784, 0.0, 0.06703478683378344, 0.0, 0.0, 0.4124407401173898, 0.33960729056256184, 0.0, 0.33727296191859424, 0.0, 0.0, 0.0, 0.0, 0.08979034229911677, 0.0, 0.38518142976105196, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
+ [0.0, 0.13757450134915059, 0.2314271028213519, 0.12775432598701517, 0.5705232126845626, 0.0, 0.0, 0.0, 0.009614236915435854, 0.0, 0.36531125789699864, 0.22124160687808572, 0.0, 0.0, 0.6177554201058338, 0.29702177034947524, 0.0, 0.27031693276904833, 0.14426353106038403, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.84531965510853, 0.0, 0.0, 0.0, 0.0, 0.03024606051820196, 0.9015367412440394, 0.4187448648843476, 0.11315910768263383, 1.2343811011663848, 0.0, 0.24145505239298407] ,
+ [0.0, 0.0, 0.5817251164456204, 0.0, 0.5633978920587661, 0.0, 0.0, 0.0, 0.0, 0.0, 0.22014333096829503, 0.11022312228841263, 0.0, 0.0, 0.4406078996794569, 0.07753776902134689, 0.0, 0.6962440948684386, 0.48295034078748456, 0.0, 0.11424727258813801, 0.0, 0.0, 0.08979034229911677, 2.84531965510853, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3231554047178068, 1.2532084879227057, 0.5614893243263129, 0.046137943483739924, 1.540851917651678, 0.0, 0.0] ,
+ [0.0, 0.0, 0.0, 0.1052712068749145, 0.0, 0.0, 0.0, 0.22902820929658516, 0.028445852424580496, 0.056702571766598944, 0.0, 0.04830169724837715, 0.07421899390203622, 0.0, 0.0, 0.007825343616291445, 0.0, 0.01868026565576617, 0.04096791345986367, 0.0, 0.004340630185881583, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0564784860967394, 0.12567556923304798, 0.06594334312122997, 0.0, 0.0, 0.06688509334837037, 0.36269590043980027, 0.030319915064814112, 0.0, 0.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.7384514725987071, 0.0, 0.0, 0.0, 0.19444256855484526, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04939114176166614, 0.38518142976105196, 0.0, 0.0, 0.0564784860967394, 0.0, 0.0, 0.0, 0.0, 0.2104088857291676, 0.4561074463243395, 0.0, 0.0, 0.0, 0.6247400008366559] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.036197082879009246, 0.1108353406999546, 0.18822158336999106, 0.017029382779510584, 0.0, 0.1421597490181444, 0.92753653417296, 0.0, 0.4720717303053189, 0.0, 0.0, 0.0, 0.058711117532272095, 0.21459856249590378, 0.023645690589139633, 0.0, 0.6128893031292931, 0.0, 0.0, 0.0, 0.0, 0.12567556923304798, 0.0, 0.0, 0.4309948864530749, 0.0, 0.02542734722301152, 0.0, 0.0, 0.2510742769865065, 0.0, 0.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.024547912417246583, 2.2722344254125035, 1.3653688666139736, 0.1773653331099141, 0.0, 0.0, 0.8948599965191218, 0.0, 0.5727697527229471, 0.0, 0.0, 0.0, 0.3108414982982819, 0.3394630945304, 0.15053136759027846, 0.17398826794432173, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06594334312122997, 0.0, 0.4309948864530749, 0.0, 0.8692491673212553, 0.0, 0.19435111707051575, 0.266221588915103, 0.676971258598654, 0.20262509195680156, 0.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6708537856962394, 0.0, 0.0, 0.034767887372249416, 0.0, 0.15171053363715095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8917732120346288, 0.9426984665379273, 0.0, 0.699350130525858, 0.0, 0.0, 0.0, 0.03024606051820196, 0.3231554047178068, 0.0, 0.0, 0.0, 0.8692491673212553, 0.0, 1.1605097349751408, 0.3461994713504898, 0.0, 1.5082427073912366, 0.0, 0.0] ,
+ [0.0, 0.20484800268752174, 0.33811183663241207, 0.037260086847355586, 0.6523482076409699, 0.0031487780150971255, 0.0, 0.0, 0.0, 0.0, 0.5023838505056603, 0.28220534987677603, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5028865181216288, 0.22121090818098985, 0.24360268371473795, 0.0, 0.0, 0.0, 0.0, 0.9015367412440394, 1.2532084879227057, 0.0, 0.2104088857291676, 0.02542734722301152, 0.0, 1.1605097349751408, 0.0, 1.681814915277294, 0.0, 0.43925238577431364, 0.04338409257246071, 0.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.05734284387314562, 0.0, 0.0, 0.08262146112756505, 0.0, 0.05535863446625411, 0.007094738878332813, 0.3001183720922399, 0.0, 0.0, 0.0, 1.5708374900647346, 2.139880360686599, 0.3710258416111474, 0.3817453264511766, 0.0, 0.0, 0.3311971887232972, 0.0, 0.0, 0.4187448648843476, 0.5614893243263129, 0.06688509334837037, 0.4561074463243395, 0.0, 0.19435111707051575, 0.3461994713504898, 1.681814915277294, 0.0, 0.2674633965945188, 0.8740258958015569, 0.2474464019557586, 0.0] ,
+ [0.0, 0.3451355245988348, 0.13936476949018967, 0.7070285247253186, 0.029039130518496915, 0.17255665551694127, 0.0, 0.8291646636467345, 0.7478386613505874, 0.0, 0.16482707917210404, 0.5143515680120436, 0.594487035114069, 0.9575685889709006, 0.3022525474987056, 0.0, 0.0, 0.0, 0.3146902017201184, 0.8821797270781289, 0.4114035987596012, 0.0, 0.0, 0.0, 0.11315910768263383, 0.046137943483739924, 0.36269590043980027, 0.0, 0.0, 0.266221588915103, 0.0, 0.0, 0.2674633965945188, 0.0, 0.42342916886466814, 0.0, 0.33323103097930845] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.656008131578071, 0.3474221602648512, 0.0, 0.0, 0.12422641272658975, 0.2778525728199598, 0.0, 1.2377447423336991, 0.0884563575795982, 0.15510336927054139, 0.1439133818303898, 0.0, 0.2525866980495684, 1.442822630962518, 0.0, 0.221278095246663, 0.0, 0.0, 1.2343811011663848, 1.540851917651678, 0.030319915064814112, 0.0, 0.2510742769865065, 0.676971258598654, 1.5082427073912366, 0.43925238577431364, 0.8740258958015569, 0.42342916886466814, 0.0, 0.0, 0.0] ,
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.30464217507912805, 0.019380062443701017, 0.028679991724856257, 0.0, 0.2006116936209859, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3542350803376888, 0.0, 0.0, 0.041488578522214464, 0.0, 0.8603703454658308, 0.7821347726775432, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20262509195680156, 0.0, 0.04338409257246071, 0.2474464019557586, 0.0, 0.0, 0.0, 0.0] ,
+ [0.0, 0.31496065589330335, 0.07089242056885223, 0.36112875829496877, 0.47600439204660566, 0.5800348339999603, 0.0, 0.15823233737507106, 0.3988597091752934, 0.0, 0.16806740202956802, 0.0, 0.0, 0.0, 1.2708509083599127, 0.32669031419868527, 0.0, 0.0, 0.0, 0.21531865324924443, 0.6698830923247899, 0.4309320418653175, 0.0, 0.0, 0.24145505239298407, 0.0, 0.0, 0.6247400008366559, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33323103097930845, 0.0, 0.0, 0.0] ,
+ ], dtype=torch.float,device=device)
+else:
+ raise Nonetype("There is no this kind of KG or dataset")
+
+
+bw_adj = fw_adj.t()
+num_feat = fw_adj.shape[0] - 1
+
+submodel = GCNClassifier(num_feat, fw_adj, bw_adj)
+
+state_dict = submodel.state_dict()
+
+# Load the kg path from argument
+
+# if args.submodel =='RGMG' and args.dataset_name =='IUXRAY':
+# KG_path = '/content/pretrainedKG/gcnclassifier_v2_ones3_t401v2t3_lr1e-6_e80.pth'
+#
+# elif args.submodel == 'VSEGCN' and args.dataset_name =='IUXRAY':
+# KG_path = '/content/pretrainedKG/iuxray_gcnclassifier_v1_ones3_t0v1t2_lr1e-6_23050521_e180.pth'
+#
+# elif args.submodel == 'VSEGCN' and args.dataset_name =='MIMICCXR':
+# KG_path = '/content/pretrainedKG/mimic_gcnclassifier_v1_ones3_t0v1t2_lr1e-6_24052021_e10.pth'
+# else:
+# raise Nonetype("There is no this kind of KG or dataset")
+
+state_dict.update({k:v for k, v in torch.load(args.KG_path).items() if k in state_dict})
+
+submodel.load_state_dict(state_dict)
+
class Trainer(object):
def __init__(self, args):
@@ -35,28 +197,37 @@ def __init__(self, args):
self.num_gpus = torch.cuda.device_count()
self.distributed = self.num_gpus > 1
if self.distributed:
+ print('init distributed')
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(
backend="nccl", init_method="env://"
)
self.device = torch.device("cuda")
+ # self.device = 'cpu'
self.rl_stage = False
self.setup_logging()
self.setup_dataset()
self.setup_network()
self.val_evaler = Evaler(
- eval_ids = cfg.DATA_LOADER.VAL_ID,
- gv_feat = cfg.DATA_LOADER.VAL_GV_FEAT,
- att_feats = cfg.DATA_LOADER.VAL_ATT_FEATS,
- eval_annfile = cfg.INFERENCE.VAL_ANNFILE
- )
+ datasets.create(name = args.dataset_name,
+ image_dir=args.image_dir,
+ ann_path=args.ann_path,
+ tokenizer=self.tokenizer,
+ split='val',
+ args = args,
+ ),
+ tokenizer=self.tokenizer
+ ) # TODO
self.test_evaler = Evaler(
- eval_ids = cfg.DATA_LOADER.TEST_ID,
- gv_feat = cfg.DATA_LOADER.TEST_GV_FEAT,
- att_feats = cfg.DATA_LOADER.TEST_ATT_FEATS,
- eval_annfile = cfg.INFERENCE.TEST_ANNFILE
- )
+ datasets.create(name = args.dataset_name,
+ image_dir=args.image_dir,
+ ann_path=args.ann_path,
+ tokenizer=self.tokenizer,
+ split='test',
+ args=args),
+ tokenizer=self.tokenizer
+ ) # TODO
self.scorer = Scorer()
def setup_logging(self):
@@ -64,7 +235,7 @@ def setup_logging(self):
self.logger.setLevel(logging.INFO)
if self.distributed and dist.get_rank() > 0:
return
-
+
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.INFO)
formatter = logging.Formatter("[%(levelname)s: %(asctime)s] %(message)s")
@@ -73,65 +244,76 @@ def setup_logging(self):
if not os.path.exists(cfg.ROOT_DIR):
os.makedirs(cfg.ROOT_DIR)
-
+
fh = logging.FileHandler(os.path.join(cfg.ROOT_DIR, cfg.LOGGER_NAME + '.txt'))
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
self.logger.addHandler(fh)
-
+
self.logger.info('Training with config:')
self.logger.info(pprint.pformat(cfg))
def setup_network(self):
- model = models.create(cfg.MODEL.TYPE)
+ # model = models.create(cfg.MODEL.TYPE, args)
+ model = models.create('XTransformer', args, submodel = submodel)
if self.distributed:
# this should be removed if we update BatchNorm stats
self.model = torch.nn.parallel.DistributedDataParallel(
- model.to(self.device),
- device_ids = [self.args.local_rank],
- output_device = self.args.local_rank,
- broadcast_buffers = False
+ model.to(self.device),
+ device_ids=[self.args.local_rank],
+ output_device=self.args.local_rank,
+ broadcast_buffers=False
)
else:
- self.model = torch.nn.DataParallel(model).cuda()
+ # self.model = torch.nn.DataParallel(model).cuda() # strange
+ self.model = model.cuda() # strange
if self.args.resume > 0:
self.model.load_state_dict(
torch.load(self.snapshot_path("caption_model", self.args.resume),
- map_location=lambda storage, loc: storage)
+ map_location=lambda storage, loc: storage)
)
self.optim = Optimizer(self.model)
+ # self.optim = build_optimizer(args, model)
self.xe_criterion = losses.create(cfg.LOSSES.XE_TYPE).cuda()
self.rl_criterion = losses.create(cfg.LOSSES.RL_TYPE).cuda()
-
+
def setup_dataset(self):
- self.coco_set = datasets.coco_dataset.CocoDataset(
- image_ids_path = cfg.DATA_LOADER.TRAIN_ID,
- input_seq = cfg.DATA_LOADER.INPUT_SEQ_PATH,
- target_seq = cfg.DATA_LOADER.TARGET_SEQ_PATH,
- gv_feat_path = cfg.DATA_LOADER.TRAIN_GV_FEAT,
- att_feats_folder = cfg.DATA_LOADER.TRAIN_ATT_FEATS,
- seq_per_img = cfg.DATA_LOADER.SEQ_PER_IMG,
- max_feat_num = cfg.DATA_LOADER.MAX_FEAT
- )
+ self.tokenizer = Tokenizer(ann_path=args.ann_path, dataset_name=args.dataset_name)
+ self.dataset = datasets.create(name = args.dataset_name,
+ image_dir=args.image_dir,
+ ann_path=args.ann_path,
+ tokenizer=self.tokenizer,
+ split='train',
+ args=args,
+ )
+ # self.coco_set = datasets.coco_dataset.CocoDataset(
+ # image_ids_path = cfg.DATA_LOADER.TRAIN_ID,
+ # input_seq = cfg.DATA_LOADER.INPUT_SEQ_PATH,
+ # target_seq = cfg.DATA_LOADER.TARGET_SEQ_PATH,
+ # gv_feat_path = cfg.DATA_LOADER.TRAIN_GV_FEAT,
+ # att_feats_folder = cfg.DATA_LOADER.TRAIN_ATT_FEATS,
+ # seq_per_img = cfg.DATA_LOADER.SEQ_PER_IMG,
+ # max_feat_num = cfg.DATA_LOADER.MAX_FEAT
+ # )
def setup_loader(self, epoch):
self.training_loader = datasets.data_loader.load_train(
- self.distributed, epoch, self.coco_set)
+ self.distributed, epoch, self.dataset)
def eval(self, epoch):
if (epoch + 1) % cfg.SOLVER.TEST_INTERVAL != 0:
return None
if self.distributed and dist.get_rank() > 0:
return None
-
+
val_res = self.val_evaler(self.model, 'val_' + str(epoch + 1))
self.logger.info('######## Epoch (VAL)' + str(epoch + 1) + ' ########')
self.logger.info(str(val_res))
- test_res = self.test_evaler(self.model,'test_' + str(epoch + 1))
+ test_res = self.test_evaler(self.model, 'test_' + str(epoch + 1))
self.logger.info('######## Epoch (TEST)' + str(epoch + 1) + ' ########')
self.logger.info(str(test_res))
@@ -152,11 +334,12 @@ def save_model(self, epoch):
snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot')
if not os.path.exists(snapshot_folder):
os.mkdir(snapshot_folder)
- torch.save(self.model.state_dict(), self.snapshot_path("caption_model", epoch+1))
+ torch.save(self.model.state_dict(), self.snapshot_path("caption_model", epoch + 1))
def make_kwargs(self, indices, input_seq, target_seq, gv_feat, att_feats, att_mask):
seq_mask = (input_seq > 0).type(torch.cuda.LongTensor)
- seq_mask[:,0] += 1
+ # print(seq_mask)
+ seq_mask[:, 0] += 1
seq_mask_sum = seq_mask.sum(-1)
max_len = int(seq_mask_sum.max())
input_seq = input_seq[:, 0:max_len].contiguous()
@@ -176,7 +359,7 @@ def scheduled_sampling(self, epoch):
if epoch > cfg.TRAIN.SCHEDULED_SAMPLING.START:
frac = (epoch - cfg.TRAIN.SCHEDULED_SAMPLING.START) // cfg.TRAIN.SCHEDULED_SAMPLING.INC_EVERY
ss_prob = min(cfg.TRAIN.SCHEDULED_SAMPLING.INC_PROB * frac, cfg.TRAIN.SCHEDULED_SAMPLING.MAX_PROB)
- self.model.module.ss_prob = ss_prob
+ # self.model.ss_prob = ss_prob
def display(self, iteration, data_time, batch_time, losses, loss_info):
if iteration % cfg.SOLVER.DISPLAY != 0:
@@ -184,7 +367,7 @@ def display(self, iteration, data_time, batch_time, losses, loss_info):
if self.distributed and dist.get_rank() > 0:
return
info_str = ' (DataTime/BatchTime: {:.3}/{:.3}) losses = {:.5}'.format(data_time.avg, batch_time.avg, losses.avg)
- self.logger.info('Iteration ' + str(iteration) + info_str +', lr = ' + str(self.optim.get_lr()))
+ self.logger.info('Iteration ' + str(iteration) + info_str + ', lr = ' + str(self.optim.get_lr()))
for name in sorted(loss_info):
self.logger.info(' ' + name + ' = ' + str(loss_info[name]))
data_time.reset()
@@ -200,6 +383,7 @@ def forward(self, kwargs):
gv_feat = kwargs[cfg.PARAM.GLOBAL_FEAT]
att_feats = kwargs[cfg.PARAM.ATT_FEATS]
att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
+ target_seq = kwargs[cfg.PARAM.TARGET_SENT]
# max
kwargs['BEAM_SIZE'] = 1
@@ -210,12 +394,12 @@ def forward(self, kwargs):
self.model.eval()
with torch.no_grad():
- seq_max, logP_max = self.model.module.decode(**kwargs)
+ seq_max, logP_max = self.model.decode(**kwargs)
self.model.train()
- rewards_max, rewards_info_max = self.scorer(ids, seq_max.data.cpu().numpy().tolist())
+ rewards_max, rewards_info_max = self.scorer(target_seq, seq_max.data.cpu().numpy().tolist()) # Modified
rewards_max = utils.expand_numpy(rewards_max)
- ids = utils.expand_numpy(ids)
+ ids = utils.expand_numpy(ids) # to check?
gv_feat = utils.expand_tensor(gv_feat, cfg.DATA_LOADER.SEQ_PER_IMG)
att_feats = utils.expand_tensor(att_feats, cfg.DATA_LOADER.SEQ_PER_IMG)
att_mask = utils.expand_tensor(att_mask, cfg.DATA_LOADER.SEQ_PER_IMG)
@@ -228,12 +412,13 @@ def forward(self, kwargs):
kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
seq_sample, logP_sample = self.model.module.decode(**kwargs)
- rewards_sample, rewards_info_sample = self.scorer(ids, seq_sample.data.cpu().numpy().tolist())
+ rewards_sample, rewards_info_sample = self.scorer(target_seq,
+ seq_sample.data.cpu().numpy().tolist()) # Modified
rewards = rewards_sample - rewards_max
rewards = torch.from_numpy(rewards).float().cuda()
loss = self.rl_criterion(seq_sample, logP_sample, rewards)
-
+
loss_info = {}
for key in rewards_info_sample:
loss_info[key + '_sample'] = rewards_info_sample[key]
@@ -247,7 +432,7 @@ def train(self):
self.optim.zero_grad()
iteration = 0
- for epoch in range(cfg.SOLVER.MAX_EPOCH):
+ for epoch in range(cfg.SOLVER.MAX_EPOCH):
if epoch == cfg.TRAIN.REINFORCEMENT.START:
self.rl_stage = True
self.setup_loader(epoch)
@@ -264,16 +449,19 @@ def train(self):
gv_feat = gv_feat.cuda()
att_feats = att_feats.cuda()
att_mask = att_mask.cuda()
+ # att_mask = torch.ones(16,70).cuda()
+ # print(att_mask.shape)
+
kwargs = self.make_kwargs(indices, input_seq, target_seq, gv_feat, att_feats, att_mask)
loss, loss_info = self.forward(kwargs)
loss.backward()
- utils.clip_gradient(self.optim.optimizer, self.model,
- cfg.SOLVER.GRAD_CLIP_TYPE, cfg.SOLVER.GRAD_CLIP)
+ # utils.clip_gradient(self.optim.optimizer, self.model,
+ # cfg.SOLVER.GRAD_CLIP_TYPE, cfg.SOLVER.GRAD_CLIP)
self.optim.step()
self.optim.zero_grad()
- self.optim.scheduler_step('Iter')
-
+ # self.optim.scheduler_step('Iter')
+
batch_time.update(time.time() - start)
start = time.time()
losses.update(loss.item())
@@ -282,31 +470,15 @@ def train(self):
if self.distributed:
dist.barrier()
-
+
self.save_model(epoch)
val = self.eval(epoch)
- self.optim.scheduler_step('Epoch', val)
- self.scheduled_sampling(epoch)
+ # self.optim.scheduler_step('Epoch', val)
+ # self.scheduled_sampling(epoch)
if self.distributed:
dist.barrier()
-def parse_args():
- """
- Parse input arguments
- """
- parser = argparse.ArgumentParser(description='Image Captioning')
- parser.add_argument('--folder', dest='folder', type=str, default=None)
- parser.add_argument("--local_rank", type=int, default=0)
- parser.add_argument("--resume", type=int, default=-1)
-
- if len(sys.argv) == 1:
- parser.print_help()
- sys.exit(1)
-
- args = parser.parse_args()
- return args
-
if __name__ == '__main__':
args = parse_args()
print('Called with args:')
diff --git a/mlclassifier.py b/mlclassifier.py
new file mode 100644
index 0000000..9223340
--- /dev/null
+++ b/mlclassifier.py
@@ -0,0 +1,185 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+class MLClassifier(nn.Module):
+
+ def __init__(self, num_classes):
+ super().__init__()
+ self.densenet121 = models.densenet121(pretrained=True)
+ num_ftrs = self.densenet121.classifier.in_features
+ # self.backbone = nn.Sequential(*list(densenet.children())[:-1])
+ self.densenet121.classifier = nn.Linear(num_ftrs, num_classes)
+
+ def forward(self, img1, img2):
+ x1 = self.densenet121.features(img1)
+ x1 = F.relu(x1, inplace=True)
+ x1 = F.adaptive_avg_pool2d(x1, (1, 1))
+ x1 = torch.flatten(x1, 1)
+ x1 = self.densenet121.classifier(x1)
+ x2 = self.densenet121.features(img2)
+ x2 = F.relu(x2, inplace=True)
+ x2 = F.adaptive_avg_pool2d(x2, (1, 1))
+ x2 = torch.flatten(x2, 1)
+ x2 = self.densenet121.classifier(x2)
+ return x1 + x2
+
+
+class ClsAttention(nn.Module):
+
+ def __init__(self, feat_size, num_classes):
+ super().__init__()
+ self.feat_size = feat_size
+ self.num_classes = num_classes
+ self.channel_w = nn.Conv2d(feat_size, num_classes, 1, bias=False)
+
+ def forward(self, feats):
+ # feats: batch size x feat size x H x W
+ batch_size, feat_size, H, W = feats.size()
+ att_maps = self.channel_w(feats)
+ att_maps = torch.softmax(att_maps.view(batch_size, self.num_classes, -1), dim=2)
+ feats_t = feats.view(batch_size, feat_size, H * W).permute(0, 2, 1)
+ cls_feats = torch.bmm(att_maps, feats_t)
+ return cls_feats
+
+
+class GCLayer(nn.Module):
+
+ def __init__(self, in_size, state_size):
+ super().__init__()
+ self.condense = nn.Conv1d(in_size, state_size, 1, bias=False)
+ self.condense_norm = nn.BatchNorm1d(state_size)
+ self.fw_trans = nn.Conv1d(in_size, state_size, 1, bias=False)
+ self.fw_norm = nn.BatchNorm1d(state_size)
+ self.bw_trans = nn.Conv1d(in_size, state_size, 1, bias=False)
+ self.bw_norm = nn.BatchNorm1d(state_size)
+ self.update = nn.Conv1d(3 * state_size, in_size, 1, bias=False)
+ self.update_norm = nn.BatchNorm1d(in_size)
+ self.relu = nn.ReLU(inplace=True)
+ # v2:
+ self.dropout = nn.Dropout(0.5)
+
+ def forward(self, states, fw_A, bw_A):
+ # states: batch size x feat size x nodes
+ condensed = self.relu(self.condense_norm(self.condense(states)))
+ fw_msg = self.relu(self.fw_norm(self.fw_trans(states).bmm(fw_A)))
+ bw_msg = self.relu(self.bw_norm(self.bw_trans(states).bmm(bw_A)))
+ updated = self.update_norm(self.update(torch.cat((condensed, fw_msg, bw_msg), dim=1)))
+ updated = self.relu(self.dropout(updated) + states)
+ return updated
+
+
+class GCN(nn.Module):
+
+ def __init__(self, in_size, state_size, steps=3):
+ super().__init__()
+ self.in_size = in_size
+ self.state_size = state_size
+ self.steps = steps
+
+ # layers = []
+ # for istep in range(steps):
+ # layers.append(GCLayer(in_size, state_size))
+ # self.layers = nn.Sequential(*layers)
+ self.layer1 = GCLayer(in_size, state_size)
+ self.layer2 = GCLayer(in_size, state_size)
+ self.layer3 = GCLayer(in_size, state_size)
+
+ def forward(self, states, fw_A, bw_A):
+ states = states.permute(0, 2, 1)
+ states = self.layer1(states, fw_A, bw_A)
+ states = self.layer2(states, fw_A, bw_A)
+ states = self.layer3(states, fw_A, bw_A)
+ return states.permute(0, 2, 1)
+
+
+class GCNClassifier(nn.Module):
+
+ def __init__(self, num_classes, fw_adj, bw_adj):
+ super().__init__()
+ self.num_classes = num_classes
+ self.densenet121 = models.densenet121(pretrained=True)
+ feat_size = self.densenet121.classifier.in_features
+ self.densenet121.classifier = nn.Linear(feat_size, num_classes)
+ self.cls_atten = ClsAttention(feat_size, num_classes)
+ self.gcn = GCN(feat_size, 256)
+ # v1:
+ self.fcs = nn.ModuleList([nn.Linear(feat_size, 1) for _ in range(num_classes)])
+ # v2:
+ self.fc2 = nn.Linear(feat_size, num_classes)
+
+ fw_D = torch.diag_embed(fw_adj.sum(dim=1))
+ bw_D = torch.diag_embed(bw_adj.sum(dim=1))
+ inv_sqrt_fw_D = fw_D.pow(-0.5)
+ inv_sqrt_fw_D[torch.isinf(inv_sqrt_fw_D)] = 0
+ inv_sqrt_bw_D = bw_D.pow(-0.5)
+ inv_sqrt_bw_D[torch.isinf(inv_sqrt_bw_D)] = 0
+ self.fw_A = inv_sqrt_fw_D.mm(fw_adj).mm(inv_sqrt_bw_D)
+ self.bw_A = inv_sqrt_bw_D.mm(bw_adj).mm(inv_sqrt_fw_D)
+
+ self.avg_fnt = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
+
+ def forward(self, img1, img2 = None):
+ if img2 is not None:
+ batch_size = img1.size(0)
+ fw_A = self.fw_A.repeat(batch_size, 1, 1)
+ bw_A = self.bw_A.repeat(batch_size, 1, 1)
+ cnn_feats1 = self.densenet121.features(img1) #no linear layer
+ cnn_feats2 = self.densenet121.features(img2) #cnn_feats1 torch.Size([16, 1024, 7, 7])
+ # print('cnn_feats1',cnn_feats1.shape)
+ global_feats1 = cnn_feats1.mean(dim=(2, 3))
+ global_feats2 = cnn_feats2.mean(dim=(2, 3))
+ cls_feats1 = self.cls_atten(cnn_feats1)
+ cls_feats2 = self.cls_atten(cnn_feats2)
+ node_feats1 = torch.cat((global_feats1.unsqueeze(1), cls_feats1), dim=1)
+ node_feats2 = torch.cat((global_feats2.unsqueeze(1), cls_feats2), dim=1)
+ node_states1 = self.gcn(node_feats1, fw_A, bw_A)
+ node_states2 = self.gcn(node_feats2, fw_A, bw_A)
+ # v1:
+ # logits = img1.new_zeros((batch_size, self.num_classes), dtype=torch.float)
+ # for c in range(self.num_classes):
+ # logits[:, c] = self.fcs[c](node_states1[:, c+1] + node_states2[:, c+1]).squeeze(1)
+ # return logits
+ # v2:
+ # global_states = node_states1.mean(dim=1) + node_states2.mean(dim=1)
+ # global_states = torch.cat((node_states1.mean(dim = 1), node_states2.mean(dim = 1)),dim = 1)
+
+ cnn_feats1_reshaped = cnn_feats1.reshape(cnn_feats1.size(0),cnn_feats1.size(1),-1).permute(0,2,1)
+ cnn_feats2_reshaped = cnn_feats2.reshape(cnn_feats2.size(0),cnn_feats2.size(1),-1).permute(0,2,1)
+
+ cnn_feats = torch.cat((cnn_feats1_reshaped,cnn_feats2_reshaped),dim = 2)
+ # logits = self.fc2(global_states)
+ # print('global_states',global_states.shape)
+ # print('logits',logits.shape)
+ node_states = torch.cat((node_states1, node_states2), dim = 2)
+
+
+ avg_feats1 = self.avg_fnt(cnn_feats1).squeeze().reshape(-1, cnn_feats1.size(1))
+ avg_feats2 = self.avg_fnt(cnn_feats2).squeeze().reshape(-1, cnn_feats1.size(1))
+ global_states = torch.cat((avg_feats1, avg_feats2), dim = 1)
+ # cnn_feats torch.Size([16, 49, 2048])
+ # node_states torch.Size([16, 21, 2048])
+ # global_states torch.Size([16, 2048])
+ if img2 is None:
+ batch_size = img1.size(0)
+ fw_A = self.fw_A.repeat(batch_size, 1, 1)
+ bw_A = self.bw_A.repeat(batch_size, 1, 1)
+ cnn_feats1 = self.densenet121.features(img1) #no linear layer
+ # print('cnn_feats1',cnn_feats1.shape)
+ global_feats1 = cnn_feats1.mean(dim=(2, 3))
+ cls_feats1 = self.cls_atten(cnn_feats1)
+ node_feats1 = torch.cat((global_feats1.unsqueeze(1), cls_feats1), dim=1)
+ node_states1 = self.gcn(node_feats1, fw_A, bw_A)
+
+ cnn_feats1_reshaped = cnn_feats1.reshape(cnn_feats1.size(0),cnn_feats1.size(1),-1).permute(0,2,1)
+ cnn_feats = cnn_feats1_reshaped
+
+ node_states = node_states1
+
+
+ avg_feats1 = self.avg_fnt(cnn_feats1).squeeze().reshape(-1, cnn_feats1.size(1))
+ global_states = avg_feats1
+
+ return cnn_feats, node_states, global_states
+
diff --git a/models/__init__.py b/models/__init__.py
index 9f56b54..222b6ef 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,17 +1,22 @@
from models.updown import UpDown
from models.xlan import XLAN
from models.xtransformer import XTransformer
+from models.dwextransformer import DWEXTransformer
+
__factory = {
'UpDown': UpDown,
'XLAN': XLAN,
- 'XTransformer': XTransformer
+ 'XTransformer': XTransformer,
+ 'DWEXtransformer' : DWEXTransformer
}
def names():
return sorted(__factory.keys())
-def create(name, *args, **kwargs):
+def create(name, args, submodel, **kwargs):
+ if name == 'XTransformer' and args.encoder_mode == 'dualwayencoder':
+ name = 'DWEXtransformer'
if name not in __factory:
raise KeyError("Unknown caption model:", name)
- return __factory[name](*args, **kwargs)
\ No newline at end of file
+ return __factory[name](args, submodel, **kwargs)
diff --git a/models/dwextransformer.py b/models/dwextransformer.py
new file mode 100644
index 0000000..378b667
--- /dev/null
+++ b/models/dwextransformer.py
@@ -0,0 +1,774 @@
+import copy
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from lib.config import cfg
+
+from layers.low_rank import LowRank
+import blocks
+import lib.utils as utils
+from models.basic_model import BasicModel
+from layers.positional_encoding import PositionalEncoding
+from .pretrained_models import ImageClassification
+from mlclassifier import GCNClassifier
+
+device = torch.device("cuda")
+
+
+def subsequent_mask(size):
+ "Mask out subsequent positions."
+ attn_shape = (1, size, size)
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
+ return torch.from_numpy(subsequent_mask) == 0
+
+
+def clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+def makeMask(input_features):
+ bs, seq, feats = input_features.shape
+ return torch.ones(bs, seq, dtype=torch.long).cuda() # hardcode to cuda
+
+
+class DWEXTransformer(BasicModel):
+ def __init__(self, args, submodel):
+ super(DWEXTransformer, self).__init__()
+ self.vocab_size = cfg.MODEL.VOCAB_SIZE + 1
+ # image pretrained
+ # self.image_pretrained_models, self.input_visual_feats = ImageClassification.image_features(
+ # 'densenet', fixed_weight=False, pretrained_model=cfg.MODEL.PretrainedImageModel)
+ if args.dataset_name == 'IUXRAY':
+ num_images = 2
+ self.get_visual_features = self.forward_iuxray
+ elif args.dataset_name == 'MIMICCXR':
+ num_images = 1
+ self.get_visual_features = self.forward_mimiccxr
+ elif args.dataset_name == 'MIMICCXR_MultiImages':
+ num_images = 2
+ self.get_visual_features = self.forward_mimiccxr
+
+ # att_feats encoder
+ cnn_sequential = []
+ self.input_visual_feats = 1024
+ cnn_sequential.append(nn.Linear(self.input_visual_feats * num_images, cfg.MODEL.ATT_FEATS_EMBED_DIM))
+ cnn_sequential.append(utils.activation(cfg.MODEL.ATT_FEATS_EMBED_ACT))
+ if cfg.MODEL.ATT_FEATS_NORM == True:
+ cnn_sequential.append(nn.LayerNorm(cfg.MODEL.ATT_FEATS_EMBED_DIM))
+ if cfg.MODEL.DROPOUT_ATT_EMBED > 0:
+ cnn_sequential.append(nn.Dropout(cfg.MODEL.DROPOUT_ATT_EMBED))
+
+ gcn_sequential = copy.deepcopy(cnn_sequential)
+
+ self.cnn_embed = nn.Sequential(*cnn_sequential) if len(cnn_sequential) > 0 else None
+ self.gcn_embed = nn.Sequential(*gcn_sequential) if len(gcn_sequential) > 0 else None
+
+ self.encoder = Encoder(
+ embed_dim=cfg.MODEL.BILINEAR.DIM,
+ dropout=cfg.MODEL.BILINEAR.ENCODE_DROPOUT,
+ att_type=cfg.MODEL.BILINEAR.ATTTYPE,
+ att_heads=cfg.MODEL.BILINEAR.HEAD,
+ att_mid_dim=cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DIM,
+ att_mid_drop=cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DROPOUT,
+ bifeat_emb_act=cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT,
+ bifeat_emb_drop=cfg.MODEL.BILINEAR.ENCODE_BIFEAT_EMB_DROPOUT,
+ ff_dropout=cfg.MODEL.BILINEAR.ENCODE_FF_DROPOUT,
+ layer_num=cfg.MODEL.BILINEAR.ENCODE_LAYERS)
+
+ self.decoder = Decoder(
+ vocab_size=self.vocab_size,
+ embed_dim=cfg.MODEL.BILINEAR.DIM,
+ dropout=cfg.MODEL.BILINEAR.DECODE_DROPOUT,
+ att_type=cfg.MODEL.BILINEAR.ATTTYPE,
+ att_heads=cfg.MODEL.BILINEAR.HEAD,
+ att_mid_dim=cfg.MODEL.BILINEAR.DECODE_ATT_MID_DIM,
+ att_mid_drop=cfg.MODEL.BILINEAR.DECODE_ATT_MID_DROPOUT,
+ bifeat_emb_act=cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT,
+ bifeat_emb_drop=cfg.MODEL.BILINEAR.DECODE_BIFEAT_EMB_DROPOUT,
+ ff_dropout=cfg.MODEL.BILINEAR.DECODE_FF_DROPOUT,
+ layer_num=cfg.MODEL.BILINEAR.DECODE_LAYERS)
+ self.submodel = submodel
+
+ def forward(self, **kwargs):
+ # forward entry
+
+ att_feats = kwargs[cfg.PARAM.ATT_FEATS]
+ seq = kwargs[cfg.PARAM.INPUT_SENT]
+ att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
+ # att_mask = torch.ones(16,70).to(device)
+ att_mask = utils.expand_tensor(att_mask, cfg.DATA_LOADER.SEQ_PER_IMG)
+ att_feats = utils.expand_tensor(att_feats, cfg.DATA_LOADER.SEQ_PER_IMG)
+ # HARDCODE: att_mask = None
+ # Regenerate later
+ att_mask = None
+
+ ##############################################
+ seq_mask = (seq > 0).type(torch.cuda.IntTensor)
+ seq_mask[:, 0] += 1
+ seq_mask = seq_mask.unsqueeze(-2)
+ seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
+ seq_mask = seq_mask.type(torch.cuda.FloatTensor)
+ ##############################################
+
+ cnn_feats, gcn_feats = self.get_visual_features(att_feats)
+ all_feats = torch.cat([cnn_feats, gcn_feats], dim=1)
+
+ att_mask = makeMask(all_feats)
+ cnn_mask = makeMask(cnn_feats)
+ gcn_mask = makeMask(gcn_feats)
+
+ cnn_feats = self.cnn_embed(cnn_feats) # forward entry
+ gcn_feats = self.gcn_embed(gcn_feats)
+
+ gx, encoder_out = self.encoder(cnn_feats, gcn_feats, cnn_mask, gcn_mask)
+
+ decoder_out = self.decoder(gx, seq, encoder_out, att_mask, seq_mask)
+ # print('decoder_out.shape',decoder_out.shape) # 4, 41, 761
+ # raise Exception('lol')
+ return decoder_out
+
+ def forward_iuxray(self, att_feats):
+
+ att_feats, node_feats, fc_feats = self.submodel(att_feats[:, 0], att_feats[:, 1]) # bs, 49 2048
+
+ return att_feats, node_feats
+
+ def forward_mimiccxr(self, att_feats):
+ if self.args.dataset_name == 'mimic_cxr_2images':
+ att_feats, node_feats, fc_feats = self.submodel(att_feats[:, 0], att_feats[:, 1])
+ else:
+ # if only one image is inputted.
+ att_feats, node_feats, fc_feats = self.submodel(att_feats)
+
+ return att_feats, node_feats
+
+ def get_logprobs_state(self, **kwargs):
+ wt = kwargs[cfg.PARAM.WT]
+ state = kwargs[cfg.PARAM.STATE]
+ encoder_out = kwargs[cfg.PARAM.ATT_FEATS]
+ att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
+
+ gx = kwargs[cfg.PARAM.GLOBAL_FEAT]
+ p_att_feats = kwargs[cfg.PARAM.P_ATT_FEATS]
+
+ if state is None:
+ ys = wt.unsqueeze(1)
+ else:
+ ys = torch.cat([state[0][0], wt.unsqueeze(1)], dim=1)
+ # ys = torch.zeros(16,1, dtype = int).to(device)
+ seq_mask = subsequent_mask(ys.size(1)).to(encoder_out.device).type(torch.cuda.FloatTensor)[:, -1, :].unsqueeze(
+ 1)
+ decoder_out = self.decoder(gx, ys[:, -1].unsqueeze(-1), encoder_out, att_mask, seq_mask, p_att_feats,
+ True).squeeze(1)
+
+ logprobs = F.log_softmax(decoder_out, dim=-1)
+ return logprobs, [ys.unsqueeze(0)]
+
+ def _expand_state(self, batch_size, beam_size, cur_beam_size, selected_beam):
+ def fn(s):
+ shape = [int(sh) for sh in s.shape]
+ beam = selected_beam
+ for _ in shape[1:]:
+ beam = beam.unsqueeze(-1)
+ beam = beam.long()
+ s = torch.gather(s.view(*([batch_size, cur_beam_size] + shape[1:])), 1,
+ beam.expand(*([batch_size, beam_size] + shape[1:])))
+
+ s = s.view(*([-1, ] + shape[1:]))
+ return s
+
+ return fn
+
+ # the beam search code is inspired by https://github.com/aimagelab/meshed-memory-transformer
+ def decode_beam(self, **kwargs):
+ att_feats = kwargs[cfg.PARAM.ATT_FEATS]
+ # att_mask = torch.ones(16,70).to(device)
+ att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
+ beam_size = kwargs['BEAM_SIZE']
+ batch_size = att_feats.size(0)
+ seq_logprob = torch.zeros((batch_size, 1, 1)).cuda()
+ log_probs = []
+ selected_words = None
+ seq_mask = torch.ones((batch_size, beam_size, 1)).cuda()
+
+ cnn_feats, gcn_feats = self.get_visual_features(att_feats)
+ all_feats = torch.cat([cnn_feats, gcn_feats], dim=1)
+
+ att_mask = makeMask(all_feats)
+ cnn_mask = makeMask(cnn_feats)
+ gcn_mask = makeMask(gcn_feats)
+
+ cnn_feats = self.cnn_embed(cnn_feats) # forward entry
+ gcn_feats = self.gcn_embed(gcn_feats)
+
+ gx, encoder_out = self.encoder(cnn_feats, gcn_feats, cnn_mask, gcn_mask)
+
+ p_att_feats = self.decoder.precompute(encoder_out)
+
+ state = None
+ wt = Variable(torch.zeros(batch_size, dtype=torch.long).cuda())
+ kwargs[cfg.PARAM.ATT_FEATS] = encoder_out
+ kwargs[cfg.PARAM.GLOBAL_FEAT] = gx
+ kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats
+
+ outputs = []
+ self.decoder.init_buffer(batch_size)
+ for t in range(cfg.MODEL.SEQ_LEN):
+ cur_beam_size = 1 if t == 0 else beam_size
+
+ kwargs[cfg.PARAM.WT] = wt
+ kwargs[cfg.PARAM.STATE] = state
+ word_logprob, state = self.get_logprobs_state(**kwargs)
+ word_logprob = word_logprob.view(batch_size, cur_beam_size, -1)
+ candidate_logprob = seq_logprob + word_logprob
+
+ # Mask sequence if it reaches EOS
+ if t > 0:
+ mask = (selected_words.view(batch_size, cur_beam_size) != 0).float().unsqueeze(-1)
+ seq_mask = seq_mask * mask
+ word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
+ old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous()
+ old_seq_logprob[:, :, 1:] = -999
+ candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask)
+
+ selected_idx, selected_logprob = self.select(batch_size, beam_size, t, candidate_logprob)
+ selected_beam = selected_idx / candidate_logprob.shape[-1]
+ selected_beam = selected_beam.long()
+ selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]
+
+ self.decoder.apply_to_states(self._expand_state(batch_size, beam_size, cur_beam_size, selected_beam))
+ seq_logprob = selected_logprob.unsqueeze(-1)
+ seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
+ outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
+ outputs.append(selected_words.unsqueeze(-1))
+
+ this_word_logprob = torch.gather(word_logprob, 1,
+ selected_beam.unsqueeze(-1).expand(batch_size, beam_size,
+ word_logprob.shape[-1]))
+ this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
+ log_probs = list(
+ torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(batch_size, beam_size, 1)) for o in log_probs)
+ log_probs.append(this_word_logprob)
+ selected_words = selected_words.view(-1, 1)
+ wt = selected_words.squeeze(-1)
+
+ if t == 0:
+ encoder_out = utils.expand_tensor(encoder_out, beam_size)
+ gx = utils.expand_tensor(gx, beam_size)
+ att_mask = utils.expand_tensor(att_mask, beam_size)
+ state[0] = state[0].squeeze(0)
+ state[0] = utils.expand_tensor(state[0], beam_size)
+ state[0] = state[0].unsqueeze(0)
+
+ p_att_feats_tmp = []
+ for p_feat in p_att_feats:
+ p_key, p_value2 = p_feat
+ p_key = utils.expand_tensor(p_key, beam_size)
+ p_value2 = utils.expand_tensor(p_value2, beam_size)
+ p_att_feats_tmp.append((p_key, p_value2))
+
+ kwargs[cfg.PARAM.ATT_FEATS] = encoder_out
+ kwargs[cfg.PARAM.GLOBAL_FEAT] = gx
+ kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
+ kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats_tmp
+
+ seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
+ outputs = torch.cat(outputs, -1)
+ outputs = torch.gather(outputs, 1, sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))
+ log_probs = torch.cat(log_probs, -1)
+ log_probs = torch.gather(log_probs, 1, sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))
+
+ outputs = outputs.contiguous()[:, 0]
+ log_probs = log_probs.contiguous()[:, 0]
+
+ self.decoder.clear_buffer()
+ return outputs, log_probs
+
+ def decode(self, **kwargs):
+ beam_size = kwargs['BEAM_SIZE']
+ greedy_decode = kwargs['GREEDY_DECODE']
+ att_feats = kwargs[cfg.PARAM.ATT_FEATS]
+ # att_mask = torch.ones(16,70).to(device)
+ att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
+
+ batch_size = att_feats.shape[0]
+ cnn_feats, gcn_feats = self.get_visual_features(att_feats)
+ all_feats = torch.cat([cnn_feats, gcn_feats], dim=1)
+
+ att_mask = makeMask(all_feats)
+ cnn_mask = makeMask(cnn_feats)
+ gcn_mask = makeMask(gcn_feats)
+
+ cnn_feats = self.cnn_embed(cnn_feats) # forward entry
+ gcn_feats = self.gcn_embed(gcn_feats)
+
+ gx, encoder_out = self.encoder(cnn_feats, gcn_feats, cnn_mask, gcn_mask)
+
+ p_att_feats = self.decoder.precompute(encoder_out)
+ self.decoder.init_buffer(batch_size)
+
+ state = None
+ sents = Variable(torch.zeros((batch_size, cfg.MODEL.SEQ_LEN), dtype=torch.long).cuda())
+ logprobs = Variable(torch.zeros(batch_size, cfg.MODEL.SEQ_LEN).cuda())
+ wt = Variable(torch.zeros(batch_size, dtype=torch.long).cuda())
+ unfinished = wt.eq(wt)
+ kwargs[cfg.PARAM.ATT_FEATS] = encoder_out
+ kwargs[cfg.PARAM.GLOBAL_FEAT] = gx
+ kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats
+ for t in range(cfg.MODEL.SEQ_LEN):
+ kwargs[cfg.PARAM.WT] = wt
+ kwargs[cfg.PARAM.STATE] = state
+ logprobs_t, state = self.get_logprobs_state(**kwargs)
+
+ if greedy_decode:
+ logP_t, wt = torch.max(logprobs_t, 1)
+ else:
+ probs_t = torch.exp(logprobs_t)
+ wt = torch.multinomial(probs_t, 1)
+ logP_t = logprobs_t.gather(1, wt)
+ wt = wt.view(-1).long()
+ unfinished = unfinished * (wt > 0)
+ wt = wt * unfinished.type_as(wt)
+ sents[:, t] = wt
+ logprobs[:, t] = logP_t.view(-1)
+
+ if unfinished.sum() == 0:
+ break
+ self.decoder.clear_buffer()
+ return sents, logprobs
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim,
+ dropout,
+ att_type,
+ att_heads,
+ att_mid_dim,
+ att_mid_drop,
+ bifeat_emb_act,
+ bifeat_emb_drop,
+ ff_dropout,
+ layer_num
+ ):
+ super(Encoder, self).__init__()
+ self.att_heads = att_heads
+ self.layers = nn.ModuleList([])
+ for i in range(layer_num):
+ sublayer = EncoderLayer(
+ embed_dim=embed_dim,
+ dropout=dropout,
+ att_type=att_type,
+ att_heads=att_heads,
+ att_mid_dim=att_mid_dim,
+ att_mid_drop=att_mid_drop,
+ bifeat_emb_act=bifeat_emb_act,
+ bifeat_emb_drop=bifeat_emb_drop,
+ ff_dropout=ff_dropout)
+ self.layers.append(sublayer)
+ assert embed_dim % 2 == 0
+ half_embed_dim = int(embed_dim / 2)
+ self.cnn_proj_norm = nn.Sequential(
+ nn.Linear(embed_dim * (layer_num + 1), half_embed_dim),
+ torch.nn.LayerNorm(half_embed_dim))
+ self.gcn_proj_norm = nn.Sequential(
+ nn.Linear(embed_dim * (layer_num + 1), half_embed_dim),
+ torch.nn.LayerNorm(half_embed_dim))
+
+ # def forward(self, x, mask):
+ # # we try not to use the mask
+ # gx = (torch.sum(x * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
+
+ def forward(self, cnn_feats, gcn_feats, cnn_mask, gcn_mask):
+
+ # we try not to use the mask
+
+ # drop mask
+
+ # if att_masks is None:
+ # att_masks = x.new_ones(x.shape[:2], dtype=torch.long)
+ # att_masks = att_masks.unsqueeze(-2)
+ # print(x.shape)
+ # print(att_masks.shape)
+ cnn_gx = (torch.sum(cnn_feats * cnn_mask.unsqueeze(-1), 1) / torch.sum(cnn_mask.unsqueeze(-1), 1))
+ gcn_gx = (torch.sum(gcn_feats * gcn_mask.unsqueeze(-1), 1) / torch.sum(gcn_mask.unsqueeze(-1), 1))
+
+ cnn_gx_arr = [cnn_gx]
+ gcn_gx_arr = [gcn_gx]
+ for layer in self.layers:
+ cnn_gx, gcn_gx, cnn_feats, gcn_feats = layer(cnn_gx, gcn_gx, cnn_feats, gcn_feats, cnn_mask,
+ gcn_mask) # modified
+ cnn_gx_arr.append(cnn_gx)
+ gcn_gx_arr.append(gcn_gx)
+
+ print(cnn_gx_arr[0].shape)
+ print(gcn_gx_arr[0].shape)
+
+ cnn_gx = torch.cat(cnn_gx_arr, dim=-1) # cat dim?
+
+ cnn_gx = self.cnn_proj_norm(cnn_gx)
+
+ gcn_gx = torch.cat(gcn_gx_arr, dim=-1)
+
+ gcn_gx = self.gcn_proj_norm(gcn_gx) # TODO
+
+ gx = torch.cat([cnn_gx, gcn_gx], dim=-1) # TODO: cat dim
+ x = torch.cat([cnn_feats, gcn_feats], dim=1) # cat at seq
+
+ return gx, x
+
+
+class EncoderLayer(nn.Module):
+ def __init__(
+ self,
+ embed_dim,
+ dropout,
+ att_type,
+ att_heads,
+ att_mid_dim,
+ att_mid_drop,
+ bifeat_emb_act,
+ bifeat_emb_drop,
+ ff_dropout
+ ):
+ super(EncoderLayer, self).__init__()
+ self.encoder_attn = clones(LowRank(
+ embed_dim=embed_dim,
+ att_type=att_type,
+ att_heads=att_heads,
+ att_mid_dim=att_mid_dim,
+ att_mid_drop=att_mid_drop), 4)
+ self.dropout = clones(nn.Dropout(dropout), 4)
+
+ self.bifeat_emb = clones(nn.Sequential(
+ nn.Linear(2 * embed_dim, embed_dim),
+ utils.activation(bifeat_emb_act),
+ nn.Dropout(bifeat_emb_drop)
+ ), 4)
+
+ self.layer_norm = clones(torch.nn.LayerNorm(embed_dim), 4)
+
+ self.ff_layer = clones(blocks.create(
+ 'FeedForward',
+ embed_dim=embed_dim,
+ ffn_embed_dim=embed_dim * 4,
+ relu_dropout=ff_dropout,
+ dropout=ff_dropout), 4)
+
+ def forward(self, cnn_gx, gcn_gx, cnn_feats, gcn_feats, cnn_mask, gcn_mask):
+ """
+ gx : torch.Size([4, 768])
+ x : torch.Size([4, 49, 768])
+ mask : torch.Size([4, 49])
+
+ """
+ assert self.ff_layer != None
+
+ cnn_gx = self.dropout[0](
+ self.encoder_attn[0](query=cnn_gx, key=cnn_feats, mask=cnn_mask, value1=cnn_gx, value2=cnn_feats))
+ cnn_x_ = torch.cat([cnn_gx.unsqueeze(1).expand_as(cnn_feats), cnn_feats], dim=-1)
+ cnn_feats = self.ff_layer[0](self.layer_norm[0](self.bifeat_emb[0](cnn_x_) + cnn_feats))
+
+ gcn_gx = self.dropout[1](
+ self.encoder_attn[1](query=gcn_gx, key=gcn_feats, mask=gcn_mask, value1=gcn_gx, value2=gcn_feats))
+ gcn_x_ = torch.cat([gcn_gx.unsqueeze(1).expand_as(gcn_feats), gcn_feats], dim=-1)
+ gcn_feats = self.ff_layer[1](self.layer_norm[1](self.bifeat_emb[1](gcn_x_) + gcn_feats))
+
+ all_feats = torch.cat([cnn_feats, gcn_feats], dim=1)
+ all_mask = torch.cat([cnn_mask, gcn_mask], dim=-1)
+
+ cnn_gx = self.dropout[2](
+ self.encoder_attn[2](query=cnn_gx, key=all_feats, mask=all_mask, value1=cnn_gx, value2=all_feats))
+ cnn_x_ = torch.cat([cnn_gx.unsqueeze(1).expand_as(cnn_feats), cnn_feats], dim=-1)
+ cnn_feats = self.ff_layer[2](self.layer_norm[2](self.bifeat_emb[2](cnn_x_) + cnn_feats))
+
+ gcn_gx = self.dropout[3](
+ self.encoder_attn[3](query=gcn_gx, key=all_feats, mask=all_mask, value1=gcn_gx, value2=all_feats))
+ gcn_x_ = torch.cat([gcn_gx.unsqueeze(1).expand_as(gcn_feats), gcn_feats], dim=-1)
+ gcn_feats = self.ff_layer[3](self.layer_norm[3](self.bifeat_emb[3](gcn_x_) + gcn_feats))
+
+ return cnn_gx, gcn_gx, cnn_feats, gcn_feats
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ vocab_size,
+ embed_dim,
+ dropout,
+ att_type,
+ att_heads,
+ att_mid_dim,
+ att_mid_drop,
+ bifeat_emb_act,
+ bifeat_emb_drop,
+ ff_dropout,
+ layer_num
+ ):
+ super(Decoder, self).__init__()
+ self.att_heads = att_heads
+ self.layers = nn.ModuleList([])
+ self.embed_dim = embed_dim
+ for i in range(layer_num):
+ sublayer = DecoderLayer(
+ embed_dim=embed_dim,
+ dropout=dropout,
+ att_type=att_type,
+ att_heads=att_heads,
+ att_mid_dim=att_mid_dim,
+ att_mid_drop=att_mid_drop,
+ bifeat_emb_act=bifeat_emb_act,
+ bifeat_emb_drop=bifeat_emb_drop,
+ ff_dropout=ff_dropout,
+ last_layer=(i == layer_num - 1))
+ self.layers.append(sublayer)
+
+ self.dropout = nn.Dropout(cfg.MODEL.DROPOUT_WORD_EMBED)
+ self.embed_tokens = nn.Embedding(vocab_size, embed_dim)
+ self.embed_scale = math.sqrt(embed_dim)
+ self.embed_positions = PositionalEncoding(
+ embed_dim, cfg.MODEL.TRANSFORMER.PE_MAX_LEN
+ )
+
+ self.layer_norm_word = torch.nn.LayerNorm(embed_dim)
+ self.generator = nn.Linear(embed_dim, vocab_size)
+
+ self.wbil1 = nn.Sequential(
+ nn.Linear(embed_dim, embed_dim),
+ utils.activation(cfg.MODEL.BILINEAR.ACT),
+ torch.nn.LayerNorm(embed_dim)
+ )
+ self.wbil2 = nn.Sequential(
+ nn.Linear(embed_dim, embed_dim),
+ utils.activation(cfg.MODEL.BILINEAR.ACT),
+ torch.nn.LayerNorm(embed_dim)
+ )
+ self.wbi_drop = nn.Dropout(cfg.MODEL.BILINEAR.DECODE_DROPOUT)
+ self.dropout_lm = nn.Dropout(cfg.MODEL.DROPOUT_LM)
+
+ self.proj_norm = nn.Sequential(
+ nn.Linear(embed_dim * (layer_num + 1), 2 * embed_dim),
+ nn.GLU(),
+ torch.nn.LayerNorm(embed_dim))
+
+ self.clear_buffer()
+
+ def init_buffer(self, batch_size):
+ self.seq_len = 0
+ self.x = torch.zeros((batch_size, 1, self.embed_dim)).cuda()
+ for layer in self.layers:
+ layer.init_buffer(batch_size)
+
+ def clear_buffer(self):
+ self.seq_len = None
+ self.x = None
+ for layer in self.layers:
+ layer.clear_buffer()
+
+ def apply_to_states(self, fn):
+ self.x = fn(self.x)
+ for layer in self.layers:
+ layer.apply_to_states(fn)
+
+ def precompute(self, encoder_out):
+ p_att_feats = []
+ for layer in self.layers:
+ key, value2 = layer.precompute(encoder_out)
+ p_att_feats.append((key, value2))
+ return p_att_feats
+
+ def forward(self, gx, prev_output_tokens, encoder_out, att_mask, seq_mask=None, p_att_feats=None, precompute=False):
+ # device = torch.device("cuda")
+ att_mask = att_mask.unsqueeze(1)
+
+ # embed positions
+ seq_len = prev_output_tokens.size(1)
+ if self.seq_len is not None:
+ seq_len = self.seq_len + seq_len
+ self.seq_len = seq_len
+ positions = self.embed_positions(seq_len)[:, -1, :].unsqueeze(1)
+ else:
+ positions = self.embed_positions(seq_len)
+
+ # embed tokens and positions
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+
+ x = x + positions
+ x = self.layer_norm_word(x)
+ if self.dropout is not None:
+ x = self.dropout(x)
+
+ # decoder layers
+ gx = self.wbil1(gx)
+ if self.x is None:
+ x_gx = (torch.sum(x.unsqueeze(1) * seq_mask.unsqueeze(-1), -2) / torch.sum(seq_mask, -1).unsqueeze(-1))
+ else:
+ self.x = self.x + x
+ x_gx = self.x / seq_len
+ x_gx = self.wbil2(x_gx)
+ gx = gx.unsqueeze(1)
+ gx = gx * x_gx
+ gx = self.wbi_drop(gx)
+
+ # print(gx.shape)
+ # print(type(gx))
+ # print((gx[0, 0, 0].dtype))
+
+ bs, att_nums, features_sizes = gx.shape
+
+ # gx_arr = torch.zeros((len(self.layers), bs, att_nums, features_sizes), dtype = torch.float32, device = device)# .to(torch.device("cuda")) # HARD CODE
+ # print(gx_arr.shape)
+ # print(gx_arr.shape)
+ # gx_arr = gx_arr.to(device)
+
+ gx_arr = [gx] # .to('cuda')
+ # gx_arr =
+
+ for layerid, layer in enumerate(self.layers):
+ # print('layerid',layerid)
+ if precompute == False:
+ p_key = None
+ p_value2 = None
+ else:
+ p_key, p_value2 = p_att_feats[layerid]
+ # print('x.shape b4',x.shape)
+ # print('precompute',precompute)
+ # print('att_mask.shape',att_mask.shape)
+ # print('seq_mask.shape',seq_mask.shape)
+ gx, x = layer(gx, x, encoder_out, att_mask, seq_mask=seq_mask, p_key=p_key, p_value2=p_value2,
+ precompute=precompute)
+ # print('gx.shape',gx.shape) # bs, 41, 768
+ # print('x.shape after',x.shape)
+ # print(type(gx_arr))
+ gx_arr.append(gx)
+ # print('test',gx_arr[layerid, :, :, :].shape)
+ # print('layerid',layerid)
+ # print('gx_arr.shape',gx_arr.shape)
+ # print('gx.device',gx.device)
+ # print('gx_arr.device',gx_arr.device)
+ # gx_arr[layerid] = gx
+ # raise Exception('decoder lol')
+ gx = torch.cat(gx_arr, dim=-1)
+
+ gx = self.proj_norm(gx)
+
+ gx = self.dropout_lm(gx)
+ out = self.generator(gx)
+ # raise Exception('decoder lol')
+ return out
+
+
+class DecoderLayer(nn.Module):
+ def __init__(
+ self,
+ embed_dim,
+ dropout,
+ att_type,
+ att_heads,
+ att_mid_dim,
+ att_mid_drop,
+ bifeat_emb_act,
+ bifeat_emb_drop,
+ ff_dropout,
+ last_layer=False
+ ):
+ super(DecoderLayer, self).__init__()
+ self.last_layer = last_layer
+ self.word_attn = LowRank(
+ embed_dim=embed_dim,
+ att_type=att_type,
+ att_heads=att_heads,
+ att_mid_dim=att_mid_dim,
+ att_mid_drop=att_mid_drop)
+ self.word_dropout = nn.Dropout(dropout)
+
+ self.cross_att = LowRank(
+ embed_dim=embed_dim,
+ att_type=att_type,
+ att_heads=att_heads,
+ att_mid_dim=att_mid_dim,
+ att_mid_drop=att_mid_drop)
+ self.cross_dropout = nn.Dropout(dropout)
+ self.layer_norm_cross = torch.nn.LayerNorm(embed_dim)
+
+ if self.last_layer == False:
+ self.bifeat_emb = nn.Sequential(
+ nn.Linear(2 * embed_dim, embed_dim),
+ utils.activation(bifeat_emb_act),
+ nn.Dropout(bifeat_emb_drop)
+ )
+ self.layer_norm_x = torch.nn.LayerNorm(embed_dim)
+
+ self.ff_layer = blocks.create(
+ 'FeedForward',
+ embed_dim=embed_dim,
+ ffn_embed_dim=embed_dim * 4,
+ relu_dropout=ff_dropout,
+ dropout=ff_dropout)
+
+ self.layer_norm_gx = torch.nn.LayerNorm(embed_dim)
+
+ def apply_to_states(self, fn):
+ self.word_attn.apply_to_states(fn)
+
+ def init_buffer(self, batch_size):
+ self.word_attn.init_buffer(batch_size)
+
+ def clear_buffer(self):
+ self.word_attn.clear_buffer()
+
+ def precompute(self, encoder_out):
+ key, value2 = self.cross_att.precompute(encoder_out, encoder_out)
+ return key, value2
+
+ def forward(
+ self,
+ gx,
+ x,
+ encoder_out,
+ att_mask,
+ seq_mask,
+ p_key=None,
+ p_value2=None,
+ precompute=False
+ ):
+ word_x = x
+ residual = x
+ x = self.word_attn.forward2(
+ query=gx,
+ key=x,
+ mask=seq_mask,
+ value1=gx,
+ value2=x)
+ x = self.word_dropout(x)
+
+ x = residual + x
+
+ residual = x
+ x = self.layer_norm_cross(x)
+ x = self.cross_att.forward2(
+ query=x,
+ key=encoder_out if precompute == False else p_key,
+ mask=att_mask,
+ value1=x,
+ value2=encoder_out if precompute == False else p_value2,
+ precompute=precompute)
+ x = self.cross_dropout(x)
+ gx = residual + x
+ gx = self.layer_norm_gx(gx)
+
+ if self.last_layer == False:
+ x_ = torch.cat([gx, word_x], dim=-1)
+ x = self.bifeat_emb(x_) + word_x
+ x = self.layer_norm_x(x)
+
+ if self.ff_layer is not None:
+ x = self.ff_layer(x)
+ else:
+ x = None
+
+ return gx, x
diff --git a/models/pretrained_models.py b/models/pretrained_models.py
new file mode 100644
index 0000000..480f6a6
--- /dev/null
+++ b/models/pretrained_models.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import gzip
+import torch
+from torch import relu
+from torch.nn import Dropout, Linear, Sequential
+from torch.nn.functional import adaptive_avg_pool2d, cross_entropy
+from torchvision import models
+
+
+# # Image processes
+# if image_model is None:
+# image_model = 'densenet'
+# self.image_feats, image_dim = ImageClassification.image_features(image_model, not finetune_image, True,
+# image_pretrained, device)
+
+class ImageClassification(torch.nn.Module):
+ def __init__(self, model, num_labels, num_classes, multi_image=1, dropout=0.0, pretrained=True):
+ super(ImageClassification, self).__init__()
+ self.image_feats, self.image_dim = self.image_features(model, False, pretrained)
+ for i in range(num_labels):
+ setattr(self, 'linear{0}'.format(i), Linear(self.image_dim, num_classes))
+ self.num_labels = num_labels
+ self.multi_image = multi_image
+ self.dropout = Dropout(p=dropout)
+
+ @classmethod
+ def fix_layers(cls, model):
+ for param in model.parameters():
+ param.requires_grad = False
+
+ @classmethod
+ def image_features(cls, name, fixed_weight=False, pretrained=True, pretrained_model=None, device='gpu'):
+ if pretrained_model is None:
+ if name == 'densenet121' or name == 'densenet':
+ m = models.densenet121(pretrained=pretrained)
+ if fixed_weight:
+ cls.fix_layers(m)
+ return Sequential(*list(m.features.children())), 1024
+ elif name == 'resnet50':
+ m = models.resnet50(pretrained=pretrained)
+ if fixed_weight:
+ cls.fix_layers(m)
+ return Sequential(*list(m.children())[:-2]), 2048
+ elif name == 'resnet152' or name == 'resnet':
+ m = models.resnet152(pretrained=pretrained)
+ if fixed_weight:
+ cls.fix_layers(m)
+ return Sequential(*list(m.children())[:-2]), 2048
+ elif name == 'vgg19' or name == 'vgg':
+ m = models.vgg19(pretrained=pretrained)
+ if fixed_weight:
+ cls.fix_layers(m)
+ return Sequential(*list(m.features.children())[:-1]), 512
+ else:
+ raise ValueError('Unknown model {0}'.format(name))
+ else:
+ d = torch.device('cpu')if device == 'cpu' else torch.device('cuda:0')
+ with gzip.open(pretrained_model) as f:
+ state = torch.load(f, map_location=d)
+ m = ImageClassification(name, 14, 3, pretrained=False)
+ m.load_state_dict(state['model'])
+ if fixed_weight:
+ cls.fix_layers(m)
+ return m.image_feats, m.image_dim
+
+ def deflatten_image(self, x):
+ if self.multi_image > 1:
+ x = x.view(int(x.shape[0] / self.multi_image), self.multi_image, x.shape[1])
+ x, _ = torch.max(x, dim=1)
+ return x
+
+ def flatten_image(self, x):
+ if self.multi_image > 1:
+ return x.flatten(start_dim=0, end_dim=1)
+ else:
+ return x
+
+ def forward(self, x):
+ x = self.flatten_image(x)
+ x = self.image_feats(x)
+ x = relu(x)
+ x = adaptive_avg_pool2d(x, (1, 1))
+ x = torch.flatten(x, 1)
+ x = self.deflatten_image(x)
+ xs = []
+ for i in range(self.num_labels):
+ xi = self.dropout(x)
+ xi = getattr(self, 'linear{0}'.format(i))(xi).unsqueeze(dim=2)
+ xs.append(xi)
+ x = torch.cat(xs, dim=2)
+ return x
\ No newline at end of file
diff --git a/models/xlan.py b/models/xlan.py
index 109dcf2..6d43815 100755
--- a/models/xlan.py
+++ b/models/xlan.py
@@ -7,7 +7,7 @@
import blocks
class XLAN(AttBasicModel):
- def __init__(self):
+ def __init__(self, args):
super(XLAN, self).__init__()
self.num_layers = 2
@@ -51,7 +51,7 @@ def Forward(self, **kwargs):
att, _ = self.attention(h_att, att_feats, att_mask, p_att_feats, precompute=True)
ctx_input = torch.cat([att, h_att], 1)
- output = self.att2ctx(ctx_input)
+ output = self.att2ctx(ctx_input) # forward entry
state = [torch.stack((h_att, output)), torch.stack((c_att, state[1][1]))]
return output, state
\ No newline at end of file
diff --git a/models/xtransformer.py b/models/xtransformer.py
index 5179608..9ec09d4 100755
--- a/models/xtransformer.py
+++ b/models/xtransformer.py
@@ -12,6 +12,9 @@
import lib.utils as utils
from models.basic_model import BasicModel
from layers.positional_encoding import PositionalEncoding
+from .pretrained_models import ImageClassification
+from mlclassifier import GCNClassifier
+device = torch.device("cuda")
def subsequent_mask(size):
"Mask out subsequent positions."
@@ -19,71 +22,129 @@ def subsequent_mask(size):
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
return torch.from_numpy(subsequent_mask) == 0
+
class XTransformer(BasicModel):
- def __init__(self):
+ def __init__(self, args, submodel):
super(XTransformer, self).__init__()
self.vocab_size = cfg.MODEL.VOCAB_SIZE + 1
+ # image pretrained
+ # self.image_pretrained_models, self.input_visual_feats = ImageClassification.image_features(
+ # 'densenet', fixed_weight=False, pretrained_model=cfg.MODEL.PretrainedImageModel)
+ if args.dataset_name == 'IUXRAY':
+ num_images = 2
+ self.get_visual_features = self.forward_iuxray
+ else: # 'MIMICCXR'
+ num_images = 1
+ self.get_visual_features = self.forward_mimiccxr
# att_feats encoder
+ self.input_visual_feats = 1024
sequential = []
- sequential.append(nn.Linear(cfg.MODEL.ATT_FEATS_DIM, cfg.MODEL.ATT_FEATS_EMBED_DIM))
+ sequential.append(nn.Linear(self.input_visual_feats * num_images, cfg.MODEL.ATT_FEATS_EMBED_DIM))
sequential.append(utils.activation(cfg.MODEL.ATT_FEATS_EMBED_ACT))
if cfg.MODEL.ATT_FEATS_NORM == True:
sequential.append(nn.LayerNorm(cfg.MODEL.ATT_FEATS_EMBED_DIM))
if cfg.MODEL.DROPOUT_ATT_EMBED > 0:
- sequential.append(nn.Dropout(cfg.MODEL.DROPOUT_ATT_EMBED))
+ sequential.append(nn.Dropout(cfg.MODEL.DROPOUT_ATT_EMBED))
+
self.att_embed = nn.Sequential(*sequential) if len(sequential) > 0 else None
self.encoder = Encoder(
- embed_dim = cfg.MODEL.BILINEAR.DIM,
- dropout = cfg.MODEL.BILINEAR.ENCODE_DROPOUT,
- att_type = cfg.MODEL.BILINEAR.ATTTYPE,
- att_heads = cfg.MODEL.BILINEAR.HEAD,
- att_mid_dim = cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DIM,
- att_mid_drop = cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DROPOUT,
- bifeat_emb_act = cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT,
- bifeat_emb_drop = cfg.MODEL.BILINEAR.ENCODE_BIFEAT_EMB_DROPOUT,
- ff_dropout = cfg.MODEL.BILINEAR.ENCODE_FF_DROPOUT,
- layer_num = cfg.MODEL.BILINEAR.ENCODE_LAYERS)
-
- self.decoder = Decoder(
- vocab_size = self.vocab_size,
- embed_dim = cfg.MODEL.BILINEAR.DIM,
- dropout = cfg.MODEL.BILINEAR.DECODE_DROPOUT,
- att_type = cfg.MODEL.BILINEAR.ATTTYPE,
- att_heads = cfg.MODEL.BILINEAR.HEAD,
- att_mid_dim = cfg.MODEL.BILINEAR.DECODE_ATT_MID_DIM,
- att_mid_drop = cfg.MODEL.BILINEAR.DECODE_ATT_MID_DROPOUT,
- bifeat_emb_act = cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT,
- bifeat_emb_drop = cfg.MODEL.BILINEAR.DECODE_BIFEAT_EMB_DROPOUT,
- ff_dropout = cfg.MODEL.BILINEAR.DECODE_FF_DROPOUT,
- layer_num = cfg.MODEL.BILINEAR.DECODE_LAYERS)
+ embed_dim=cfg.MODEL.BILINEAR.DIM,
+ dropout=cfg.MODEL.BILINEAR.ENCODE_DROPOUT,
+ att_type=cfg.MODEL.BILINEAR.ATTTYPE,
+ att_heads=cfg.MODEL.BILINEAR.HEAD,
+ att_mid_dim=cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DIM,
+ att_mid_drop=cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DROPOUT,
+ bifeat_emb_act=cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT,
+ bifeat_emb_drop=cfg.MODEL.BILINEAR.ENCODE_BIFEAT_EMB_DROPOUT,
+ ff_dropout=cfg.MODEL.BILINEAR.ENCODE_FF_DROPOUT,
+ layer_num=cfg.MODEL.BILINEAR.ENCODE_LAYERS)
+
+ self.decoder = Decoder(
+ vocab_size=self.vocab_size,
+ embed_dim=cfg.MODEL.BILINEAR.DIM,
+ dropout=cfg.MODEL.BILINEAR.DECODE_DROPOUT,
+ att_type=cfg.MODEL.BILINEAR.ATTTYPE,
+ att_heads=cfg.MODEL.BILINEAR.HEAD,
+ att_mid_dim=cfg.MODEL.BILINEAR.DECODE_ATT_MID_DIM,
+ att_mid_drop=cfg.MODEL.BILINEAR.DECODE_ATT_MID_DROPOUT,
+ bifeat_emb_act=cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT,
+ bifeat_emb_drop=cfg.MODEL.BILINEAR.DECODE_BIFEAT_EMB_DROPOUT,
+ ff_dropout=cfg.MODEL.BILINEAR.DECODE_FF_DROPOUT,
+ layer_num=cfg.MODEL.BILINEAR.DECODE_LAYERS)
+ self.submodel = submodel
def forward(self, **kwargs):
+ # forward entry
+
att_feats = kwargs[cfg.PARAM.ATT_FEATS]
seq = kwargs[cfg.PARAM.INPUT_SENT]
att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
+ # att_mask = torch.ones(16,70).to(device)
att_mask = utils.expand_tensor(att_mask, cfg.DATA_LOADER.SEQ_PER_IMG)
att_feats = utils.expand_tensor(att_feats, cfg.DATA_LOADER.SEQ_PER_IMG)
##############################################
seq_mask = (seq > 0).type(torch.cuda.IntTensor)
- seq_mask[:,0] += 1
+ seq_mask[:, 0] += 1
seq_mask = seq_mask.unsqueeze(-2)
seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
seq_mask = seq_mask.type(torch.cuda.FloatTensor)
##############################################
-
- att_feats = self.att_embed(att_feats)
+ # print('att_feats.shape b4 pretrain',att_feats.shape)
+ # if len(att_feats.shape) == 5:
+ # batch_size = att_feats.shape[0]
+ # img_size = att_feats.shape[1]
+ # channel_size = att_feats.shape[2]
+ # hidden_size1 = att_feats.shape[3]
+ # hidden_size2 = att_feats.shape[4]
+ # att_feats = att_feats.view(batch_size, -1, hidden_size1, hidden_size2)
+
+ # FEED IUXRAY: two images
+ # att_feats_0 = self.image_pretrained_models(att_feats[:, 0])
+ # att_feats_1 = self.image_pretrained_models(att_feats[:, 1])
+ # att_feats = torch.cat((att_feats_0, att_feats_1), dim=1) # shape (bs, 2048, 7, 7)
+ att_feats = self.get_visual_features(att_feats)
+ batch_size, feat_size, _,_ = att_feats.shape
+ att_feats = att_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
+ # print('att_feats.shape after pretrain', att_feats.shape)
+ att_feats = self.att_embed(att_feats) # forward entry
+ # print('att_feats.shape after pretrain', att_feats.shape)
+
gx, encoder_out = self.encoder(att_feats, att_mask)
+ # print(gx.shape, encoder_out.shape)
+
decoder_out = self.decoder(gx, seq, encoder_out, att_mask, seq_mask)
+ # print('decoder_out.shape',decoder_out.shape) # 4, 41, 761
+ # raise Exception('lol')
return decoder_out
+ def forward_iuxray(self, att_feats):
+
+ att_feats, node_feats, fc_feats = self.submodel(att_feats[:, 0], att_feats[:, 1]) #bs, 49 2048
+ att_feats = torch.cat((att_feats, node_feats), dim = 1)#Gcn+cnn
+# att_feats = att_feats#cnn only
+# att_feats = node_feats#gcn only
+ att_feats = att_feats.permute(0,2,1)
+ att_feats = att_feats.unsqueeze(-1) # bs, 2048,74,1
+ return att_feats
+
+ def forward_mimiccxr(self, att_feats):
+ att_feats, node_feats, fc_feats = self.submodel(att_feats)
+ att_feats = torch.cat((att_feats, node_feats), dim = 1) #torch.Size([16, 70, 2048])
+# att_feats = att_feats#cnn only
+# att_feats = node_feats#gcn only
+ att_feats = att_feats.permute(0,2,1)
+ att_feats = att_feats.unsqueeze(-1) # bs, 2048,74,1
+ return att_feats
+
def get_logprobs_state(self, **kwargs):
wt = kwargs[cfg.PARAM.WT]
state = kwargs[cfg.PARAM.STATE]
encoder_out = kwargs[cfg.PARAM.ATT_FEATS]
att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
+
gx = kwargs[cfg.PARAM.GLOBAL_FEAT]
p_att_feats = kwargs[cfg.PARAM.P_ATT_FEATS]
@@ -91,9 +152,13 @@ def get_logprobs_state(self, **kwargs):
ys = wt.unsqueeze(1)
else:
ys = torch.cat([state[0][0], wt.unsqueeze(1)], dim=1)
- seq_mask = subsequent_mask(ys.size(1)).to(encoder_out.device).type(torch.cuda.FloatTensor)[:, -1, :].unsqueeze(1)
+ # ys = torch.zeros(16,1, dtype = int).to(device)
+ seq_mask = subsequent_mask(ys.size(1)).to(encoder_out.device).type(torch.cuda.FloatTensor)[:, -1, :].unsqueeze(
+ 1)
decoder_out = self.decoder(gx, ys[:, -1].unsqueeze(-1), encoder_out, att_mask, seq_mask, p_att_feats, True).squeeze(1)
-
+ # print('HHHHHHHHHHHHHHHHHHHHHHH',decoder_out.shape)
+ # raise Exception('end')
+
logprobs = F.log_softmax(decoder_out, dim=-1)
return logprobs, [ys.unsqueeze(0)]
@@ -103,15 +168,19 @@ def fn(s):
beam = selected_beam
for _ in shape[1:]:
beam = beam.unsqueeze(-1)
+ beam = beam.long()
s = torch.gather(s.view(*([batch_size, cur_beam_size] + shape[1:])), 1,
beam.expand(*([batch_size, beam_size] + shape[1:])))
+
s = s.view(*([-1, ] + shape[1:]))
return s
+
return fn
# the beam search code is inspired by https://github.com/aimagelab/meshed-memory-transformer
def decode_beam(self, **kwargs):
att_feats = kwargs[cfg.PARAM.ATT_FEATS]
+ # att_mask = torch.ones(16,70).to(device)
att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
beam_size = kwargs['BEAM_SIZE']
batch_size = att_feats.size(0)
@@ -120,7 +189,13 @@ def decode_beam(self, **kwargs):
selected_words = None
seq_mask = torch.ones((batch_size, beam_size, 1)).cuda()
+ # Modified
+ att_feats = self.get_visual_features(att_feats)
+ batch_size, feat_size, _ ,_= att_feats.shape
+ att_feats = att_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
+ # print(att_feats.shape)
att_feats = self.att_embed(att_feats)
+
gx, encoder_out = self.encoder(att_feats, att_mask)
p_att_feats = self.decoder.precompute(encoder_out)
@@ -152,6 +227,7 @@ def decode_beam(self, **kwargs):
selected_idx, selected_logprob = self.select(batch_size, beam_size, t, candidate_logprob)
selected_beam = selected_idx / candidate_logprob.shape[-1]
+ selected_beam = selected_beam.long()
selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]
self.decoder.apply_to_states(self._expand_state(batch_size, beam_size, cur_beam_size, selected_beam))
@@ -161,7 +237,8 @@ def decode_beam(self, **kwargs):
outputs.append(selected_words.unsqueeze(-1))
this_word_logprob = torch.gather(word_logprob, 1,
- selected_beam.unsqueeze(-1).expand(batch_size, beam_size, word_logprob.shape[-1]))
+ selected_beam.unsqueeze(-1).expand(batch_size, beam_size,
+ word_logprob.shape[-1]))
this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
log_probs = list(
torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(batch_size, beam_size, 1)) for o in log_probs)
@@ -188,7 +265,7 @@ def decode_beam(self, **kwargs):
kwargs[cfg.PARAM.GLOBAL_FEAT] = gx
kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats_tmp
-
+
seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
outputs = torch.cat(outputs, -1)
outputs = torch.gather(outputs, 1, sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN))
@@ -205,14 +282,20 @@ def decode(self, **kwargs):
beam_size = kwargs['BEAM_SIZE']
greedy_decode = kwargs['GREEDY_DECODE']
att_feats = kwargs[cfg.PARAM.ATT_FEATS]
+ # att_mask = torch.ones(16,70).to(device)
att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
- batch_size = att_feats.size(0)
+ att_feats = self.get_visual_features(att_feats)
+ batch_size, feat_size, _ ,_= att_feats.shape
+ att_feats = att_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
+
att_feats = self.att_embed(att_feats)
gx, encoder_out = self.encoder(att_feats, att_mask)
+ # print(gx.shape)
+ # print(encoder_out.shape)
p_att_feats = self.decoder.precompute(encoder_out)
self.decoder.init_buffer(batch_size)
-
+
state = None
sents = Variable(torch.zeros((batch_size, cfg.MODEL.SEQ_LEN), dtype=torch.long).cuda())
logprobs = Variable(torch.zeros(batch_size, cfg.MODEL.SEQ_LEN).cuda())
@@ -225,7 +308,7 @@ def decode(self, **kwargs):
kwargs[cfg.PARAM.WT] = wt
kwargs[cfg.PARAM.STATE] = state
logprobs_t, state = self.get_logprobs_state(**kwargs)
-
+
if greedy_decode:
logP_t, wt = torch.max(logprobs_t, 1)
else:
@@ -235,80 +318,93 @@ def decode(self, **kwargs):
wt = wt.view(-1).long()
unfinished = unfinished * (wt > 0)
wt = wt * unfinished.type_as(wt)
- sents[:,t] = wt
- logprobs[:,t] = logP_t.view(-1)
+ sents[:, t] = wt
+ logprobs[:, t] = logP_t.view(-1)
if unfinished.sum() == 0:
break
self.decoder.clear_buffer()
return sents, logprobs
+
class Encoder(nn.Module):
def __init__(
- self,
- embed_dim,
- dropout,
- att_type,
- att_heads,
- att_mid_dim,
- att_mid_drop,
- bifeat_emb_act,
- bifeat_emb_drop,
- ff_dropout,
- layer_num
+ self,
+ embed_dim,
+ dropout,
+ att_type,
+ att_heads,
+ att_mid_dim,
+ att_mid_drop,
+ bifeat_emb_act,
+ bifeat_emb_drop,
+ ff_dropout,
+ layer_num
):
super(Encoder, self).__init__()
self.att_heads = att_heads
- self.layers = nn.ModuleList([])
+ self.layers = nn.ModuleList([])
for i in range(layer_num):
- sublayer = EncoderLayer(
- embed_dim = embed_dim,
- dropout = dropout,
- att_type = att_type,
- att_heads = att_heads,
- att_mid_dim = att_mid_dim,
- att_mid_drop = att_mid_drop,
- bifeat_emb_act = bifeat_emb_act,
- bifeat_emb_drop = bifeat_emb_drop,
- ff_dropout = ff_dropout)
+ sublayer = EncoderLayer(
+ embed_dim=embed_dim,
+ dropout=dropout,
+ att_type=att_type,
+ att_heads=att_heads,
+ att_mid_dim=att_mid_dim,
+ att_mid_drop=att_mid_drop,
+ bifeat_emb_act=bifeat_emb_act,
+ bifeat_emb_drop=bifeat_emb_drop,
+ ff_dropout=ff_dropout)
self.layers.append(sublayer)
-
+
self.proj_norm = nn.Sequential(
- nn.Linear(embed_dim * (layer_num + 1), embed_dim),
+ nn.Linear(embed_dim * (layer_num + 1), embed_dim),
torch.nn.LayerNorm(embed_dim))
+
+
+
def forward(self, x, mask):
- gx = (torch.sum(x * mask.unsqueeze(-1), 1) / torch.sum(mask.unsqueeze(-1), 1))
+ # drop mask
+ att_masks = mask
+ # if att_masks is None:
+ # att_masks = x.new_ones(x.shape[:2], dtype=torch.long)
+ # att_masks = att_masks.unsqueeze(-2)
+ # print(x.shape)
+ # print(att_masks.shape)
+ gx = (torch.sum(x * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
gx_arr = [gx]
for layer in self.layers:
- gx, x = layer(gx, x, mask)
+ gx, x = layer(gx, x, att_masks) # modified
gx_arr.append(gx)
gx = torch.cat(gx_arr, dim=-1)
gx = self.proj_norm(gx)
+
return gx, x
+
class EncoderLayer(nn.Module):
def __init__(
- self,
- embed_dim,
- dropout,
- att_type,
- att_heads,
- att_mid_dim,
- att_mid_drop,
- bifeat_emb_act,
- bifeat_emb_drop,
- ff_dropout
+ self,
+ embed_dim,
+ dropout,
+ att_type,
+ att_heads,
+ att_mid_dim,
+ att_mid_drop,
+ bifeat_emb_act,
+ bifeat_emb_drop,
+ ff_dropout
):
super(EncoderLayer, self).__init__()
self.encoder_attn = LowRank(
- embed_dim = embed_dim,
- att_type = att_type,
- att_heads = att_heads,
- att_mid_dim = att_mid_dim,
- att_mid_drop = att_mid_drop)
+ embed_dim=embed_dim,
+ att_type=att_type,
+ att_heads=att_heads,
+ att_mid_dim=att_mid_dim,
+ att_mid_drop=att_mid_drop)
self.dropout = nn.Dropout(dropout)
self.bifeat_emb = nn.Sequential(
@@ -320,60 +416,70 @@ def __init__(
self.ff_layer = blocks.create(
'FeedForward',
- embed_dim = embed_dim,
- ffn_embed_dim = embed_dim * 4,
- relu_dropout = ff_dropout,
- dropout = ff_dropout)
+ embed_dim=embed_dim,
+ ffn_embed_dim=embed_dim * 4,
+ relu_dropout=ff_dropout,
+ dropout=ff_dropout)
def forward(self, gx, x, mask):
+ """
+ gx : torch.Size([4, 768])
+ x : torch.Size([4, 49, 768])
+ mask : torch.Size([4, 49])
+
+ """
+
gx = self.encoder_attn(
- query = gx,
- key = x,
- mask = mask,
- value1 = gx,
- value2 = x
+ query=gx,
+ key=x,
+ mask=mask,
+ value1=gx,
+ value2=x
)
+
gx = self.dropout(gx)
- x_ = torch.cat([gx.unsqueeze(1).expand_as(x), x], dim = -1)
+ x_ = torch.cat([gx.unsqueeze(1).expand_as(x), x], dim=-1)
x = self.bifeat_emb(x_) + x
x = self.layer_norm(x)
if self.ff_layer is not None:
x = self.ff_layer(x)
+
return gx, x
+
class Decoder(nn.Module):
def __init__(
- self,
- vocab_size,
- embed_dim,
- dropout,
- att_type,
- att_heads,
- att_mid_dim,
- att_mid_drop,
- bifeat_emb_act,
- bifeat_emb_drop,
- ff_dropout,
- layer_num
+ self,
+ vocab_size,
+ embed_dim,
+ dropout,
+ att_type,
+ att_heads,
+ att_mid_dim,
+ att_mid_drop,
+ bifeat_emb_act,
+ bifeat_emb_drop,
+ ff_dropout,
+ layer_num
):
super(Decoder, self).__init__()
self.att_heads = att_heads
self.layers = nn.ModuleList([])
self.embed_dim = embed_dim
for i in range(layer_num):
- sublayer = DecoderLayer(
- embed_dim = embed_dim,
- dropout = dropout,
- att_type = att_type,
- att_heads = att_heads,
- att_mid_dim = att_mid_dim,
- att_mid_drop = att_mid_drop,
- bifeat_emb_act = bifeat_emb_act,
- bifeat_emb_drop = bifeat_emb_drop,
- ff_dropout = ff_dropout,
- last_layer = (i == layer_num -1))
+ sublayer = DecoderLayer(
+ embed_dim=embed_dim,
+ dropout=dropout,
+ att_type=att_type,
+ att_heads=att_heads,
+ att_mid_dim=att_mid_dim,
+ att_mid_drop=att_mid_drop,
+ bifeat_emb_act=bifeat_emb_act,
+ bifeat_emb_drop=bifeat_emb_drop,
+ ff_dropout=ff_dropout,
+ last_layer=(i == layer_num - 1))
self.layers.append(sublayer)
self.dropout = nn.Dropout(cfg.MODEL.DROPOUT_WORD_EMBED)
@@ -397,7 +503,7 @@ def __init__(
torch.nn.LayerNorm(embed_dim)
)
self.wbi_drop = nn.Dropout(cfg.MODEL.BILINEAR.DECODE_DROPOUT)
- self.dropout_lm = nn.Dropout(cfg.MODEL.DROPOUT_LM)
+ self.dropout_lm = nn.Dropout(cfg.MODEL.DROPOUT_LM)
self.proj_norm = nn.Sequential(
nn.Linear(embed_dim * (layer_num + 1), 2 * embed_dim),
@@ -431,14 +537,15 @@ def precompute(self, encoder_out):
return p_att_feats
def forward(self, gx, prev_output_tokens, encoder_out, att_mask, seq_mask=None, p_att_feats=None, precompute=False):
+ # device = torch.device("cuda")
att_mask = att_mask.unsqueeze(1)
-
+
# embed positions
seq_len = prev_output_tokens.size(1)
if self.seq_len is not None:
seq_len = self.seq_len + seq_len
self.seq_len = seq_len
- positions = self.embed_positions(seq_len)[:,-1,:].unsqueeze(1)
+ positions = self.embed_positions(seq_len)[:, -1, :].unsqueeze(1)
else:
positions = self.embed_positions(seq_len)
@@ -449,7 +556,7 @@ def forward(self, gx, prev_output_tokens, encoder_out, att_mask, seq_mask=None,
x = self.layer_norm_word(x)
if self.dropout is not None:
x = self.dropout(x)
-
+
# decoder layers
gx = self.wbil1(gx)
if self.x is None:
@@ -462,53 +569,84 @@ def forward(self, gx, prev_output_tokens, encoder_out, att_mask, seq_mask=None,
gx = gx * x_gx
gx = self.wbi_drop(gx)
- gx_arr = [gx]
+ # print(gx.shape)
+ # print(type(gx))
+ # print((gx[0, 0, 0].dtype))
+
+ bs, att_nums, features_sizes = gx.shape
+
+ # gx_arr = torch.zeros((len(self.layers), bs, att_nums, features_sizes), dtype = torch.float32, device = device)# .to(torch.device("cuda")) # HARD CODE
+ # print(gx_arr.shape)
+ # print(gx_arr.shape)
+ # gx_arr = gx_arr.to(device)
+
+ gx_arr = [gx] # .to('cuda')
+ # gx_arr =
+
for layerid, layer in enumerate(self.layers):
+ # print('layerid',layerid)
if precompute == False:
p_key = None
p_value2 = None
else:
- p_key, p_value2 = p_att_feats[layerid]
- gx, x = layer(gx, x, encoder_out, att_mask, seq_mask=seq_mask, p_key=p_key, p_value2=p_value2, precompute=precompute)
+ p_key, p_value2 = p_att_feats[layerid]
+ # print('x.shape b4',x.shape)
+ # print('precompute',precompute)
+ # print('att_mask.shape',att_mask.shape)
+ # print('seq_mask.shape',seq_mask.shape)
+ gx, x = layer(gx, x, encoder_out, att_mask, seq_mask=seq_mask, p_key=p_key, p_value2=p_value2,
+ precompute=precompute)
+ # print('gx.shape',gx.shape) # bs, 41, 768
+ # print('x.shape after',x.shape)
+ # print(type(gx_arr))
gx_arr.append(gx)
+ # print('test',gx_arr[layerid, :, :, :].shape)
+ # print('layerid',layerid)
+ # print('gx_arr.shape',gx_arr.shape)
+ # print('gx.device',gx.device)
+ # print('gx_arr.device',gx_arr.device)
+ # gx_arr[layerid] = gx
+ # raise Exception('decoder lol')
+ gx = torch.cat(gx_arr, dim=-1)
- gx = torch.cat(gx_arr, dim = -1)
gx = self.proj_norm(gx)
gx = self.dropout_lm(gx)
out = self.generator(gx)
+ # raise Exception('decoder lol')
return out
+
class DecoderLayer(nn.Module):
def __init__(
- self,
- embed_dim,
- dropout,
- att_type,
- att_heads,
- att_mid_dim,
- att_mid_drop,
- bifeat_emb_act,
- bifeat_emb_drop,
- ff_dropout,
- last_layer = False
+ self,
+ embed_dim,
+ dropout,
+ att_type,
+ att_heads,
+ att_mid_dim,
+ att_mid_drop,
+ bifeat_emb_act,
+ bifeat_emb_drop,
+ ff_dropout,
+ last_layer=False
):
super(DecoderLayer, self).__init__()
self.last_layer = last_layer
self.word_attn = LowRank(
- embed_dim = embed_dim,
- att_type = att_type,
- att_heads = att_heads,
- att_mid_dim = att_mid_dim,
- att_mid_drop = att_mid_drop)
+ embed_dim=embed_dim,
+ att_type=att_type,
+ att_heads=att_heads,
+ att_mid_dim=att_mid_dim,
+ att_mid_drop=att_mid_drop)
self.word_dropout = nn.Dropout(dropout)
self.cross_att = LowRank(
- embed_dim = embed_dim,
- att_type = att_type,
- att_heads = att_heads,
- att_mid_dim = att_mid_dim,
- att_mid_drop = att_mid_drop)
+ embed_dim=embed_dim,
+ att_type=att_type,
+ att_heads=att_heads,
+ att_mid_dim=att_mid_dim,
+ att_mid_drop=att_mid_drop)
self.cross_dropout = nn.Dropout(dropout)
self.layer_norm_cross = torch.nn.LayerNorm(embed_dim)
@@ -522,10 +660,10 @@ def __init__(
self.ff_layer = blocks.create(
'FeedForward',
- embed_dim = embed_dim,
- ffn_embed_dim = embed_dim * 4,
- relu_dropout = ff_dropout,
- dropout = ff_dropout)
+ embed_dim=embed_dim,
+ ffn_embed_dim=embed_dim * 4,
+ relu_dropout=ff_dropout,
+ dropout=ff_dropout)
self.layer_norm_gx = torch.nn.LayerNorm(embed_dim)
@@ -543,42 +681,43 @@ def precompute(self, encoder_out):
return key, value2
def forward(
- self,
- gx,
- x,
- encoder_out,
- att_mask,
- seq_mask,
- p_key=None,
- p_value2=None,
- precompute=False
+ self,
+ gx,
+ x,
+ encoder_out,
+ att_mask,
+ seq_mask,
+ p_key=None,
+ p_value2=None,
+ precompute=False
):
word_x = x
residual = x
x = self.word_attn.forward2(
- query = gx,
- key = x,
- mask = seq_mask,
- value1 = gx,
- value2 = x)
+ query=gx,
+ key=x,
+ mask=seq_mask,
+ value1=gx,
+ value2=x)
x = self.word_dropout(x)
+
x = residual + x
residual = x
x = self.layer_norm_cross(x)
x = self.cross_att.forward2(
- query = x,
- key = encoder_out if precompute == False else p_key,
- mask = att_mask,
- value1 = x,
- value2 = encoder_out if precompute == False else p_value2,
+ query=x,
+ key=encoder_out if precompute == False else p_key,
+ mask=att_mask,
+ value1=x,
+ value2=encoder_out if precompute == False else p_value2,
precompute=precompute)
x = self.cross_dropout(x)
gx = residual + x
gx = self.layer_norm_gx(gx)
if self.last_layer == False:
- x_ = torch.cat([gx, word_x], dim = -1)
+ x_ = torch.cat([gx, word_x], dim=-1)
x = self.bifeat_emb(x_) + word_x
x = self.layer_norm_x(x)
@@ -586,4 +725,5 @@ def forward(
x = self.ff_layer(x)
else:
x = None
+
return gx, x
diff --git a/optimizer/optimizer.py b/optimizer/optimizer.py
index 4cbeeff..0d816e9 100755
--- a/optimizer/optimizer.py
+++ b/optimizer/optimizer.py
@@ -117,3 +117,28 @@ def get_lr(self):
lr.append(param_group['lr'])
lr = sorted(list(set(lr)))
return lr
+
+
+
+def build_optimizer(args, model):
+ # print(model.submodel.parameters())
+ #edit
+ # ve_params = list(map(id, model.submodel.parameters()))
+ # ed_params = filter(lambda x: id(x) not in ve_params, model.parameters())
+ # optimizer = torch.optim.Adam(
+ # [{'params': model.submodel.parameters(), 'lr': 5e-5}, #edit
+ # {'params': ed_params, 'lr': 1e-4}],
+ # weight_decay=5e-5,
+ # amsgrad=True
+ # )
+ optimizer = torch.optim.Adam(
+ [{'params': model, 'lr': 5e-5}], #edit
+ weight_decay=5e-5,
+ amsgrad=True
+ )
+ return optimizer
+
+
+def build_lr_scheduler(args, optimizer):
+ lr_scheduler = getattr(torch.optim.lr_scheduler, args.lr_scheduler)(optimizer, args.step_size, args.gamma)
+ return lr_scheduler
diff --git a/pycocoevalcap/README.md b/pycocoevalcap/README.md
new file mode 100644
index 0000000..942de18
--- /dev/null
+++ b/pycocoevalcap/README.md
@@ -0,0 +1,23 @@
+Microsoft COCO Caption Evaluation Tools
+---
+
+Modified the code to work with Python 3.
+
+### Requirements
+* Python 3.x
+* Java 1.8
+* pycocotools
+
+---
+
+### Tested on
+* Windows 10, Python 3.5.
+
+---
+### To fix Windows JVM memory error:
+Add the following in System Variables
+ Variable name : _JAVA_OPTIONS
+ Variable value : -Xmx1024M
+
+---
+Original code : https://github.com/tylin/coco-caption
diff --git a/pycocoevalcap/__init__.py b/pycocoevalcap/__init__.py
new file mode 100644
index 0000000..680063e
--- /dev/null
+++ b/pycocoevalcap/__init__.py
@@ -0,0 +1 @@
+__author__ = 'tylin'
\ No newline at end of file
diff --git a/pycocoevalcap/bleu/LICENSE b/pycocoevalcap/bleu/LICENSE
new file mode 100644
index 0000000..9ccf677
--- /dev/null
+++ b/pycocoevalcap/bleu/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/pycocoevalcap/bleu/__init__.py b/pycocoevalcap/bleu/__init__.py
new file mode 100644
index 0000000..680063e
--- /dev/null
+++ b/pycocoevalcap/bleu/__init__.py
@@ -0,0 +1 @@
+__author__ = 'tylin'
\ No newline at end of file
diff --git a/pycocoevalcap/bleu/bleu.py b/pycocoevalcap/bleu/bleu.py
new file mode 100644
index 0000000..60e723e
--- /dev/null
+++ b/pycocoevalcap/bleu/bleu.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+#
+# File Name : bleu.py
+#
+# Description : Wrapper for BLEU scorer.
+#
+# Creation Date : 06-01-2015
+# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
+# Authors : Hao Fang and Tsung-Yi Lin
+
+# Last modified : Wed 22 May 2019 08:10:00 PM EDT
+# By Sabarish Sivanath
+# To support Python 3
+
+from .bleu_scorer import BleuScorer
+
+
+class Bleu:
+ def __init__(self, n=4):
+ # default compute Blue score up to 4
+ self._n = n
+ self._hypo_for_image = {}
+ self.ref_for_image = {}
+
+ def compute_score(self, gts, res, score_option = 'closest', verbose = 1):
+ '''
+ Inputs:
+ gts - ground truths
+ res - predictions
+ score_option - {shortest, closest, average}
+ verbose - 1 or 0
+ Outputs:
+ Blue scores
+ '''
+ assert(gts.keys() == res.keys())
+ imgIds = gts.keys()
+
+ bleu_scorer = BleuScorer(n=self._n)
+ for id in imgIds:
+ hypo = res[id]
+ ref = gts[id]
+
+ # Sanity check.
+ assert(type(hypo) is list)
+ assert(len(hypo) == 1)
+ assert(type(ref) is list)
+ #assert(len(ref) >= 1)
+
+ bleu_scorer += (hypo[0], ref)
+
+ score, scores = bleu_scorer.compute_score(option = score_option, verbose =verbose)
+
+ # return (bleu, bleu_info)
+ return score, scores
+
+ def method(self):
+ return "Bleu"
diff --git a/pycocoevalcap/bleu/bleu_scorer.py b/pycocoevalcap/bleu/bleu_scorer.py
new file mode 100644
index 0000000..d5646aa
--- /dev/null
+++ b/pycocoevalcap/bleu/bleu_scorer.py
@@ -0,0 +1,268 @@
+# bleu_scorer.py
+# David Chiang
+
+# Copyright (c) 2004-2006 University of Maryland. All rights
+# reserved. Do not redistribute without permission from the
+# author. Not for commercial use.
+
+# Modified by:
+# Hao Fang
+# Tsung-Yi Lin
+
+# Last modified : Wed 22 May 2019 08:10:00 PM EDT
+# By Sabarish Sivanath
+# To support Python 3
+
+'''Provides:
+cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
+cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
+'''
+
+import copy
+import sys, math, re
+from collections import defaultdict
+
+def precook(s, n=4, out=False):
+ """Takes a string as input and returns an object that can be given to
+ either cook_refs or cook_test. This is optional: cook_refs and cook_test
+ can take string arguments as well."""
+ words = s.split()
+ counts = defaultdict(int)
+ for k in range(1,n+1):
+ for i in range(len(words)-k+1):
+ ngram = tuple(words[i:i+k])
+ counts[ngram] += 1
+ return (len(words), counts)
+
+def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
+ '''Takes a list of reference sentences for a single segment
+ and returns an object that encapsulates everything that BLEU
+ needs to know about them.'''
+
+ reflen = []
+ maxcounts = {}
+ for ref in refs:
+ rl, counts = precook(ref, n)
+ reflen.append(rl)
+ for (ngram,count) in counts.items():
+ maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
+
+ # Calculate effective reference sentence length.
+ if eff == "shortest":
+ reflen = min(reflen)
+ elif eff == "average":
+ reflen = float(sum(reflen))/len(reflen)
+
+ ## lhuang: N.B.: leave reflen computaiton to the very end!!
+
+ ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
+
+ return (reflen, maxcounts)
+
+def cook_test(test, refs , eff=None, n=4):
+ '''Takes a test sentence and returns an object that
+ encapsulates everything that BLEU needs to know about it.'''
+
+ reflen = refs[0]
+ refmaxcounts = refs[1]
+
+ testlen, counts = precook(test, n, True)
+
+ result = {}
+
+ # Calculate effective reference sentence length.
+
+ if eff == "closest":
+ result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
+ else: ## i.e., "average" or "shortest" or None
+ result["reflen"] = reflen
+
+ result["testlen"] = testlen
+
+ result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
+
+ result['correct'] = [0]*n
+ for (ngram, count) in counts.items():
+ result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
+
+ return result
+
+class BleuScorer(object):
+ """Bleu scorer.
+ """
+
+ __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
+ # special_reflen is used in oracle (proportional effective ref len for a node).
+
+ def copy(self):
+ ''' copy the refs.'''
+ new = BleuScorer(n=self.n)
+ new.ctest = copy.copy(self.ctest)
+ new.crefs = copy.copy(self.crefs)
+ new._score = None
+ return new
+
+ def __init__(self, test=None, refs=None, n=4, special_reflen=None):
+ ''' singular instance '''
+
+ self.n = n
+ self.crefs = []
+ self.ctest = []
+ self.cook_append(test, refs)
+ self.special_reflen = special_reflen
+
+ def cook_append(self, test, refs):
+ '''called by constructor and __iadd__ to avoid creating new instances.'''
+
+ if refs is not None:
+ self.crefs.append(cook_refs(refs))
+ if test is not None:
+ cooked_test = cook_test(test, self.crefs[-1])
+ self.ctest.append(cooked_test) ## N.B.: -1
+ else:
+ self.ctest.append(None) # lens of crefs and ctest have to match
+
+ self._score = None ## need to recompute
+
+ def ratio(self, option=None):
+ self.compute_score(option=option)
+ return self._ratio
+
+ def score_ratio(self, option=None):
+ '''return (bleu, len_ratio) pair'''
+ return (self.fscore(option=option), self.ratio(option=option))
+
+ def score_ratio_str(self, option=None):
+ return "%.4f (%.2f)" % self.score_ratio(option)
+
+ def reflen(self, option=None):
+ self.compute_score(option=option)
+ return self._reflen
+
+ def testlen(self, option=None):
+ self.compute_score(option=option)
+ return self._testlen
+
+ def retest(self, new_test):
+ if type(new_test) is str:
+ new_test = [new_test]
+ assert len(new_test) == len(self.crefs), new_test
+ self.ctest = []
+ for t, rs in zip(new_test, self.crefs):
+ self.ctest.append(cook_test(t, rs))
+ self._score = None
+
+ return self
+
+ def rescore(self, new_test):
+ ''' replace test(s) with new test(s), and returns the new score.'''
+
+ return self.retest(new_test).compute_score()
+
+ def size(self):
+ assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
+ return len(self.crefs)
+
+ def __iadd__(self, other):
+ '''add an instance (e.g., from another sentence).'''
+
+ if type(other) is tuple:
+ ## avoid creating new BleuScorer instances
+ self.cook_append(other[0], other[1])
+ else:
+ assert self.compatible(other), "incompatible BLEUs."
+ self.ctest.extend(other.ctest)
+ self.crefs.extend(other.crefs)
+ self._score = None ## need to recompute
+
+ return self
+
+ def compatible(self, other):
+ return isinstance(other, BleuScorer) and self.n == other.n
+
+ def single_reflen(self, option="average"):
+ return self._single_reflen(self.crefs[0][0], option)
+
+ def _single_reflen(self, reflens, option=None, testlen=None):
+
+ if option == "shortest":
+ reflen = min(reflens)
+ elif option == "average":
+ reflen = float(sum(reflens))/len(reflens)
+ elif option == "closest":
+ reflen = min((abs(l-testlen), l) for l in reflens)[1]
+ else:
+ assert False, "unsupported reflen option %s" % option
+
+ return reflen
+
+ def recompute_score(self, option=None, verbose=0):
+ self._score = None
+ return self.compute_score(option, verbose)
+
+ def compute_score(self, option=None, verbose=0):
+ n = self.n
+ small = 1e-9
+ tiny = 1e-15 ## so that if guess is 0 still return 0
+ bleu_list = [[] for _ in range(n)]
+
+ if self._score is not None:
+ return self._score
+
+ if option is None:
+ option = "average" if len(self.crefs) == 1 else "closest"
+
+ self._testlen = 0
+ self._reflen = 0
+ totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
+
+ # for each sentence
+ for comps in self.ctest:
+ testlen = comps['testlen']
+ self._testlen += testlen
+
+ if self.special_reflen is None: ## need computation
+ reflen = self._single_reflen(comps['reflen'], option, testlen)
+ else:
+ reflen = self.special_reflen
+
+ self._reflen += reflen
+
+ for key in ['guess','correct']:
+ for k in range(n):
+ totalcomps[key][k] += comps[key][k]
+
+ # append per image bleu score
+ bleu = 1.
+ for k in range(n):
+ bleu *= (float(comps['correct'][k]) + tiny) \
+ /(float(comps['guess'][k]) + small)
+ bleu_list[k].append(bleu ** (1./(k+1)))
+ ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
+ if ratio < 1:
+ for k in range(n):
+ bleu_list[k][-1] *= math.exp(1 - 1/ratio)
+
+ if verbose > 1:
+ print(comps, reflen)
+
+ totalcomps['reflen'] = self._reflen
+ totalcomps['testlen'] = self._testlen
+
+ bleus = []
+ bleu = 1.
+ for k in range(n):
+ bleu *= float(totalcomps['correct'][k] + tiny) \
+ / (totalcomps['guess'][k] + small)
+ bleus.append(bleu ** (1./(k+1)))
+ ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
+ if ratio < 1:
+ for k in range(n):
+ bleus[k] *= math.exp(1 - 1/ratio)
+
+ if verbose > 0:
+ print(totalcomps)
+ print("ratio:", ratio)
+
+ self._score = bleus
+ return self._score, bleu_list
diff --git a/pycocoevalcap/cider/__init__.py b/pycocoevalcap/cider/__init__.py
new file mode 100644
index 0000000..3f7d85b
--- /dev/null
+++ b/pycocoevalcap/cider/__init__.py
@@ -0,0 +1 @@
+__author__ = 'tylin'
diff --git a/pycocoevalcap/cider/cider.py b/pycocoevalcap/cider/cider.py
new file mode 100644
index 0000000..7aadb9a
--- /dev/null
+++ b/pycocoevalcap/cider/cider.py
@@ -0,0 +1,55 @@
+# Filename: cider.py
+#
+# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
+# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
+#
+# Creation Date: Sun Feb 8 14:16:54 2015
+#
+# Authors: Ramakrishna Vedantam and Tsung-Yi Lin
+
+
+from .cider_scorer import CiderScorer
+import pdb
+
+class Cider:
+ """
+ Main Class to compute the CIDEr metric
+
+ """
+ def __init__(self, test=None, refs=None, n=4, sigma=6.0):
+ # set cider to sum over 1 to 4-grams
+ self._n = n
+ # set the standard deviation parameter for gaussian penalty
+ self._sigma = sigma
+
+ def compute_score(self, gts, res):
+ """
+ Main function to compute CIDEr score
+ :param hypo_for_image (dict) : dictionary with key and value
+ ref_for_image (dict) : dictionary with key and value
+ :return: cider (float) : computed CIDEr score for the corpus
+ """
+
+ assert(gts.keys() == res.keys())
+ imgIds = gts.keys()
+
+ cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
+
+ for id in imgIds:
+ hypo = res[id]
+ ref = gts[id]
+
+ # Sanity check.
+ assert(type(hypo) is list)
+ assert(len(hypo) == 1)
+ assert(type(ref) is list)
+ assert(len(ref) > 0)
+
+ cider_scorer += (hypo[0], ref)
+
+ (score, scores) = cider_scorer.compute_score()
+
+ return score, scores
+
+ def method(self):
+ return "CIDEr"
\ No newline at end of file
diff --git a/pycocoevalcap/cider/cider_scorer.py b/pycocoevalcap/cider/cider_scorer.py
new file mode 100644
index 0000000..94752e8
--- /dev/null
+++ b/pycocoevalcap/cider/cider_scorer.py
@@ -0,0 +1,197 @@
+#!/usr/bin/env python
+# Tsung-Yi Lin
+# Ramakrishna Vedantam
+
+
+# Last modified : Wed 22 May 2019 08:10:00 PM EDT
+# By Sabarish Sivanath
+# To support Python 3
+
+import copy
+from collections import defaultdict
+import numpy as np
+import pdb
+import math
+
+def precook(s, n=4, out=False):
+ """
+ Takes a string as input and returns an object that can be given to
+ either cook_refs or cook_test. This is optional: cook_refs and cook_test
+ can take string arguments as well.
+ :param s: string : sentence to be converted into ngrams
+ :param n: int : number of ngrams for which representation is calculated
+ :return: term frequency vector for occuring ngrams
+ """
+ words = s.split()
+ counts = defaultdict(int)
+ for k in range(1,n+1):
+ for i in range(len(words)-k+1):
+ ngram = tuple(words[i:i+k])
+ counts[ngram] += 1
+ return counts
+
+def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
+ '''Takes a list of reference sentences for a single segment
+ and returns an object that encapsulates everything that BLEU
+ needs to know about them.
+ :param refs: list of string : reference sentences for some image
+ :param n: int : number of ngrams for which (ngram) representation is calculated
+ :return: result (list of dict)
+ '''
+ return [precook(ref, n) for ref in refs]
+
+def cook_test(test, n=4):
+ '''Takes a test sentence and returns an object that
+ encapsulates everything that BLEU needs to know about it.
+ :param test: list of string : hypothesis sentence for some image
+ :param n: int : number of ngrams for which (ngram) representation is calculated
+ :return: result (dict)
+ '''
+ return precook(test, n, True)
+
+class CiderScorer(object):
+ """CIDEr scorer.
+ """
+
+ def copy(self):
+ ''' copy the refs.'''
+ new = CiderScorer(n=self.n)
+ new.ctest = copy.copy(self.ctest)
+ new.crefs = copy.copy(self.crefs)
+ return new
+
+ def __init__(self, test=None, refs=None, n=4, sigma=6.0):
+ ''' singular instance '''
+ self.n = n
+ self.sigma = sigma
+ self.crefs = []
+ self.ctest = []
+ self.document_frequency = defaultdict(float)
+ self.cook_append(test, refs)
+ self.ref_len = None
+
+ def cook_append(self, test, refs):
+ '''called by constructor and __iadd__ to avoid creating new instances.'''
+
+ if refs is not None:
+ self.crefs.append(cook_refs(refs))
+ if test is not None:
+ self.ctest.append(cook_test(test)) ## N.B.: -1
+ else:
+ self.ctest.append(None) # lens of crefs and ctest have to match
+
+ def size(self):
+ assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
+ return len(self.crefs)
+
+ def __iadd__(self, other):
+ '''add an instance (e.g., from another sentence).'''
+
+ if type(other) is tuple:
+ ## avoid creating new CiderScorer instances
+ self.cook_append(other[0], other[1])
+ else:
+ self.ctest.extend(other.ctest)
+ self.crefs.extend(other.crefs)
+
+ return self
+ def compute_doc_freq(self):
+ '''
+ Compute term frequency for reference data.
+ This will be used to compute idf (inverse document frequency later)
+ The term frequency is stored in the object
+ :return: None
+ '''
+ for refs in self.crefs:
+ # refs, k ref captions of one image
+ for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
+ self.document_frequency[ngram] += 1
+ # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
+
+ def compute_cider(self):
+ def counts2vec(cnts):
+ """
+ Function maps counts of ngram to vector of tfidf weights.
+ The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
+ The n-th entry of array denotes length of n-grams.
+ :param cnts:
+ :return: vec (array of dict), norm (array of float), length (int)
+ """
+ vec = [defaultdict(float) for _ in range(self.n)]
+ length = 0
+ norm = [0.0 for _ in range(self.n)]
+ for (ngram,term_freq) in cnts.items():
+ # give word count 1 if it doesn't appear in reference corpus
+ df = np.log(max(1.0, self.document_frequency[ngram]))
+ # ngram index
+ n = len(ngram)-1
+ # tf (term_freq) * idf (precomputed idf) for n-grams
+ vec[n][ngram] = float(term_freq)*(self.ref_len - df)
+ # compute norm for the vector. the norm will be used for computing similarity
+ norm[n] += pow(vec[n][ngram], 2)
+
+ if n == 1:
+ length += term_freq
+ norm = [np.sqrt(n) for n in norm]
+ return vec, norm, length
+
+ def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
+ '''
+ Compute the cosine similarity of two vectors.
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis
+ :param vec_ref: array of dictionary for vector corresponding to reference
+ :param norm_hyp: array of float for vector corresponding to hypothesis
+ :param norm_ref: array of float for vector corresponding to reference
+ :param length_hyp: int containing length of hypothesis
+ :param length_ref: int containing length of reference
+ :return: array of score for each n-grams cosine similarity
+ '''
+ delta = float(length_hyp - length_ref)
+ # measure consine similarity
+ val = np.array([0.0 for _ in range(self.n)])
+ for n in range(self.n):
+ # ngram
+ for (ngram,count) in vec_hyp[n].items():
+ # vrama91 : added clipping
+ val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
+
+ if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
+ val[n] /= (norm_hyp[n]*norm_ref[n])
+
+ assert(not math.isnan(val[n]))
+ # vrama91: added a length based gaussian penalty
+ val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
+ return val
+
+ # compute log reference length
+ self.ref_len = np.log(float(len(self.crefs)))
+
+ scores = []
+ for test, refs in zip(self.ctest, self.crefs):
+ # compute vector for test captions
+ vec, norm, length = counts2vec(test)
+ # compute vector for ref captions
+ score = np.array([0.0 for _ in range(self.n)])
+ for ref in refs:
+ vec_ref, norm_ref, length_ref = counts2vec(ref)
+ score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
+ # change by vrama91 - mean of ngram scores, instead of sum
+ score_avg = np.mean(score)
+ # divide by number of references
+ score_avg /= len(refs)
+ # multiply score by 10
+ score_avg *= 10.0
+ # append score of an image to the score list
+ scores.append(score_avg)
+ return scores
+
+ def compute_score(self, option=None, verbose=0):
+ # compute idf
+ self.compute_doc_freq()
+ # assert to check document frequency
+ assert(len(self.ctest) >= max(self.document_frequency.values()))
+ # compute cider score
+ score = self.compute_cider()
+ # debug
+ # print score
+ return np.mean(np.array(score)), np.array(score)
\ No newline at end of file
diff --git a/pycocoevalcap/eval.py b/pycocoevalcap/eval.py
new file mode 100644
index 0000000..21f53dc
--- /dev/null
+++ b/pycocoevalcap/eval.py
@@ -0,0 +1,74 @@
+__author__ = 'tylin'
+from .tokenizer.ptbtokenizer import PTBTokenizer
+from .bleu.bleu import Bleu
+from .meteor.meteor import Meteor
+from .rouge.rouge import Rouge
+from .cider.cider import Cider
+
+class COCOEvalCap:
+ def __init__(self, coco, cocoRes):
+ self.evalImgs = []
+ self.eval = {}
+ self.imgToEval = {}
+ self.coco = coco
+ self.cocoRes = cocoRes
+ self.params = {'image_id': cocoRes.getImgIds()}
+
+ def evaluate(self):
+ imgIds = self.params['image_id']
+ # imgIds = self.coco.getImgIds()
+ gts = {}
+ res = {}
+ for imgId in imgIds:
+ gts[imgId] = self.coco.imgToAnns[imgId]
+ res[imgId] = self.cocoRes.imgToAnns[imgId]
+
+ # =================================================
+ # Set up scorers
+ # =================================================
+ print('tokenization...')
+ tokenizer = PTBTokenizer()
+ gts = tokenizer.tokenize(gts)
+ res = tokenizer.tokenize(res)
+
+ # =================================================
+ # Set up scorers
+ # =================================================
+ print('setting up scorers...')
+ scorers = [
+ (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
+ (Meteor(),"METEOR"),
+ (Rouge(), "ROUGE_L"),
+ (Cider(), "CIDEr")
+ ]
+
+ # =================================================
+ # Compute scores
+ # =================================================
+ eval = {}
+ for scorer, method in scorers:
+ print('computing %s score...'%(scorer.method()))
+ score, scores = scorer.compute_score(gts, res)
+ if type(method) == list:
+ for sc, scs, m in zip(score, scores, method):
+ self.setEval(sc, m)
+ self.setImgToEvalImgs(scs, imgIds, m)
+ print("%s: %0.3f"%(m, sc))
+ else:
+ self.setEval(score, method)
+ self.setImgToEvalImgs(scores, imgIds, method)
+ print("%s: %0.3f"%(method, score))
+ self.setEvalImgs()
+
+ def setEval(self, score, method):
+ self.eval[method] = score
+
+ def setImgToEvalImgs(self, scores, imgIds, method):
+ for imgId, score in zip(imgIds, scores):
+ if not imgId in self.imgToEval:
+ self.imgToEval[imgId] = {}
+ self.imgToEval[imgId]["image_id"] = imgId
+ self.imgToEval[imgId][method] = score
+
+ def setEvalImgs(self):
+ self.evalImgs = [eval for imgId, eval in self.imgToEval.items()]
diff --git a/pycocoevalcap/license.txt b/pycocoevalcap/license.txt
new file mode 100644
index 0000000..3ada56f
--- /dev/null
+++ b/pycocoevalcap/license.txt
@@ -0,0 +1,26 @@
+Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+The views and conclusions contained in the software and documentation are those
+of the authors and should not be interpreted as representing official policies,
+either expressed or implied, of the FreeBSD Project.
\ No newline at end of file
diff --git a/pycocoevalcap/meteor/__init__.py b/pycocoevalcap/meteor/__init__.py
new file mode 100644
index 0000000..349338d
--- /dev/null
+++ b/pycocoevalcap/meteor/__init__.py
@@ -0,0 +1 @@
+from .meteor import *
\ No newline at end of file
diff --git a/pycocoevalcap/meteor/data/paraphrase-en.gz b/pycocoevalcap/meteor/data/paraphrase-en.gz
new file mode 100644
index 0000000..88033c8
Binary files /dev/null and b/pycocoevalcap/meteor/data/paraphrase-en.gz differ
diff --git a/pycocoevalcap/meteor/meteor-1.5.jar b/pycocoevalcap/meteor/meteor-1.5.jar
new file mode 100644
index 0000000..a833bc0
Binary files /dev/null and b/pycocoevalcap/meteor/meteor-1.5.jar differ
diff --git a/pycocoevalcap/meteor/meteor.py b/pycocoevalcap/meteor/meteor.py
new file mode 100644
index 0000000..114b42a
--- /dev/null
+++ b/pycocoevalcap/meteor/meteor.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+
+# Python wrapper for METEOR implementation, by Xinlei Chen
+# Acknowledge Michael Denkowski for the generous discussion and help
+
+# Last modified : Wed 22 May 2019 08:10:00 PM EDT
+# By Sabarish Sivanath
+# To support Python 3
+
+import os
+import sys
+import subprocess
+import threading
+
+# Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
+METEOR_JAR = 'meteor-1.5.jar'
+# print METEOR_JAR
+
+class Meteor:
+
+ def __init__(self):
+ self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
+ '-', '-', '-stdio', '-l', 'en', '-norm']
+ self.meteor_p = subprocess.Popen(self.meteor_cmd, \
+ cwd=os.path.dirname(os.path.abspath(__file__)), \
+ stdin=subprocess.PIPE, \
+ stdout=subprocess.PIPE, \
+ stderr=subprocess.PIPE,
+ universal_newlines = True,
+ bufsize = 1)
+ # Used to guarantee thread safety
+ self.lock = threading.Lock()
+
+ def compute_score(self, gts, res):
+ assert(gts.keys() == res.keys())
+ imgIds = gts.keys()
+ scores = []
+
+ eval_line = 'EVAL'
+ self.lock.acquire()
+ for i in imgIds:
+ assert(len(res[i]) == 1)
+ stat = self._stat(res[i][0], gts[i])
+ eval_line += ' ||| {}'.format(stat)
+
+ self.meteor_p.stdin.write('{}\n'.format(eval_line))
+ for i in range(0,len(imgIds)):
+ scores.append(float(self.meteor_p.stdout.readline().strip()))
+ score = float(self.meteor_p.stdout.readline().strip())
+ self.lock.release()
+
+ return score, scores
+
+ def method(self):
+ return "METEOR"
+
+ def _stat(self, hypothesis_str, reference_list):
+ # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
+ hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ')
+ score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
+ self.meteor_p.stdin.write('{}\n'.format(score_line))
+ return self.meteor_p.stdout.readline().strip()
+
+ def _score(self, hypothesis_str, reference_list):
+ self.lock.acquire()
+ # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
+ hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ')
+ score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
+ self.meteor_p.stdin.write('{}\n'.format(score_line))
+ stats = self.meteor_p.stdout.readline().strip()
+ eval_line = 'EVAL ||| {}'.format(stats)
+ # EVAL ||| stats
+ self.meteor_p.stdin.write('{}\n'.format(eval_line))
+ score = float(self.meteor_p.stdout.readline().strip())
+ # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice
+ # thanks for Andrej for pointing this out
+ score = float(self.meteor_p.stdout.readline().strip())
+ self.lock.release()
+ return score
+
+ def __del__(self):
+ self.lock.acquire()
+ self.meteor_p.stdin.close()
+ self.meteor_p.kill()
+ self.meteor_p.wait()
+ self.lock.release()
diff --git a/pycocoevalcap/rouge/__init__.py b/pycocoevalcap/rouge/__init__.py
new file mode 100644
index 0000000..e3c0469
--- /dev/null
+++ b/pycocoevalcap/rouge/__init__.py
@@ -0,0 +1 @@
+from .rouge import *
\ No newline at end of file
diff --git a/pycocoevalcap/rouge/rouge.py b/pycocoevalcap/rouge/rouge.py
new file mode 100644
index 0000000..3a10f5a
--- /dev/null
+++ b/pycocoevalcap/rouge/rouge.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python
+#
+# File Name : rouge.py
+#
+# Description : Computes ROUGE-L metric as described by Lin and Hovey (2004)
+#
+# Creation Date : 2015-01-07 06:03
+# Author : Ramakrishna Vedantam
+
+import numpy as np
+import pdb
+
+def my_lcs(string, sub):
+ """
+ Calculates longest common subsequence for a pair of tokenized strings
+ :param string : list of str : tokens from a string split using whitespace
+ :param sub : list of str : shorter string, also split using whitespace
+ :returns: length (list of int): length of the longest common subsequence between the two strings
+
+ Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
+ """
+ if(len(string)< len(sub)):
+ sub, string = string, sub
+
+ lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
+
+ for j in range(1,len(sub)+1):
+ for i in range(1,len(string)+1):
+ if(string[i-1] == sub[j-1]):
+ lengths[i][j] = lengths[i-1][j-1] + 1
+ else:
+ lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
+
+ return lengths[len(string)][len(sub)]
+
+class Rouge():
+ '''
+ Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
+
+ '''
+ def __init__(self):
+ # vrama91: updated the value below based on discussion with Hovey
+ self.beta = 1.2
+
+ def calc_score(self, candidate, refs):
+ """
+ Compute ROUGE-L score given one candidate and references for an image
+ :param candidate: str : candidate sentence to be evaluated
+ :param refs: list of str : COCO reference sentences for the particular image to be evaluated
+ :returns score: int (ROUGE-L score for the candidate evaluated against references)
+ """
+ assert(len(candidate)==1)
+ assert(len(refs)>0)
+ prec = []
+ rec = []
+
+ # split into tokens
+ token_c = candidate[0].split(" ")
+
+ for reference in refs:
+ # split into tokens
+ token_r = reference.split(" ")
+ # compute the longest common subsequence
+ lcs = my_lcs(token_r, token_c)
+ prec.append(lcs/float(len(token_c)))
+ rec.append(lcs/float(len(token_r)))
+
+ prec_max = max(prec)
+ rec_max = max(rec)
+
+ if(prec_max!=0 and rec_max !=0):
+ score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
+ else:
+ score = 0.0
+ return score
+
+ def compute_score(self, gts, res):
+ """
+ Computes Rouge-L score given a set of reference and candidate sentences for the dataset
+ Invoked by evaluate_captions.py
+ :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
+ :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
+ :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
+ """
+ assert(gts.keys() == res.keys())
+ imgIds = gts.keys()
+
+ score = []
+ for id in imgIds:
+ hypo = res[id]
+ ref = gts[id]
+
+ score.append(self.calc_score(hypo, ref))
+
+ # Sanity check.
+ assert(type(hypo) is list)
+ assert(len(hypo) == 1)
+ assert(type(ref) is list)
+ assert(len(ref) > 0)
+
+ average_score = np.mean(np.array(score))
+ return average_score, np.array(score)
+
+ def method(self):
+ return "Rouge"
diff --git a/pycocoevalcap/tokenizer/__init__.py b/pycocoevalcap/tokenizer/__init__.py
new file mode 100644
index 0000000..71357a4
--- /dev/null
+++ b/pycocoevalcap/tokenizer/__init__.py
@@ -0,0 +1 @@
+__author__ = 'hfang'
diff --git a/pycocoevalcap/tokenizer/ptbtokenizer.py b/pycocoevalcap/tokenizer/ptbtokenizer.py
new file mode 100644
index 0000000..b7d06e1
--- /dev/null
+++ b/pycocoevalcap/tokenizer/ptbtokenizer.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python
+#
+# File Name : ptbtokenizer.py
+#
+# Description : Do the PTB Tokenization and remove punctuations.
+#
+# Creation Date : 29-12-2014
+# Last Modified : Thu Mar 19 09:53:35 2015
+# Authors : Hao Fang and Tsung-Yi Lin
+
+import os
+import sys
+import subprocess
+import tempfile
+import itertools
+
+
+# Last modified : Wed 22 May 2019 08:10:00 PM EDT
+# By Sabarish Sivanath
+# To support Python 3
+
+# path to the stanford corenlp jar
+STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
+
+# punctuations to be removed from the sentences
+PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
+ ".", "?", "!", ",", ":", "-", "--", "...", ";"]
+
+class PTBTokenizer:
+ """Python wrapper of Stanford PTBTokenizer"""
+
+ def tokenize(self, captions_for_image):
+ cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \
+ 'edu.stanford.nlp.process.PTBTokenizer', \
+ '-preserveLines', '-lowerCase']
+
+ # ======================================================
+ # prepare data for PTB Tokenizer
+ # ======================================================
+ final_tokenized_captions_for_image = {}
+ image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
+ sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
+
+ # ======================================================
+ # save sentences to temporary file
+ # ======================================================
+ path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
+ tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
+ tmp_file.write(sentences.encode('utf-8'))
+ tmp_file.close()
+
+ # ======================================================
+ # tokenize sentence
+ # ======================================================
+ cmd.append(os.path.basename(tmp_file.name))
+ p_tokenizer = subprocess.Popen(cmd,
+ cwd=path_to_jar_dirname,
+ stdout=subprocess.PIPE,
+ universal_newlines = True,
+ bufsize = 1)
+ token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
+ lines = token_lines.split('\n')
+ # remove temp file
+ os.remove(tmp_file.name)
+
+ # ======================================================
+ # create dictionary for tokenized captions
+ # ======================================================
+ for k, line in zip(image_id, lines):
+ if not k in final_tokenized_captions_for_image:
+ final_tokenized_captions_for_image[k] = []
+ tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
+ if w not in PUNCTUATIONS])
+ final_tokenized_captions_for_image[k].append(tokenized_caption)
+
+ return final_tokenized_captions_for_image
diff --git a/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar b/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar
new file mode 100644
index 0000000..3cfa0a0
Binary files /dev/null and b/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar differ
diff --git a/scorer/cider.py b/scorer/cider.py
index c4cc910..3415ae4 100755
--- a/scorer/cider.py
+++ b/scorer/cider.py
@@ -33,20 +33,24 @@ def compute_score(self, gts, res):
:return: cider (float) : computed CIDEr score for the corpus
"""
- # clear all the previous hypos and refs
- self.cider_scorer.clear()
- for i, hypo in enumerate(res):
- ref = gts[i]
+ assert(gts.keys() == res.keys())
+ imgIds = gts.keys()
+
+ cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
+
+ for id in imgIds:
+ hypo = res[id]
+ ref = gts[id]
# Sanity check.
- #assert(type(hypo) is list)
- #assert(len(hypo) == 1)
+ assert(type(hypo) is list)
+ assert(len(hypo) == 1)
assert(type(ref) is list)
assert(len(ref) > 0)
- self.cider_scorer += (hypo, ref)
+ cider_scorer += (hypo[0], ref)
- (score, scores) = self.cider_scorer.compute_score()
+ (score, scores) = cider_scorer.compute_score()
return score, scores
diff --git a/scorer/cider_scorer.py b/scorer/cider_scorer.py
index 793e180..71c4330 100755
--- a/scorer/cider_scorer.py
+++ b/scorer/cider_scorer.py
@@ -11,6 +11,7 @@
import math
import pickle
from lib.config import cfg
+import gzip
def precook(words, n=4, out=False):
"""
@@ -64,10 +65,14 @@ def __init__(self, test=None, refs=None, n=4, sigma=6.0):
self.sigma = sigma
self.crefs = []
self.ctest = []
-
- cider_cache = pickle.load(open(cfg.SCORER.CIDER_CACHED, 'rb'), encoding='bytes')
- self.document_frequency = cider_cache['document_frequency']
- self.ref_len = cider_cache['ref_len']
+ cider_cache = None # TODO
+ try:
+ cider_cache = pickle.load(gzip.open(cfg.SCORER.CIDER_CACHED))
+ except:
+ cider_cache = None
+ # print('CIDEr df: {0}'.format(len(cider_cache)))
+ self.document_frequency = defaultdict(float) if not cider_cache else cider_cache['document_frequency']
+ self.ref_len = None if not cider_cache else cider_cache['ref_len']
self.cook_append(test, refs)
def clear(self):
@@ -169,7 +174,7 @@ def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
return val
# compute log reference length
- # self.ref_len = np.log(float(len(self.crefs))) ###########################
+ self.ref_len = np.log(float(len(self.crefs))) ###########################
scores = []
for test, refs in zip(self.ctest, self.crefs):
diff --git a/scorer/scorer.py b/scorer/scorer.py
index bb1e25a..6e1460c 100755
--- a/scorer/scorer.py
+++ b/scorer/scorer.py
@@ -23,16 +23,16 @@ def __init__(self):
super(Scorer, self).__init__()
self.scorers = []
self.weights = cfg.SCORER.WEIGHTS
- self.gts = pickle.load(open(cfg.SCORER.GT_PATH, 'rb'), encoding='bytes')
+ # self.gts = pickle.load(open(cfg.SCORER.GT_PATH, 'rb'), encoding='bytes')
for name in cfg.SCORER.TYPES:
self.scorers.append(factory[name]())
- def __call__(self, ids, res):
+ def __call__(self, gts, res):
hypo = [get_sents(r) for r in res]
- gts = [self.gts[i] for i in ids]
+ gts = gts
rewards_info = {}
- rewards = np.zeros(len(ids))
+ rewards = np.zeros(len(gts))
for i, scorer in enumerate(self.scorers):
score, scores = scorer.compute_score(gts, hypo)
rewards += self.weights[i] * scores