# torch_s2p2.py
from typing import List, Optional, Union

import torch
from torch import nn

from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
from easy_tpp.ssm.models import LLH, Int_Backward_LLH, Int_Forward_LLH

class ComplexEmbedding(nn.Module):
    """Embedding with separate real and imaginary tables, returning a complex tensor."""

    def __init__(self, *args, **kwargs):
        super(ComplexEmbedding, self).__init__()
        self.real_embedding = nn.Embedding(*args, **kwargs)
        self.imag_embedding = nn.Embedding(*args, **kwargs)
        # Shrink the default initialisation so the initial complex weights are near zero.
        self.real_embedding.weight.data *= 1e-3
        self.imag_embedding.weight.data *= 1e-3

    def forward(self, x):
        return torch.complex(
            self.real_embedding(x),
            self.imag_embedding(x),
        )
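
# A minimal usage sketch of ComplexEmbedding (sizes and inputs here are
# illustrative assumptions, not values used by the model):
#
#   emb = ComplexEmbedding(num_embeddings=10, embedding_dim=4)
#   z = emb(torch.tensor([[0, 3, 7]]))   # complex tensor of shape [1, 3, 4]
#   assert z.dtype == torch.complex64    # real/imag parts come from float32 tables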

class IntensityNet(nn.Module):
    """Linear head plus scaled softplus, mapping hidden states to per-type intensities."""

    def __init__(self, input_dim, bias, num_event_types):
        super().__init__()
        self.intensity_net = nn.Linear(input_dim, num_event_types, bias=bias)
        self.softplus = ScaledSoftplus(num_event_types)

    def forward(self, x):
        return self.softplus(self.intensity_net(x))
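
# Sketch of the intensity head (assumed sizes): the linear layer produces one
# pre-activation per event type, and ScaledSoftplus maps each to a strictly
# positive value, as a point-process intensity must be:
#
#   net = IntensityNet(input_dim=8, bias=True, num_event_types=3)
#   lam = net(torch.randn(2, 5, 8))      # shape [2, 5, 3], all entries > 0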

class S2P2(TorchBaseModel):
    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super(S2P2, self).__init__(model_config)
        self.n_layers = model_config.num_layers
        self.P = model_config.model_specs["P"]  # Hidden state dimension
        self.H = model_config.hidden_size  # Residual stream dimension
        self.beta = model_config.model_specs.get("beta", 1.0)
        self.bias = model_config.model_specs.get("bias", True)
        self.simple_mark = model_config.model_specs.get("simple_mark", True)
        layer_kwargs = dict(
            P=self.P,
            H=self.H,
            dt_init_min=model_config.model_specs.get("dt_init_min", 1e-4),
            dt_init_max=model_config.model_specs.get("dt_init_max", 0.1),
            act_func=model_config.model_specs.get("act_func", "full_glu"),
            dropout_rate=model_config.model_specs.get("dropout_rate", 0.0),
            for_loop=model_config.model_specs.get("for_loop", False),
            pre_norm=model_config.model_specs.get("pre_norm", True),
            post_norm=model_config.model_specs.get("post_norm", False),
            simple_mark=self.simple_mark,
            relative_time=model_config.model_specs.get("relative_time", False),
            complex_values=model_config.model_specs.get("complex_values", True),
        )
        int_forward_variant = model_config.model_specs.get("int_forward_variant", False)
        int_backward_variant = model_config.model_specs.get(
            "int_backward_variant", False
        )
        # At most one of the two integration variants may be enabled.
        assert (int_forward_variant + int_backward_variant) <= 1
        if int_forward_variant:
            llh_layer = Int_Forward_LLH
        elif int_backward_variant:
            llh_layer = Int_Backward_LLH
        else:
            llh_layer = LLH
        self.backward_variant = int_backward_variant
        self.layers = nn.ModuleList(
            [
                llh_layer(**layer_kwargs, is_first_layer=i == 0)
                for i in range(self.n_layers)
            ]
        )
        # One embedding shared across layers, used as input to a layer-specific,
        # input-aware impulse.
        self.layers_mark_emb = nn.Embedding(
            self.num_event_types_pad,
            self.H,
        )
        self.layer_type_emb = None  # Drop the type embedding inherited from EasyTPP.
        self.intensity_net = IntensityNet(
            input_dim=self.H,
            bias=self.bias,
            num_event_types=self.num_event_types,
        )
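    # A hypothetical model_specs fragment matching the keys read above (the key
    # names come from the .get calls in __init__; the values are illustrative):
    #
    #   model_specs = {
    #       "P": 32,                       # per-layer hidden state dimension
    #       "beta": 1.0,
    #       "act_func": "full_glu",
    #       "complex_values": True,
    #       "int_backward_variant": True,  # at most one variant may be enabled
    #   }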
    def _get_intensity(
        self, x_LP: Union[torch.Tensor, List[torch.Tensor]], right_us_BNH
    ) -> torch.Tensor:
        """
        Assume time has already been evolved; take a vertical stack of hidden
        states and produce intensities.
        """
        left_u_H = None
        for i, layer in enumerate(self.layers):
            # Sometimes it is convenient to pass a list over the layers rather
            # than a single stacked tensor.
            if isinstance(x_LP, list):
                left_u_H = layer.depth_pass(
                    x_LP[i], current_left_u_H=left_u_H, prev_right_u_H=right_us_BNH[i]
                )
            else:
                left_u_H = layer.depth_pass(
                    x_LP[..., i, :],
                    current_left_u_H=left_u_H,
                    prev_right_u_H=right_us_BNH[i],
                )
        return self.intensity_net(left_u_H)
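    # Note the division of labour: _get_intensity assumes the stacked states are
    # already at the query times, while _evolve_and_get_intensity_at_sampled_dts
    # below first decays each layer's right-limit state by dt before running the
    # same depth pass and intensity readout.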
    def _evolve_and_get_intensity_at_sampled_dts(self, x_LP, dt_G, right_us_H):
        left_u_GH = None
        for i, layer in enumerate(self.layers):
            # Decay the post-event (right-limit) state forward by each sampled dt...
            x_GP = layer.get_left_limit(
                right_limit_P=x_LP[..., i, :],
                dt_G=dt_G,
                next_left_u_GH=left_u_GH,
                current_right_u_H=right_us_H[i],
            )
            # ...then run the usual depth pass on the evolved states.
            left_u_GH = layer.depth_pass(
                current_left_x_P=x_GP,
                current_left_u_H=left_u_GH,
                prev_right_u_H=right_us_H[i],
            )
        return self.intensity_net(left_u_GH)
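    # Shape sketch (illustrative): with B=2 sequences, N-1=5 intervals, and G=10
    # Monte Carlo samples per interval, x_LP is [2, 5, L, P], dt_G is [2, 5, 10],
    # and the returned intensities are [2, 5, 10, num_event_types].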
    def forward(
        self, batch, initial_state_BLP: Optional[torch.Tensor] = None, **kwargs
    ) -> dict:
        """
        Batched forward pass over all layers: returns right limits at
        [t_0, ..., t_N] and left limits at [t_1, ..., t_N].
        """
        t_BN, dt_BN, marks_BN, batch_non_pad_mask, _ = batch
        right_xs_BNP = []  # Right limits, including both t_0 and t_N.
        left_xs_BNm1P = []
        # Start with None, as this is the 'input' to the first layer.
        right_us_BNH = [None]
        left_u_BNH, right_u_BNH = None, None
        # For each event, compute the fixed impulse via alpha_m for an event of type m.
        alpha_BNP = self.layers_mark_emb(marks_BN)
        for l_i, layer in enumerate(self.layers):
            init_state = (
                initial_state_BLP[:, l_i] if initial_state_BLP is not None else None
            )
            # Each layer returns the right limits of xs at the current layer and
            # the us for the next layer (as transformations of ys), all at times
            # [t_0, t_1, ..., t_N]:
            #   x_BNP: right limits of the hidden state
            #   next_layer_left_u_BNH: only available for the backward variant
            #   next_layer_right_u_BNH: always returned, but only used for RT
            x_BNP, next_layer_left_u_BNH, next_layer_right_u_BNH = layer.forward(
                left_u_BNH, right_u_BNH, alpha_BNP, dt_BN, init_state
            )
            assert next_layer_right_u_BNH is not None
            right_xs_BNP.append(x_BNP)
            if next_layer_left_u_BNH is None:  # NOT the backward variant
                left_xs_BNm1P.append(
                    layer.get_left_limit(  # current and next at the event level
                        x_BNP[..., :-1, :],  # at times [t_0, t_1, ..., t_{N-1}]
                        # with dts [t_1 - t_0, t_2 - t_1, ..., t_N - t_{N-1}]
                        dt_BN[..., 1:].unsqueeze(-1),
                        # at times [t_0, t_1, ..., t_{N-1}]
                        current_right_u_H=right_u_BNH
                        if right_u_BNH is None
                        else right_u_BNH[..., :-1, :],
                        # at times [t_1, t_2, ..., t_N]
                        next_left_u_GH=left_u_BNH
                        if left_u_BNH is None
                        else left_u_BNH[..., 1:, :].unsqueeze(-2),
                    ).squeeze(-2)
                )
            right_us_BNH.append(next_layer_right_u_BNH)
            left_u_BNH, right_u_BNH = next_layer_left_u_BNH, next_layer_right_u_BNH
        right_xs_BNLP = torch.stack(right_xs_BNP, dim=-2)
        ret_val = {
            "right_xs_BNLP": right_xs_BNLP,  # [t_0, ..., t_N]
            "right_us_BNH": right_us_BNH,  # [t_0, ..., t_N]; list starting with None
        }
        if left_u_BNH is not None:  # backward variant
            # The next inputs after the last layer -> transformation of ys.
            ret_val["left_u_BNm1H"] = left_u_BNH[..., 1:, :]
        else:  # NOT the backward variant
            ret_val["left_xs_BNm1LP"] = torch.stack(left_xs_BNm1P, dim=-2)
        # 'seq_len - 1' left limits for [t_1, ..., t_N] (u if available, x if not);
        # 'seq_len' right limits for [t_0, t_1, ..., t_N] (xs or us).
        return ret_val
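    # Shape bookkeeping sketch (B=batch, N=events incl. t_0, L=layers, P=state
    # dim, H=residual dim), assuming a 2-layer model on a sequence [t_0, ..., t_3]:
    #   right_xs_BNLP:  [B, 4, 2, P]  right limits at t_0..t_3
    #   left_xs_BNm1LP: [B, 3, 2, P]  left limits at t_1..t_3 (non-backward variant)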
    def loglike_loss(self, batch, **kwargs):
        # Hidden states at the left and right limits around each event time.
        # Note the shift by 1 in indices: for a sequence [t_0, t_1, ..., t_N],
        # the forward pass produces:
        #   left_x:  x0, x1, x2, ... <-> x_{t_1-}, x_{t_2-}, ..., x_{t_N-} for all layers
        #            OR             <-> u_{t_1-}, u_{t_2-}, ..., u_{t_N-} for the last layer
        #   right_x: x0, x1, x2, ... <-> x_{t_0+}, x_{t_1+}, ..., x_{t_N+} for all layers
        #   right_u: u0, u1, u2, ... <-> u_{t_0+}, u_{t_1+}, ..., u_{t_N+} for all layers
        forward_results = self.forward(batch)
        right_xs_BNLP, right_us_BNH = (
            forward_results["right_xs_BNLP"],
            forward_results["right_us_BNH"],
        )
        right_us_BNm1H = [
            None if right_u_BNH is None else right_u_BNH[:, :-1, :]
            for right_u_BNH in right_us_BNH
        ]
        ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch
        # Evaluate intensities at each event *from the left limit*; _get_intensity: [LP] -> [M].
        # There is no need to discard the left limit of t_N because "marks_mask" handles it.
        if "left_u_BNm1H" in forward_results:  # ONLY the backward variant
            intensity_B_Nm1_M = self.intensity_net(forward_results["left_u_BNm1H"])
        else:  # NOT the backward variant
            intensity_B_Nm1_M = self._get_intensity(
                forward_results["left_xs_BNm1LP"], right_us_BNm1H
            )
        # Sample dts in each interval for Monte Carlo: [batch_size, num_times=N-1, num_mc_sample].
        # N-1 because we only consider the intervals between the N events; G denotes grid points.
        dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:])
        # Evaluate intensities at the sampled dts *from the left limit* after decay -> (B, N-1, G, M).
        intensity_dts_B_Nm1_G_M = self._evolve_and_get_intensity_at_sampled_dts(
            # x_{t_i+} evolves up to x_{t_{i+1}-}, and to many points in between, for i=0,...,N-1.
            right_xs_BNLP[:, :-1],
            dts_sample_B_Nm1_G,
            right_us_BNm1H,
        )
        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            lambda_at_event=intensity_B_Nm1_M,
            lambdas_loss_samples=intensity_dts_B_Nm1_G_M,
            time_delta_seq=dts_BN[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            type_seq=marks_BN[:, 1:],
        )
        # Compute the loss to optimize: the negative log-likelihood.
        loss = -(event_ll - non_event_ll).sum()
        return loss, num_events
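    # For reference, the objective above is the negative of the standard TPP
    # log-likelihood,
    #   log L = sum_i log lambda_{m_i}(t_i) - integral_0^T sum_m lambda_m(t) dt,
    # where compute_loglikelihood estimates the integral (non_event_ll) by Monte
    # Carlo over the dts sampled in each inter-event interval.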
    def compute_intensities_at_sample_times(
        self, event_times_BN, inter_event_times_BN, marks_BN, sample_dtimes, **kwargs
    ):
        """Compute the intensity at sampled times, not only event times, *from the left limit*.

        Args:
            event_times_BN (tensor): [batch_size, seq_len], event time seqs.
            inter_event_times_BN (tensor): [batch_size, seq_len], time delta seqs.
            marks_BN (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps.

        Returns:
            tensor: [batch_size, num_times, num_mc_sample, num_event_types],
            intensity at each timestamp for each event type.
        """
        compute_last_step_only = kwargs.get("compute_last_step_only", False)
        # Assume inter_event_times_BN always starts from 0.
        _input = event_times_BN, inter_event_times_BN, marks_BN, None, None
        # 'seq_len - 1' left limits for [t_1, ..., t_N];
        # 'seq_len' right limits for [t_0, t_1, ..., t_N].
        forward_results = self.forward(_input)
        right_xs_BNLP, right_us_BNH = (
            forward_results["right_xs_BNLP"],
            forward_results["right_us_BNH"],
        )
        if compute_last_step_only:
            # Fix indices for right_us_BNH, a list [None, tensor([B, N, H]), ...].
            right_us_B1H = [
                None if right_u_BNH is None else right_u_BNH[:, -1:, :]
                for right_u_BNH in right_us_BNH
            ]
            sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts(
                # The slice [:, -1:, :, :] keeps the time dim; equivalent to
                # right_xs_BNLP[:, -1, :, :][:, None, ...].
                right_xs_BNLP[:, -1:, :, :],
                sample_dtimes[:, -1:, :],
                right_us_B1H,
            )
        else:
            sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts(
                right_xs_BNLP, sample_dtimes, right_us_BNH
            )
        return sampled_intensity  # [B, N, MC, M]
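
# A hypothetical end-to-end sketch (the config object and batch layout are
# assumed to follow the EasyTPP conventions used above):
#
#   model = S2P2(model_config)
#   loss, num_events = model.loglike_loss(batch)  # batch = (t, dt, marks, non_pad_mask, _)
#   loss.backward()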