diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..73f69e0 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml +# Editor-based HTTP Client requests +/httpRequests/ diff --git a/.idea/image-captioning.iml b/.idea/image-captioning.iml new file mode 100644 index 0000000..8b8c395 --- /dev/null +++ b/.idea/image-captioning.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..eeb3b77 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,87 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..79a4365 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..1197701 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/artimes_experiments/iuxray_rgmg/config.yml b/artimes_experiments/iuxray_rgmg/config.yml new file mode 100644 index 0000000..5dada3f --- /dev/null +++ b/artimes_experiments/iuxray_rgmg/config.yml @@ -0,0 +1,144 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + 
BATCH_SIZE: 16 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 9999 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 16 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 1 + MAX_FEAT: 50 + +############################ MODEL ############################ +MODEL: + TYPE: 'XTransformer' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760 + ########## word embedding ########## + WORD_EMBED_DIM: 768 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.1 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 768 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: True + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 768 + ENCODE_ATT_MID_DIM: [96, 48, 96] + DECODE_ATT_MID_DIM: [96, 48, 96] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 768 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 6 + 
DECODE_LAYERS: 6 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_FF_DROPOUT: 0.5 + DECODE_FF_DROPOUT: 0.5 + ELU_ALPHA: 1.3 + BIFEAT_EMB_ACT: 'RELU' + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + DECODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.000001 + TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM' + MAX_EPOCH: 20 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 100 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.98] + EPS: 1.0e-9 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 3 + SETP_TYPE: 'Iter' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 768 # For Noam only + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.1 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 2 + GREEDY_DECODE: True diff --git a/artimes_experiments/iuxray_rgmg/train.sh b/artimes_experiments/iuxray_rgmg/train.sh new file mode 100644 index 0000000..fca3a35 --- /dev/null +++ b/artimes_experiments/iuxray_rgmg/train.sh @@ -0,0 +1,7 @@ +#CUDA_VISIBLE_DEVICES=0 -m torch.distributed.launch 
--nproc_per_node=1 +python3 main.py --folder ./artimes_experiments/iuxray_rgmg --resume 0 --submodel rgmg --KG_path /project/CVML/pretrained_kg/rgmg_iuxray_pretrain.pth --dataset_name IUXRAY --image_dir /project/CVML/Parallel-R2Gen-KG/data/iu_xray/images/ --ann_path /project/CVML/Parallel-R2Gen-KG/data/iuxray/annotation.json + + + + +###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number diff --git a/datasets/__init__.py b/datasets/__init__.py index e69de29..8a3d705 100755 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -0,0 +1,19 @@ +from datasets.coco_dataset import CocoDataset +from datasets.radiology_dataset import IUXRAY +from datasets.radiology_dataset import MIMICCXR +from datasets.radiology_dataset import MimiccxrMultiImage + +__factory = { + 'IUXRAY': IUXRAY, + 'MIMICCXR': MIMICCXR, + 'MIMICCXR_MultiImages': MimiccxrMultiImage, + 'COCO': CocoDataset, +} + +def names(): + return sorted(__factory.keys()) + +def create(name, *args, **kwargs): + if name not in __factory: + raise KeyError("Unknown Dataset:", name) + return __factory[name](*args, **kwargs) \ No newline at end of file diff --git a/datasets/data_loader.py b/datasets/data_loader.py index 2dc9476..f147add 100755 --- a/datasets/data_loader.py +++ b/datasets/data_loader.py @@ -3,97 +3,118 @@ from torchvision import transforms from lib.config import cfg from datasets.coco_dataset import CocoDataset +from datasets.radiology_dataset import IUXRAY import samplers.distributed import numpy as np +import argparse +import sys + def sample_collate(batch): + mask_dim = 70 indices, input_seq, target_seq, gv_feat, att_feats = zip(*batch) - - indices = np.stack(indices, axis=0).reshape(-1) - input_seq = torch.cat([torch.from_numpy(b) for b in input_seq], 0) - target_seq = torch.cat([torch.from_numpy(b) for b in target_seq], 0) - gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0) - atts_num = [x.shape[0] for x in 
att_feats] - max_att_num = np.max(atts_num) + max_seq_length = max([len(x) for x in input_seq]) + input_seqs = np.zeros((len(input_seq), max_seq_length), dtype=int) + target_seqs = np.zeros((len(target_seq), max_seq_length), dtype=int) + # print(max_seq_length) - feat_arr = [] - mask_arr = [] - for i, num in enumerate(atts_num): - tmp_feat = np.zeros((1, max_att_num, att_feats[i].shape[1]), dtype=np.float32) - tmp_feat[:, 0:att_feats[i].shape[0], :] = att_feats[i] - feat_arr.append(torch.from_numpy(tmp_feat)) + for i, input in enumerate(input_seq): + input_seqs[i, :len(input)] = input - tmp_mask = np.zeros((1, max_att_num), dtype=np.float32) - tmp_mask[:, 0:num] = 1 - mask_arr.append(torch.from_numpy(tmp_mask)) + for i, target in enumerate(target_seq): + target_seqs[i, :len(target)] = target - att_feats = torch.cat(feat_arr, 0) - att_mask = torch.cat(mask_arr, 0) + indices = np.stack(indices, axis=0).reshape(-1) - return indices, input_seq, target_seq, gv_feat, att_feats, att_mask + gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0) + + # IT DOESNT MATTER WHAT THE SHAPE OF MASK IS, WE WILL GET THE MASK LATER + mask_arr = torch.ones([len(att_feats), mask_dim]).float() + + att_mask = torch.cat([mask_arr], 0) + att_feats = torch.stack(att_feats, 0) # TODO for mimic + """ + indices (40, ) + input_seq (40, 60) + target_seq (40, 60) + gv_feat (40, 1) + att_feats (40, 2, 3, 224, 224) + att_mask (40, 1, 3) => (40, 49) + """ + return indices, torch.LongTensor(input_seqs), torch.LongTensor(target_seqs), gv_feat, att_feats, att_mask def sample_collate_val(batch): - indices, gv_feat, att_feats = zip(*batch) - - indices = np.stack(indices, axis=0).reshape(-1) - gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0) + mask_dim = 70 + + indices, input_seq, target_seq, gv_feat, att_feats = zip(*batch) + + max_seq_length = max([len(x) for x in input_seq]) + input_seqs = np.zeros((len(input_seq), max_seq_length), dtype=int) + target_seqs = 
np.zeros((len(target_seq), max_seq_length), dtype=int) + # print(max_seq_length) - atts_num = [x.shape[0] for x in att_feats] - max_att_num = np.max(atts_num) + for i, input in enumerate(input_seq): + input_seqs[i, :len(input)] = input - feat_arr = [] - mask_arr = [] - for i, num in enumerate(atts_num): - tmp_feat = np.zeros((1, max_att_num, att_feats[i].shape[1]), dtype=np.float32) - tmp_feat[:, 0:att_feats[i].shape[0], :] = att_feats[i] - feat_arr.append(torch.from_numpy(tmp_feat)) + for i, target in enumerate(target_seq): + target_seqs[i, :len(target)] = target - tmp_mask = np.zeros((1, max_att_num), dtype=np.float32) - tmp_mask[:, 0:num] = 1 - mask_arr.append(torch.from_numpy(tmp_mask)) + indices = np.stack(indices, axis=0).reshape(-1) + + gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0) - att_feats = torch.cat(feat_arr, 0) - att_mask = torch.cat(mask_arr, 0) + # IT DOESNT MATTER WHAT THE SHAPE OF MASK IS, WE WILL GET THE MASK LATER + mask_arr = torch.ones([len(att_feats), mask_dim]).float() - return indices, gv_feat, att_feats, att_mask + att_mask = torch.cat([mask_arr], 0) + att_feats = torch.stack(att_feats, 0) + """ + indices (40, ) + input_seq (40, 60) + target_seq (40, 60) + gv_feat (40, 1) + att_feats (40, 2, 3, 224, 224) + att_mask (40, 1, 3) => (40, 49) + """ + return indices, torch.LongTensor(target_seqs), gv_feat, att_feats, att_mask -def load_train(distributed, epoch, coco_set): - sampler = samplers.distributed.DistributedSampler(coco_set, epoch=epoch) \ +def load_train(distributed, epoch, dataset): + sampler = samplers.distributed.DistributedSampler(dataset, epoch=epoch) \ if distributed else None shuffle = cfg.DATA_LOADER.SHUFFLE if sampler is None else False - + loader = torch.utils.data.DataLoader( - coco_set, + dataset, batch_size = cfg.TRAIN.BATCH_SIZE, - shuffle = shuffle, - num_workers = cfg.DATA_LOADER.NUM_WORKERS, - drop_last = cfg.DATA_LOADER.DROP_LAST, + shuffle = shuffle, + num_workers = cfg.DATA_LOADER.NUM_WORKERS, + 
drop_last = cfg.DATA_LOADER.DROP_LAST, pin_memory = cfg.DATA_LOADER.PIN_MEMORY, - sampler = sampler, + sampler = sampler, collate_fn = sample_collate ) return loader -def load_val(image_ids_path, gv_feat_path, att_feats_folder): - coco_set = CocoDataset( - image_ids_path = image_ids_path, - input_seq = None, - target_seq = None, - gv_feat_path = gv_feat_path, - att_feats_folder = att_feats_folder, - seq_per_img = 1, - max_feat_num = cfg.DATA_LOADER.MAX_FEAT - ) +def load_val(dataset): + # coco_set = CocoDataset( + # image_ids_path = image_ids_path, + # input_seq = None, + # target_seq = None, + # gv_feat_path = gv_feat_path, + # att_feats_folder = att_feats_folder, + # seq_per_img = 1, + # max_feat_num = cfg.DATA_LOADER.MAX_FEAT + # ) loader = torch.utils.data.DataLoader( - coco_set, + dataset, batch_size = cfg.TEST.BATCH_SIZE, - shuffle = False, - num_workers = cfg.DATA_LOADER.NUM_WORKERS, - drop_last = False, - pin_memory = cfg.DATA_LOADER.PIN_MEMORY, + shuffle = False, + num_workers = cfg.DATA_LOADER.NUM_WORKERS, + drop_last = False, + pin_memory = cfg.DATA_LOADER.PIN_MEMORY, collate_fn = sample_collate_val ) - return loader \ No newline at end of file + return loader diff --git a/datasets/radiology_dataset.py b/datasets/radiology_dataset.py new file mode 100644 index 0000000..cee6976 --- /dev/null +++ b/datasets/radiology_dataset.py @@ -0,0 +1,217 @@ +import os +import numpy as np +import torch +from torch.utils.data import Dataset +from torchvision import transforms +import json +from PIL import Image +from .tokenizers import Tokenizer +import random +import time +import copy + +def random_position(image_1, image_2, thr): + if random.random() < thr: + img = torch.stack((image_1, image_2), 0) + else: + img = torch.stack((image_2, image_1), 0) + return img + + + +class BaseDataset(Dataset): + def __init__(self, image_dir, ann_path, tokenizer, split, args): + self.image_dir = image_dir + self.ann_path = ann_path + self.max_seq_length = 60 # hardcode + self.split 
= split + self.args = args + self.tokenizer = tokenizer + self.training_ratio = args.training_ratio + assert 0.0 <= self.training_ratio <= 1.0 + + if split == 'train': + self.transform = transforms.Compose([ + transforms.Resize(256), + transforms.RandomCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))]) + else: + self.transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))]) + self.texts = json.loads(open(self.ann_path, 'r').read()) + + self.examples = self.texts[self.split] + if args.dataset_name == 'MIMICCXR_MultiImages': + self.examples = self.convert_to_multi_images(self.examples) + if self.split == 'train': + self.examples = self.apply_training_ratio(self.examples) + for i in range(len(self.examples)): + self.examples[i]['ids'] = self.tokenizer(self.examples[i]['report'])[:self.max_seq_length] + + def apply_training_ratio(self, dataset): + t = time.time() + total = len(dataset) + print('{} set: applying training_ratio {} ... '.format(self.split,self.training_ratio), end='', flush=True) + select = int(total // (1/self.training_ratio)) + 1 + dataset = dataset[:select] + print('done %d->%d (%.2fs)' % (total, select, time.time() - t), flush=True) + return dataset + + def convert_to_multi_images(self, dataset, print_num=True): + t = time.time() + n = 0 + if print_num: + print('{} set: Converting to multiple image reports ... 
'.format(self.split), end='', flush=True) + mergedDataset = [] + total = len(dataset) + + buffer = None + for i in range(total): + document = dataset[i] + id = document['id'] + image_path = document['image_path'][0] + # report = document['report'] + # split = document['split'] + study_id = document['study_id'] + # subject_id = document['subject_id'] + + if study_id == buffer: + mergedDataset[-1]['image_path'].append(image_path) + mergedDataset[-1]['id'].append(id) + else: + newDocument = copy.deepcopy(document) + newDocument['id'] = [newDocument['id']] + mergedDataset.append(newDocument) + n += 1 + buffer = study_id + if print_num: + print('done %d->%d (%.2fs)' % (total, n, time.time() - t), flush=True) + return mergedDataset + + def __len__(self): + return len(self.examples) + + +class IUXRAY(BaseDataset): + def __getitem__(self, idx): + # indices = np.array([idx]).astype('int') # Modified + image_id = self.examples[idx]['id'] + indices = np.array([image_id]) + example = self.examples[idx] + image_path = example['image_path'] + image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') + image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB') + if self.transform is not None: + image_1 = self.transform(image_1) + image_2 = self.transform(image_2) + if self.split == 'train': + image = random_position(image_1, image_2, 0.5) + else: + image = torch.stack((image_1, image_2), 0) + + + report_ids = np.array(example['ids']) + + input_sequence = np.zeros(self.max_seq_length, dtype='int') + target_sequence = np.zeros(self.max_seq_length, dtype='int') + + input_sequence[:len(report_ids)] = report_ids + target_sequence[:len(report_ids)-1] = report_ids[1:] + + gv_feat = np.zeros((1, 1)) # Never been used + # report_masks = example['mask'] + # seq_length = len(report_ids) + return indices, input_sequence, target_sequence, gv_feat, image + + +class MIMICCXR(BaseDataset): # MimiccxrSingleImageDataset + def __getitem__(self, idx): + # 
indices = np.array([idx]).astype('int') # Modified + image_id = self.examples[idx]['id'] + indices = np.array([image_id]) + example = self.examples[idx] + image_path = example['image_path'] + image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') + if self.transform is not None: + image = self.transform(image) + report_ids = np.array(example['ids']) + + input_sequence = np.ones(self.max_seq_length, dtype='int') + target_sequence = np.ones(self.max_seq_length, dtype='int') + + input_sequence[:len(report_ids)] = report_ids + target_sequence[:len(report_ids) - 1] = report_ids[1:] + + gv_feat = np.zeros((1, 1)) # Never been used + # report_masks = example['mask'] + # seq_length = len(report_ids) + return indices, input_sequence, target_sequence, gv_feat, image + +class MimiccxrMultiImage(BaseDataset): # MimiccxrMultiImageDataset + def __getitem__(self, idx): + # indices = np.array([idx]).astype('int') # Modified + image_id = self.examples[idx]['id'] + indices = np.array([image_id]) + example = self.examples[idx] + image_path = example['image_path'] + image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') + if self.transform is not None: + image = self.transform(image) + report_ids = np.array(example['ids']) + + input_sequence = np.ones(self.max_seq_length, dtype='int') + target_sequence = np.ones(self.max_seq_length, dtype='int') + + input_sequence[:len(report_ids)] = report_ids + target_sequence[:len(report_ids) - 1] = report_ids[1:] + + gv_feat = np.zeros((1, 1)) # Never been used + # report_masks = example['mask'] + # seq_length = len(report_ids) + return indices, input_sequence, target_sequence, gv_feat, image + +class MimiccxrMultiImage(BaseDataset): + def __getitem__(self, idx): + image_id = str(self.examples[idx]['subject_id']) + '_' + str(self.examples[idx]['study_id']) + indices = np.array([image_id]) + example = self.examples[idx] + image_path = example['image_path'] + image_1 = 
Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') + try: + image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB') + # if this record only have one image, duplicate image1 + except: + image_2 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB') + + if self.transform is not None: + image_1 = self.transform(image_1) + image_2 = self.transform(image_2) + + def random_position(image_1, image_2, thr): + if random.random() < thr: + img = torch.stack((image_1, image_2), 0) + else: + img = torch.stack((image_2, image_1), 0) + return img + + if self.split == 'train': + image = random_position(image_1, image_2, 0.5) + else: + image = torch.stack((image_1, image_2), 0) + + report_ids = np.array(example['ids']) + + input_sequence = np.zeros(self.max_seq_length, dtype='int') + target_sequence = np.zeros(self.max_seq_length, dtype='int') + + input_sequence[:len(report_ids)] = report_ids + target_sequence[:len(report_ids)-1] = report_ids[1:] + + gv_feat = np.zeros((1, 1)) # Never been used + return indices, input_sequence, target_sequence, gv_feat, image \ No newline at end of file diff --git a/datasets/tokenizers.py b/datasets/tokenizers.py new file mode 100644 index 0000000..36d2261 --- /dev/null +++ b/datasets/tokenizers.py @@ -0,0 +1,129 @@ +import json +import re +from collections import Counter + + +class Tokenizer(object): + def __init__(self, ann_path, dataset_name): + self.ann_path = ann_path # todo + self.threshold = 3 + self.dataset_name = dataset_name # todo + if self.dataset_name == 'iu_xray': + self.clean_report = self.clean_report_iu_xray + else: + self.clean_report = self.clean_report_mimic_cxr + self.ann = json.loads(open(self.ann_path, 'r').read()) + self.token2idx, self.idx2token = self.create_vocabulary_and_CIDEr_DF() + + def create_vocabulary_and_CIDEr_DF(self): + # print('Building Vocabulary and CIDEr DF! ') + print('Building Vocabulary! 
') + total_tokens = [] + + for example in self.ann['train']: + tokens = self.clean_report(example['report']).split() + for token in tokens: + total_tokens.append(token) + + counter = Counter(total_tokens) + vocab = [k for k, v in counter.items() if v >= self.threshold] + [''] + vocab.sort() + token2idx, idx2token = {}, {} + for idx, token in enumerate(vocab): + token2idx[token] = idx + 1 + idx2token[idx + 1] = token + print('Vocab Done. Dataset: {} Vocab Size: {}, include "" , non exists "sos" "eos"'.format(self.dataset_name,len(vocab))) + + # for text in self.ann['train']: + # toks = tokenizer.tokenize(textfilter.filter(text)) + # toks = tokenfilter.filter(toks) + # ftext = ' '.join(toks) + # ftexts.append(ftext) + # + # df = GenEval.compute_cider_df(texts) + # with gzip.open(args.output, 'w') as f: + # pickle.dump(df, f) + + + return token2idx, idx2token + + def clean_report_iu_xray(self, report): + report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \ + .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \ + .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ + .strip().lower().split('. ') + sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', ''). + replace('\\', '').replace("'", '').strip().lower()) + tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] + report = ' . '.join(tokens) + ' .' 
+ return report + + def clean_report_mimic_cxr(self, report): + report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \ + .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \ + .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \ + .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \ + .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \ + .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \ + .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ + .strip().lower().split('. ') + sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '') + .replace('\\', '').replace("'", '').strip().lower()) + tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] + report = ' . '.join(tokens) + ' .' 
+ return report + + # def get_id2word(self): + + def get_token_by_id(self, id): + return self.idx2token[id] + + def get_id_by_token(self, token): + if token not in self.token2idx: + return self.token2idx[''] + return self.token2idx[token] + + def get_vocab_size(self): + return len(self.token2idx) + + def __call__(self, report): + tokens = self.clean_report(report).split() + ids = [] + for token in tokens: + ids.append(self.get_id_by_token(token)) + ids = [0] + ids + [0] + return ids + + def decode(self, ids): + txt = '' + for i, idx in enumerate(ids): + if idx > 0: + if i >= 1: + txt += ' ' + txt += self.idx2token[idx] + else: + break + return txt + + def decode_batch(self, ids_batch): + out = [] + for ids in ids_batch: + out.append(self.decode(ids)) + return out + + +# def cider_df(texts): +# tokenizer = get_tokenizer('nltk') +# textfilter = get_textfilter('lower') +# tokenfilter = get_tokenfilter('none') + +# ftexts = [] +# for text in texts: +# toks = tokenizer.tokenize(textfilter.filter(text)) +# toks = tokenfilter.filter(toks) +# ftext = ' '.join(toks) +# ftexts.append(ftext) + +# df = GenEval.compute_cider_df(texts) # texts : [report_text, ... ] , where report_text : [filtered_token, ... 
] +# with gzip.open('mimic-cxr_train-df.bin.gz', 'w') as f: +# pickle.dump(df, f) \ No newline at end of file diff --git a/evaluation/evaler.py b/evaluation/evaler.py index 873062a..0c1fdf5 100755 --- a/evaluation/evaler.py +++ b/evaluation/evaler.py @@ -9,56 +9,113 @@ import datasets.data_loader as data_loader from lib.config import cfg +from pycocoevalcap.bleu.bleu import Bleu +from pycocoevalcap.meteor import Meteor +from pycocoevalcap.rouge import Rouge +from scorer.cider import Cider +device = torch.device("cuda") + + +def compute_scores(gts, res): + """ + Performs the MS COCO evaluation using the Python 3 implementation (https://github.com/salaniz/pycocoevalcap) + + :param gts: Dictionary with the image ids and their gold captions, + :param res: Dictionary with the image ids ant their generated captions + :print: Evaluation score (the mean of the scores of all the instances) for each measure + """ + + # Set up scorers + scorers = [ + (Bleu(4), ["BLEU_1", "BLEU_2", "BLEU_3", "BLEU_4"]), + (Cider(), 'CIDEr'), + (Meteor(), "METEOR"), + (Rouge(), "ROUGE_L") + ] + eval_res = {} + # Compute score for each metric + for scorer, method in scorers: + try: + score, scores = scorer.compute_score(gts, res, verbose=0) + except TypeError: + score, scores = scorer.compute_score(gts, res) + if type(method) == list: + for sc, m in zip(score, method): + eval_res[m] = sc + else: + eval_res[method] = score + return eval_res + + class Evaler(object): def __init__( - self, - eval_ids, - gv_feat, - att_feats, - eval_annfile + self, + dataset, + tokenizer ): super(Evaler, self).__init__() - self.vocab = utils.load_vocab(cfg.INFERENCE.VOCAB) + self.tokenizer = tokenizer + # self.vocab = utils.load_vocab(cfg.INFERENCE.VOCAB) # TODO - self.eval_ids = np.array(utils.load_ids(eval_ids)) - self.eval_loader = data_loader.load_val(eval_ids, gv_feat, att_feats) - self.evaler = evaluation.create(cfg.INFERENCE.EVAL, eval_annfile) + # self.eval_ids = np.array(utils.load_ids(eval_ids)) + 
self.eval_loader = data_loader.load_val(dataset) + # self.evaler = evaluation.create(cfg.INFERENCE.EVAL, eval_annfile) + self.evaler = compute_scores def make_kwargs(self, indices, ids, gv_feat, att_feats, att_mask): kwargs = {} kwargs[cfg.PARAM.INDICES] = indices kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat kwargs[cfg.PARAM.ATT_FEATS] = att_feats + # att_mask_a = torch.ones(16,70).to(device) kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask + # kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask_a kwargs['BEAM_SIZE'] = cfg.INFERENCE.BEAM_SIZE kwargs['GREEDY_DECODE'] = cfg.INFERENCE.GREEDY_DECODE return kwargs - + def __call__(self, model, rname): model.eval() - - results = [] + + results, golden_sents = {}, {} with torch.no_grad(): - for _, (indices, gv_feat, att_feats, att_mask) in tqdm.tqdm(enumerate(self.eval_loader)): - ids = self.eval_ids[indices] + for _, (indices, target_seq, gv_feat, att_feats, att_mask) in tqdm.tqdm(enumerate(self.eval_loader)): + ids = indices gv_feat = gv_feat.cuda() att_feats = att_feats.cuda() att_mask = att_mask.cuda() kwargs = self.make_kwargs(indices, ids, gv_feat, att_feats, att_mask) if kwargs['BEAM_SIZE'] > 1: - seq, _ = model.module.decode_beam(**kwargs) + seq, _ = model.decode_beam(**kwargs) # modified else: - seq, _ = model.module.decode(**kwargs) - sents = utils.decode_sequence(self.vocab, seq.data) + seq, _ = model.decode(**kwargs) + sents = utils.decode_sequence(self.tokenizer.idx2token, seq.data) # to check + # sents: [sent (str), ... 
] + gold_sents = utils.decode_sequence(self.tokenizer.idx2token, + target_seq.data) # to check target_seq callable + + + # for sid, sent in enumerate(sents): + # result = {ids[sid]: [sent]} + # results.append(result) + # for sid, sent in enumerate(gold_sents): + # g_sents = {ids[sid]: [sent]} + # golden_sents.append(g_sents) + for sid, sent in enumerate(sents): - result = {cfg.INFERENCE.ID_KEY: int(ids[sid]), cfg.INFERENCE.CAP_KEY: sent} - results.append(result) - eval_res = self.evaler.eval(results) + results[ids[sid]] = [sent] # sents: [sent (str), ... ] + # results.append(result) + for sid, sent in enumerate(gold_sents): + golden_sents[ids[sid]] = [sent] + # golden_sents.append(g_sents) + # golden_sents : {image_id, [sent (str) ]} + eval_res = self.evaler(golden_sents, results) result_folder = os.path.join(cfg.ROOT_DIR, 'result') if not os.path.exists(result_folder): os.mkdir(result_folder) - json.dump(results, open(os.path.join(result_folder, 'result_' + rname +'.json'), 'w')) + json.dump(results, + open(os.path.join(result_folder, 'result_' + rname + '.json'), 'w')) # store the generated sentences model.train() - return eval_res \ No newline at end of file + return eval_res diff --git a/experiments/xlan_rl/train.sh b/experiments/xlan_rl/train.sh deleted file mode 100644 index 648cb01..0000000 --- a/experiments/xlan_rl/train.sh +++ /dev/null @@ -1,3 +0,0 @@ -CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments/xlan_rl --resume 47 - -# 47 is the epoch number of the pretrained model diff --git a/experiments/xtransformer_rl/train.sh b/experiments/xtransformer_rl/train.sh deleted file mode 100644 index 8ecb209..0000000 --- a/experiments/xtransformer_rl/train.sh +++ /dev/null @@ -1,3 +0,0 @@ -CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments/xtransformer_rl --resume 39 - -# 39 is the epoch number of the pretrained model diff --git a/experiments/xlan/config.yml b/experiments_iuxray/xlan/config.yml similarity index 95% rename from 
experiments/xlan/config.yml rename to experiments_iuxray/xlan/config.yml index eedf3e9..8967371 100644 --- a/experiments/xlan/config.yml +++ b/experiments_iuxray/xlan/config.yml @@ -39,8 +39,8 @@ DATA_LOADER: ############################ MODEL ############################ MODEL: TYPE: 'XLAN' - SEQ_LEN: 17 # include / - VOCAB_SIZE: 9487 # exclude / + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 9487 # TODO # exclude / ########## word embedding ########## WORD_EMBED_DIM: 1024 WORD_EMBED_ACT: 'CELU' @@ -52,7 +52,7 @@ MODEL: GVFEAT_EMBED_ACT: 'NONE' DROPOUT_GV_EMBED: 0.0 ########## attention features ########## - ATT_FEATS_DIM: 2048 + ATT_FEATS_DIM: 1024 # Modified ATT_FEATS_EMBED_DIM: 1024 ATT_FEATS_EMBED_ACT: 'CELU' DROPOUT_ATT_EMBED: 0.5 diff --git a/experiments/xlan/train.sh b/experiments_iuxray/xlan/train.sh similarity index 50% rename from experiments/xlan/train.sh rename to experiments_iuxray/xlan/train.sh index e16d6f5..c498753 100644 --- a/experiments/xlan/train.sh +++ b/experiments_iuxray/xlan/train.sh @@ -1 +1 @@ -CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch --nproc_per_node=4 main.py --folder ./experiments/xlan +CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch --nproc_per_node=4 main.py --folder ./experiments_iuxray/xlan diff --git a/experiments/xlan_rl/config.yml b/experiments_iuxray/xlan_rl/config.yml similarity index 95% rename from experiments/xlan_rl/config.yml rename to experiments_iuxray/xlan_rl/config.yml index 861da38..931c356 100644 --- a/experiments/xlan_rl/config.yml +++ b/experiments_iuxray/xlan_rl/config.yml @@ -39,8 +39,8 @@ DATA_LOADER: ############################ MODEL ############################ MODEL: TYPE: 'XLAN' - SEQ_LEN: 17 # include / - VOCAB_SIZE: 9487 # exclude / + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 9487 # TODO # exclude / ########## word embedding ########## WORD_EMBED_DIM: 1024 WORD_EMBED_ACT: 'CELU' @@ -52,7 +52,7 @@ MODEL: GVFEAT_EMBED_ACT: 'NONE' DROPOUT_GV_EMBED: 0.0 
########## attention features ########## - ATT_FEATS_DIM: 2048 + ATT_FEATS_DIM: 1024 # Modified ATT_FEATS_EMBED_DIM: 1024 ATT_FEATS_EMBED_ACT: 'CELU' # 'RELU', 'NONE' DROPOUT_ATT_EMBED: 0.5 diff --git a/experiments_iuxray/xlan_rl/train.sh b/experiments_iuxray/xlan_rl/train.sh new file mode 100644 index 0000000..5079ea7 --- /dev/null +++ b/experiments_iuxray/xlan_rl/train.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xlan_rl --resume 47 + +# 47 is the epoch number of the pretrained model diff --git a/experiments/xtransformer/config.yml b/experiments_iuxray/xtransformer/config.yml similarity index 90% rename from experiments/xtransformer/config.yml rename to experiments_iuxray/xtransformer/config.yml index 51d1c6b..3c4e4fd 100644 --- a/experiments/xtransformer/config.yml +++ b/experiments_iuxray/xtransformer/config.yml @@ -3,14 +3,14 @@ SEED: 1546884941.160048 ############################ TRAIN ############################ TRAIN: - BATCH_SIZE: 40 + BATCH_SIZE: 16 #################### REINFORCEMENT #################### REINFORCEMENT: START: 9999 -############################ TEST ############################ +############################ TEST ############################ TEST: - BATCH_SIZE: 36 + BATCH_SIZE: 16 ############################ DATA_LOADER ############################ DATA_LOADER: @@ -27,18 +27,18 @@ DATA_LOADER: TEST_ID: './mscoco/txt/coco_test_image_id.txt' INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' - SEQ_PER_IMG: 5 + SEQ_PER_IMG: 1 MAX_FEAT: 50 ############################ MODEL ############################ MODEL: TYPE: 'XTransformer' - SEQ_LEN: 17 # include / - VOCAB_SIZE: 9487 # exclude / + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760 ########## word embedding ########## WORD_EMBED_DIM: 768 WORD_EMBED_ACT: 'CELU' - WORD_EMBED_NORM: False + WORD_EMBED_NORM: False DROPOUT_WORD_EMBED: 0.1 ########## global 
features ########## GVFEAT_DIM: 2048 @@ -46,7 +46,7 @@ MODEL: GVFEAT_EMBED_ACT: 'NONE' DROPOUT_GV_EMBED: 0.0 ########## attention features ########## - ATT_FEATS_DIM: 2048 + ATT_FEATS_DIM: 1024 # Modified ATT_FEATS_EMBED_DIM: 768 ATT_FEATS_EMBED_ACT: 'CELU' DROPOUT_ATT_EMBED: 0.5 @@ -87,9 +87,9 @@ MODEL: ENCODE_BIFEAT_EMB_DROPOUT: 0.3 DECODE_BIFEAT_EMB_DROPOUT: 0.3 -############################ SOLVER ############################ +############################ SOLVER ############################ SOLVER: - BASE_LR: 0.0005 + BASE_LR: 0.000001 TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM' MAX_EPOCH: 70 MAX_ITER: -1 @@ -98,7 +98,7 @@ SOLVER: WEIGHT_DECAY: 0.0000 WEIGHT_DECAY_BIAS: 0.0 BIAS_LR_FACTOR: 1 - DISPLAY: 20 + DISPLAY: 100 TEST_INTERVAL: 1 SNAPSHOT_ITERS: 1 diff --git a/experiments_iuxray/xtransformer/train.sh b/experiments_iuxray/xtransformer/train.sh new file mode 100644 index 0000000..c5b6482 --- /dev/null +++ b/experiments_iuxray/xtransformer/train.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 main.py --folder ./experiments_iuxray/xtransformer --resume 0 + +### To resume from a checkpoint, place your model in experiments_mimiccxr/xtransformer/snapshot and change the 0 after --resume to that checkpoint's number diff --git a/experiments_iuxray/xtransformer_VSEGCN/config.yml b/experiments_iuxray/xtransformer_VSEGCN/config.yml new file mode 100644 index 0000000..3c4e4fd --- /dev/null +++ b/experiments_iuxray/xtransformer_VSEGCN/config.yml @@ -0,0 +1,144 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 16 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 9999 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 16 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: 
'./mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 1 + MAX_FEAT: 50 + +############################ MODEL ############################ +MODEL: + TYPE: 'XTransformer' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760 + ########## word embedding ########## + WORD_EMBED_DIM: 768 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.1 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 768 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: True + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 768 + ENCODE_ATT_MID_DIM: [96, 48, 96] + DECODE_ATT_MID_DIM: [96, 48, 96] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 768 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 6 + DECODE_LAYERS: 6 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_FF_DROPOUT: 0.5 + DECODE_FF_DROPOUT: 0.5 + ELU_ALPHA: 1.3 + BIFEAT_EMB_ACT: 'RELU' + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + DECODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.000001 + TYPE: 'RADAM' 
# 'ADAM', 'SGD', 'RADAM' + MAX_EPOCH: 70 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 100 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.98] + EPS: 1.0e-9 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 3 + SETP_TYPE: 'Iter' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 768 # For Noam only + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.1 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 2 + GREEDY_DECODE: True diff --git a/experiments_iuxray/xtransformer_VSEGCN/train.sh b/experiments_iuxray/xtransformer_VSEGCN/train.sh new file mode 100644 index 0000000..467a7a3 --- /dev/null +++ b/experiments_iuxray/xtransformer_VSEGCN/train.sh @@ -0,0 +1,5 @@ +CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 main.py --folder ./experiments_iuxray/xtransformer_VSEGCN + --submodel VSEGCN \ + --resume 0 + +###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number diff --git a/experiments_iuxray/xtransformer_rl/config.yml 
b/experiments_iuxray/xtransformer_rl/config.yml new file mode 100644 index 0000000..659e45e --- /dev/null +++ b/experiments_iuxray/xtransformer_rl/config.yml @@ -0,0 +1,147 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 16 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 0 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 16 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 5 + MAX_FEAT: 50 + +############################ MODEL ############################ +MODEL: + TYPE: 'XTransformer' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 9487 # TODO # exclude / + ########## word embedding ########## + WORD_EMBED_DIM: 768 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.1 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 768 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: True + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.0 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 
0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 768 + ENCODE_ATT_MID_DIM: [96, 48, 96] + DECODE_ATT_MID_DIM: [96, 48, 96] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 768 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.1 + ENCODE_LAYERS: 6 + DECODE_LAYERS: 6 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_FF_DROPOUT: 0.5 + DECODE_FF_DROPOUT: 0.5 + ELU_ALPHA: 1.3 + BIFEAT_EMB_ACT: 'RELU' + ENCODE_BIFEAT_EMB_DROPOUT: 0.1 + DECODE_BIFEAT_EMB_DROPOUT: 0.1 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.000005 + TYPE: 'RADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP' + MAX_EPOCH: 60 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 20 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.999] + EPS: 1.0e-8 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Plateau' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 3 + SETP_TYPE: 'Epoch' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 768 # For Noam only + + PLATEAU_FACTOR: 0.8 + PLATEAU_PATIENCE: 3 + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.1 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: 
'./mscoco/misc/captions_test5k.json' + BEAM_SIZE: 2 + GREEDY_DECODE: True diff --git a/experiments_iuxray/xtransformer_rl/train.sh b/experiments_iuxray/xtransformer_rl/train.sh new file mode 100644 index 0000000..e69c213 --- /dev/null +++ b/experiments_iuxray/xtransformer_rl/train.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xtransformer_rl --resume 0 + +# The value passed to --resume (0 here) is the epoch number of the pretrained model diff --git a/experiments_iuxray_dwe/xlan/config.yml b/experiments_iuxray_dwe/xlan/config.yml new file mode 100644 index 0000000..8967371 --- /dev/null +++ b/experiments_iuxray_dwe/xlan/config.yml @@ -0,0 +1,148 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 10 + #################### SCHEDULED_SAMPLING #################### + SCHEDULED_SAMPLING: + START: 6 + INC_EVERY: 5 + INC_PROB: 0.05 + MAX_PROB: 0.5 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 9999 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 36 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 5 + MAX_FEAT: -1 + +############################ MODEL ############################ +MODEL: + TYPE: 'XLAN' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 9487 # TODO # exclude / + ########## word embedding ########## + WORD_EMBED_DIM: 1024 + WORD_EMBED_ACT: 'CELU' + 
WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.5 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 1024 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: False + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 1024 + ENCODE_ATT_MID_DIM: [128, 64, 128] + DECODE_ATT_MID_DIM: [128, 64, 128] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 1024 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 4 + DECODE_LAYERS: 1 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_BLOCK: 'LowRankBilinearEnc' + DECODE_BLOCK: 'LowRankBilinearDec' + ELU_ALPHA: 1.3 + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.0005 + TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP' + MAX_EPOCH: 70 + MAX_ITER: -1 + GRAD_CLIP: 0.5 # Norm:5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Norm' # 'Clamp', 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 20 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.98] + EPS: 1.0e-9 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 3 + SETP_TYPE: 'Iter' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 1024 # For Noam only + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 
'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.0 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 3 + GREEDY_DECODE: True diff --git a/experiments/xtransformer/train.sh b/experiments_iuxray_dwe/xlan/train.sh similarity index 50% rename from experiments/xtransformer/train.sh rename to experiments_iuxray_dwe/xlan/train.sh index 787b9d8..c498753 100644 --- a/experiments/xtransformer/train.sh +++ b/experiments_iuxray_dwe/xlan/train.sh @@ -1 +1 @@ -CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch --nproc_per_node=4 main.py --folder ./experiments/xtransformer +CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch --nproc_per_node=4 main.py --folder ./experiments_iuxray/xlan diff --git a/experiments_iuxray_dwe/xlan_rl/config.yml b/experiments_iuxray_dwe/xlan_rl/config.yml new file mode 100644 index 0000000..931c356 --- /dev/null +++ b/experiments_iuxray_dwe/xlan_rl/config.yml @@ -0,0 +1,151 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 20 + #################### SCHEDULED_SAMPLING #################### + SCHEDULED_SAMPLING: + START: 6 + INC_EVERY: 5 + INC_PROB: 0.05 + MAX_PROB: 0.5 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 0 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 36 + +############################ DATA_LOADER ############################ +DATA_LOADER: + 
NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 5 + MAX_FEAT: -1 + +############################ MODEL ############################ +MODEL: + TYPE: 'XLAN' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 9487 # TODO # exclude / + ########## word embedding ########## + WORD_EMBED_DIM: 1024 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.5 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 1024 + ATT_FEATS_EMBED_ACT: 'CELU' # 'RELU', 'NONE' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: False + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' # 'RELU', 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 1024 + ENCODE_ATT_MID_DIM: [128, 64, 128] + DECODE_ATT_MID_DIM: [128, 64, 128] + ENCODE_ATT_MID_DROPOUT: 0.0 + DECODE_ATT_MID_DROPOUT: 0.0 + ATT_DIM: 1024 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 4 + DECODE_LAYERS: 1 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_BLOCK: 'LowRankBilinearEnc' + DECODE_BLOCK: 'LowRankBilinearDec' + ELU_ALPHA: 1.3 + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER 
############################ +SOLVER: + BASE_LR: 0.00001 + TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP' + MAX_EPOCH: 35 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 20 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.999] + EPS: 1.0e-8 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Plateau' # 'Fix', 'Step', 'MultiStep', 'Poly', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 3 + SETP_TYPE: 'Epoch' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 1024 # For Noam only + + PLATEAU_FACTOR: 0.8 + PLATEAU_PATIENCE: 3 + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.0 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 3 + GREEDY_DECODE: True diff --git a/experiments_iuxray_dwe/xlan_rl/train.sh b/experiments_iuxray_dwe/xlan_rl/train.sh new file mode 100644 index 0000000..5079ea7 --- /dev/null +++ b/experiments_iuxray_dwe/xlan_rl/train.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xlan_rl --resume 47 + +# 47 is the epoch number of the pretrained model diff --git a/experiments_iuxray_dwe/xtransformer/config.yml b/experiments_iuxray_dwe/xtransformer/config.yml new file mode 
100644 index 0000000..2a8a23d --- /dev/null +++ b/experiments_iuxray_dwe/xtransformer/config.yml @@ -0,0 +1,144 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 12 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 9999 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 12 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 1 + MAX_FEAT: 50 + +############################ MODEL ############################ +MODEL: + TYPE: 'XTransformer' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760 + ########## word embedding ########## + WORD_EMBED_DIM: 768 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.1 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 768 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: True + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + 
BILINEAR: + DIM: 768 + ENCODE_ATT_MID_DIM: [96, 48, 96] + DECODE_ATT_MID_DIM: [96, 48, 96] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 768 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 6 + DECODE_LAYERS: 6 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_FF_DROPOUT: 0.5 + DECODE_FF_DROPOUT: 0.5 + ELU_ALPHA: 1.3 + BIFEAT_EMB_ACT: 'RELU' + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + DECODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.001 + TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM' + MAX_EPOCH: 70 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 100 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.98] + EPS: 1.0e-9 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Step' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.9 + STEP_SIZE: 1 + SETP_TYPE: 'Iter' # 'Epoch', 'Iter' + WARMUP: 1000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 768 # For Noam only + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.1 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 2 + GREEDY_DECODE: True diff --git 
a/experiments_iuxray_dwe/xtransformer/train.sh b/experiments_iuxray_dwe/xtransformer/train.sh new file mode 100644 index 0000000..07292c5 --- /dev/null +++ b/experiments_iuxray_dwe/xtransformer/train.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 main.py --folder ./experiments_iuxray_dwe/xtransformer --resume 0 --encoder_mode dualwayencoder --dataset_name IUXRAY --submodel VSEGCN + +###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number diff --git a/experiments_iuxray_dwe/xtransformer_VSEGCN/config.yml b/experiments_iuxray_dwe/xtransformer_VSEGCN/config.yml new file mode 100644 index 0000000..3c4e4fd --- /dev/null +++ b/experiments_iuxray_dwe/xtransformer_VSEGCN/config.yml @@ -0,0 +1,144 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 16 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 9999 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 16 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 1 + MAX_FEAT: 50 + +############################ MODEL ############################ +MODEL: + TYPE: 'XTransformer' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760 + ########## word embedding ########## + 
WORD_EMBED_DIM: 768 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.1 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 768 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: True + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 768 + ENCODE_ATT_MID_DIM: [96, 48, 96] + DECODE_ATT_MID_DIM: [96, 48, 96] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 768 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 6 + DECODE_LAYERS: 6 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_FF_DROPOUT: 0.5 + DECODE_FF_DROPOUT: 0.5 + ELU_ALPHA: 1.3 + BIFEAT_EMB_ACT: 'RELU' + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + DECODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.000001 + TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM' + MAX_EPOCH: 70 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 100 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.98] + EPS: 1.0e-9 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 3 + SETP_TYPE: 'Iter' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 768 # For Noam only + +############################ 
LOSSES ############################ +LOSSES: + XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.1 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 2 + GREEDY_DECODE: True diff --git a/experiments_iuxray_dwe/xtransformer_VSEGCN/train.sh b/experiments_iuxray_dwe/xtransformer_VSEGCN/train.sh new file mode 100644 index 0000000..467a7a3 --- /dev/null +++ b/experiments_iuxray_dwe/xtransformer_VSEGCN/train.sh @@ -0,0 +1,5 @@ +CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 main.py --folder ./experiments_iuxray/xtransformer_VSEGCN + --submodel VSEGCN \ + --resume 0 + +###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number diff --git a/experiments_iuxray_dwe/xtransformer_rl/config.yml b/experiments_iuxray_dwe/xtransformer_rl/config.yml new file mode 100644 index 0000000..659e45e --- /dev/null +++ b/experiments_iuxray_dwe/xtransformer_rl/config.yml @@ -0,0 +1,147 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 16 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 0 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 16 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: 
'./mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 5 + MAX_FEAT: 50 + +############################ MODEL ############################ +MODEL: + TYPE: 'XTransformer' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 9487 # TODO # exclude / + ########## word embedding ########## + WORD_EMBED_DIM: 768 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.1 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 768 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: True + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.0 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 768 + ENCODE_ATT_MID_DIM: [96, 48, 96] + DECODE_ATT_MID_DIM: [96, 48, 96] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 768 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.1 + ENCODE_LAYERS: 6 + DECODE_LAYERS: 6 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_FF_DROPOUT: 0.5 + DECODE_FF_DROPOUT: 0.5 + ELU_ALPHA: 1.3 + BIFEAT_EMB_ACT: 'RELU' + ENCODE_BIFEAT_EMB_DROPOUT: 0.1 + DECODE_BIFEAT_EMB_DROPOUT: 0.1 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.000005 + TYPE: 'RADAM' # 'ADAM', 
'SGD', 'ADAGRAD', 'RMSPROP' + MAX_EPOCH: 60 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 20 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.999] + EPS: 1.0e-8 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Plateau' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 3 + SETP_TYPE: 'Epoch' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 768 # For Noam only + + PLATEAU_FACTOR: 0.8 + PLATEAU_PATIENCE: 3 + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.1 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 2 + GREEDY_DECODE: True diff --git a/experiments_iuxray_dwe/xtransformer_rl/train.sh b/experiments_iuxray_dwe/xtransformer_rl/train.sh new file mode 100644 index 0000000..e69c213 --- /dev/null +++ b/experiments_iuxray_dwe/xtransformer_rl/train.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xtransformer_rl --resume 0 + +# 39 is the epoch number of the pretrained model diff --git a/experiments_iuxray_testing/xlan/config.yml b/experiments_iuxray_testing/xlan/config.yml new file mode 100644 index 0000000..8967371 --- /dev/null +++ 
b/experiments_iuxray_testing/xlan/config.yml @@ -0,0 +1,148 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 10 + #################### SCHEDULED_SAMPLING #################### + SCHEDULED_SAMPLING: + START: 6 + INC_EVERY: 5 + INC_PROB: 0.05 + MAX_PROB: 0.5 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 9999 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 36 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 5 + MAX_FEAT: -1 + +############################ MODEL ############################ +MODEL: + TYPE: 'XLAN' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 9487 # TODO # exclude / + ########## word embedding ########## + WORD_EMBED_DIM: 1024 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.5 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 1024 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: False + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + 
DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 1024 + ENCODE_ATT_MID_DIM: [128, 64, 128] + DECODE_ATT_MID_DIM: [128, 64, 128] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 1024 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 4 + DECODE_LAYERS: 1 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_BLOCK: 'LowRankBilinearEnc' + DECODE_BLOCK: 'LowRankBilinearDec' + ELU_ALPHA: 1.3 + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.0005 + TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP' + MAX_EPOCH: 70 + MAX_ITER: -1 + GRAD_CLIP: 0.5 # Norm:5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Norm' # 'Clamp', 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 20 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.98] + EPS: 1.0e-9 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 3 + SETP_TYPE: 'Iter' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 1024 # For Noam only + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.0 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' 
+ BEAM_SIZE: 3 + GREEDY_DECODE: True diff --git a/experiments_iuxray_testing/xlan/train.sh b/experiments_iuxray_testing/xlan/train.sh new file mode 100644 index 0000000..c498753 --- /dev/null +++ b/experiments_iuxray_testing/xlan/train.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch --nproc_per_node=4 main.py --folder ./experiments_iuxray/xlan diff --git a/experiments_iuxray_testing/xlan_rl/config.yml b/experiments_iuxray_testing/xlan_rl/config.yml new file mode 100644 index 0000000..931c356 --- /dev/null +++ b/experiments_iuxray_testing/xlan_rl/config.yml @@ -0,0 +1,151 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 20 + #################### SCHEDULED_SAMPLING #################### + SCHEDULED_SAMPLING: + START: 6 + INC_EVERY: 5 + INC_PROB: 0.05 + MAX_PROB: 0.5 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 0 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 36 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 5 + MAX_FEAT: -1 + +############################ MODEL ############################ +MODEL: + TYPE: 'XLAN' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 9487 # TODO # exclude / + ########## word embedding ########## + WORD_EMBED_DIM: 1024 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.5 + 
########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 1024 + ATT_FEATS_EMBED_ACT: 'CELU' # 'RELU', 'NONE' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: False + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' # 'RELU', 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 1024 + ENCODE_ATT_MID_DIM: [128, 64, 128] + DECODE_ATT_MID_DIM: [128, 64, 128] + ENCODE_ATT_MID_DROPOUT: 0.0 + DECODE_ATT_MID_DROPOUT: 0.0 + ATT_DIM: 1024 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 4 + DECODE_LAYERS: 1 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_BLOCK: 'LowRankBilinearEnc' + DECODE_BLOCK: 'LowRankBilinearDec' + ELU_ALPHA: 1.3 + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.00001 + TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP' + MAX_EPOCH: 35 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 20 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.999] + EPS: 1.0e-8 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Plateau' # 'Fix', 'Step', 'MultiStep', 'Poly', 'Noam' + GAMMA: 0.8 + STEP_SIZE: 3 + SETP_TYPE: 'Epoch' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 1024 # For Noam only + + PLATEAU_FACTOR: 0.8 + PLATEAU_PATIENCE: 3 + +############################ LOSSES
############################ +LOSSES: + XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.0 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 3 + GREEDY_DECODE: True diff --git a/experiments_iuxray_testing/xlan_rl/train.sh b/experiments_iuxray_testing/xlan_rl/train.sh new file mode 100644 index 0000000..5079ea7 --- /dev/null +++ b/experiments_iuxray_testing/xlan_rl/train.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xlan_rl --resume 47 + +# 47 is the epoch number of the pretrained model diff --git a/experiments_iuxray_testing/xtransformer/config.yml b/experiments_iuxray_testing/xtransformer/config.yml new file mode 100644 index 0000000..b6d726f --- /dev/null +++ b/experiments_iuxray_testing/xtransformer/config.yml @@ -0,0 +1,144 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 16 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 9999 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 16 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + 
VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 1 + MAX_FEAT: 50 + +############################ MODEL ############################ +MODEL: + TYPE: 'XTransformer' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 760 # TODO # exclude / IUXRAY: 760 + ########## word embedding ########## + WORD_EMBED_DIM: 768 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.1 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 768 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: True + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 768 + ENCODE_ATT_MID_DIM: [96, 48, 96] + DECODE_ATT_MID_DIM: [96, 48, 96] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 768 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 3 + DECODE_LAYERS: 3 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_FF_DROPOUT: 0.5 + DECODE_FF_DROPOUT: 0.5 + ELU_ALPHA: 1.3 + BIFEAT_EMB_ACT: 'RELU' + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + DECODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.000001 + TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM' + MAX_EPOCH: 70 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + 
DISPLAY: 100 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.98] + EPS: 1.0e-9 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 3 + SETP_TYPE: 'Iter' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 768 # For Noam only + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.1 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 2 + GREEDY_DECODE: True diff --git a/experiments_iuxray_testing/xtransformer/train.sh b/experiments_iuxray_testing/xtransformer/train.sh new file mode 100644 index 0000000..873e8e0 --- /dev/null +++ b/experiments_iuxray_testing/xtransformer/train.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 main.py --folder ./experiments_iuxray/xtransformer diff --git a/experiments/xtransformer_rl/config.yml b/experiments_iuxray_testing/xtransformer_rl/config.yml similarity index 95% rename from experiments/xtransformer_rl/config.yml rename to experiments_iuxray_testing/xtransformer_rl/config.yml index 7544a0e..8fb2803 100644 --- a/experiments/xtransformer_rl/config.yml +++ b/experiments_iuxray_testing/xtransformer_rl/config.yml @@ -33,8 +33,8 @@ DATA_LOADER: ############################ MODEL 
############################ MODEL: TYPE: 'XTransformer' - SEQ_LEN: 17 # include / - VOCAB_SIZE: 9487 # exclude / + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 9487 # TODO # exclude / ########## word embedding ########## WORD_EMBED_DIM: 768 WORD_EMBED_ACT: 'CELU' @@ -46,7 +46,7 @@ MODEL: GVFEAT_EMBED_ACT: 'NONE' DROPOUT_GV_EMBED: 0.0 ########## attention features ########## - ATT_FEATS_DIM: 2048 + ATT_FEATS_DIM: 1024 # Modified ATT_FEATS_EMBED_DIM: 768 ATT_FEATS_EMBED_ACT: 'CELU' DROPOUT_ATT_EMBED: 0.5 @@ -89,7 +89,7 @@ MODEL: ############################ SOLVER ############################ SOLVER: - BASE_LR: 0.00001 + BASE_LR: 0.000005 TYPE: 'RADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP' MAX_EPOCH: 60 MAX_ITER: -1 diff --git a/experiments_iuxray_testing/xtransformer_rl/train.sh b/experiments_iuxray_testing/xtransformer_rl/train.sh new file mode 100644 index 0000000..c97b3fd --- /dev/null +++ b/experiments_iuxray_testing/xtransformer_rl/train.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_iuxray/xtransformer_rl --resume 39 + +# 39 is the epoch number of the pretrained model diff --git a/experiments_mimiccxr/xlan/config.yml b/experiments_mimiccxr/xlan/config.yml new file mode 100644 index 0000000..e79cc1f --- /dev/null +++ b/experiments_mimiccxr/xlan/config.yml @@ -0,0 +1,148 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 10 + #################### SCHEDULED_SAMPLING #################### + SCHEDULED_SAMPLING: + START: 6 + INC_EVERY: 5 + INC_PROB: 0.05 + MAX_PROB: 0.5 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 9999 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 36 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + 
VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 5 + MAX_FEAT: -1 + +############################ MODEL ############################ +MODEL: + TYPE: 'XLAN' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 7863 # TODO # exclude / + ########## word embedding ########## + WORD_EMBED_DIM: 1024 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.5 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 1024 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: False + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 1024 + ENCODE_ATT_MID_DIM: [128, 64, 128] + DECODE_ATT_MID_DIM: [128, 64, 128] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 1024 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 4 + DECODE_LAYERS: 1 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_BLOCK: 'LowRankBilinearEnc' + DECODE_BLOCK: 'LowRankBilinearDec' + ELU_ALPHA: 1.3 + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.0005 + TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP' + MAX_EPOCH: 70 + MAX_ITER: -1 + 
GRAD_CLIP: 0.5 # Norm:5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Norm' # 'Clamp', 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 2000 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.98] + EPS: 1.0e-9 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 300 + SETP_TYPE: 'Iter' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 1024 # For Noam only + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.0 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 3 + GREEDY_DECODE: True diff --git a/experiments_mimiccxr/xlan/train.sh b/experiments_mimiccxr/xlan/train.sh new file mode 100644 index 0000000..55a8bec --- /dev/null +++ b/experiments_mimiccxr/xlan/train.sh @@ -0,0 +1,4 @@ +CUDA_VISIBLE_DEVICES=3,2,1,0 python3 -m torch.distributed.launch \ + --nproc_per_node=4 main.py --folder ./experiments_mimiccxr/xlan \ + --dataset_name MIMICCXR \ + --image_dir /content/mimic_cxr/images --ann_path /content/mimic_cxr/annotation.json \ No newline at end of file diff --git a/experiments_mimiccxr/xlan_rl/config.yml b/experiments_mimiccxr/xlan_rl/config.yml new file mode 100644 index 0000000..1d7cf07 --- /dev/null +++ b/experiments_mimiccxr/xlan_rl/config.yml @@ -0,0 
+1,151 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 20 + #################### SCHEDULED_SAMPLING #################### + SCHEDULED_SAMPLING: + START: 6 + INC_EVERY: 5 + INC_PROB: 0.05 + MAX_PROB: 0.5 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 0 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 36 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 5 + MAX_FEAT: -1 + +############################ MODEL ############################ +MODEL: + TYPE: 'XLAN' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 7863 # TODO # exclude / + ########## word embedding ########## + WORD_EMBED_DIM: 1024 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.5 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 1024 + ATT_FEATS_EMBED_ACT: 'CELU' # 'RELU', 'NONE' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: False + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' # 'RELU', 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + 
DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 1024 + ENCODE_ATT_MID_DIM: [128, 64, 128] + DECODE_ATT_MID_DIM: [128, 64, 128] + ENCODE_ATT_MID_DROPOUT: 0.0 + DECODE_ATT_MID_DROPOUT: 0.0 + ATT_DIM: 1024 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 4 + DECODE_LAYERS: 1 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_BLOCK: 'LowRankBilinearEnc' + DECODE_BLOCK: 'LowRankBilinearDec' + ELU_ALPHA: 1.3 + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.00001 + TYPE: 'ADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP' + MAX_EPOCH: 35 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 20 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.999] + EPS: 1.0e-8 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Plateau' # 'Fix', 'Step', 'MultiStep', 'Poly', 'Noam' + GAMMA: 0.8 + STEP_SIZE: 300 + SETP_TYPE: 'Epoch' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 1024 # For Noam only + + PLATEAU_FACTOR: 0.8 + PLATEAU_PATIENCE: 3 + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'CrossEntropy' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.0 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE:
'./mscoco/misc/captions_test5k.json' + BEAM_SIZE: 3 + GREEDY_DECODE: True diff --git a/experiments_mimiccxr/xlan_rl/train.sh b/experiments_mimiccxr/xlan_rl/train.sh new file mode 100644 index 0000000..6a3862a --- /dev/null +++ b/experiments_mimiccxr/xlan_rl/train.sh @@ -0,0 +1,5 @@ +CUDA_VISIBLE_DEVICES=0 python3 main.py \ + --folder ./experiments_mimiccxr/xlan_rl --resume 47 \ + --dataset_name MIMICCXR \ + --image_dir /content/mimic_cxr/images --ann_path /content/mimic_cxr/annotation.json +# 47 is the epoch number of the pretrained model diff --git a/experiments_mimiccxr/xtransformer/config.yml b/experiments_mimiccxr/xtransformer/config.yml new file mode 100644 index 0000000..a2a7d12 --- /dev/null +++ b/experiments_mimiccxr/xtransformer/config.yml @@ -0,0 +1,144 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 16 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 9999 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 16 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 1 + MAX_FEAT: 50 + +############################ MODEL ############################ +MODEL: + TYPE: 'XTransformer' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 7863 # TODO # exclude / IUXRAY: 760 + ########## word embedding ########## + WORD_EMBED_DIM: 768 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: 
False + DROPOUT_WORD_EMBED: 0.1 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 768 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: True + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.5 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 768 + ENCODE_ATT_MID_DIM: [96, 48, 96] + DECODE_ATT_MID_DIM: [96, 48, 96] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 768 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.5 + ENCODE_LAYERS: 3 + DECODE_LAYERS: 3 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_FF_DROPOUT: 0.5 + DECODE_FF_DROPOUT: 0.5 + ELU_ALPHA: 1.3 + BIFEAT_EMB_ACT: 'RELU' + ENCODE_BIFEAT_EMB_DROPOUT: 0.3 + DECODE_BIFEAT_EMB_DROPOUT: 0.3 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.000001 + TYPE: 'RADAM' # 'ADAM', 'SGD', 'RADAM' + MAX_EPOCH: 70 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:0.5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp' , 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 2000 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.98] + EPS: 1.0e-9 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Noam' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 300 # modified + SETP_TYPE: 'Iter' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 768 # For Noam only + +############################ LOSSES ############################ +LOSSES: + 
XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.1 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 2 + GREEDY_DECODE: True diff --git a/experiments_mimiccxr/xtransformer/train.sh b/experiments_mimiccxr/xtransformer/train.sh new file mode 100644 index 0000000..066e383 --- /dev/null +++ b/experiments_mimiccxr/xtransformer/train.sh @@ -0,0 +1,7 @@ +CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch \ + --nproc_per_node=1 main.py --folder ./experiments_mimiccxr/xtransformer \ + --dataset_name MIMICCXR \ + --image_dir /content/mimic_cxr/images --ann_path /content/mimic_cxr/annotation.json \ + --submodel VSEGCN --resume 0 + +###if you want to use checkpoint, download your model in experiments_mimiccxr/xtransformer/snapshot and change 0 to your model's number diff --git a/experiments_mimiccxr/xtransformer_rl/config.yml b/experiments_mimiccxr/xtransformer_rl/config.yml new file mode 100644 index 0000000..f5fa64d --- /dev/null +++ b/experiments_mimiccxr/xtransformer_rl/config.yml @@ -0,0 +1,147 @@ +LOGGER_NAME: 'log' +SEED: 1546884941.160048 + +############################ TRAIN ############################ +TRAIN: + BATCH_SIZE: 16 + #################### REINFORCEMENT #################### + REINFORCEMENT: + START: 0 + +############################ TEST ############################ +TEST: + BATCH_SIZE: 16 + +############################ DATA_LOADER ############################ +DATA_LOADER: + NUM_WORKERS: 4 + SHUFFLE: True + TRAIN_GV_FEAT: '' + 
TRAIN_ATT_FEATS: './mscoco/feature/up_down_100' + VAL_GV_FEAT: '' + VAL_ATT_FEATS: './mscoco/feature/up_down_100' + TEST_GV_FEAT: '' + TEST_ATT_FEATS: './mscoco/feature/up_down_100' + TRAIN_ID: './mscoco/txt/coco_train_image_id.txt' + VAL_ID: './mscoco/txt/coco_val_image_id.txt' + TEST_ID: './mscoco/txt/coco_test_image_id.txt' + INPUT_SEQ_PATH: './mscoco/sent/coco_train_input.pkl' + TARGET_SEQ_PATH: './mscoco/sent/coco_train_target.pkl' + SEQ_PER_IMG: 5 + MAX_FEAT: 50 + +############################ MODEL ############################ +MODEL: + TYPE: 'XTransformer' + SEQ_LEN: 60 # Modified # include / + VOCAB_SIZE: 7863 # TODO # exclude / + ########## word embedding ########## + WORD_EMBED_DIM: 768 + WORD_EMBED_ACT: 'CELU' + WORD_EMBED_NORM: False + DROPOUT_WORD_EMBED: 0.1 + ########## global features ########## + GVFEAT_DIM: 2048 + GVFEAT_EMBED_DIM: -1 + GVFEAT_EMBED_ACT: 'NONE' + DROPOUT_GV_EMBED: 0.0 + ########## attention features ########## + ATT_FEATS_DIM: 1024 # Modified + ATT_FEATS_EMBED_DIM: 768 + ATT_FEATS_EMBED_ACT: 'CELU' + DROPOUT_ATT_EMBED: 0.5 + ATT_FEATS_NORM: True + ########## attention param ########## + ATT_HIDDEN_SIZE: -1 + ATT_HIDDEN_DROP: 0.0 + ATT_ACT: 'TANH' + ########## rnn param ########## + RNN_SIZE: 1024 + DROPOUT_LM: 0.0 + + ########## BOTTOM_UP ########## + BOTTOM_UP: + DROPOUT_FIRST_INPUT: 0.0 + DROPOUT_SEC_INPUT: 0.0 + + ########## BILINEAR ########## + BILINEAR: + DIM: 768 + ENCODE_ATT_MID_DIM: [96, 48, 96] + DECODE_ATT_MID_DIM: [96, 48, 96] + ENCODE_ATT_MID_DROPOUT: 0.1 + DECODE_ATT_MID_DROPOUT: 0.1 + ATT_DIM: 768 + ACT: 'CELU' + ENCODE_DROPOUT: 0.5 + DECODE_DROPOUT: 0.1 + ENCODE_LAYERS: 3 + DECODE_LAYERS: 3 + TYPE: 'LowRank' + ATTTYPE: 'SCAtt' # SCAtt, BasicAtt + HEAD: 8 + ENCODE_FF_DROPOUT: 0.5 + DECODE_FF_DROPOUT: 0.5 + ELU_ALPHA: 1.3 + BIFEAT_EMB_ACT: 'RELU' + ENCODE_BIFEAT_EMB_DROPOUT: 0.1 + DECODE_BIFEAT_EMB_DROPOUT: 0.1 + +############################ SOLVER ############################ +SOLVER: + BASE_LR: 0.000005 + TYPE: 
'RADAM' # 'ADAM', 'SGD', 'ADAGRAD', 'RMSPROP' + MAX_EPOCH: 60 + MAX_ITER: -1 + GRAD_CLIP: 0.1 # Norm:5 , Clamp:0.1 + GRAD_CLIP_TYPE: 'Clamp' # 'Clamp', 'Norm' + WEIGHT_DECAY: 0.0000 + WEIGHT_DECAY_BIAS: 0.0 + BIAS_LR_FACTOR: 1 + DISPLAY: 20 + TEST_INTERVAL: 1 + SNAPSHOT_ITERS: 1 + + ########## SGD ########## + SGD: + MOMENTUM: 0.9 + ########## ADAM ########## + ADAM: + BETAS: [0.9, 0.999] + EPS: 1.0e-8 + ########## LR_POLICY ########## + LR_POLICY: + TYPE: 'Plateau' # 'Fix', 'Step', 'Noam', 'Plateau' + GAMMA: 0.8 + STEP_SIZE: 300 + SETP_TYPE: 'Epoch' # 'Epoch', 'Iter' + WARMUP: 10000 # For Noam only + FACTOR: 1.0 # For Noam only + MODEL_SIZE: 768 # For Noam only + + PLATEAU_FACTOR: 0.8 + PLATEAU_PATIENCE: 3 + +############################ LOSSES ############################ +LOSSES: + XE_TYPE: 'LabelSmoothing' # 'CrossEntropy', 'LabelSmoothing' + LABELSMOOTHING: 0.1 + RL_TYPE: 'RewardCriterion' + +############################ SCORER ############################ +SCORER: + TYPES: ['CIDEr'] + WEIGHTS: [1.0] + GT_PATH: './mscoco/misc/coco_train_gts.pkl' + CIDER_CACHED: './mscoco/misc/coco_train_cider.pkl' + +############################ INFERENCE ############################ +INFERENCE: + VOCAB: './mscoco/txt/coco_vocabulary.txt' + ID_KEY: 'image_id' + CAP_KEY: 'caption' + EVAL: 'COCO' + VAL_ANNFILE: './mscoco/misc/captions_val5k.json' + TEST_ANNFILE: './mscoco/misc/captions_test5k.json' + BEAM_SIZE: 2 + GREEDY_DECODE: True diff --git a/experiments_mimiccxr/xtransformer_rl/train.sh b/experiments_mimiccxr/xtransformer_rl/train.sh new file mode 100644 index 0000000..c267ddd --- /dev/null +++ b/experiments_mimiccxr/xtransformer_rl/train.sh @@ -0,0 +1,6 @@ +CUDA_VISIBLE_DEVICES=0 python3 main.py --folder ./experiments_mimiccxr/xtransformer_rl \ + --dataset_name MIMICCXR \ + --image_dir /content/mimic_cxr/images --ann_path /content/mimic_cxr/annotation.json \ + --submodel VSEGCN + +# 39 is the epoch number of the pretrained model diff --git a/layers/sc_att.py 
b/layers/sc_att.py index 7bc4dea..89afadf 100755 --- a/layers/sc_att.py +++ b/layers/sc_att.py @@ -12,6 +12,19 @@ def __init__(self, mid_dims, mid_dropout): self.attention_last2 = nn.Linear(mid_dims[-2], mid_dims[-1]) def forward(self, att_map, att_mask, value1, value2): + """ + att_map, att_mask, value1, value2 + + torch.Size([4, 8, 49, 96]) + torch.Size([4, 49]) + torch.Size([4, 8, 96]) + torch.Size([4, 8, 49, 96]) + + """ + # print('att_map, att_mask, value1, value2') + # for i in [att_map, att_mask, value1, value2]: + # print(i.shape) + if self.attention_basic is not None: att_map = self.attention_basic(att_map) @@ -37,4 +50,5 @@ def forward(self, att_map, att_mask, value1, value2): value2 = torch.matmul(alpha_spatial.unsqueeze(-2), value2).squeeze(-2) attn = value1 * value2 * alpha_channel + # raise Exception('lol') return attn diff --git a/lib/config.py b/lib/config.py index a8da4fa..5c04958 100755 --- a/lib/config.py +++ b/lib/config.py @@ -87,13 +87,15 @@ # ---------------------------------------------------------------------------- # __C.MODEL = edict() +__C.MODEL.PretrainedImageModel = '/content/image-captioning/model_auc14.dict.gz' # TODO # Modified + __C.MODEL.TYPE = 'UpDown' # 'UpDown', 'XLAN', 'XTransformer' -__C.MODEL.SEQ_LEN = 17 # include / +__C.MODEL.SEQ_LEN = 60 # include / # modified -__C.MODEL.VOCAB_SIZE = 9487 # exclude / +__C.MODEL.VOCAB_SIZE = 760 # exclude / # TODO : IUXRAY: 760 -__C.MODEL.WORD_EMBED_DIM = 1000 +__C.MODEL.WORD_EMBED_DIM = 512 # TODO # Modified __C.MODEL.WORD_EMBED_ACT = 'NONE' # 'RELU', 'CELU', 'NONE' @@ -101,7 +103,7 @@ __C.MODEL.DROPOUT_WORD_EMBED = 0.0 -__C.MODEL.GVFEAT_DIM = 2048 +__C.MODEL.GVFEAT_DIM = 2048 # TODO __C.MODEL.GVFEAT_EMBED_DIM = -1 @@ -109,7 +111,7 @@ __C.MODEL.DROPOUT_GV_EMBED = 0.0 -__C.MODEL.ATT_FEATS_DIM = 2048 +__C.MODEL.ATT_FEATS_DIM = 1024 # Not used. 
Modified on the init stage of model __C.MODEL.ATT_FEATS_EMBED_DIM = -1 @@ -119,13 +121,13 @@ __C.MODEL.ATT_FEATS_NORM = False -__C.MODEL.ATT_HIDDEN_SIZE = 512 +__C.MODEL.ATT_HIDDEN_SIZE = 512 # TODO __C.MODEL.ATT_HIDDEN_DROP = 0.0 __C.MODEL.ATT_ACT = 'RELU' # 'RELU', 'CELU', 'TANH' -__C.MODEL.RNN_SIZE = 1000 +__C.MODEL.RNN_SIZE = 1000 # TODO __C.MODEL.DROPOUT_LM = 0.5 @@ -304,7 +306,7 @@ __C.INFERENCE.ID_KEY = 'image_id' -__C.INFERENCE.CAP_KEY = 'caption' +__C.INFERENCE.CAP_KEY = 'report' # Modified __C.INFERENCE.EVAL = 'COCO' diff --git a/lib/utils.py b/lib/utils.py index d331c75..8b8c4e7 100755 --- a/lib/utils.py +++ b/lib/utils.py @@ -75,6 +75,7 @@ def decode_sequence(vocab, seq): words = [] for t in range(T): ix = seq[n, t] + ix = ix.item() if ix == 0: break words.append(vocab[ix]) diff --git a/main.py b/main.py index f478dbb..03fe3a6 100755 --- a/main.py +++ b/main.py @@ -16,12 +16,174 @@ import losses import models import datasets +from datasets.radiology_dataset import IUXRAY, MIMICCXR +from datasets.tokenizers import Tokenizer import lib.utils as utils from lib.utils import AverageMeter -from optimizer.optimizer import Optimizer +from optimizer.optimizer import Optimizer, build_optimizer from evaluation.evaler import Evaler from scorer.scorer import Scorer from lib.config import cfg, cfg_from_file +from mlclassifier import GCNClassifier +device = torch.device('cuda') + +def parse_args(): + """ + Parse input arguments + """ + parser = argparse.ArgumentParser(description='Image Captioning') + parser.add_argument('--folder', dest='folder', type=str, default=None) + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument("--resume", type=int, default=-1) + parser.add_argument('--image_dir', type=str, default='/content/iu_xray_resized/images/', + help='the path to the directory containing the data.') + parser.add_argument('--ann_path', type=str, default='/content/iu_xray_resized/annotation.json', + help='the path to the directory 
containing the data.') + + parser.add_argument('--dataset_name', type=str, default='IUXRAY', choices=['IUXRAY', 'MIMICCXR','MIMICCXR_MultiImages'], + help='the dataset to be used.') + parser.add_argument('--submodel', type=str, default='RGMG', choices=['RGMG', 'VSEGCN'], + help='the knowledge graph to be used.') + # Encoder Mode + parser.add_argument('--encoder_mode', type=str, default='normal', choices=['normal', 'dualwayencoder'], + help='Specify the transformer encoder') + + parser.add_argument('--training_ratio', type = float, default = '1.0', help ='Select the training ratio. Recommend: 0.001, 0.005, 0.01, 0.1, 0.5 and 1.0') + parser.add_argument('--KG_path', type = str, help='the path to the pretrained kg checkpoint') + + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + + args = parser.parse_args() + return args + +args = parse_args() + +if args.submodel =='RGMG' and args.dataset_name =='IUXRAY': + fw_adj = torch.tensor([ + #FOR RGMG on IUXray + [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], dtype=torch.float,device=device) +elif args.submodel == 'VSEGCN' and args.dataset_name =='IUXRAY': + fw_adj = torch.tensor([ +#FOR VSEGCN on IUXray + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08624227881040383, 0.0, 0.0, 0.0, 0.0, 0.08531678128946102, 0.0, 0.0] , + [0.0, 0.0, 0.0, 0.01865074607267221, 0.0, 0.2924299133554616, 0.0, 0.0, 0.21304488410089617, 0.0, 0.0, 0.0, 0.0, 0.0, 0.17180192556684684, 0.6418055548125825, 0.0, 0.5855658364897064, 0.0, 0.8458489347533727, 0.9602592859311171, 0.4274547554463511, 0.0, 0.7595885904689661, 0.0] , + [0.0, 0.0, 0.01865074607267221, 0.0, 0.0, 0.4009853199098073, 1.5255730949811777, 0.0, 0.08521151259101123, 0.631755218959081, 0.0, 0.752383206747696, 0.0, 0.0, 0.331650626508743, 0.426960806313068, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7569183619130871, 0.0] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.5610118144338234, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49754224048515705, 0.059920020742461305, 0.12193009552667401, 0.0, 0.6916982549261147, 0.0, 0.028649105523448872, 0.0, 1.38484543548606, 0.0, 0.0, 0.0, 0.6395125017555444] , + [0.0, 0.0, 0.2924299133554616, 0.4009853199098073, 0.5610118144338234, 0.0, 1.522720025998771, 0.6039632016294225, 1.1321805681072825, 0.9342837995278563, 1.281557969181883, 0.8673131734216728, 1.2434062032175066, 0.3490255809790961, 0.6754221656115674, 0.4370111421665694, 0.0, 0.0, 0.0, 0.0, 0.40348845012792584, 0.0, 0.0, 0.06928636204125185, 0.0] , + [0.0, 0.0, 
0.0, 1.5255730949811777, 0.0, 1.522720025998771, 0.0, 0.0, 0.0, 0.0, 0.0, 2.1794995623878415, 0.0, 0.0, 1.535623430834679, 0.0, 0.0, 0.06231769272515846, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.6039632016294225, 0.0, 0.0, 0.0, 1.5819475025089174, 0.0, 0.3162811291776416, 0.0, 0.5912241758519928, 0.7710172862925886, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8175373019274815, 0.0, 0.0, 0.0, 0.0] , + [0.0, 0.0, 0.21304488410089617, 0.08521151259101123, 0.0, 1.1321805681072825, 0.0, 0.0, 0.0, 0.9061920646608415, 0.0, 0.7391379799976754, 0.0, 0.0, 1.1938741371126225, 0.0952618484445128, 0.0, 0.0, 0.05704063562431511, 0.0, 0.0, 0.16859312153006234, 0.0, 1.376195693906577, 0.0] , + [0.0, 0.0, 0.0, 0.631755218959081, 0.0, 0.9342837995278563, 0.0, 1.5819475025089174, 0.9061920646608415, 0.0, 0.9602592859311171, 1.6911467944739096, 0.3242705192111205, 0.39747392323441544, 0.8649491061267922, 0.50827416218806, 0.0, 0.0, 0.0, 0.0, 0.623787049309904, 0.0, 0.021989647338186796, 0.0, 0.46624078048150774] , + [0.0, 0.0, 0.0, 0.0, 0.0, 1.281557969181883, 0.0, 0.0, 0.0, 0.9602592859311171, 0.0, 0.793205201267951, 0.0, 1.068148247942302, 1.535623430834679, 0.0, 0.0, 0.46778280083332285, 0.0, 0.0, 1.294461374017791, 0.0, 0.0, 0.0, 0.6669114759436587] , + [0.0, 0.0, 0.0, 0.752383206747696, 0.0, 0.8673131734216728, 2.1794995623878415, 0.3162811291776416, 0.7391379799976754, 1.6911467944739096, 0.793205201267951, 0.0, 0.0, 0.007276287257039512, 1.3910422020235713, 0.0, 0.0, 0.0, 0.0, 0.0, 0.23358941333252825, 0.0, 0.0, 0.3693909544915901, 0.29918669581834156] , + [0.0, 0.0, 0.0, 0.0, 0.49754224048515705, 1.2434062032175066, 0.0, 0.0, 0.0, 0.3242705192111205, 0.0, 0.0, 0.0, 0.4321594812223054, 0.0, 0.20648748355473706, 0.2166398550187551, 0.0, 0.16826627073453929, 0.0, 0.6584726072977941, 0.0, 0.0, 0.0, 1.4172170703435527] , + [0.0, 0.0, 0.0, 0.0, 0.059920020742461305, 0.3490255809790961, 0.0, 0.5912241758519928, 0.0, 0.39747392323441544, 1.068148247942302, 
0.007276287257039512, 0.4321594812223054, 0.0, 0.6161631241992449, 0.0, 0.0, 0.0, 0.0, 0.0, 2.1179703724409795, 0.35302216066358155, 0.0, 0.6443340011659412, 0.0] , + [0.0, 0.0, 0.17180192556684684, 0.331650626508743, 0.12193009552667401, 0.6754221656115674, 1.535623430834679, 0.7710172862925886, 1.1938741371126225, 0.8649491061267922, 1.535623430834679, 1.3910422020235713, 0.0, 0.6161631241992449, 0.0, 0.5240225191561991, 0.0, 0.0, 0.0, 0.0, 1.0937906785556397, 0.3096717197899676, 0.0, 1.430262915176853, 0.3484577448251243] , + [0.0, 0.0, 0.6418055548125825, 0.426960806313068, 0.0, 0.4370111421665694, 0.0, 0.0, 0.0952618484445128, 0.50827416218806, 0.0, 0.0, 0.20648748355473706, 0.0, 0.5240225191561991, 0.0, 0.0, 0.0, 0.0, 0.2272906111845001, 0.5060040136535207, 0.0, 0.3096717197899676, 0.4186620034983729, 0.0] , + [0.0, 0.0, 0.0, 0.0, 0.6916982549261147, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2166398550187551, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10158746275990369, 0.029803617870273677, 0.0, 0.0, 0.137502534460031, 0.0, 0.07092804383736119] , + [0.0, 0.08624227881040383, 0.5855658364897064, 0.0, 0.0, 0.0, 0.06231769272515846, 0.0, 0.0, 0.0, 0.46778280083332285, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1461991767058608, 0.845848934753373, 0.0, 0.0, 0.9958502310338198, 0.0, 0.0] , + [0.0, 0.0, 0.0, 0.0, 0.028649105523448872, 0.0, 0.0, 0.0, 0.05704063562431511, 0.0, 0.0, 0.0, 0.16826627073453929, 0.0, 0.0, 0.0, 0.10158746275990369, 0.1461991767058608, 0.0, 0.08964361822629091, 0.0034771927022252134, 0.0, 0.08912895017581536, 0.0, 0.06907447518803858] , + [0.0, 0.0, 0.8458489347533727, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2272906111845001, 0.029803617870273677, 0.845848934753373, 0.08964361822629091, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] , + [0.0, 0.0, 0.9602592859311171, 0.0, 1.38484543548606, 0.40348845012792584, 0.0, 0.8175373019274815, 0.0, 0.623787049309904, 1.294461374017791, 0.23358941333252825, 0.6584726072977941, 2.1179703724409795, 1.0937906785556397, 
0.5060040136535207, 0.0, 0.0, 0.0034771927022252134, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] , + [0.0, 0.0, 0.4274547554463511, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16859312153006234, 0.0, 0.0, 0.0, 0.0, 0.35302216066358155, 0.3096717197899676, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49199327658392233, 0.6449325692248835] , + [0.0, 0.08531678128946102, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.021989647338186796, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3096717197899676, 0.137502534460031, 0.9958502310338198, 0.08912895017581536, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] , + [0.0, 0.0, 0.7595885904689661, 0.7569183619130871, 0.0, 0.06928636204125185, 0.0, 0.0, 1.376195693906577, 0.0, 0.0, 0.3693909544915901, 0.0, 0.6443340011659412, 1.430262915176853, 0.4186620034983729, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49199327658392233, 0.0, 0.0, 0.0] , + [0.0, 0.0, 0.0, 0.0, 0.6395125017555444, 0.0, 0.0, 0.0, 0.0, 0.46624078048150774, 0.6669114759436587, 0.29918669581834156, 1.4172170703435527, 0.0, 0.3484577448251243, 0.0, 0.07092804383736119, 0.0, 0.06907447518803858, 0.0, 0.0, 0.6449325692248835, 0.0, 0.0, 0.0] , + ], dtype=torch.float,device=device) +elif args.submodel == 'VSEGCN' and args.dataset_name != 'IUXRAY': + fw_adj = torch.tensor([ +#FOR VSEGCN on mimic + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] , + [0.0, 0.0, 0.554004673298568, 0.3230466430334976, 0.2631825039749924, 0.7016287008480984, 0.24746163509836083, 0.0010015098042039912, 0.5068736479106769, 0.0, 0.9524974820767607, 0.0, 0.274863833501621, 0.41494699762748494, 0.5840689600939755, 0.0, 0.0, 0.1302442384969287, 0.0, 0.6339873741551628, 0.0, 0.0, 0.0, 0.0, 0.13757450134915059, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20484800268752174, 0.0, 0.3451355245988348, 0.0, 0.0, 0.31496065589330335] , + [0.0, 0.554004673298568, 0.0, 0.0, 0.8777428349239078, 0.48756689514453816, 0.15609090494487807, 0.0, 0.2680360388191779, 0.0, 
0.6199516595911282, 0.42238326753620825, 0.11635380051050957, 0.0020601220064393085, 0.6321690244338739, 0.0, 0.0, 0.08334337914718853, 0.5194350922631013, 0.12854873550846002, 0.6928120918404297, 0.0, 0.0, 0.9837299370816517, 0.2314271028213519, 0.5817251164456204, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33811183663241207, 0.0, 0.13936476949018967, 0.0, 0.0, 0.07089242056885223] , + [0.0, 0.3230466430334976, 0.0, 0.0, 0.49497396267096466, 0.6594324804768135, 0.0, 0.706453671957811, 0.4967156785324145, 0.0, 0.8143326399158364, 0.09215690907124767, 1.0617056232382251, 0.1725070158434512, 0.3677521118316441, 0.0, 0.0, 0.0, 0.0, 0.31306543839429407, 1.8437693102858905, 0.0, 0.0, 0.0, 0.12775432598701517, 0.0, 0.1052712068749145, 0.0, 0.0, 0.0, 0.0, 0.037260086847355586, 0.0, 0.7070285247253186, 0.0, 0.0, 0.36112875829496877] , + [0.0, 0.2631825039749924, 0.8777428349239078, 0.49497396267096466, 0.0, 0.5081730413487658, 0.0, 0.0, 0.4345747745875523, 0.0, 0.9271810641740084, 0.0, 0.43551881399664394, 0.0, 0.32510239740035224, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5363277306682555, 0.0, 0.0, 0.35190615068655606, 0.5705232126845626, 0.5633978920587661, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6523482076409699, 0.05734284387314562, 0.029039130518496915, 0.0, 0.0, 0.47600439204660566] , + [0.0, 0.7016287008480984, 0.48756689514453816, 0.6594324804768135, 0.5081730413487658, 0.0, 0.5879416440743396, 0.7282368105665861, 0.6655288739610595, 0.0, 0.5829843060689561, 1.1084681450499108, 0.23022619090918064, 0.7869908341567744, 0.9237236161240492, 0.0, 0.0, 0.00679805799268029, 0.3508343765662928, 0.7283807004431154, 1.2142749896595786, 0.0, 0.0, 0.21380809279477983, 0.0, 0.0, 0.0, 0.7384514725987071, 0.036197082879009246, 0.024547912417246583, 0.0, 0.0031487780150971255, 0.0, 0.17255665551694127, 0.0, 0.30464217507912805, 0.5800348339999603] , + [0.0, 0.24746163509836083, 0.15609090494487807, 0.0, 0.0, 0.5879416440743396, 0.0, 0.0, 0.0, 0.0, 0.050047418864644054, 1.1040108892060951, 0.0, 1.0758558236410007, 
0.20507481280870768, 0.0, 0.0, 0.20482994095079318, 0.20102931464117388, 0.5609723973768375, 0.5907785254095393, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1108353406999546, 2.2722344254125035, 0.6708537856962394, 0.0, 0.0, 0.0, 0.656008131578071, 0.019380062443701017, 0.0] , + [0.0, 0.0010015098042039912, 0.0, 0.706453671957811, 0.0, 0.7282368105665861, 0.0, 0.0, 0.8928317911567151, 0.0, 0.2412417134011427, 1.1756920438484275, 0.45754054092273994, 0.0, 0.0, 0.07862826830847675, 0.0, 0.08084792186159706, 0.17752298133873087, 0.0, 0.23167676395941764, 0.0, 0.0, 0.0, 0.0, 0.0, 0.22902820929658516, 0.0, 0.18822158336999106, 1.3653688666139736, 0.0, 0.0, 0.08262146112756505, 0.8291646636467345, 0.3474221602648512, 0.028679991724856257, 0.15823233737507106] , + [0.0, 0.5068736479106769, 0.2680360388191779, 0.4967156785324145, 0.4345747745875523, 0.6655288739610595, 0.0, 0.8928317911567151, 0.0, 0.0, 0.5266439565166569, 0.7732181640322486, 0.8796823806455875, 0.04357377094721733, 0.38483053231749587, 0.0, 0.0, 0.01715279715950329, 0.0, 0.1935198098014637, 0.7726446050833446, 0.0, 0.0, 0.01203069298044543, 0.009614236915435854, 0.0, 0.028445852424580496, 0.0, 0.017029382779510584, 0.1773653331099141, 0.0, 0.0, 0.0, 0.7478386613505874, 0.0, 0.0, 0.3988597091752934] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.11819146504902717, 0.3285711228697073, 0.0, 0.0, 0.0, 0.0, 0.1284847079738626, 0.3441683893946561, 0.0, 0.0, 0.0, 0.056702571766598944, 0.19444256855484526, 0.0, 0.0, 0.034767887372249416, 0.0, 0.05535863446625411, 0.0, 0.0, 0.2006116936209859, 0.0] , + [0.0, 0.9524974820767607, 0.6199516595911282, 0.8143326399158364, 0.9271810641740084, 0.5829843060689561, 0.050047418864644054, 0.2412417134011427, 0.5266439565166569, 0.0, 0.0, 0.27606588231803425, 0.2702763655065547, 0.5155464612930855, 0.5345046760872463, 0.0, 0.0, 0.0, 0.0, 0.6619248249548156, 0.11676497651446016, 0.0, 0.0, 0.012298886029973912, 0.36531125789699864, 
0.22014333096829503, 0.0, 0.0, 0.1421597490181444, 0.0, 0.0, 0.5023838505056603, 0.007094738878332813, 0.16482707917210404, 0.12422641272658975, 0.0, 0.16806740202956802] , + [0.0, 0.0, 0.42238326753620825, 0.09215690907124767, 0.0, 1.1084681450499108, 1.1040108892060951, 1.1756920438484275, 0.7732181640322486, 0.0, 0.27606588231803425, 0.0, 0.4571939169349286, 0.44285881406541816, 0.014435627721654973, 0.0, 0.0, 0.0, 0.35645685756909534, 0.3273858184572163, 0.0, 0.0, 0.0, 0.0, 0.22124160687808572, 0.11022312228841263, 0.04830169724837715, 0.0, 0.92753653417296, 0.8948599965191218, 0.15171053363715095, 0.28220534987677603, 0.3001183720922399, 0.5143515680120436, 0.2778525728199598, 0.0, 0.0] , + [0.0, 0.274863833501621, 0.11635380051050957, 1.0617056232382251, 0.43551881399664394, 0.23022619090918064, 0.0, 0.45754054092273994, 0.8796823806455875, 0.0, 0.2702763655065547, 0.4571939169349286, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.448038949914647, 0.0, 0.0, 0.07261956980307784, 0.0, 0.0, 0.07421899390203622, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.594487035114069, 0.0, 0.0, 0.0] , + [0.0, 0.41494699762748494, 0.0020601220064393085, 0.1725070158434512, 0.0, 0.7869908341567744, 1.0758558236410007, 0.0, 0.04357377094721733, 0.0, 0.5155464612930855, 0.44285881406541816, 0.0, 0.0, 0.9595683261315329, 0.0, 0.0, 0.0, 0.0, 2.4365238654909045, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4720717303053189, 0.5727697527229471, 0.0, 0.0, 0.0, 0.9575685889709006, 1.2377447423336991, 0.0, 0.0] , + [0.0, 0.5840689600939755, 0.6321690244338739, 0.3677521118316441, 0.32510239740035224, 0.9237236161240492, 0.20507481280870768, 0.0, 0.38483053231749587, 0.0, 0.5345046760872463, 0.014435627721654973, 0.0, 0.9595683261315329, 0.0, 0.17363312653927349, 0.19895016785061964, 0.0, 0.0, 1.208717109379973, 0.0, 0.015167532069747917, 0.0, 0.06703478683378344, 0.6177554201058338, 0.4406078996794569, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3022525474987056, 0.0884563575795982, 0.0, 1.2708509083599127] , + 
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.07862826830847675, 0.0, 0.11819146504902717, 0.0, 0.0, 0.0, 0.0, 0.17363312653927349, 0.0, 2.447688346482624, 0.17983209096647654, 0.0, 0.0, 0.0, 0.28338141933933964, 0.0, 0.0, 0.29702177034947524, 0.07753776902134689, 0.007825343616291445, 0.0, 0.0, 0.0, 0.0, 0.0, 1.5708374900647346, 0.0, 0.15510336927054139, 0.3542350803376888, 0.32669031419868527] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3285711228697073, 0.0, 0.0, 0.0, 0.0, 0.19895016785061964, 2.447688346482624, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14503213452982808, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.139880360686599, 0.0, 0.1439133818303898, 0.0, 0.0] , + [0.0, 0.1302442384969287, 0.08334337914718853, 0.0, 0.0, 0.00679805799268029, 0.20482994095079318, 0.08084792186159706, 0.01715279715950329, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.17983209096647654, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4124407401173898, 0.27031693276904833, 0.6962440948684386, 0.01868026565576617, 0.0, 0.058711117532272095, 0.3108414982982819, 0.8917732120346288, 0.5028865181216288, 0.3710258416111474, 0.0, 0.0, 0.0, 0.0] , + [0.0, 0.0, 0.5194350922631013, 0.0, 0.0, 0.3508343765662928, 0.20102931464117388, 0.17752298133873087, 0.0, 0.0, 0.0, 0.35645685756909534, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.3095615625659667, 0.0, 0.0, 0.33960729056256184, 0.14426353106038403, 0.48295034078748456, 0.04096791345986367, 0.0, 0.21459856249590378, 0.3394630945304, 0.9426984665379273, 0.22121090818098985, 0.3817453264511766, 0.3146902017201184, 0.2525866980495684, 0.041488578522214464, 0.0] , + [0.0, 0.6339873741551628, 0.12854873550846002, 0.31306543839429407, 0.0, 0.7283807004431154, 0.5609723973768375, 0.0, 0.1935198098014637, 0.0, 0.6619248249548156, 0.3273858184572163, 0.0, 2.4365238654909045, 1.208717109379973, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.023645690589139633, 0.15053136759027846, 0.0, 0.24360268371473795, 0.0, 0.8821797270781289, 1.442822630962518, 0.0, 
0.21531865324924443] , + [0.0, 0.0, 0.6928120918404297, 1.8437693102858905, 0.5363277306682555, 1.2142749896595786, 0.5907785254095393, 0.23167676395941764, 0.7726446050833446, 0.0, 0.11676497651446016, 0.0, 1.448038949914647, 0.0, 0.0, 0.0, 0.0, 0.0, 1.3095615625659667, 0.0, 0.0, 0.0, 0.0, 0.33727296191859424, 0.0, 0.11424727258813801, 0.004340630185881583, 0.0, 0.0, 0.17398826794432173, 0.699350130525858, 0.0, 0.0, 0.4114035987596012, 0.0, 0.8603703454658308, 0.6698830923247899] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1284847079738626, 0.0, 0.0, 0.0, 0.0, 0.015167532069747917, 0.28338141933933964, 0.14503213452982808, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6128893031292931, 0.0, 0.0, 0.0, 0.3311971887232972, 0.0, 0.221278095246663, 0.7821347726775432, 0.4309320418653175] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3441683893946561, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04939114176166614, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] , + [0.0, 0.0, 0.9837299370816517, 0.0, 0.35190615068655606, 0.21380809279477983, 0.0, 0.0, 0.01203069298044543, 0.0, 0.012298886029973912, 0.0, 0.07261956980307784, 0.0, 0.06703478683378344, 0.0, 0.0, 0.4124407401173898, 0.33960729056256184, 0.0, 0.33727296191859424, 0.0, 0.0, 0.0, 0.0, 0.08979034229911677, 0.0, 0.38518142976105196, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] , + [0.0, 0.13757450134915059, 0.2314271028213519, 0.12775432598701517, 0.5705232126845626, 0.0, 0.0, 0.0, 0.009614236915435854, 0.0, 0.36531125789699864, 0.22124160687808572, 0.0, 0.0, 0.6177554201058338, 0.29702177034947524, 0.0, 0.27031693276904833, 0.14426353106038403, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.84531965510853, 0.0, 0.0, 0.0, 0.0, 0.03024606051820196, 0.9015367412440394, 0.4187448648843476, 0.11315910768263383, 1.2343811011663848, 0.0, 0.24145505239298407] , + [0.0, 0.0, 0.5817251164456204, 0.0, 0.5633978920587661, 0.0, 0.0, 0.0, 0.0, 0.0, 0.22014333096829503, 
0.11022312228841263, 0.0, 0.0, 0.4406078996794569, 0.07753776902134689, 0.0, 0.6962440948684386, 0.48295034078748456, 0.0, 0.11424727258813801, 0.0, 0.0, 0.08979034229911677, 2.84531965510853, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3231554047178068, 1.2532084879227057, 0.5614893243263129, 0.046137943483739924, 1.540851917651678, 0.0, 0.0] , + [0.0, 0.0, 0.0, 0.1052712068749145, 0.0, 0.0, 0.0, 0.22902820929658516, 0.028445852424580496, 0.056702571766598944, 0.0, 0.04830169724837715, 0.07421899390203622, 0.0, 0.0, 0.007825343616291445, 0.0, 0.01868026565576617, 0.04096791345986367, 0.0, 0.004340630185881583, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0564784860967394, 0.12567556923304798, 0.06594334312122997, 0.0, 0.0, 0.06688509334837037, 0.36269590043980027, 0.030319915064814112, 0.0, 0.0] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.7384514725987071, 0.0, 0.0, 0.0, 0.19444256855484526, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04939114176166614, 0.38518142976105196, 0.0, 0.0, 0.0564784860967394, 0.0, 0.0, 0.0, 0.0, 0.2104088857291676, 0.4561074463243395, 0.0, 0.0, 0.0, 0.6247400008366559] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.036197082879009246, 0.1108353406999546, 0.18822158336999106, 0.017029382779510584, 0.0, 0.1421597490181444, 0.92753653417296, 0.0, 0.4720717303053189, 0.0, 0.0, 0.0, 0.058711117532272095, 0.21459856249590378, 0.023645690589139633, 0.0, 0.6128893031292931, 0.0, 0.0, 0.0, 0.0, 0.12567556923304798, 0.0, 0.0, 0.4309948864530749, 0.0, 0.02542734722301152, 0.0, 0.0, 0.2510742769865065, 0.0, 0.0] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.024547912417246583, 2.2722344254125035, 1.3653688666139736, 0.1773653331099141, 0.0, 0.0, 0.8948599965191218, 0.0, 0.5727697527229471, 0.0, 0.0, 0.0, 0.3108414982982819, 0.3394630945304, 0.15053136759027846, 0.17398826794432173, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06594334312122997, 0.0, 0.4309948864530749, 0.0, 0.8692491673212553, 0.0, 0.19435111707051575, 0.266221588915103, 0.676971258598654, 0.20262509195680156, 0.0] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.6708537856962394, 0.0, 0.0, 0.034767887372249416, 0.0, 0.15171053363715095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8917732120346288, 0.9426984665379273, 0.0, 0.699350130525858, 0.0, 0.0, 0.0, 0.03024606051820196, 0.3231554047178068, 0.0, 0.0, 0.0, 0.8692491673212553, 0.0, 1.1605097349751408, 0.3461994713504898, 0.0, 1.5082427073912366, 0.0, 0.0] , + [0.0, 0.20484800268752174, 0.33811183663241207, 0.037260086847355586, 0.6523482076409699, 0.0031487780150971255, 0.0, 0.0, 0.0, 0.0, 0.5023838505056603, 0.28220534987677603, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5028865181216288, 0.22121090818098985, 0.24360268371473795, 0.0, 0.0, 0.0, 0.0, 0.9015367412440394, 1.2532084879227057, 0.0, 0.2104088857291676, 0.02542734722301152, 0.0, 1.1605097349751408, 0.0, 1.681814915277294, 0.0, 0.43925238577431364, 0.04338409257246071, 0.0] , + [0.0, 0.0, 0.0, 0.0, 0.05734284387314562, 0.0, 0.0, 0.08262146112756505, 0.0, 0.05535863446625411, 0.007094738878332813, 0.3001183720922399, 0.0, 0.0, 0.0, 1.5708374900647346, 2.139880360686599, 0.3710258416111474, 0.3817453264511766, 0.0, 0.0, 0.3311971887232972, 0.0, 0.0, 0.4187448648843476, 0.5614893243263129, 0.06688509334837037, 0.4561074463243395, 0.0, 0.19435111707051575, 0.3461994713504898, 1.681814915277294, 0.0, 0.2674633965945188, 0.8740258958015569, 0.2474464019557586, 0.0] , + [0.0, 0.3451355245988348, 0.13936476949018967, 0.7070285247253186, 0.029039130518496915, 0.17255665551694127, 0.0, 0.8291646636467345, 0.7478386613505874, 0.0, 0.16482707917210404, 0.5143515680120436, 0.594487035114069, 0.9575685889709006, 0.3022525474987056, 0.0, 0.0, 0.0, 0.3146902017201184, 0.8821797270781289, 0.4114035987596012, 0.0, 0.0, 0.0, 0.11315910768263383, 0.046137943483739924, 0.36269590043980027, 0.0, 0.0, 0.266221588915103, 0.0, 0.0, 0.2674633965945188, 0.0, 0.42342916886466814, 0.0, 0.33323103097930845] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.656008131578071, 0.3474221602648512, 0.0, 0.0, 0.12422641272658975, 0.2778525728199598, 0.0, 1.2377447423336991, 
0.0884563575795982, 0.15510336927054139, 0.1439133818303898, 0.0, 0.2525866980495684, 1.442822630962518, 0.0, 0.221278095246663, 0.0, 0.0, 1.2343811011663848, 1.540851917651678, 0.030319915064814112, 0.0, 0.2510742769865065, 0.676971258598654, 1.5082427073912366, 0.43925238577431364, 0.8740258958015569, 0.42342916886466814, 0.0, 0.0, 0.0] , + [0.0, 0.0, 0.0, 0.0, 0.0, 0.30464217507912805, 0.019380062443701017, 0.028679991724856257, 0.0, 0.2006116936209859, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3542350803376888, 0.0, 0.0, 0.041488578522214464, 0.0, 0.8603703454658308, 0.7821347726775432, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20262509195680156, 0.0, 0.04338409257246071, 0.2474464019557586, 0.0, 0.0, 0.0, 0.0] , + [0.0, 0.31496065589330335, 0.07089242056885223, 0.36112875829496877, 0.47600439204660566, 0.5800348339999603, 0.0, 0.15823233737507106, 0.3988597091752934, 0.0, 0.16806740202956802, 0.0, 0.0, 0.0, 1.2708509083599127, 0.32669031419868527, 0.0, 0.0, 0.0, 0.21531865324924443, 0.6698830923247899, 0.4309320418653175, 0.0, 0.0, 0.24145505239298407, 0.0, 0.0, 0.6247400008366559, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33323103097930845, 0.0, 0.0, 0.0] , + ], dtype=torch.float,device=device) +else: + raise Nonetype("There is no this kind of KG or dataset") + + +bw_adj = fw_adj.t() +num_feat = fw_adj.shape[0] - 1 + +submodel = GCNClassifier(num_feat, fw_adj, bw_adj) + +state_dict = submodel.state_dict() + +# Load the kg path from argument + +# if args.submodel =='RGMG' and args.dataset_name =='IUXRAY': +# KG_path = '/content/pretrainedKG/gcnclassifier_v2_ones3_t401v2t3_lr1e-6_e80.pth' +# +# elif args.submodel == 'VSEGCN' and args.dataset_name =='IUXRAY': +# KG_path = '/content/pretrainedKG/iuxray_gcnclassifier_v1_ones3_t0v1t2_lr1e-6_23050521_e180.pth' +# +# elif args.submodel == 'VSEGCN' and args.dataset_name =='MIMICCXR': +# KG_path = '/content/pretrainedKG/mimic_gcnclassifier_v1_ones3_t0v1t2_lr1e-6_24052021_e10.pth' +# else: +# raise Nonetype("There is no this kind of KG or dataset") + 
+state_dict.update({k:v for k, v in torch.load(args.KG_path).items() if k in state_dict}) + +submodel.load_state_dict(state_dict) + class Trainer(object): def __init__(self, args): @@ -35,28 +197,37 @@ def __init__(self, args): self.num_gpus = torch.cuda.device_count() self.distributed = self.num_gpus > 1 if self.distributed: + print('init distributed') torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group( backend="nccl", init_method="env://" ) self.device = torch.device("cuda") + # self.device = 'cpu' self.rl_stage = False self.setup_logging() self.setup_dataset() self.setup_network() self.val_evaler = Evaler( - eval_ids = cfg.DATA_LOADER.VAL_ID, - gv_feat = cfg.DATA_LOADER.VAL_GV_FEAT, - att_feats = cfg.DATA_LOADER.VAL_ATT_FEATS, - eval_annfile = cfg.INFERENCE.VAL_ANNFILE - ) + datasets.create(name = args.dataset_name, + image_dir=args.image_dir, + ann_path=args.ann_path, + tokenizer=self.tokenizer, + split='val', + args = args, + ), + tokenizer=self.tokenizer + ) # TODO self.test_evaler = Evaler( - eval_ids = cfg.DATA_LOADER.TEST_ID, - gv_feat = cfg.DATA_LOADER.TEST_GV_FEAT, - att_feats = cfg.DATA_LOADER.TEST_ATT_FEATS, - eval_annfile = cfg.INFERENCE.TEST_ANNFILE - ) + datasets.create(name = args.dataset_name, + image_dir=args.image_dir, + ann_path=args.ann_path, + tokenizer=self.tokenizer, + split='test', + args=args), + tokenizer=self.tokenizer + ) # TODO self.scorer = Scorer() def setup_logging(self): @@ -64,7 +235,7 @@ def setup_logging(self): self.logger.setLevel(logging.INFO) if self.distributed and dist.get_rank() > 0: return - + ch = logging.StreamHandler(stream=sys.stdout) ch.setLevel(logging.INFO) formatter = logging.Formatter("[%(levelname)s: %(asctime)s] %(message)s") @@ -73,65 +244,76 @@ def setup_logging(self): if not os.path.exists(cfg.ROOT_DIR): os.makedirs(cfg.ROOT_DIR) - + fh = logging.FileHandler(os.path.join(cfg.ROOT_DIR, cfg.LOGGER_NAME + '.txt')) fh.setLevel(logging.INFO) fh.setFormatter(formatter) 
self.logger.addHandler(fh) - + self.logger.info('Training with config:') self.logger.info(pprint.pformat(cfg)) def setup_network(self): - model = models.create(cfg.MODEL.TYPE) + # model = models.create(cfg.MODEL.TYPE, args) + model = models.create('XTransformer', args, submodel = submodel) if self.distributed: # this should be removed if we update BatchNorm stats self.model = torch.nn.parallel.DistributedDataParallel( - model.to(self.device), - device_ids = [self.args.local_rank], - output_device = self.args.local_rank, - broadcast_buffers = False + model.to(self.device), + device_ids=[self.args.local_rank], + output_device=self.args.local_rank, + broadcast_buffers=False ) else: - self.model = torch.nn.DataParallel(model).cuda() + # self.model = torch.nn.DataParallel(model).cuda() # strange + self.model = model.cuda() # strange if self.args.resume > 0: self.model.load_state_dict( torch.load(self.snapshot_path("caption_model", self.args.resume), - map_location=lambda storage, loc: storage) + map_location=lambda storage, loc: storage) ) self.optim = Optimizer(self.model) + # self.optim = build_optimizer(args, model) self.xe_criterion = losses.create(cfg.LOSSES.XE_TYPE).cuda() self.rl_criterion = losses.create(cfg.LOSSES.RL_TYPE).cuda() - + def setup_dataset(self): - self.coco_set = datasets.coco_dataset.CocoDataset( - image_ids_path = cfg.DATA_LOADER.TRAIN_ID, - input_seq = cfg.DATA_LOADER.INPUT_SEQ_PATH, - target_seq = cfg.DATA_LOADER.TARGET_SEQ_PATH, - gv_feat_path = cfg.DATA_LOADER.TRAIN_GV_FEAT, - att_feats_folder = cfg.DATA_LOADER.TRAIN_ATT_FEATS, - seq_per_img = cfg.DATA_LOADER.SEQ_PER_IMG, - max_feat_num = cfg.DATA_LOADER.MAX_FEAT - ) + self.tokenizer = Tokenizer(ann_path=args.ann_path, dataset_name=args.dataset_name) + self.dataset = datasets.create(name = args.dataset_name, + image_dir=args.image_dir, + ann_path=args.ann_path, + tokenizer=self.tokenizer, + split='train', + args=args, + ) + # self.coco_set = datasets.coco_dataset.CocoDataset( + # 
image_ids_path = cfg.DATA_LOADER.TRAIN_ID, + # input_seq = cfg.DATA_LOADER.INPUT_SEQ_PATH, + # target_seq = cfg.DATA_LOADER.TARGET_SEQ_PATH, + # gv_feat_path = cfg.DATA_LOADER.TRAIN_GV_FEAT, + # att_feats_folder = cfg.DATA_LOADER.TRAIN_ATT_FEATS, + # seq_per_img = cfg.DATA_LOADER.SEQ_PER_IMG, + # max_feat_num = cfg.DATA_LOADER.MAX_FEAT + # ) def setup_loader(self, epoch): self.training_loader = datasets.data_loader.load_train( - self.distributed, epoch, self.coco_set) + self.distributed, epoch, self.dataset) def eval(self, epoch): if (epoch + 1) % cfg.SOLVER.TEST_INTERVAL != 0: return None if self.distributed and dist.get_rank() > 0: return None - + val_res = self.val_evaler(self.model, 'val_' + str(epoch + 1)) self.logger.info('######## Epoch (VAL)' + str(epoch + 1) + ' ########') self.logger.info(str(val_res)) - test_res = self.test_evaler(self.model,'test_' + str(epoch + 1)) + test_res = self.test_evaler(self.model, 'test_' + str(epoch + 1)) self.logger.info('######## Epoch (TEST)' + str(epoch + 1) + ' ########') self.logger.info(str(test_res)) @@ -152,11 +334,12 @@ def save_model(self, epoch): snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot') if not os.path.exists(snapshot_folder): os.mkdir(snapshot_folder) - torch.save(self.model.state_dict(), self.snapshot_path("caption_model", epoch+1)) + torch.save(self.model.state_dict(), self.snapshot_path("caption_model", epoch + 1)) def make_kwargs(self, indices, input_seq, target_seq, gv_feat, att_feats, att_mask): seq_mask = (input_seq > 0).type(torch.cuda.LongTensor) - seq_mask[:,0] += 1 + # print(seq_mask) + seq_mask[:, 0] += 1 seq_mask_sum = seq_mask.sum(-1) max_len = int(seq_mask_sum.max()) input_seq = input_seq[:, 0:max_len].contiguous() @@ -176,7 +359,7 @@ def scheduled_sampling(self, epoch): if epoch > cfg.TRAIN.SCHEDULED_SAMPLING.START: frac = (epoch - cfg.TRAIN.SCHEDULED_SAMPLING.START) // cfg.TRAIN.SCHEDULED_SAMPLING.INC_EVERY ss_prob = min(cfg.TRAIN.SCHEDULED_SAMPLING.INC_PROB * frac, 
cfg.TRAIN.SCHEDULED_SAMPLING.MAX_PROB) - self.model.module.ss_prob = ss_prob + # self.model.ss_prob = ss_prob def display(self, iteration, data_time, batch_time, losses, loss_info): if iteration % cfg.SOLVER.DISPLAY != 0: @@ -184,7 +367,7 @@ def display(self, iteration, data_time, batch_time, losses, loss_info): if self.distributed and dist.get_rank() > 0: return info_str = ' (DataTime/BatchTime: {:.3}/{:.3}) losses = {:.5}'.format(data_time.avg, batch_time.avg, losses.avg) - self.logger.info('Iteration ' + str(iteration) + info_str +', lr = ' + str(self.optim.get_lr())) + self.logger.info('Iteration ' + str(iteration) + info_str + ', lr = ' + str(self.optim.get_lr())) for name in sorted(loss_info): self.logger.info(' ' + name + ' = ' + str(loss_info[name])) data_time.reset() @@ -200,6 +383,7 @@ def forward(self, kwargs): gv_feat = kwargs[cfg.PARAM.GLOBAL_FEAT] att_feats = kwargs[cfg.PARAM.ATT_FEATS] att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK] + target_seq = kwargs[cfg.PARAM.TARGET_SENT] # max kwargs['BEAM_SIZE'] = 1 @@ -210,12 +394,12 @@ def forward(self, kwargs): self.model.eval() with torch.no_grad(): - seq_max, logP_max = self.model.module.decode(**kwargs) + seq_max, logP_max = self.model.decode(**kwargs) self.model.train() - rewards_max, rewards_info_max = self.scorer(ids, seq_max.data.cpu().numpy().tolist()) + rewards_max, rewards_info_max = self.scorer(target_seq, seq_max.data.cpu().numpy().tolist()) # Modified rewards_max = utils.expand_numpy(rewards_max) - ids = utils.expand_numpy(ids) + ids = utils.expand_numpy(ids) # to check? 
gv_feat = utils.expand_tensor(gv_feat, cfg.DATA_LOADER.SEQ_PER_IMG) att_feats = utils.expand_tensor(att_feats, cfg.DATA_LOADER.SEQ_PER_IMG) att_mask = utils.expand_tensor(att_mask, cfg.DATA_LOADER.SEQ_PER_IMG) @@ -228,12 +412,13 @@ def forward(self, kwargs): kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask seq_sample, logP_sample = self.model.module.decode(**kwargs) - rewards_sample, rewards_info_sample = self.scorer(ids, seq_sample.data.cpu().numpy().tolist()) + rewards_sample, rewards_info_sample = self.scorer(target_seq, + seq_sample.data.cpu().numpy().tolist()) # Modified rewards = rewards_sample - rewards_max rewards = torch.from_numpy(rewards).float().cuda() loss = self.rl_criterion(seq_sample, logP_sample, rewards) - + loss_info = {} for key in rewards_info_sample: loss_info[key + '_sample'] = rewards_info_sample[key] @@ -247,7 +432,7 @@ def train(self): self.optim.zero_grad() iteration = 0 - for epoch in range(cfg.SOLVER.MAX_EPOCH): + for epoch in range(cfg.SOLVER.MAX_EPOCH): if epoch == cfg.TRAIN.REINFORCEMENT.START: self.rl_stage = True self.setup_loader(epoch) @@ -264,16 +449,19 @@ def train(self): gv_feat = gv_feat.cuda() att_feats = att_feats.cuda() att_mask = att_mask.cuda() + # att_mask = torch.ones(16,70).cuda() + # print(att_mask.shape) + kwargs = self.make_kwargs(indices, input_seq, target_seq, gv_feat, att_feats, att_mask) loss, loss_info = self.forward(kwargs) loss.backward() - utils.clip_gradient(self.optim.optimizer, self.model, - cfg.SOLVER.GRAD_CLIP_TYPE, cfg.SOLVER.GRAD_CLIP) + # utils.clip_gradient(self.optim.optimizer, self.model, + # cfg.SOLVER.GRAD_CLIP_TYPE, cfg.SOLVER.GRAD_CLIP) self.optim.step() self.optim.zero_grad() - self.optim.scheduler_step('Iter') - + # self.optim.scheduler_step('Iter') + batch_time.update(time.time() - start) start = time.time() losses.update(loss.item()) @@ -282,31 +470,15 @@ def train(self): if self.distributed: dist.barrier() - + self.save_model(epoch) val = self.eval(epoch) - self.optim.scheduler_step('Epoch', 
val) - self.scheduled_sampling(epoch) + # self.optim.scheduler_step('Epoch', val) + # self.scheduled_sampling(epoch) if self.distributed: dist.barrier() -def parse_args(): - """ - Parse input arguments - """ - parser = argparse.ArgumentParser(description='Image Captioning') - parser.add_argument('--folder', dest='folder', type=str, default=None) - parser.add_argument("--local_rank", type=int, default=0) - parser.add_argument("--resume", type=int, default=-1) - - if len(sys.argv) == 1: - parser.print_help() - sys.exit(1) - - args = parser.parse_args() - return args - if __name__ == '__main__': args = parse_args() print('Called with args:') diff --git a/mlclassifier.py b/mlclassifier.py new file mode 100644 index 0000000..9223340 --- /dev/null +++ b/mlclassifier.py @@ -0,0 +1,185 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +class MLClassifier(nn.Module): + + def __init__(self, num_classes): + super().__init__() + self.densenet121 = models.densenet121(pretrained=True) + num_ftrs = self.densenet121.classifier.in_features + # self.backbone = nn.Sequential(*list(densenet.children())[:-1]) + self.densenet121.classifier = nn.Linear(num_ftrs, num_classes) + + def forward(self, img1, img2): + x1 = self.densenet121.features(img1) + x1 = F.relu(x1, inplace=True) + x1 = F.adaptive_avg_pool2d(x1, (1, 1)) + x1 = torch.flatten(x1, 1) + x1 = self.densenet121.classifier(x1) + x2 = self.densenet121.features(img2) + x2 = F.relu(x2, inplace=True) + x2 = F.adaptive_avg_pool2d(x2, (1, 1)) + x2 = torch.flatten(x2, 1) + x2 = self.densenet121.classifier(x2) + return x1 + x2 + + +class ClsAttention(nn.Module): + + def __init__(self, feat_size, num_classes): + super().__init__() + self.feat_size = feat_size + self.num_classes = num_classes + self.channel_w = nn.Conv2d(feat_size, num_classes, 1, bias=False) + + def forward(self, feats): + # feats: batch size x feat size x H x W + batch_size, feat_size, H, W = feats.size() + 
att_maps = self.channel_w(feats) + att_maps = torch.softmax(att_maps.view(batch_size, self.num_classes, -1), dim=2) + feats_t = feats.view(batch_size, feat_size, H * W).permute(0, 2, 1) + cls_feats = torch.bmm(att_maps, feats_t) + return cls_feats + + +class GCLayer(nn.Module): + + def __init__(self, in_size, state_size): + super().__init__() + self.condense = nn.Conv1d(in_size, state_size, 1, bias=False) + self.condense_norm = nn.BatchNorm1d(state_size) + self.fw_trans = nn.Conv1d(in_size, state_size, 1, bias=False) + self.fw_norm = nn.BatchNorm1d(state_size) + self.bw_trans = nn.Conv1d(in_size, state_size, 1, bias=False) + self.bw_norm = nn.BatchNorm1d(state_size) + self.update = nn.Conv1d(3 * state_size, in_size, 1, bias=False) + self.update_norm = nn.BatchNorm1d(in_size) + self.relu = nn.ReLU(inplace=True) + # v2: + self.dropout = nn.Dropout(0.5) + + def forward(self, states, fw_A, bw_A): + # states: batch size x feat size x nodes + condensed = self.relu(self.condense_norm(self.condense(states))) + fw_msg = self.relu(self.fw_norm(self.fw_trans(states).bmm(fw_A))) + bw_msg = self.relu(self.bw_norm(self.bw_trans(states).bmm(bw_A))) + updated = self.update_norm(self.update(torch.cat((condensed, fw_msg, bw_msg), dim=1))) + updated = self.relu(self.dropout(updated) + states) + return updated + + +class GCN(nn.Module): + + def __init__(self, in_size, state_size, steps=3): + super().__init__() + self.in_size = in_size + self.state_size = state_size + self.steps = steps + + # layers = [] + # for istep in range(steps): + # layers.append(GCLayer(in_size, state_size)) + # self.layers = nn.Sequential(*layers) + self.layer1 = GCLayer(in_size, state_size) + self.layer2 = GCLayer(in_size, state_size) + self.layer3 = GCLayer(in_size, state_size) + + def forward(self, states, fw_A, bw_A): + states = states.permute(0, 2, 1) + states = self.layer1(states, fw_A, bw_A) + states = self.layer2(states, fw_A, bw_A) + states = self.layer3(states, fw_A, bw_A) + return states.permute(0, 
2, 1) + + +class GCNClassifier(nn.Module): + + def __init__(self, num_classes, fw_adj, bw_adj): + super().__init__() + self.num_classes = num_classes + self.densenet121 = models.densenet121(pretrained=True) + feat_size = self.densenet121.classifier.in_features + self.densenet121.classifier = nn.Linear(feat_size, num_classes) + self.cls_atten = ClsAttention(feat_size, num_classes) + self.gcn = GCN(feat_size, 256) + # v1: + self.fcs = nn.ModuleList([nn.Linear(feat_size, 1) for _ in range(num_classes)]) + # v2: + self.fc2 = nn.Linear(feat_size, num_classes) + + fw_D = torch.diag_embed(fw_adj.sum(dim=1)) + bw_D = torch.diag_embed(bw_adj.sum(dim=1)) + inv_sqrt_fw_D = fw_D.pow(-0.5) + inv_sqrt_fw_D[torch.isinf(inv_sqrt_fw_D)] = 0 + inv_sqrt_bw_D = bw_D.pow(-0.5) + inv_sqrt_bw_D[torch.isinf(inv_sqrt_bw_D)] = 0 + self.fw_A = inv_sqrt_fw_D.mm(fw_adj).mm(inv_sqrt_bw_D) + self.bw_A = inv_sqrt_bw_D.mm(bw_adj).mm(inv_sqrt_fw_D) + + self.avg_fnt = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0) + + def forward(self, img1, img2 = None): + if img2 is not None: + batch_size = img1.size(0) + fw_A = self.fw_A.repeat(batch_size, 1, 1) + bw_A = self.bw_A.repeat(batch_size, 1, 1) + cnn_feats1 = self.densenet121.features(img1) #no linear layer + cnn_feats2 = self.densenet121.features(img2) #cnn_feats1 torch.Size([16, 1024, 7, 7]) + # print('cnn_feats1',cnn_feats1.shape) + global_feats1 = cnn_feats1.mean(dim=(2, 3)) + global_feats2 = cnn_feats2.mean(dim=(2, 3)) + cls_feats1 = self.cls_atten(cnn_feats1) + cls_feats2 = self.cls_atten(cnn_feats2) + node_feats1 = torch.cat((global_feats1.unsqueeze(1), cls_feats1), dim=1) + node_feats2 = torch.cat((global_feats2.unsqueeze(1), cls_feats2), dim=1) + node_states1 = self.gcn(node_feats1, fw_A, bw_A) + node_states2 = self.gcn(node_feats2, fw_A, bw_A) + # v1: + # logits = img1.new_zeros((batch_size, self.num_classes), dtype=torch.float) + # for c in range(self.num_classes): + # logits[:, c] = self.fcs[c](node_states1[:, c+1] + node_states2[:, 
c+1]).squeeze(1) + # return logits + # v2: + # global_states = node_states1.mean(dim=1) + node_states2.mean(dim=1) + # global_states = torch.cat((node_states1.mean(dim = 1), node_states2.mean(dim = 1)),dim = 1) + + cnn_feats1_reshaped = cnn_feats1.reshape(cnn_feats1.size(0),cnn_feats1.size(1),-1).permute(0,2,1) + cnn_feats2_reshaped = cnn_feats2.reshape(cnn_feats2.size(0),cnn_feats2.size(1),-1).permute(0,2,1) + + cnn_feats = torch.cat((cnn_feats1_reshaped,cnn_feats2_reshaped),dim = 2) + # logits = self.fc2(global_states) + # print('global_states',global_states.shape) + # print('logits',logits.shape) + node_states = torch.cat((node_states1, node_states2), dim = 2) + + + avg_feats1 = self.avg_fnt(cnn_feats1).squeeze().reshape(-1, cnn_feats1.size(1)) + avg_feats2 = self.avg_fnt(cnn_feats2).squeeze().reshape(-1, cnn_feats1.size(1)) + global_states = torch.cat((avg_feats1, avg_feats2), dim = 1) + # cnn_feats torch.Size([16, 49, 2048]) + # node_states torch.Size([16, 21, 2048]) + # global_states torch.Size([16, 2048]) + if img2 is None: + batch_size = img1.size(0) + fw_A = self.fw_A.repeat(batch_size, 1, 1) + bw_A = self.bw_A.repeat(batch_size, 1, 1) + cnn_feats1 = self.densenet121.features(img1) #no linear layer + # print('cnn_feats1',cnn_feats1.shape) + global_feats1 = cnn_feats1.mean(dim=(2, 3)) + cls_feats1 = self.cls_atten(cnn_feats1) + node_feats1 = torch.cat((global_feats1.unsqueeze(1), cls_feats1), dim=1) + node_states1 = self.gcn(node_feats1, fw_A, bw_A) + + cnn_feats1_reshaped = cnn_feats1.reshape(cnn_feats1.size(0),cnn_feats1.size(1),-1).permute(0,2,1) + cnn_feats = cnn_feats1_reshaped + + node_states = node_states1 + + + avg_feats1 = self.avg_fnt(cnn_feats1).squeeze().reshape(-1, cnn_feats1.size(1)) + global_states = avg_feats1 + + return cnn_feats, node_states, global_states + diff --git a/models/__init__.py b/models/__init__.py index 9f56b54..222b6ef 100755 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,17 +1,22 @@ from models.updown import UpDown 
from models.xlan import XLAN from models.xtransformer import XTransformer +from models.dwextransformer import DWEXTransformer + __factory = { 'UpDown': UpDown, 'XLAN': XLAN, - 'XTransformer': XTransformer + 'XTransformer': XTransformer, + 'DWEXtransformer' : DWEXTransformer } def names(): return sorted(__factory.keys()) -def create(name, *args, **kwargs): +def create(name, args, submodel, **kwargs): + if name == 'XTransformer' and args.encoder_mode == 'dualwayencoder': + name = 'DWEXtransformer' if name not in __factory: raise KeyError("Unknown caption model:", name) - return __factory[name](*args, **kwargs) \ No newline at end of file + return __factory[name](args, submodel, **kwargs) diff --git a/models/dwextransformer.py b/models/dwextransformer.py new file mode 100644 index 0000000..378b667 --- /dev/null +++ b/models/dwextransformer.py @@ -0,0 +1,774 @@ +import copy +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from lib.config import cfg + +from layers.low_rank import LowRank +import blocks +import lib.utils as utils +from models.basic_model import BasicModel +from layers.positional_encoding import PositionalEncoding +from .pretrained_models import ImageClassification +from mlclassifier import GCNClassifier + +device = torch.device("cuda") + + +def subsequent_mask(size): + "Mask out subsequent positions." 
+ attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') + return torch.from_numpy(subsequent_mask) == 0 + + +def clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + + +def makeMask(input_features): + bs, seq, feats = input_features.shape + return torch.ones(bs, seq, dtype=torch.long).cuda() # hardcode to cuda + + +class DWEXTransformer(BasicModel): + def __init__(self, args, submodel): + super(DWEXTransformer, self).__init__() + self.vocab_size = cfg.MODEL.VOCAB_SIZE + 1 + # image pretrained + # self.image_pretrained_models, self.input_visual_feats = ImageClassification.image_features( + # 'densenet', fixed_weight=False, pretrained_model=cfg.MODEL.PretrainedImageModel) + if args.dataset_name == 'IUXRAY': + num_images = 2 + self.get_visual_features = self.forward_iuxray + elif args.dataset_name == 'MIMICCXR': + num_images = 1 + self.get_visual_features = self.forward_mimiccxr + elif args.dataset_name == 'MIMICCXR_MultiImages': + num_images = 2 + self.get_visual_features = self.forward_mimiccxr + + # att_feats encoder + cnn_sequential = [] + self.input_visual_feats = 1024 + cnn_sequential.append(nn.Linear(self.input_visual_feats * num_images, cfg.MODEL.ATT_FEATS_EMBED_DIM)) + cnn_sequential.append(utils.activation(cfg.MODEL.ATT_FEATS_EMBED_ACT)) + if cfg.MODEL.ATT_FEATS_NORM == True: + cnn_sequential.append(nn.LayerNorm(cfg.MODEL.ATT_FEATS_EMBED_DIM)) + if cfg.MODEL.DROPOUT_ATT_EMBED > 0: + cnn_sequential.append(nn.Dropout(cfg.MODEL.DROPOUT_ATT_EMBED)) + + gcn_sequential = copy.deepcopy(cnn_sequential) + + self.cnn_embed = nn.Sequential(*cnn_sequential) if len(cnn_sequential) > 0 else None + self.gcn_embed = nn.Sequential(*gcn_sequential) if len(gcn_sequential) > 0 else None + + self.encoder = Encoder( + embed_dim=cfg.MODEL.BILINEAR.DIM, + dropout=cfg.MODEL.BILINEAR.ENCODE_DROPOUT, + att_type=cfg.MODEL.BILINEAR.ATTTYPE, + att_heads=cfg.MODEL.BILINEAR.HEAD, + 
att_mid_dim=cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DIM, + att_mid_drop=cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DROPOUT, + bifeat_emb_act=cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT, + bifeat_emb_drop=cfg.MODEL.BILINEAR.ENCODE_BIFEAT_EMB_DROPOUT, + ff_dropout=cfg.MODEL.BILINEAR.ENCODE_FF_DROPOUT, + layer_num=cfg.MODEL.BILINEAR.ENCODE_LAYERS) + + self.decoder = Decoder( + vocab_size=self.vocab_size, + embed_dim=cfg.MODEL.BILINEAR.DIM, + dropout=cfg.MODEL.BILINEAR.DECODE_DROPOUT, + att_type=cfg.MODEL.BILINEAR.ATTTYPE, + att_heads=cfg.MODEL.BILINEAR.HEAD, + att_mid_dim=cfg.MODEL.BILINEAR.DECODE_ATT_MID_DIM, + att_mid_drop=cfg.MODEL.BILINEAR.DECODE_ATT_MID_DROPOUT, + bifeat_emb_act=cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT, + bifeat_emb_drop=cfg.MODEL.BILINEAR.DECODE_BIFEAT_EMB_DROPOUT, + ff_dropout=cfg.MODEL.BILINEAR.DECODE_FF_DROPOUT, + layer_num=cfg.MODEL.BILINEAR.DECODE_LAYERS) + self.submodel = submodel + + def forward(self, **kwargs): + # forward entry + + att_feats = kwargs[cfg.PARAM.ATT_FEATS] + seq = kwargs[cfg.PARAM.INPUT_SENT] + att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK] + # att_mask = torch.ones(16,70).to(device) + att_mask = utils.expand_tensor(att_mask, cfg.DATA_LOADER.SEQ_PER_IMG) + att_feats = utils.expand_tensor(att_feats, cfg.DATA_LOADER.SEQ_PER_IMG) + # HARDCODE: att_mask = None + # Regenerate later + att_mask = None + + ############################################## + seq_mask = (seq > 0).type(torch.cuda.IntTensor) + seq_mask[:, 0] += 1 + seq_mask = seq_mask.unsqueeze(-2) + seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) + seq_mask = seq_mask.type(torch.cuda.FloatTensor) + ############################################## + + cnn_feats, gcn_feats = self.get_visual_features(att_feats) + all_feats = torch.cat([cnn_feats, gcn_feats], dim=1) + + att_mask = makeMask(all_feats) + cnn_mask = makeMask(cnn_feats) + gcn_mask = makeMask(gcn_feats) + + cnn_feats = self.cnn_embed(cnn_feats) # forward entry + gcn_feats = self.gcn_embed(gcn_feats) + + gx, encoder_out = 
self.encoder(cnn_feats, gcn_feats, cnn_mask, gcn_mask) + + decoder_out = self.decoder(gx, seq, encoder_out, att_mask, seq_mask) + # print('decoder_out.shape',decoder_out.shape) # 4, 41, 761 + # raise Exception('lol') + return decoder_out + + def forward_iuxray(self, att_feats): + + att_feats, node_feats, fc_feats = self.submodel(att_feats[:, 0], att_feats[:, 1]) # bs, 49 2048 + + return att_feats, node_feats + + def forward_mimiccxr(self, att_feats): + if self.args.dataset_name == 'mimic_cxr_2images': + att_feats, node_feats, fc_feats = self.submodel(att_feats[:, 0], att_feats[:, 1]) + else: + # if only one image is inputted. + att_feats, node_feats, fc_feats = self.submodel(att_feats) + + return att_feats, node_feats + + def get_logprobs_state(self, **kwargs): + wt = kwargs[cfg.PARAM.WT] + state = kwargs[cfg.PARAM.STATE] + encoder_out = kwargs[cfg.PARAM.ATT_FEATS] + att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK] + + gx = kwargs[cfg.PARAM.GLOBAL_FEAT] + p_att_feats = kwargs[cfg.PARAM.P_ATT_FEATS] + + if state is None: + ys = wt.unsqueeze(1) + else: + ys = torch.cat([state[0][0], wt.unsqueeze(1)], dim=1) + # ys = torch.zeros(16,1, dtype = int).to(device) + seq_mask = subsequent_mask(ys.size(1)).to(encoder_out.device).type(torch.cuda.FloatTensor)[:, -1, :].unsqueeze( + 1) + decoder_out = self.decoder(gx, ys[:, -1].unsqueeze(-1), encoder_out, att_mask, seq_mask, p_att_feats, + True).squeeze(1) + + logprobs = F.log_softmax(decoder_out, dim=-1) + return logprobs, [ys.unsqueeze(0)] + + def _expand_state(self, batch_size, beam_size, cur_beam_size, selected_beam): + def fn(s): + shape = [int(sh) for sh in s.shape] + beam = selected_beam + for _ in shape[1:]: + beam = beam.unsqueeze(-1) + beam = beam.long() + s = torch.gather(s.view(*([batch_size, cur_beam_size] + shape[1:])), 1, + beam.expand(*([batch_size, beam_size] + shape[1:]))) + + s = s.view(*([-1, ] + shape[1:])) + return s + + return fn + + # the beam search code is inspired by 
https://github.com/aimagelab/meshed-memory-transformer + def decode_beam(self, **kwargs): + att_feats = kwargs[cfg.PARAM.ATT_FEATS] + # att_mask = torch.ones(16,70).to(device) + att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK] + beam_size = kwargs['BEAM_SIZE'] + batch_size = att_feats.size(0) + seq_logprob = torch.zeros((batch_size, 1, 1)).cuda() + log_probs = [] + selected_words = None + seq_mask = torch.ones((batch_size, beam_size, 1)).cuda() + + cnn_feats, gcn_feats = self.get_visual_features(att_feats) + all_feats = torch.cat([cnn_feats, gcn_feats], dim=1) + + att_mask = makeMask(all_feats) + cnn_mask = makeMask(cnn_feats) + gcn_mask = makeMask(gcn_feats) + + cnn_feats = self.cnn_embed(cnn_feats) # forward entry + gcn_feats = self.gcn_embed(gcn_feats) + + gx, encoder_out = self.encoder(cnn_feats, gcn_feats, cnn_mask, gcn_mask) + + p_att_feats = self.decoder.precompute(encoder_out) + + state = None + wt = Variable(torch.zeros(batch_size, dtype=torch.long).cuda()) + kwargs[cfg.PARAM.ATT_FEATS] = encoder_out + kwargs[cfg.PARAM.GLOBAL_FEAT] = gx + kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats + + outputs = [] + self.decoder.init_buffer(batch_size) + for t in range(cfg.MODEL.SEQ_LEN): + cur_beam_size = 1 if t == 0 else beam_size + + kwargs[cfg.PARAM.WT] = wt + kwargs[cfg.PARAM.STATE] = state + word_logprob, state = self.get_logprobs_state(**kwargs) + word_logprob = word_logprob.view(batch_size, cur_beam_size, -1) + candidate_logprob = seq_logprob + word_logprob + + # Mask sequence if it reaches EOS + if t > 0: + mask = (selected_words.view(batch_size, cur_beam_size) != 0).float().unsqueeze(-1) + seq_mask = seq_mask * mask + word_logprob = word_logprob * seq_mask.expand_as(word_logprob) + old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous() + old_seq_logprob[:, :, 1:] = -999 + candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask) + + selected_idx, selected_logprob = self.select(batch_size, beam_size, t, candidate_logprob) + 
selected_beam = selected_idx / candidate_logprob.shape[-1] + selected_beam = selected_beam.long() + selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] + + self.decoder.apply_to_states(self._expand_state(batch_size, beam_size, cur_beam_size, selected_beam)) + seq_logprob = selected_logprob.unsqueeze(-1) + seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1)) + outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs) + outputs.append(selected_words.unsqueeze(-1)) + + this_word_logprob = torch.gather(word_logprob, 1, + selected_beam.unsqueeze(-1).expand(batch_size, beam_size, + word_logprob.shape[-1])) + this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) + log_probs = list( + torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(batch_size, beam_size, 1)) for o in log_probs) + log_probs.append(this_word_logprob) + selected_words = selected_words.view(-1, 1) + wt = selected_words.squeeze(-1) + + if t == 0: + encoder_out = utils.expand_tensor(encoder_out, beam_size) + gx = utils.expand_tensor(gx, beam_size) + att_mask = utils.expand_tensor(att_mask, beam_size) + state[0] = state[0].squeeze(0) + state[0] = utils.expand_tensor(state[0], beam_size) + state[0] = state[0].unsqueeze(0) + + p_att_feats_tmp = [] + for p_feat in p_att_feats: + p_key, p_value2 = p_feat + p_key = utils.expand_tensor(p_key, beam_size) + p_value2 = utils.expand_tensor(p_value2, beam_size) + p_att_feats_tmp.append((p_key, p_value2)) + + kwargs[cfg.PARAM.ATT_FEATS] = encoder_out + kwargs[cfg.PARAM.GLOBAL_FEAT] = gx + kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask + kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats_tmp + + seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True) + outputs = torch.cat(outputs, -1) + outputs = torch.gather(outputs, 1, sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN)) + log_probs = torch.cat(log_probs, -1) + log_probs = torch.gather(log_probs, 1, 
sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN)) + + outputs = outputs.contiguous()[:, 0] + log_probs = log_probs.contiguous()[:, 0] + + self.decoder.clear_buffer() + return outputs, log_probs + + def decode(self, **kwargs): + beam_size = kwargs['BEAM_SIZE'] + greedy_decode = kwargs['GREEDY_DECODE'] + att_feats = kwargs[cfg.PARAM.ATT_FEATS] + # att_mask = torch.ones(16,70).to(device) + att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK] + + batch_size = att_feats.shape[0] + cnn_feats, gcn_feats = self.get_visual_features(att_feats) + all_feats = torch.cat([cnn_feats, gcn_feats], dim=1) + + att_mask = makeMask(all_feats) + cnn_mask = makeMask(cnn_feats) + gcn_mask = makeMask(gcn_feats) + + cnn_feats = self.cnn_embed(cnn_feats) # forward entry + gcn_feats = self.gcn_embed(gcn_feats) + + gx, encoder_out = self.encoder(cnn_feats, gcn_feats, cnn_mask, gcn_mask) + + p_att_feats = self.decoder.precompute(encoder_out) + self.decoder.init_buffer(batch_size) + + state = None + sents = Variable(torch.zeros((batch_size, cfg.MODEL.SEQ_LEN), dtype=torch.long).cuda()) + logprobs = Variable(torch.zeros(batch_size, cfg.MODEL.SEQ_LEN).cuda()) + wt = Variable(torch.zeros(batch_size, dtype=torch.long).cuda()) + unfinished = wt.eq(wt) + kwargs[cfg.PARAM.ATT_FEATS] = encoder_out + kwargs[cfg.PARAM.GLOBAL_FEAT] = gx + kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats + for t in range(cfg.MODEL.SEQ_LEN): + kwargs[cfg.PARAM.WT] = wt + kwargs[cfg.PARAM.STATE] = state + logprobs_t, state = self.get_logprobs_state(**kwargs) + + if greedy_decode: + logP_t, wt = torch.max(logprobs_t, 1) + else: + probs_t = torch.exp(logprobs_t) + wt = torch.multinomial(probs_t, 1) + logP_t = logprobs_t.gather(1, wt) + wt = wt.view(-1).long() + unfinished = unfinished * (wt > 0) + wt = wt * unfinished.type_as(wt) + sents[:, t] = wt + logprobs[:, t] = logP_t.view(-1) + + if unfinished.sum() == 0: + break + self.decoder.clear_buffer() + return sents, logprobs + + +class Encoder(nn.Module): + def __init__( + self, + 
embed_dim, + dropout, + att_type, + att_heads, + att_mid_dim, + att_mid_drop, + bifeat_emb_act, + bifeat_emb_drop, + ff_dropout, + layer_num + ): + super(Encoder, self).__init__() + self.att_heads = att_heads + self.layers = nn.ModuleList([]) + for i in range(layer_num): + sublayer = EncoderLayer( + embed_dim=embed_dim, + dropout=dropout, + att_type=att_type, + att_heads=att_heads, + att_mid_dim=att_mid_dim, + att_mid_drop=att_mid_drop, + bifeat_emb_act=bifeat_emb_act, + bifeat_emb_drop=bifeat_emb_drop, + ff_dropout=ff_dropout) + self.layers.append(sublayer) + assert embed_dim % 2 == 0 + half_embed_dim = int(embed_dim / 2) + self.cnn_proj_norm = nn.Sequential( + nn.Linear(embed_dim * (layer_num + 1), half_embed_dim), + torch.nn.LayerNorm(half_embed_dim)) + self.gcn_proj_norm = nn.Sequential( + nn.Linear(embed_dim * (layer_num + 1), half_embed_dim), + torch.nn.LayerNorm(half_embed_dim)) + + # def forward(self, x, mask): + # # we try not to use the mask + # gx = (torch.sum(x * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1)) + + def forward(self, cnn_feats, gcn_feats, cnn_mask, gcn_mask): + + # we try not to use the mask + + # drop mask + + # if att_masks is None: + # att_masks = x.new_ones(x.shape[:2], dtype=torch.long) + # att_masks = att_masks.unsqueeze(-2) + # print(x.shape) + # print(att_masks.shape) + cnn_gx = (torch.sum(cnn_feats * cnn_mask.unsqueeze(-1), 1) / torch.sum(cnn_mask.unsqueeze(-1), 1)) + gcn_gx = (torch.sum(gcn_feats * gcn_mask.unsqueeze(-1), 1) / torch.sum(gcn_mask.unsqueeze(-1), 1)) + + cnn_gx_arr = [cnn_gx] + gcn_gx_arr = [gcn_gx] + for layer in self.layers: + cnn_gx, gcn_gx, cnn_feats, gcn_feats = layer(cnn_gx, gcn_gx, cnn_feats, gcn_feats, cnn_mask, + gcn_mask) # modified + cnn_gx_arr.append(cnn_gx) + gcn_gx_arr.append(gcn_gx) + + print(cnn_gx_arr[0].shape) + print(gcn_gx_arr[0].shape) + + cnn_gx = torch.cat(cnn_gx_arr, dim=-1) # cat dim? 
+ + cnn_gx = self.cnn_proj_norm(cnn_gx) + + gcn_gx = torch.cat(gcn_gx_arr, dim=-1) + + gcn_gx = self.gcn_proj_norm(gcn_gx) # TODO + + gx = torch.cat([cnn_gx, gcn_gx], dim=-1) # TODO: cat dim + x = torch.cat([cnn_feats, gcn_feats], dim=1) # cat at seq + + return gx, x + + +class EncoderLayer(nn.Module): + def __init__( + self, + embed_dim, + dropout, + att_type, + att_heads, + att_mid_dim, + att_mid_drop, + bifeat_emb_act, + bifeat_emb_drop, + ff_dropout + ): + super(EncoderLayer, self).__init__() + self.encoder_attn = clones(LowRank( + embed_dim=embed_dim, + att_type=att_type, + att_heads=att_heads, + att_mid_dim=att_mid_dim, + att_mid_drop=att_mid_drop), 4) + self.dropout = clones(nn.Dropout(dropout), 4) + + self.bifeat_emb = clones(nn.Sequential( + nn.Linear(2 * embed_dim, embed_dim), + utils.activation(bifeat_emb_act), + nn.Dropout(bifeat_emb_drop) + ), 4) + + self.layer_norm = clones(torch.nn.LayerNorm(embed_dim), 4) + + self.ff_layer = clones(blocks.create( + 'FeedForward', + embed_dim=embed_dim, + ffn_embed_dim=embed_dim * 4, + relu_dropout=ff_dropout, + dropout=ff_dropout), 4) + + def forward(self, cnn_gx, gcn_gx, cnn_feats, gcn_feats, cnn_mask, gcn_mask): + """ + gx : torch.Size([4, 768]) + x : torch.Size([4, 49, 768]) + mask : torch.Size([4, 49]) + + """ + assert self.ff_layer != None + + cnn_gx = self.dropout[0]( + self.encoder_attn[0](query=cnn_gx, key=cnn_feats, mask=cnn_mask, value1=cnn_gx, value2=cnn_feats)) + cnn_x_ = torch.cat([cnn_gx.unsqueeze(1).expand_as(cnn_feats), cnn_feats], dim=-1) + cnn_feats = self.ff_layer[0](self.layer_norm[0](self.bifeat_emb[0](cnn_x_) + cnn_feats)) + + gcn_gx = self.dropout[1]( + self.encoder_attn[1](query=gcn_gx, key=gcn_feats, mask=gcn_mask, value1=gcn_gx, value2=gcn_feats)) + gcn_x_ = torch.cat([gcn_gx.unsqueeze(1).expand_as(gcn_feats), gcn_feats], dim=-1) + gcn_feats = self.ff_layer[1](self.layer_norm[1](self.bifeat_emb[1](gcn_x_) + gcn_feats)) + + all_feats = torch.cat([cnn_feats, gcn_feats], dim=1) + all_mask = 
torch.cat([cnn_mask, gcn_mask], dim=-1) + + cnn_gx = self.dropout[2]( + self.encoder_attn[2](query=cnn_gx, key=all_feats, mask=all_mask, value1=cnn_gx, value2=all_feats)) + cnn_x_ = torch.cat([cnn_gx.unsqueeze(1).expand_as(cnn_feats), cnn_feats], dim=-1) + cnn_feats = self.ff_layer[2](self.layer_norm[2](self.bifeat_emb[2](cnn_x_) + cnn_feats)) + + gcn_gx = self.dropout[3]( + self.encoder_attn[3](query=gcn_gx, key=all_feats, mask=all_mask, value1=gcn_gx, value2=all_feats)) + gcn_x_ = torch.cat([gcn_gx.unsqueeze(1).expand_as(gcn_feats), gcn_feats], dim=-1) + gcn_feats = self.ff_layer[3](self.layer_norm[3](self.bifeat_emb[3](gcn_x_) + gcn_feats)) + + return cnn_gx, gcn_gx, cnn_feats, gcn_feats + + +class Decoder(nn.Module): + def __init__( + self, + vocab_size, + embed_dim, + dropout, + att_type, + att_heads, + att_mid_dim, + att_mid_drop, + bifeat_emb_act, + bifeat_emb_drop, + ff_dropout, + layer_num + ): + super(Decoder, self).__init__() + self.att_heads = att_heads + self.layers = nn.ModuleList([]) + self.embed_dim = embed_dim + for i in range(layer_num): + sublayer = DecoderLayer( + embed_dim=embed_dim, + dropout=dropout, + att_type=att_type, + att_heads=att_heads, + att_mid_dim=att_mid_dim, + att_mid_drop=att_mid_drop, + bifeat_emb_act=bifeat_emb_act, + bifeat_emb_drop=bifeat_emb_drop, + ff_dropout=ff_dropout, + last_layer=(i == layer_num - 1)) + self.layers.append(sublayer) + + self.dropout = nn.Dropout(cfg.MODEL.DROPOUT_WORD_EMBED) + self.embed_tokens = nn.Embedding(vocab_size, embed_dim) + self.embed_scale = math.sqrt(embed_dim) + self.embed_positions = PositionalEncoding( + embed_dim, cfg.MODEL.TRANSFORMER.PE_MAX_LEN + ) + + self.layer_norm_word = torch.nn.LayerNorm(embed_dim) + self.generator = nn.Linear(embed_dim, vocab_size) + + self.wbil1 = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + utils.activation(cfg.MODEL.BILINEAR.ACT), + torch.nn.LayerNorm(embed_dim) + ) + self.wbil2 = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + 
utils.activation(cfg.MODEL.BILINEAR.ACT), + torch.nn.LayerNorm(embed_dim) + ) + self.wbi_drop = nn.Dropout(cfg.MODEL.BILINEAR.DECODE_DROPOUT) + self.dropout_lm = nn.Dropout(cfg.MODEL.DROPOUT_LM) + + self.proj_norm = nn.Sequential( + nn.Linear(embed_dim * (layer_num + 1), 2 * embed_dim), + nn.GLU(), + torch.nn.LayerNorm(embed_dim)) + + self.clear_buffer() + + def init_buffer(self, batch_size): + self.seq_len = 0 + self.x = torch.zeros((batch_size, 1, self.embed_dim)).cuda() + for layer in self.layers: + layer.init_buffer(batch_size) + + def clear_buffer(self): + self.seq_len = None + self.x = None + for layer in self.layers: + layer.clear_buffer() + + def apply_to_states(self, fn): + self.x = fn(self.x) + for layer in self.layers: + layer.apply_to_states(fn) + + def precompute(self, encoder_out): + p_att_feats = [] + for layer in self.layers: + key, value2 = layer.precompute(encoder_out) + p_att_feats.append((key, value2)) + return p_att_feats + + def forward(self, gx, prev_output_tokens, encoder_out, att_mask, seq_mask=None, p_att_feats=None, precompute=False): + # device = torch.device("cuda") + att_mask = att_mask.unsqueeze(1) + + # embed positions + seq_len = prev_output_tokens.size(1) + if self.seq_len is not None: + seq_len = self.seq_len + seq_len + self.seq_len = seq_len + positions = self.embed_positions(seq_len)[:, -1, :].unsqueeze(1) + else: + positions = self.embed_positions(seq_len) + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + x = x + positions + x = self.layer_norm_word(x) + if self.dropout is not None: + x = self.dropout(x) + + # decoder layers + gx = self.wbil1(gx) + if self.x is None: + x_gx = (torch.sum(x.unsqueeze(1) * seq_mask.unsqueeze(-1), -2) / torch.sum(seq_mask, -1).unsqueeze(-1)) + else: + self.x = self.x + x + x_gx = self.x / seq_len + x_gx = self.wbil2(x_gx) + gx = gx.unsqueeze(1) + gx = gx * x_gx + gx = self.wbi_drop(gx) + + # print(gx.shape) + # print(type(gx)) + # print((gx[0, 0, 
0].dtype)) + + bs, att_nums, features_sizes = gx.shape + + # gx_arr = torch.zeros((len(self.layers), bs, att_nums, features_sizes), dtype = torch.float32, device = device)# .to(torch.device("cuda")) # HARD CODE + # print(gx_arr.shape) + # print(gx_arr.shape) + # gx_arr = gx_arr.to(device) + + gx_arr = [gx] # .to('cuda') + # gx_arr = + + for layerid, layer in enumerate(self.layers): + # print('layerid',layerid) + if precompute == False: + p_key = None + p_value2 = None + else: + p_key, p_value2 = p_att_feats[layerid] + # print('x.shape b4',x.shape) + # print('precompute',precompute) + # print('att_mask.shape',att_mask.shape) + # print('seq_mask.shape',seq_mask.shape) + gx, x = layer(gx, x, encoder_out, att_mask, seq_mask=seq_mask, p_key=p_key, p_value2=p_value2, + precompute=precompute) + # print('gx.shape',gx.shape) # bs, 41, 768 + # print('x.shape after',x.shape) + # print(type(gx_arr)) + gx_arr.append(gx) + # print('test',gx_arr[layerid, :, :, :].shape) + # print('layerid',layerid) + # print('gx_arr.shape',gx_arr.shape) + # print('gx.device',gx.device) + # print('gx_arr.device',gx_arr.device) + # gx_arr[layerid] = gx + # raise Exception('decoder lol') + gx = torch.cat(gx_arr, dim=-1) + + gx = self.proj_norm(gx) + + gx = self.dropout_lm(gx) + out = self.generator(gx) + # raise Exception('decoder lol') + return out + + +class DecoderLayer(nn.Module): + def __init__( + self, + embed_dim, + dropout, + att_type, + att_heads, + att_mid_dim, + att_mid_drop, + bifeat_emb_act, + bifeat_emb_drop, + ff_dropout, + last_layer=False + ): + super(DecoderLayer, self).__init__() + self.last_layer = last_layer + self.word_attn = LowRank( + embed_dim=embed_dim, + att_type=att_type, + att_heads=att_heads, + att_mid_dim=att_mid_dim, + att_mid_drop=att_mid_drop) + self.word_dropout = nn.Dropout(dropout) + + self.cross_att = LowRank( + embed_dim=embed_dim, + att_type=att_type, + att_heads=att_heads, + att_mid_dim=att_mid_dim, + att_mid_drop=att_mid_drop) + self.cross_dropout = 
nn.Dropout(dropout) + self.layer_norm_cross = torch.nn.LayerNorm(embed_dim) + + if self.last_layer == False: + self.bifeat_emb = nn.Sequential( + nn.Linear(2 * embed_dim, embed_dim), + utils.activation(bifeat_emb_act), + nn.Dropout(bifeat_emb_drop) + ) + self.layer_norm_x = torch.nn.LayerNorm(embed_dim) + + self.ff_layer = blocks.create( + 'FeedForward', + embed_dim=embed_dim, + ffn_embed_dim=embed_dim * 4, + relu_dropout=ff_dropout, + dropout=ff_dropout) + + self.layer_norm_gx = torch.nn.LayerNorm(embed_dim) + + def apply_to_states(self, fn): + self.word_attn.apply_to_states(fn) + + def init_buffer(self, batch_size): + self.word_attn.init_buffer(batch_size) + + def clear_buffer(self): + self.word_attn.clear_buffer() + + def precompute(self, encoder_out): + key, value2 = self.cross_att.precompute(encoder_out, encoder_out) + return key, value2 + + def forward( + self, + gx, + x, + encoder_out, + att_mask, + seq_mask, + p_key=None, + p_value2=None, + precompute=False + ): + word_x = x + residual = x + x = self.word_attn.forward2( + query=gx, + key=x, + mask=seq_mask, + value1=gx, + value2=x) + x = self.word_dropout(x) + + x = residual + x + + residual = x + x = self.layer_norm_cross(x) + x = self.cross_att.forward2( + query=x, + key=encoder_out if precompute == False else p_key, + mask=att_mask, + value1=x, + value2=encoder_out if precompute == False else p_value2, + precompute=precompute) + x = self.cross_dropout(x) + gx = residual + x + gx = self.layer_norm_gx(gx) + + if self.last_layer == False: + x_ = torch.cat([gx, word_x], dim=-1) + x = self.bifeat_emb(x_) + word_x + x = self.layer_norm_x(x) + + if self.ff_layer is not None: + x = self.ff_layer(x) + else: + x = None + + return gx, x diff --git a/models/pretrained_models.py b/models/pretrained_models.py new file mode 100644 index 0000000..480f6a6 --- /dev/null +++ b/models/pretrained_models.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import gzip +import torch +from torch import relu 
+from torch.nn import Dropout, Linear, Sequential +from torch.nn.functional import adaptive_avg_pool2d, cross_entropy +from torchvision import models + + +# # Image processes +# if image_model is None: +# image_model = 'densenet' +# self.image_feats, image_dim = ImageClassification.image_features(image_model, not finetune_image, True, +# image_pretrained, device) + +class ImageClassification(torch.nn.Module): + def __init__(self, model, num_labels, num_classes, multi_image=1, dropout=0.0, pretrained=True): + super(ImageClassification, self).__init__() + self.image_feats, self.image_dim = self.image_features(model, False, pretrained) + for i in range(num_labels): + setattr(self, 'linear{0}'.format(i), Linear(self.image_dim, num_classes)) + self.num_labels = num_labels + self.multi_image = multi_image + self.dropout = Dropout(p=dropout) + + @classmethod + def fix_layers(cls, model): + for param in model.parameters(): + param.requires_grad = False + + @classmethod + def image_features(cls, name, fixed_weight=False, pretrained=True, pretrained_model=None, device='gpu'): + if pretrained_model is None: + if name == 'densenet121' or name == 'densenet': + m = models.densenet121(pretrained=pretrained) + if fixed_weight: + cls.fix_layers(m) + return Sequential(*list(m.features.children())), 1024 + elif name == 'resnet50': + m = models.resnet50(pretrained=pretrained) + if fixed_weight: + cls.fix_layers(m) + return Sequential(*list(m.children())[:-2]), 2048 + elif name == 'resnet152' or name == 'resnet': + m = models.resnet152(pretrained=pretrained) + if fixed_weight: + cls.fix_layers(m) + return Sequential(*list(m.children())[:-2]), 2048 + elif name == 'vgg19' or name == 'vgg': + m = models.vgg19(pretrained=pretrained) + if fixed_weight: + cls.fix_layers(m) + return Sequential(*list(m.features.children())[:-1]), 512 + else: + raise ValueError('Unknown model {0}'.format(name)) + else: + d = torch.device('cpu')if device == 'cpu' else torch.device('cuda:0') + with 
gzip.open(pretrained_model) as f: + state = torch.load(f, map_location=d) + m = ImageClassification(name, 14, 3, pretrained=False) + m.load_state_dict(state['model']) + if fixed_weight: + cls.fix_layers(m) + return m.image_feats, m.image_dim + + def deflatten_image(self, x): + if self.multi_image > 1: + x = x.view(int(x.shape[0] / self.multi_image), self.multi_image, x.shape[1]) + x, _ = torch.max(x, dim=1) + return x + + def flatten_image(self, x): + if self.multi_image > 1: + return x.flatten(start_dim=0, end_dim=1) + else: + return x + + def forward(self, x): + x = self.flatten_image(x) + x = self.image_feats(x) + x = relu(x) + x = adaptive_avg_pool2d(x, (1, 1)) + x = torch.flatten(x, 1) + x = self.deflatten_image(x) + xs = [] + for i in range(self.num_labels): + xi = self.dropout(x) + xi = getattr(self, 'linear{0}'.format(i))(xi).unsqueeze(dim=2) + xs.append(xi) + x = torch.cat(xs, dim=2) + return x \ No newline at end of file diff --git a/models/xlan.py b/models/xlan.py index 109dcf2..6d43815 100755 --- a/models/xlan.py +++ b/models/xlan.py @@ -7,7 +7,7 @@ import blocks class XLAN(AttBasicModel): - def __init__(self): + def __init__(self, args): super(XLAN, self).__init__() self.num_layers = 2 @@ -51,7 +51,7 @@ def Forward(self, **kwargs): att, _ = self.attention(h_att, att_feats, att_mask, p_att_feats, precompute=True) ctx_input = torch.cat([att, h_att], 1) - output = self.att2ctx(ctx_input) + output = self.att2ctx(ctx_input) # forward entry state = [torch.stack((h_att, output)), torch.stack((c_att, state[1][1]))] return output, state \ No newline at end of file diff --git a/models/xtransformer.py b/models/xtransformer.py index 5179608..9ec09d4 100755 --- a/models/xtransformer.py +++ b/models/xtransformer.py @@ -12,6 +12,9 @@ import lib.utils as utils from models.basic_model import BasicModel from layers.positional_encoding import PositionalEncoding +from .pretrained_models import ImageClassification +from mlclassifier import GCNClassifier +device = 
torch.device("cuda") def subsequent_mask(size): "Mask out subsequent positions." @@ -19,71 +22,129 @@ def subsequent_mask(size): subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') return torch.from_numpy(subsequent_mask) == 0 + class XTransformer(BasicModel): - def __init__(self): + def __init__(self, args, submodel): super(XTransformer, self).__init__() self.vocab_size = cfg.MODEL.VOCAB_SIZE + 1 + # image pretrained + # self.image_pretrained_models, self.input_visual_feats = ImageClassification.image_features( + # 'densenet', fixed_weight=False, pretrained_model=cfg.MODEL.PretrainedImageModel) + if args.dataset_name == 'IUXRAY': + num_images = 2 + self.get_visual_features = self.forward_iuxray + else: # 'MIMICCXR' + num_images = 1 + self.get_visual_features = self.forward_mimiccxr # att_feats encoder + self.input_visual_feats = 1024 sequential = [] - sequential.append(nn.Linear(cfg.MODEL.ATT_FEATS_DIM, cfg.MODEL.ATT_FEATS_EMBED_DIM)) + sequential.append(nn.Linear(self.input_visual_feats * num_images, cfg.MODEL.ATT_FEATS_EMBED_DIM)) sequential.append(utils.activation(cfg.MODEL.ATT_FEATS_EMBED_ACT)) if cfg.MODEL.ATT_FEATS_NORM == True: sequential.append(nn.LayerNorm(cfg.MODEL.ATT_FEATS_EMBED_DIM)) if cfg.MODEL.DROPOUT_ATT_EMBED > 0: - sequential.append(nn.Dropout(cfg.MODEL.DROPOUT_ATT_EMBED)) + sequential.append(nn.Dropout(cfg.MODEL.DROPOUT_ATT_EMBED)) + self.att_embed = nn.Sequential(*sequential) if len(sequential) > 0 else None self.encoder = Encoder( - embed_dim = cfg.MODEL.BILINEAR.DIM, - dropout = cfg.MODEL.BILINEAR.ENCODE_DROPOUT, - att_type = cfg.MODEL.BILINEAR.ATTTYPE, - att_heads = cfg.MODEL.BILINEAR.HEAD, - att_mid_dim = cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DIM, - att_mid_drop = cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DROPOUT, - bifeat_emb_act = cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT, - bifeat_emb_drop = cfg.MODEL.BILINEAR.ENCODE_BIFEAT_EMB_DROPOUT, - ff_dropout = cfg.MODEL.BILINEAR.ENCODE_FF_DROPOUT, - layer_num = cfg.MODEL.BILINEAR.ENCODE_LAYERS) - - 
self.decoder = Decoder( - vocab_size = self.vocab_size, - embed_dim = cfg.MODEL.BILINEAR.DIM, - dropout = cfg.MODEL.BILINEAR.DECODE_DROPOUT, - att_type = cfg.MODEL.BILINEAR.ATTTYPE, - att_heads = cfg.MODEL.BILINEAR.HEAD, - att_mid_dim = cfg.MODEL.BILINEAR.DECODE_ATT_MID_DIM, - att_mid_drop = cfg.MODEL.BILINEAR.DECODE_ATT_MID_DROPOUT, - bifeat_emb_act = cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT, - bifeat_emb_drop = cfg.MODEL.BILINEAR.DECODE_BIFEAT_EMB_DROPOUT, - ff_dropout = cfg.MODEL.BILINEAR.DECODE_FF_DROPOUT, - layer_num = cfg.MODEL.BILINEAR.DECODE_LAYERS) + embed_dim=cfg.MODEL.BILINEAR.DIM, + dropout=cfg.MODEL.BILINEAR.ENCODE_DROPOUT, + att_type=cfg.MODEL.BILINEAR.ATTTYPE, + att_heads=cfg.MODEL.BILINEAR.HEAD, + att_mid_dim=cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DIM, + att_mid_drop=cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DROPOUT, + bifeat_emb_act=cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT, + bifeat_emb_drop=cfg.MODEL.BILINEAR.ENCODE_BIFEAT_EMB_DROPOUT, + ff_dropout=cfg.MODEL.BILINEAR.ENCODE_FF_DROPOUT, + layer_num=cfg.MODEL.BILINEAR.ENCODE_LAYERS) + + self.decoder = Decoder( + vocab_size=self.vocab_size, + embed_dim=cfg.MODEL.BILINEAR.DIM, + dropout=cfg.MODEL.BILINEAR.DECODE_DROPOUT, + att_type=cfg.MODEL.BILINEAR.ATTTYPE, + att_heads=cfg.MODEL.BILINEAR.HEAD, + att_mid_dim=cfg.MODEL.BILINEAR.DECODE_ATT_MID_DIM, + att_mid_drop=cfg.MODEL.BILINEAR.DECODE_ATT_MID_DROPOUT, + bifeat_emb_act=cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT, + bifeat_emb_drop=cfg.MODEL.BILINEAR.DECODE_BIFEAT_EMB_DROPOUT, + ff_dropout=cfg.MODEL.BILINEAR.DECODE_FF_DROPOUT, + layer_num=cfg.MODEL.BILINEAR.DECODE_LAYERS) + self.submodel = submodel def forward(self, **kwargs): + # forward entry + att_feats = kwargs[cfg.PARAM.ATT_FEATS] seq = kwargs[cfg.PARAM.INPUT_SENT] att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK] + # att_mask = torch.ones(16,70).to(device) att_mask = utils.expand_tensor(att_mask, cfg.DATA_LOADER.SEQ_PER_IMG) att_feats = utils.expand_tensor(att_feats, cfg.DATA_LOADER.SEQ_PER_IMG) 
############################################## seq_mask = (seq > 0).type(torch.cuda.IntTensor) - seq_mask[:,0] += 1 + seq_mask[:, 0] += 1 seq_mask = seq_mask.unsqueeze(-2) seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) seq_mask = seq_mask.type(torch.cuda.FloatTensor) ############################################## - - att_feats = self.att_embed(att_feats) + # print('att_feats.shape b4 pretrain',att_feats.shape) + # if len(att_feats.shape) == 5: + # batch_size = att_feats.shape[0] + # img_size = att_feats.shape[1] + # channel_size = att_feats.shape[2] + # hidden_size1 = att_feats.shape[3] + # hidden_size2 = att_feats.shape[4] + # att_feats = att_feats.view(batch_size, -1, hidden_size1, hidden_size2) + + # FEED IUXRAY: two images + # att_feats_0 = self.image_pretrained_models(att_feats[:, 0]) + # att_feats_1 = self.image_pretrained_models(att_feats[:, 1]) + # att_feats = torch.cat((att_feats_0, att_feats_1), dim=1) # shape (bs, 2048, 7, 7) + att_feats = self.get_visual_features(att_feats) + batch_size, feat_size, _,_ = att_feats.shape + att_feats = att_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1) + # print('att_feats.shape after pretrain', att_feats.shape) + att_feats = self.att_embed(att_feats) # forward entry + # print('att_feats.shape after pretrain', att_feats.shape) + gx, encoder_out = self.encoder(att_feats, att_mask) + # print(gx.shape, encoder_out.shape) + decoder_out = self.decoder(gx, seq, encoder_out, att_mask, seq_mask) + # print('decoder_out.shape',decoder_out.shape) # 4, 41, 761 + # raise Exception('lol') return decoder_out + def forward_iuxray(self, att_feats): + + att_feats, node_feats, fc_feats = self.submodel(att_feats[:, 0], att_feats[:, 1]) #bs, 49 2048 + att_feats = torch.cat((att_feats, node_feats), dim = 1)#Gcn+cnn +# att_feats = att_feats#cnn only +# att_feats = node_feats#gcn only + att_feats = att_feats.permute(0,2,1) + att_feats = att_feats.unsqueeze(-1) # bs, 2048,74,1 + return att_feats + + def 
forward_mimiccxr(self, att_feats): + att_feats, node_feats, fc_feats = self.submodel(att_feats) + att_feats = torch.cat((att_feats, node_feats), dim = 1) #torch.Size([16, 70, 2048]) +# att_feats = att_feats#cnn only +# att_feats = node_feats#gcn only + att_feats = att_feats.permute(0,2,1) + att_feats = att_feats.unsqueeze(-1) # bs, 2048,74,1 + return att_feats + def get_logprobs_state(self, **kwargs): wt = kwargs[cfg.PARAM.WT] state = kwargs[cfg.PARAM.STATE] encoder_out = kwargs[cfg.PARAM.ATT_FEATS] att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK] + gx = kwargs[cfg.PARAM.GLOBAL_FEAT] p_att_feats = kwargs[cfg.PARAM.P_ATT_FEATS] @@ -91,9 +152,13 @@ def get_logprobs_state(self, **kwargs): ys = wt.unsqueeze(1) else: ys = torch.cat([state[0][0], wt.unsqueeze(1)], dim=1) - seq_mask = subsequent_mask(ys.size(1)).to(encoder_out.device).type(torch.cuda.FloatTensor)[:, -1, :].unsqueeze(1) + # ys = torch.zeros(16,1, dtype = int).to(device) + seq_mask = subsequent_mask(ys.size(1)).to(encoder_out.device).type(torch.cuda.FloatTensor)[:, -1, :].unsqueeze( + 1) decoder_out = self.decoder(gx, ys[:, -1].unsqueeze(-1), encoder_out, att_mask, seq_mask, p_att_feats, True).squeeze(1) - + # print('HHHHHHHHHHHHHHHHHHHHHHH',decoder_out.shape) + # raise Exception('end') + logprobs = F.log_softmax(decoder_out, dim=-1) return logprobs, [ys.unsqueeze(0)] @@ -103,15 +168,19 @@ def fn(s): beam = selected_beam for _ in shape[1:]: beam = beam.unsqueeze(-1) + beam = beam.long() s = torch.gather(s.view(*([batch_size, cur_beam_size] + shape[1:])), 1, beam.expand(*([batch_size, beam_size] + shape[1:]))) + s = s.view(*([-1, ] + shape[1:])) return s + return fn # the beam search code is inspired by https://github.com/aimagelab/meshed-memory-transformer def decode_beam(self, **kwargs): att_feats = kwargs[cfg.PARAM.ATT_FEATS] + # att_mask = torch.ones(16,70).to(device) att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK] beam_size = kwargs['BEAM_SIZE'] batch_size = att_feats.size(0) @@ -120,7 +189,13 @@ def 
decode_beam(self, **kwargs): selected_words = None seq_mask = torch.ones((batch_size, beam_size, 1)).cuda() + # Modified + att_feats = self.get_visual_features(att_feats) + batch_size, feat_size, _ ,_= att_feats.shape + att_feats = att_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1) + # print(att_feats.shape) att_feats = self.att_embed(att_feats) + gx, encoder_out = self.encoder(att_feats, att_mask) p_att_feats = self.decoder.precompute(encoder_out) @@ -152,6 +227,7 @@ def decode_beam(self, **kwargs): selected_idx, selected_logprob = self.select(batch_size, beam_size, t, candidate_logprob) selected_beam = selected_idx / candidate_logprob.shape[-1] + selected_beam = selected_beam.long() selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] self.decoder.apply_to_states(self._expand_state(batch_size, beam_size, cur_beam_size, selected_beam)) @@ -161,7 +237,8 @@ def decode_beam(self, **kwargs): outputs.append(selected_words.unsqueeze(-1)) this_word_logprob = torch.gather(word_logprob, 1, - selected_beam.unsqueeze(-1).expand(batch_size, beam_size, word_logprob.shape[-1])) + selected_beam.unsqueeze(-1).expand(batch_size, beam_size, + word_logprob.shape[-1])) this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) log_probs = list( torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(batch_size, beam_size, 1)) for o in log_probs) @@ -188,7 +265,7 @@ def decode_beam(self, **kwargs): kwargs[cfg.PARAM.GLOBAL_FEAT] = gx kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats_tmp - + seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True) outputs = torch.cat(outputs, -1) outputs = torch.gather(outputs, 1, sort_idxs.expand(batch_size, beam_size, cfg.MODEL.SEQ_LEN)) @@ -205,14 +282,20 @@ def decode(self, **kwargs): beam_size = kwargs['BEAM_SIZE'] greedy_decode = kwargs['GREEDY_DECODE'] att_feats = kwargs[cfg.PARAM.ATT_FEATS] + # att_mask = torch.ones(16,70).to(device) 
att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK] - batch_size = att_feats.size(0) + att_feats = self.get_visual_features(att_feats) + batch_size, feat_size, _ ,_= att_feats.shape + att_feats = att_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1) + att_feats = self.att_embed(att_feats) gx, encoder_out = self.encoder(att_feats, att_mask) + # print(gx.shape) + # print(encoder_out.shape) p_att_feats = self.decoder.precompute(encoder_out) self.decoder.init_buffer(batch_size) - + state = None sents = Variable(torch.zeros((batch_size, cfg.MODEL.SEQ_LEN), dtype=torch.long).cuda()) logprobs = Variable(torch.zeros(batch_size, cfg.MODEL.SEQ_LEN).cuda()) @@ -225,7 +308,7 @@ def decode(self, **kwargs): kwargs[cfg.PARAM.WT] = wt kwargs[cfg.PARAM.STATE] = state logprobs_t, state = self.get_logprobs_state(**kwargs) - + if greedy_decode: logP_t, wt = torch.max(logprobs_t, 1) else: @@ -235,80 +318,93 @@ def decode(self, **kwargs): wt = wt.view(-1).long() unfinished = unfinished * (wt > 0) wt = wt * unfinished.type_as(wt) - sents[:,t] = wt - logprobs[:,t] = logP_t.view(-1) + sents[:, t] = wt + logprobs[:, t] = logP_t.view(-1) if unfinished.sum() == 0: break self.decoder.clear_buffer() return sents, logprobs + class Encoder(nn.Module): def __init__( - self, - embed_dim, - dropout, - att_type, - att_heads, - att_mid_dim, - att_mid_drop, - bifeat_emb_act, - bifeat_emb_drop, - ff_dropout, - layer_num + self, + embed_dim, + dropout, + att_type, + att_heads, + att_mid_dim, + att_mid_drop, + bifeat_emb_act, + bifeat_emb_drop, + ff_dropout, + layer_num ): super(Encoder, self).__init__() self.att_heads = att_heads - self.layers = nn.ModuleList([]) + self.layers = nn.ModuleList([]) for i in range(layer_num): - sublayer = EncoderLayer( - embed_dim = embed_dim, - dropout = dropout, - att_type = att_type, - att_heads = att_heads, - att_mid_dim = att_mid_dim, - att_mid_drop = att_mid_drop, - bifeat_emb_act = bifeat_emb_act, - bifeat_emb_drop = bifeat_emb_drop, - ff_dropout = ff_dropout) + 
sublayer = EncoderLayer( + embed_dim=embed_dim, + dropout=dropout, + att_type=att_type, + att_heads=att_heads, + att_mid_dim=att_mid_dim, + att_mid_drop=att_mid_drop, + bifeat_emb_act=bifeat_emb_act, + bifeat_emb_drop=bifeat_emb_drop, + ff_dropout=ff_dropout) self.layers.append(sublayer) - + self.proj_norm = nn.Sequential( - nn.Linear(embed_dim * (layer_num + 1), embed_dim), + nn.Linear(embed_dim * (layer_num + 1), embed_dim), torch.nn.LayerNorm(embed_dim)) + + + def forward(self, x, mask): - gx = (torch.sum(x * mask.unsqueeze(-1), 1) / torch.sum(mask.unsqueeze(-1), 1)) + # drop mask + att_masks = mask + # if att_masks is None: + # att_masks = x.new_ones(x.shape[:2], dtype=torch.long) + # att_masks = att_masks.unsqueeze(-2) + # print(x.shape) + # print(att_masks.shape) + gx = (torch.sum(x * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1)) gx_arr = [gx] for layer in self.layers: - gx, x = layer(gx, x, mask) + gx, x = layer(gx, x, att_masks) # modified gx_arr.append(gx) gx = torch.cat(gx_arr, dim=-1) gx = self.proj_norm(gx) + return gx, x + class EncoderLayer(nn.Module): def __init__( - self, - embed_dim, - dropout, - att_type, - att_heads, - att_mid_dim, - att_mid_drop, - bifeat_emb_act, - bifeat_emb_drop, - ff_dropout + self, + embed_dim, + dropout, + att_type, + att_heads, + att_mid_dim, + att_mid_drop, + bifeat_emb_act, + bifeat_emb_drop, + ff_dropout ): super(EncoderLayer, self).__init__() self.encoder_attn = LowRank( - embed_dim = embed_dim, - att_type = att_type, - att_heads = att_heads, - att_mid_dim = att_mid_dim, - att_mid_drop = att_mid_drop) + embed_dim=embed_dim, + att_type=att_type, + att_heads=att_heads, + att_mid_dim=att_mid_dim, + att_mid_drop=att_mid_drop) self.dropout = nn.Dropout(dropout) self.bifeat_emb = nn.Sequential( @@ -320,60 +416,70 @@ def __init__( self.ff_layer = blocks.create( 'FeedForward', - embed_dim = embed_dim, - ffn_embed_dim = embed_dim * 4, - relu_dropout = ff_dropout, - dropout = ff_dropout) + 
embed_dim=embed_dim, + ffn_embed_dim=embed_dim * 4, + relu_dropout=ff_dropout, + dropout=ff_dropout) def forward(self, gx, x, mask): + """ + gx : torch.Size([4, 768]) + x : torch.Size([4, 49, 768]) + mask : torch.Size([4, 49]) + + """ + gx = self.encoder_attn( - query = gx, - key = x, - mask = mask, - value1 = gx, - value2 = x + query=gx, + key=x, + mask=mask, + value1=gx, + value2=x ) + gx = self.dropout(gx) - x_ = torch.cat([gx.unsqueeze(1).expand_as(x), x], dim = -1) + x_ = torch.cat([gx.unsqueeze(1).expand_as(x), x], dim=-1) x = self.bifeat_emb(x_) + x x = self.layer_norm(x) if self.ff_layer is not None: x = self.ff_layer(x) + return gx, x + class Decoder(nn.Module): def __init__( - self, - vocab_size, - embed_dim, - dropout, - att_type, - att_heads, - att_mid_dim, - att_mid_drop, - bifeat_emb_act, - bifeat_emb_drop, - ff_dropout, - layer_num + self, + vocab_size, + embed_dim, + dropout, + att_type, + att_heads, + att_mid_dim, + att_mid_drop, + bifeat_emb_act, + bifeat_emb_drop, + ff_dropout, + layer_num ): super(Decoder, self).__init__() self.att_heads = att_heads self.layers = nn.ModuleList([]) self.embed_dim = embed_dim for i in range(layer_num): - sublayer = DecoderLayer( - embed_dim = embed_dim, - dropout = dropout, - att_type = att_type, - att_heads = att_heads, - att_mid_dim = att_mid_dim, - att_mid_drop = att_mid_drop, - bifeat_emb_act = bifeat_emb_act, - bifeat_emb_drop = bifeat_emb_drop, - ff_dropout = ff_dropout, - last_layer = (i == layer_num -1)) + sublayer = DecoderLayer( + embed_dim=embed_dim, + dropout=dropout, + att_type=att_type, + att_heads=att_heads, + att_mid_dim=att_mid_dim, + att_mid_drop=att_mid_drop, + bifeat_emb_act=bifeat_emb_act, + bifeat_emb_drop=bifeat_emb_drop, + ff_dropout=ff_dropout, + last_layer=(i == layer_num - 1)) self.layers.append(sublayer) self.dropout = nn.Dropout(cfg.MODEL.DROPOUT_WORD_EMBED) @@ -397,7 +503,7 @@ def __init__( torch.nn.LayerNorm(embed_dim) ) self.wbi_drop = nn.Dropout(cfg.MODEL.BILINEAR.DECODE_DROPOUT) - 
self.dropout_lm = nn.Dropout(cfg.MODEL.DROPOUT_LM) + self.dropout_lm = nn.Dropout(cfg.MODEL.DROPOUT_LM) self.proj_norm = nn.Sequential( nn.Linear(embed_dim * (layer_num + 1), 2 * embed_dim), @@ -431,14 +537,15 @@ def precompute(self, encoder_out): return p_att_feats def forward(self, gx, prev_output_tokens, encoder_out, att_mask, seq_mask=None, p_att_feats=None, precompute=False): + # device = torch.device("cuda") att_mask = att_mask.unsqueeze(1) - + # embed positions seq_len = prev_output_tokens.size(1) if self.seq_len is not None: seq_len = self.seq_len + seq_len self.seq_len = seq_len - positions = self.embed_positions(seq_len)[:,-1,:].unsqueeze(1) + positions = self.embed_positions(seq_len)[:, -1, :].unsqueeze(1) else: positions = self.embed_positions(seq_len) @@ -449,7 +556,7 @@ def forward(self, gx, prev_output_tokens, encoder_out, att_mask, seq_mask=None, x = self.layer_norm_word(x) if self.dropout is not None: x = self.dropout(x) - + # decoder layers gx = self.wbil1(gx) if self.x is None: @@ -462,53 +569,84 @@ def forward(self, gx, prev_output_tokens, encoder_out, att_mask, seq_mask=None, gx = gx * x_gx gx = self.wbi_drop(gx) - gx_arr = [gx] + # print(gx.shape) + # print(type(gx)) + # print((gx[0, 0, 0].dtype)) + + bs, att_nums, features_sizes = gx.shape + + # gx_arr = torch.zeros((len(self.layers), bs, att_nums, features_sizes), dtype = torch.float32, device = device)# .to(torch.device("cuda")) # HARD CODE + # print(gx_arr.shape) + # print(gx_arr.shape) + # gx_arr = gx_arr.to(device) + + gx_arr = [gx] # .to('cuda') + # gx_arr = + for layerid, layer in enumerate(self.layers): + # print('layerid',layerid) if precompute == False: p_key = None p_value2 = None else: - p_key, p_value2 = p_att_feats[layerid] - gx, x = layer(gx, x, encoder_out, att_mask, seq_mask=seq_mask, p_key=p_key, p_value2=p_value2, precompute=precompute) + p_key, p_value2 = p_att_feats[layerid] + # print('x.shape b4',x.shape) + # print('precompute',precompute) + # 
print('att_mask.shape',att_mask.shape) + # print('seq_mask.shape',seq_mask.shape) + gx, x = layer(gx, x, encoder_out, att_mask, seq_mask=seq_mask, p_key=p_key, p_value2=p_value2, + precompute=precompute) + # print('gx.shape',gx.shape) # bs, 41, 768 + # print('x.shape after',x.shape) + # print(type(gx_arr)) gx_arr.append(gx) + # print('test',gx_arr[layerid, :, :, :].shape) + # print('layerid',layerid) + # print('gx_arr.shape',gx_arr.shape) + # print('gx.device',gx.device) + # print('gx_arr.device',gx_arr.device) + # gx_arr[layerid] = gx + # raise Exception('decoder lol') + gx = torch.cat(gx_arr, dim=-1) - gx = torch.cat(gx_arr, dim = -1) gx = self.proj_norm(gx) gx = self.dropout_lm(gx) out = self.generator(gx) + # raise Exception('decoder lol') return out + class DecoderLayer(nn.Module): def __init__( - self, - embed_dim, - dropout, - att_type, - att_heads, - att_mid_dim, - att_mid_drop, - bifeat_emb_act, - bifeat_emb_drop, - ff_dropout, - last_layer = False + self, + embed_dim, + dropout, + att_type, + att_heads, + att_mid_dim, + att_mid_drop, + bifeat_emb_act, + bifeat_emb_drop, + ff_dropout, + last_layer=False ): super(DecoderLayer, self).__init__() self.last_layer = last_layer self.word_attn = LowRank( - embed_dim = embed_dim, - att_type = att_type, - att_heads = att_heads, - att_mid_dim = att_mid_dim, - att_mid_drop = att_mid_drop) + embed_dim=embed_dim, + att_type=att_type, + att_heads=att_heads, + att_mid_dim=att_mid_dim, + att_mid_drop=att_mid_drop) self.word_dropout = nn.Dropout(dropout) self.cross_att = LowRank( - embed_dim = embed_dim, - att_type = att_type, - att_heads = att_heads, - att_mid_dim = att_mid_dim, - att_mid_drop = att_mid_drop) + embed_dim=embed_dim, + att_type=att_type, + att_heads=att_heads, + att_mid_dim=att_mid_dim, + att_mid_drop=att_mid_drop) self.cross_dropout = nn.Dropout(dropout) self.layer_norm_cross = torch.nn.LayerNorm(embed_dim) @@ -522,10 +660,10 @@ def __init__( self.ff_layer = blocks.create( 'FeedForward', - embed_dim = 
embed_dim, - ffn_embed_dim = embed_dim * 4, - relu_dropout = ff_dropout, - dropout = ff_dropout) + embed_dim=embed_dim, + ffn_embed_dim=embed_dim * 4, + relu_dropout=ff_dropout, + dropout=ff_dropout) self.layer_norm_gx = torch.nn.LayerNorm(embed_dim) @@ -543,42 +681,43 @@ def precompute(self, encoder_out): return key, value2 def forward( - self, - gx, - x, - encoder_out, - att_mask, - seq_mask, - p_key=None, - p_value2=None, - precompute=False + self, + gx, + x, + encoder_out, + att_mask, + seq_mask, + p_key=None, + p_value2=None, + precompute=False ): word_x = x residual = x x = self.word_attn.forward2( - query = gx, - key = x, - mask = seq_mask, - value1 = gx, - value2 = x) + query=gx, + key=x, + mask=seq_mask, + value1=gx, + value2=x) x = self.word_dropout(x) + x = residual + x residual = x x = self.layer_norm_cross(x) x = self.cross_att.forward2( - query = x, - key = encoder_out if precompute == False else p_key, - mask = att_mask, - value1 = x, - value2 = encoder_out if precompute == False else p_value2, + query=x, + key=encoder_out if precompute == False else p_key, + mask=att_mask, + value1=x, + value2=encoder_out if precompute == False else p_value2, precompute=precompute) x = self.cross_dropout(x) gx = residual + x gx = self.layer_norm_gx(gx) if self.last_layer == False: - x_ = torch.cat([gx, word_x], dim = -1) + x_ = torch.cat([gx, word_x], dim=-1) x = self.bifeat_emb(x_) + word_x x = self.layer_norm_x(x) @@ -586,4 +725,5 @@ def forward( x = self.ff_layer(x) else: x = None + return gx, x diff --git a/optimizer/optimizer.py b/optimizer/optimizer.py index 4cbeeff..0d816e9 100755 --- a/optimizer/optimizer.py +++ b/optimizer/optimizer.py @@ -117,3 +117,28 @@ def get_lr(self): lr.append(param_group['lr']) lr = sorted(list(set(lr))) return lr + + + +def build_optimizer(args, model): + # print(model.submodel.parameters()) + #edit + # ve_params = list(map(id, model.submodel.parameters())) + # ed_params = filter(lambda x: id(x) not in ve_params, 
model.parameters()) + # optimizer = torch.optim.Adam( + # [{'params': model.submodel.parameters(), 'lr': 5e-5}, #edit + # {'params': ed_params, 'lr': 1e-4}], + # weight_decay=5e-5, + # amsgrad=True + # ) + optimizer = torch.optim.Adam( + [{'params': model, 'lr': 5e-5}], #edit + weight_decay=5e-5, + amsgrad=True + ) + return optimizer + + +def build_lr_scheduler(args, optimizer): + lr_scheduler = getattr(torch.optim.lr_scheduler, args.lr_scheduler)(optimizer, args.step_size, args.gamma) + return lr_scheduler diff --git a/pycocoevalcap/README.md b/pycocoevalcap/README.md new file mode 100644 index 0000000..942de18 --- /dev/null +++ b/pycocoevalcap/README.md @@ -0,0 +1,23 @@ +Microsoft COCO Caption Evaluation Tools
+--- + +Modified the code to work with Python 3.
+ +### Requirements +* Python 3.x +* Java 1.8 +* pycocotools + +--- + +### Tested on +* Windows 10, Python 3.5. + +--- +### To fix Windows JVM memory error:
+Add the following in System Variables
+    Variable name : _JAVA_OPTIONS
+    Variable value : -Xmx1024M
+ +--- +Original code : https://github.com/tylin/coco-caption
diff --git a/pycocoevalcap/__init__.py b/pycocoevalcap/__init__.py new file mode 100644 index 0000000..680063e --- /dev/null +++ b/pycocoevalcap/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' \ No newline at end of file diff --git a/pycocoevalcap/bleu/LICENSE b/pycocoevalcap/bleu/LICENSE new file mode 100644 index 0000000..9ccf677 --- /dev/null +++ b/pycocoevalcap/bleu/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
diff --git a/pycocoevalcap/bleu/__init__.py b/pycocoevalcap/bleu/__init__.py new file mode 100644 index 0000000..680063e --- /dev/null +++ b/pycocoevalcap/bleu/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' \ No newline at end of file diff --git a/pycocoevalcap/bleu/bleu.py b/pycocoevalcap/bleu/bleu.py new file mode 100644 index 0000000..60e723e --- /dev/null +++ b/pycocoevalcap/bleu/bleu.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# +# File Name : bleu.py +# +# Description : Wrapper for BLEU scorer. +# +# Creation Date : 06-01-2015 +# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT +# Authors : Hao Fang and Tsung-Yi Lin + +# Last modified : Wed 22 May 2019 08:10:00 PM EDT +# By Sabarish Sivanath +# To support Python 3 + +from .bleu_scorer import BleuScorer + + +class Bleu: + def __init__(self, n=4): + # default compute Blue score up to 4 + self._n = n + self._hypo_for_image = {} + self.ref_for_image = {} + + def compute_score(self, gts, res, score_option = 'closest', verbose = 1): + ''' + Inputs: + gts - ground truths + res - predictions + score_option - {shortest, closest, average} + verbose - 1 or 0 + Outputs: + Blue scores + ''' + assert(gts.keys() == res.keys()) + imgIds = gts.keys() + + bleu_scorer = BleuScorer(n=self._n) + for id in imgIds: + hypo = res[id] + ref = gts[id] + + # Sanity check. + assert(type(hypo) is list) + assert(len(hypo) == 1) + assert(type(ref) is list) + #assert(len(ref) >= 1) + + bleu_scorer += (hypo[0], ref) + + score, scores = bleu_scorer.compute_score(option = score_option, verbose =verbose) + + # return (bleu, bleu_info) + return score, scores + + def method(self): + return "Bleu" diff --git a/pycocoevalcap/bleu/bleu_scorer.py b/pycocoevalcap/bleu/bleu_scorer.py new file mode 100644 index 0000000..d5646aa --- /dev/null +++ b/pycocoevalcap/bleu/bleu_scorer.py @@ -0,0 +1,268 @@ +# bleu_scorer.py +# David Chiang + +# Copyright (c) 2004-2006 University of Maryland. All rights +# reserved. 
Do not redistribute without permission from the +# author. Not for commercial use. + +# Modified by: +# Hao Fang +# Tsung-Yi Lin + +# Last modified : Wed 22 May 2019 08:10:00 PM EDT +# By Sabarish Sivanath +# To support Python 3 + +'''Provides: +cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). +cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). +''' + +import copy +import sys, math, re +from collections import defaultdict + +def precook(s, n=4, out=False): + """Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well.""" + words = s.split() + counts = defaultdict(int) + for k in range(1,n+1): + for i in range(len(words)-k+1): + ngram = tuple(words[i:i+k]) + counts[ngram] += 1 + return (len(words), counts) + +def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them.''' + + reflen = [] + maxcounts = {} + for ref in refs: + rl, counts = precook(ref, n) + reflen.append(rl) + for (ngram,count) in counts.items(): + maxcounts[ngram] = max(maxcounts.get(ngram,0), count) + + # Calculate effective reference sentence length. + if eff == "shortest": + reflen = min(reflen) + elif eff == "average": + reflen = float(sum(reflen))/len(reflen) + + ## lhuang: N.B.: leave reflen computaiton to the very end!! + + ## lhuang: N.B.: in case of "closest", keep a list of reflens!! 
(bad design) + + return (reflen, maxcounts) + +def cook_test(test, refs , eff=None, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it.''' + + reflen = refs[0] + refmaxcounts = refs[1] + + testlen, counts = precook(test, n, True) + + result = {} + + # Calculate effective reference sentence length. + + if eff == "closest": + result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] + else: ## i.e., "average" or "shortest" or None + result["reflen"] = reflen + + result["testlen"] = testlen + + result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] + + result['correct'] = [0]*n + for (ngram, count) in counts.items(): + result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) + + return result + +class BleuScorer(object): + """Bleu scorer. + """ + + __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" + # special_reflen is used in oracle (proportional effective ref len for a node). 
+ + def copy(self): + ''' copy the refs.''' + new = BleuScorer(n=self.n) + new.ctest = copy.copy(self.ctest) + new.crefs = copy.copy(self.crefs) + new._score = None + return new + + def __init__(self, test=None, refs=None, n=4, special_reflen=None): + ''' singular instance ''' + + self.n = n + self.crefs = [] + self.ctest = [] + self.cook_append(test, refs) + self.special_reflen = special_reflen + + def cook_append(self, test, refs): + '''called by constructor and __iadd__ to avoid creating new instances.''' + + if refs is not None: + self.crefs.append(cook_refs(refs)) + if test is not None: + cooked_test = cook_test(test, self.crefs[-1]) + self.ctest.append(cooked_test) ## N.B.: -1 + else: + self.ctest.append(None) # lens of crefs and ctest have to match + + self._score = None ## need to recompute + + def ratio(self, option=None): + self.compute_score(option=option) + return self._ratio + + def score_ratio(self, option=None): + '''return (bleu, len_ratio) pair''' + return (self.fscore(option=option), self.ratio(option=option)) + + def score_ratio_str(self, option=None): + return "%.4f (%.2f)" % self.score_ratio(option) + + def reflen(self, option=None): + self.compute_score(option=option) + return self._reflen + + def testlen(self, option=None): + self.compute_score(option=option) + return self._testlen + + def retest(self, new_test): + if type(new_test) is str: + new_test = [new_test] + assert len(new_test) == len(self.crefs), new_test + self.ctest = [] + for t, rs in zip(new_test, self.crefs): + self.ctest.append(cook_test(t, rs)) + self._score = None + + return self + + def rescore(self, new_test): + ''' replace test(s) with new test(s), and returns the new score.''' + + return self.retest(new_test).compute_score() + + def size(self): + assert len(self.crefs) == len(self.ctest), "refs/test mismatch! 
%d<>%d" % (len(self.crefs), len(self.ctest)) + return len(self.crefs) + + def __iadd__(self, other): + '''add an instance (e.g., from another sentence).''' + + if type(other) is tuple: + ## avoid creating new BleuScorer instances + self.cook_append(other[0], other[1]) + else: + assert self.compatible(other), "incompatible BLEUs." + self.ctest.extend(other.ctest) + self.crefs.extend(other.crefs) + self._score = None ## need to recompute + + return self + + def compatible(self, other): + return isinstance(other, BleuScorer) and self.n == other.n + + def single_reflen(self, option="average"): + return self._single_reflen(self.crefs[0][0], option) + + def _single_reflen(self, reflens, option=None, testlen=None): + + if option == "shortest": + reflen = min(reflens) + elif option == "average": + reflen = float(sum(reflens))/len(reflens) + elif option == "closest": + reflen = min((abs(l-testlen), l) for l in reflens)[1] + else: + assert False, "unsupported reflen option %s" % option + + return reflen + + def recompute_score(self, option=None, verbose=0): + self._score = None + return self.compute_score(option, verbose) + + def compute_score(self, option=None, verbose=0): + n = self.n + small = 1e-9 + tiny = 1e-15 ## so that if guess is 0 still return 0 + bleu_list = [[] for _ in range(n)] + + if self._score is not None: + return self._score + + if option is None: + option = "average" if len(self.crefs) == 1 else "closest" + + self._testlen = 0 + self._reflen = 0 + totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} + + # for each sentence + for comps in self.ctest: + testlen = comps['testlen'] + self._testlen += testlen + + if self.special_reflen is None: ## need computation + reflen = self._single_reflen(comps['reflen'], option, testlen) + else: + reflen = self.special_reflen + + self._reflen += reflen + + for key in ['guess','correct']: + for k in range(n): + totalcomps[key][k] += comps[key][k] + + # append per image bleu score + bleu = 1. 
+ for k in range(n): + bleu *= (float(comps['correct'][k]) + tiny) \ + /(float(comps['guess'][k]) + small) + bleu_list[k].append(bleu ** (1./(k+1))) + ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division + if ratio < 1: + for k in range(n): + bleu_list[k][-1] *= math.exp(1 - 1/ratio) + + if verbose > 1: + print(comps, reflen) + + totalcomps['reflen'] = self._reflen + totalcomps['testlen'] = self._testlen + + bleus = [] + bleu = 1. + for k in range(n): + bleu *= float(totalcomps['correct'][k] + tiny) \ + / (totalcomps['guess'][k] + small) + bleus.append(bleu ** (1./(k+1))) + ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division + if ratio < 1: + for k in range(n): + bleus[k] *= math.exp(1 - 1/ratio) + + if verbose > 0: + print(totalcomps) + print("ratio:", ratio) + + self._score = bleus + return self._score, bleu_list diff --git a/pycocoevalcap/cider/__init__.py b/pycocoevalcap/cider/__init__.py new file mode 100644 index 0000000..3f7d85b --- /dev/null +++ b/pycocoevalcap/cider/__init__.py @@ -0,0 +1 @@ +__author__ = 'tylin' diff --git a/pycocoevalcap/cider/cider.py b/pycocoevalcap/cider/cider.py new file mode 100644 index 0000000..7aadb9a --- /dev/null +++ b/pycocoevalcap/cider/cider.py @@ -0,0 +1,55 @@ +# Filename: cider.py +# +# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric +# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) +# +# Creation Date: Sun Feb 8 14:16:54 2015 +# +# Authors: Ramakrishna Vedantam and Tsung-Yi Lin + + +from .cider_scorer import CiderScorer +import pdb + +class Cider: + """ + Main Class to compute the CIDEr metric + + """ + def __init__(self, test=None, refs=None, n=4, sigma=6.0): + # set cider to sum over 1 to 4-grams + self._n = n + # set the standard deviation parameter for gaussian penalty + self._sigma = sigma + + def compute_score(self, gts, res): + """ + Main function to compute CIDEr score + :param 
hypo_for_image (dict) : dictionary with key and value + ref_for_image (dict) : dictionary with key and value + :return: cider (float) : computed CIDEr score for the corpus + """ + + assert(gts.keys() == res.keys()) + imgIds = gts.keys() + + cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) + + for id in imgIds: + hypo = res[id] + ref = gts[id] + + # Sanity check. + assert(type(hypo) is list) + assert(len(hypo) == 1) + assert(type(ref) is list) + assert(len(ref) > 0) + + cider_scorer += (hypo[0], ref) + + (score, scores) = cider_scorer.compute_score() + + return score, scores + + def method(self): + return "CIDEr" \ No newline at end of file diff --git a/pycocoevalcap/cider/cider_scorer.py b/pycocoevalcap/cider/cider_scorer.py new file mode 100644 index 0000000..94752e8 --- /dev/null +++ b/pycocoevalcap/cider/cider_scorer.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python +# Tsung-Yi Lin +# Ramakrishna Vedantam + + +# Last modified : Wed 22 May 2019 08:10:00 PM EDT +# By Sabarish Sivanath +# To support Python 3 + +import copy +from collections import defaultdict +import numpy as np +import pdb +import math + +def precook(s, n=4, out=False): + """ + Takes a string as input and returns an object that can be given to + either cook_refs or cook_test. This is optional: cook_refs and cook_test + can take string arguments as well. + :param s: string : sentence to be converted into ngrams + :param n: int : number of ngrams for which representation is calculated + :return: term frequency vector for occuring ngrams + """ + words = s.split() + counts = defaultdict(int) + for k in range(1,n+1): + for i in range(len(words)-k+1): + ngram = tuple(words[i:i+k]) + counts[ngram] += 1 + return counts + +def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" + '''Takes a list of reference sentences for a single segment + and returns an object that encapsulates everything that BLEU + needs to know about them. 
+ :param refs: list of string : reference sentences for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (list of dict) + ''' + return [precook(ref, n) for ref in refs] + +def cook_test(test, n=4): + '''Takes a test sentence and returns an object that + encapsulates everything that BLEU needs to know about it. + :param test: list of string : hypothesis sentence for some image + :param n: int : number of ngrams for which (ngram) representation is calculated + :return: result (dict) + ''' + return precook(test, n, True) + +class CiderScorer(object): + """CIDEr scorer. + """ + + def copy(self): + ''' copy the refs.''' + new = CiderScorer(n=self.n) + new.ctest = copy.copy(self.ctest) + new.crefs = copy.copy(self.crefs) + return new + + def __init__(self, test=None, refs=None, n=4, sigma=6.0): + ''' singular instance ''' + self.n = n + self.sigma = sigma + self.crefs = [] + self.ctest = [] + self.document_frequency = defaultdict(float) + self.cook_append(test, refs) + self.ref_len = None + + def cook_append(self, test, refs): + '''called by constructor and __iadd__ to avoid creating new instances.''' + + if refs is not None: + self.crefs.append(cook_refs(refs)) + if test is not None: + self.ctest.append(cook_test(test)) ## N.B.: -1 + else: + self.ctest.append(None) # lens of crefs and ctest have to match + + def size(self): + assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) + return len(self.crefs) + + def __iadd__(self, other): + '''add an instance (e.g., from another sentence).''' + + if type(other) is tuple: + ## avoid creating new CiderScorer instances + self.cook_append(other[0], other[1]) + else: + self.ctest.extend(other.ctest) + self.crefs.extend(other.crefs) + + return self + def compute_doc_freq(self): + ''' + Compute term frequency for reference data. 
+ This will be used to compute idf (inverse document frequency later) + The term frequency is stored in the object + :return: None + ''' + for refs in self.crefs: + # refs, k ref captions of one image + for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): + self.document_frequency[ngram] += 1 + # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) + + def compute_cider(self): + def counts2vec(cnts): + """ + Function maps counts of ngram to vector of tfidf weights. + The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. + The n-th entry of array denotes length of n-grams. + :param cnts: + :return: vec (array of dict), norm (array of float), length (int) + """ + vec = [defaultdict(float) for _ in range(self.n)] + length = 0 + norm = [0.0 for _ in range(self.n)] + for (ngram,term_freq) in cnts.items(): + # give word count 1 if it doesn't appear in reference corpus + df = np.log(max(1.0, self.document_frequency[ngram])) + # ngram index + n = len(ngram)-1 + # tf (term_freq) * idf (precomputed idf) for n-grams + vec[n][ngram] = float(term_freq)*(self.ref_len - df) + # compute norm for the vector. the norm will be used for computing similarity + norm[n] += pow(vec[n][ngram], 2) + + if n == 1: + length += term_freq + norm = [np.sqrt(n) for n in norm] + return vec, norm, length + + def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): + ''' + Compute the cosine similarity of two vectors. 
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis + :param vec_ref: array of dictionary for vector corresponding to reference + :param norm_hyp: array of float for vector corresponding to hypothesis + :param norm_ref: array of float for vector corresponding to reference + :param length_hyp: int containing length of hypothesis + :param length_ref: int containing length of reference + :return: array of score for each n-grams cosine similarity + ''' + delta = float(length_hyp - length_ref) + # measure consine similarity + val = np.array([0.0 for _ in range(self.n)]) + for n in range(self.n): + # ngram + for (ngram,count) in vec_hyp[n].items(): + # vrama91 : added clipping + val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] + + if (norm_hyp[n] != 0) and (norm_ref[n] != 0): + val[n] /= (norm_hyp[n]*norm_ref[n]) + + assert(not math.isnan(val[n])) + # vrama91: added a length based gaussian penalty + val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) + return val + + # compute log reference length + self.ref_len = np.log(float(len(self.crefs))) + + scores = [] + for test, refs in zip(self.ctest, self.crefs): + # compute vector for test captions + vec, norm, length = counts2vec(test) + # compute vector for ref captions + score = np.array([0.0 for _ in range(self.n)]) + for ref in refs: + vec_ref, norm_ref, length_ref = counts2vec(ref) + score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) + # change by vrama91 - mean of ngram scores, instead of sum + score_avg = np.mean(score) + # divide by number of references + score_avg /= len(refs) + # multiply score by 10 + score_avg *= 10.0 + # append score of an image to the score list + scores.append(score_avg) + return scores + + def compute_score(self, option=None, verbose=0): + # compute idf + self.compute_doc_freq() + # assert to check document frequency + assert(len(self.ctest) >= max(self.document_frequency.values())) + # compute cider score + score = self.compute_cider() 
+ # debug + # print score + return np.mean(np.array(score)), np.array(score) \ No newline at end of file diff --git a/pycocoevalcap/eval.py b/pycocoevalcap/eval.py new file mode 100644 index 0000000..21f53dc --- /dev/null +++ b/pycocoevalcap/eval.py @@ -0,0 +1,74 @@ +__author__ = 'tylin' +from .tokenizer.ptbtokenizer import PTBTokenizer +from .bleu.bleu import Bleu +from .meteor.meteor import Meteor +from .rouge.rouge import Rouge +from .cider.cider import Cider + +class COCOEvalCap: + def __init__(self, coco, cocoRes): + self.evalImgs = [] + self.eval = {} + self.imgToEval = {} + self.coco = coco + self.cocoRes = cocoRes + self.params = {'image_id': cocoRes.getImgIds()} + + def evaluate(self): + imgIds = self.params['image_id'] + # imgIds = self.coco.getImgIds() + gts = {} + res = {} + for imgId in imgIds: + gts[imgId] = self.coco.imgToAnns[imgId] + res[imgId] = self.cocoRes.imgToAnns[imgId] + + # ================================================= + # Set up scorers + # ================================================= + print('tokenization...') + tokenizer = PTBTokenizer() + gts = tokenizer.tokenize(gts) + res = tokenizer.tokenize(res) + + # ================================================= + # Set up scorers + # ================================================= + print('setting up scorers...') + scorers = [ + (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), + (Meteor(),"METEOR"), + (Rouge(), "ROUGE_L"), + (Cider(), "CIDEr") + ] + + # ================================================= + # Compute scores + # ================================================= + eval = {} + for scorer, method in scorers: + print('computing %s score...'%(scorer.method())) + score, scores = scorer.compute_score(gts, res) + if type(method) == list: + for sc, scs, m in zip(score, scores, method): + self.setEval(sc, m) + self.setImgToEvalImgs(scs, imgIds, m) + print("%s: %0.3f"%(m, sc)) + else: + self.setEval(score, method) + self.setImgToEvalImgs(scores, imgIds, method) + print("%s: 
%0.3f"%(method, score)) + self.setEvalImgs() + + def setEval(self, score, method): + self.eval[method] = score + + def setImgToEvalImgs(self, scores, imgIds, method): + for imgId, score in zip(imgIds, scores): + if not imgId in self.imgToEval: + self.imgToEval[imgId] = {} + self.imgToEval[imgId]["image_id"] = imgId + self.imgToEval[imgId][method] = score + + def setEvalImgs(self): + self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] diff --git a/pycocoevalcap/license.txt b/pycocoevalcap/license.txt new file mode 100644 index 0000000..3ada56f --- /dev/null +++ b/pycocoevalcap/license.txt @@ -0,0 +1,26 @@ +Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +The views and conclusions contained in the software and documentation are those +of the authors and should not be interpreted as representing official policies, +either expressed or implied, of the FreeBSD Project. \ No newline at end of file diff --git a/pycocoevalcap/meteor/__init__.py b/pycocoevalcap/meteor/__init__.py new file mode 100644 index 0000000..349338d --- /dev/null +++ b/pycocoevalcap/meteor/__init__.py @@ -0,0 +1 @@ +from .meteor import * \ No newline at end of file diff --git a/pycocoevalcap/meteor/data/paraphrase-en.gz b/pycocoevalcap/meteor/data/paraphrase-en.gz new file mode 100644 index 0000000..88033c8 Binary files /dev/null and b/pycocoevalcap/meteor/data/paraphrase-en.gz differ diff --git a/pycocoevalcap/meteor/meteor-1.5.jar b/pycocoevalcap/meteor/meteor-1.5.jar new file mode 100644 index 0000000..a833bc0 Binary files /dev/null and b/pycocoevalcap/meteor/meteor-1.5.jar differ diff --git a/pycocoevalcap/meteor/meteor.py b/pycocoevalcap/meteor/meteor.py new file mode 100644 index 0000000..114b42a --- /dev/null +++ b/pycocoevalcap/meteor/meteor.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python + +# Python wrapper for METEOR implementation, by Xinlei Chen +# Acknowledge Michael Denkowski for the generous discussion and help + +# Last modified : Wed 22 May 2019 08:10:00 PM EDT +# By Sabarish Sivanath +# To support Python 3 + +import os +import sys +import subprocess +import threading + +# Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 
+METEOR_JAR = 'meteor-1.5.jar' +# print METEOR_JAR + +class Meteor: + + def __init__(self): + self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ + '-', '-', '-stdio', '-l', 'en', '-norm'] + self.meteor_p = subprocess.Popen(self.meteor_cmd, \ + cwd=os.path.dirname(os.path.abspath(__file__)), \ + stdin=subprocess.PIPE, \ + stdout=subprocess.PIPE, \ + stderr=subprocess.PIPE, + universal_newlines = True, + bufsize = 1) + # Used to guarantee thread safety + self.lock = threading.Lock() + + def compute_score(self, gts, res): + assert(gts.keys() == res.keys()) + imgIds = gts.keys() + scores = [] + + eval_line = 'EVAL' + self.lock.acquire() + for i in imgIds: + assert(len(res[i]) == 1) + stat = self._stat(res[i][0], gts[i]) + eval_line += ' ||| {}'.format(stat) + + self.meteor_p.stdin.write('{}\n'.format(eval_line)) + for i in range(0,len(imgIds)): + scores.append(float(self.meteor_p.stdout.readline().strip())) + score = float(self.meteor_p.stdout.readline().strip()) + self.lock.release() + + return score, scores + + def method(self): + return "METEOR" + + def _stat(self, hypothesis_str, reference_list): + # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words + hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') + score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) + self.meteor_p.stdin.write('{}\n'.format(score_line)) + return self.meteor_p.stdout.readline().strip() + + def _score(self, hypothesis_str, reference_list): + self.lock.acquire() + # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words + hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') + score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) + self.meteor_p.stdin.write('{}\n'.format(score_line)) + stats = self.meteor_p.stdout.readline().strip() + eval_line = 'EVAL ||| {}'.format(stats) + # EVAL ||| stats + self.meteor_p.stdin.write('{}\n'.format(eval_line)) + score = 
float(self.meteor_p.stdout.readline().strip()) + # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice + # thanks for Andrej for pointing this out + score = float(self.meteor_p.stdout.readline().strip()) + self.lock.release() + return score + + def __del__(self): + self.lock.acquire() + self.meteor_p.stdin.close() + self.meteor_p.kill() + self.meteor_p.wait() + self.lock.release() diff --git a/pycocoevalcap/rouge/__init__.py b/pycocoevalcap/rouge/__init__.py new file mode 100644 index 0000000..e3c0469 --- /dev/null +++ b/pycocoevalcap/rouge/__init__.py @@ -0,0 +1 @@ +from .rouge import * \ No newline at end of file diff --git a/pycocoevalcap/rouge/rouge.py b/pycocoevalcap/rouge/rouge.py new file mode 100644 index 0000000..3a10f5a --- /dev/null +++ b/pycocoevalcap/rouge/rouge.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# +# File Name : rouge.py +# +# Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) +# +# Creation Date : 2015-01-07 06:03 +# Author : Ramakrishna Vedantam + +import numpy as np +import pdb + +def my_lcs(string, sub): + """ + Calculates longest common subsequence for a pair of tokenized strings + :param string : list of str : tokens from a string split using whitespace + :param sub : list of str : shorter string, also split using whitespace + :returns: length (list of int): length of the longest common subsequence between the two strings + + Note: my_lcs only gives length of the longest common subsequence, not the actual LCS + """ + if(len(string)< len(sub)): + sub, string = string, sub + + lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] + + for j in range(1,len(sub)+1): + for i in range(1,len(string)+1): + if(string[i-1] == sub[j-1]): + lengths[i][j] = lengths[i-1][j-1] + 1 + else: + lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) + + return lengths[len(string)][len(sub)] + +class Rouge(): + ''' + Class for computing ROUGE-L score for a set of 
candidate sentences for the MS COCO test set + + ''' + def __init__(self): + # vrama91: updated the value below based on discussion with Hovey + self.beta = 1.2 + + def calc_score(self, candidate, refs): + """ + Compute ROUGE-L score given one candidate and references for an image + :param candidate: str : candidate sentence to be evaluated + :param refs: list of str : COCO reference sentences for the particular image to be evaluated + :returns score: int (ROUGE-L score for the candidate evaluated against references) + """ + assert(len(candidate)==1) + assert(len(refs)>0) + prec = [] + rec = [] + + # split into tokens + token_c = candidate[0].split(" ") + + for reference in refs: + # split into tokens + token_r = reference.split(" ") + # compute the longest common subsequence + lcs = my_lcs(token_r, token_c) + prec.append(lcs/float(len(token_c))) + rec.append(lcs/float(len(token_r))) + + prec_max = max(prec) + rec_max = max(rec) + + if(prec_max!=0 and rec_max !=0): + score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) + else: + score = 0.0 + return score + + def compute_score(self, gts, res): + """ + Computes Rouge-L score given a set of reference and candidate sentences for the dataset + Invoked by evaluate_captions.py + :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values + :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values + :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) + """ + assert(gts.keys() == res.keys()) + imgIds = gts.keys() + + score = [] + for id in imgIds: + hypo = res[id] + ref = gts[id] + + score.append(self.calc_score(hypo, ref)) + + # Sanity check. 
+ assert(type(hypo) is list) + assert(len(hypo) == 1) + assert(type(ref) is list) + assert(len(ref) > 0) + + average_score = np.mean(np.array(score)) + return average_score, np.array(score) + + def method(self): + return "Rouge" diff --git a/pycocoevalcap/tokenizer/__init__.py b/pycocoevalcap/tokenizer/__init__.py new file mode 100644 index 0000000..71357a4 --- /dev/null +++ b/pycocoevalcap/tokenizer/__init__.py @@ -0,0 +1 @@ +__author__ = 'hfang' diff --git a/pycocoevalcap/tokenizer/ptbtokenizer.py b/pycocoevalcap/tokenizer/ptbtokenizer.py new file mode 100644 index 0000000..b7d06e1 --- /dev/null +++ b/pycocoevalcap/tokenizer/ptbtokenizer.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +# +# File Name : ptbtokenizer.py +# +# Description : Do the PTB Tokenization and remove punctuations. +# +# Creation Date : 29-12-2014 +# Last Modified : Thu Mar 19 09:53:35 2015 +# Authors : Hao Fang and Tsung-Yi Lin + +import os +import sys +import subprocess +import tempfile +import itertools + + +# Last modified : Wed 22 May 2019 08:10:00 PM EDT +# By Sabarish Sivanath +# To support Python 3 + +# path to the stanford corenlp jar +STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' + +# punctuations to be removed from the sentences +PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ + ".", "?", "!", ",", ":", "-", "--", "...", ";"] + +class PTBTokenizer: + """Python wrapper of Stanford PTBTokenizer""" + + def tokenize(self, captions_for_image): + cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ + 'edu.stanford.nlp.process.PTBTokenizer', \ + '-preserveLines', '-lowerCase'] + + # ====================================================== + # prepare data for PTB Tokenizer + # ====================================================== + final_tokenized_captions_for_image = {} + image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] + sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 
+ + # ====================================================== + # save sentences to temporary file + # ====================================================== + path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) + tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) + tmp_file.write(sentences.encode('utf-8')) + tmp_file.close() + + # ====================================================== + # tokenize sentence + # ====================================================== + cmd.append(os.path.basename(tmp_file.name)) + p_tokenizer = subprocess.Popen(cmd, + cwd=path_to_jar_dirname, + stdout=subprocess.PIPE, + universal_newlines = True, + bufsize = 1) + token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] + lines = token_lines.split('\n') + # remove temp file + os.remove(tmp_file.name) + + # ====================================================== + # create dictionary for tokenized captions + # ====================================================== + for k, line in zip(image_id, lines): + if not k in final_tokenized_captions_for_image: + final_tokenized_captions_for_image[k] = [] + tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ + if w not in PUNCTUATIONS]) + final_tokenized_captions_for_image[k].append(tokenized_caption) + + return final_tokenized_captions_for_image diff --git a/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar b/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar new file mode 100644 index 0000000..3cfa0a0 Binary files /dev/null and b/pycocoevalcap/tokenizer/stanford-corenlp-3.4.1.jar differ diff --git a/scorer/cider.py b/scorer/cider.py index c4cc910..3415ae4 100755 --- a/scorer/cider.py +++ b/scorer/cider.py @@ -33,20 +33,24 @@ def compute_score(self, gts, res): :return: cider (float) : computed CIDEr score for the corpus """ - # clear all the previous hypos and refs - self.cider_scorer.clear() - for i, hypo in enumerate(res): - ref = gts[i] + assert(gts.keys() == res.keys()) + imgIds 
= gts.keys() + + cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) + + for id in imgIds: + hypo = res[id] + ref = gts[id] # Sanity check. - #assert(type(hypo) is list) - #assert(len(hypo) == 1) + assert(type(hypo) is list) + assert(len(hypo) == 1) assert(type(ref) is list) assert(len(ref) > 0) - self.cider_scorer += (hypo, ref) + cider_scorer += (hypo[0], ref) - (score, scores) = self.cider_scorer.compute_score() + (score, scores) = cider_scorer.compute_score() return score, scores diff --git a/scorer/cider_scorer.py b/scorer/cider_scorer.py index 793e180..71c4330 100755 --- a/scorer/cider_scorer.py +++ b/scorer/cider_scorer.py @@ -11,6 +11,7 @@ import math import pickle from lib.config import cfg +import gzip def precook(words, n=4, out=False): """ @@ -64,10 +65,14 @@ def __init__(self, test=None, refs=None, n=4, sigma=6.0): self.sigma = sigma self.crefs = [] self.ctest = [] - - cider_cache = pickle.load(open(cfg.SCORER.CIDER_CACHED, 'rb'), encoding='bytes') - self.document_frequency = cider_cache['document_frequency'] - self.ref_len = cider_cache['ref_len'] + cider_cache = None # TODO + try: + cider_cache = pickle.load(gzip.open(cfg.SCORER.CIDER_CACHED)) + except: + cider_cache = None + # print('CIDEr df: {0}'.format(len(cider_cache))) + self.document_frequency = defaultdict(float) if not cider_cache else cider_cache['document_frequency'] + self.ref_len = None if not cider_cache else cider_cache['ref_len'] self.cook_append(test, refs) def clear(self): @@ -169,7 +174,7 @@ def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): return val # compute log reference length - # self.ref_len = np.log(float(len(self.crefs))) ########################### + self.ref_len = np.log(float(len(self.crefs))) ########################### scores = [] for test, refs in zip(self.ctest, self.crefs): diff --git a/scorer/scorer.py b/scorer/scorer.py index bb1e25a..6e1460c 100755 --- a/scorer/scorer.py +++ b/scorer/scorer.py @@ -23,16 +23,16 @@ def __init__(self): 
super(Scorer, self).__init__() self.scorers = [] self.weights = cfg.SCORER.WEIGHTS - self.gts = pickle.load(open(cfg.SCORER.GT_PATH, 'rb'), encoding='bytes') + # self.gts = pickle.load(open(cfg.SCORER.GT_PATH, 'rb'), encoding='bytes') for name in cfg.SCORER.TYPES: self.scorers.append(factory[name]()) - def __call__(self, ids, res): + def __call__(self, gts, res): hypo = [get_sents(r) for r in res] - gts = [self.gts[i] for i in ids] + gts = gts rewards_info = {} - rewards = np.zeros(len(ids)) + rewards = np.zeros(len(gts)) for i, scorer in enumerate(self.scorers): score, scores = scorer.compute_score(gts, hypo) rewards += self.weights[i] * scores