JensLundsgaard commited on
Commit
29c3f73
·
verified ·
1 Parent(s): 4173997

Upload raffael_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. raffael_model.py +468 -0
raffael_model.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Complete High-Quality ConvLSTM Autoencoder
3
+ - Uses true ConvLSTM (not regular LSTM)
4
+ - Complete Encoder (2D CNN + ConvLSTM) with flattened latents
5
+ - Complete Decoder (ConvLSTM + ConvTranspose)
6
+ - Optional Empty/Non-empty Classifier
7
+ - Works with 128x128 input images
8
+ - Latent format: (B, T, N) where N is flattened spatial dimensions
9
+ - Includes ResNet-style residual connections in CNN layers
10
+ """
11
+ import torch
12
+ import torch.nn as nn
13
+ from raffael_conv_lstm import ConvLSTM
14
+ from huggingface_hub import PyTorchModelHubMixin
15
+
16
+
17
+ class ResidualBlock(nn.Module):
18
+ """
19
+ Residual block for encoder with optional downsampling
20
+ Supports ablation: can disable residual connections and batch normalization
21
+ """
22
+ def __init__(self, in_channels, out_channels, downsample=False, use_residual=True, use_batchnorm=True):
23
+ super(ResidualBlock, self).__init__()
24
+
25
+ self.use_residual = use_residual
26
+ stride = 2 if downsample else 1
27
+
28
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
29
+ self.bn1 = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
30
+ self.relu = nn.ReLU(inplace=True)
31
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
32
+ self.bn2 = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
33
+
34
+ # Projection shortcut if channels change or downsampling (only if using residual)
35
+ if use_residual and (in_channels != out_channels or downsample):
36
+ shortcut_layers = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)]
37
+ if use_batchnorm:
38
+ shortcut_layers.append(nn.BatchNorm2d(out_channels))
39
+ self.shortcut = nn.Sequential(*shortcut_layers)
40
+ else:
41
+ self.shortcut = nn.Identity()
42
+
43
+ def forward(self, x):
44
+ identity = self.shortcut(x) if self.use_residual else 0
45
+
46
+ out = self.conv1(x)
47
+ out = self.bn1(out)
48
+ out = self.relu(out)
49
+
50
+ out = self.conv2(out)
51
+ out = self.bn2(out)
52
+
53
+ if self.use_residual:
54
+ out += identity
55
+ out = self.relu(out)
56
+
57
+ return out
58
+
59
+
60
+ class ResidualUpBlock(nn.Module):
61
+ """
62
+ Residual block for decoder with upsampling
63
+ Supports ablation: can disable residual connections and batch normalization
64
+ """
65
+ def __init__(self, in_channels, out_channels, use_residual=True, use_batchnorm=True):
66
+ super(ResidualUpBlock, self).__init__()
67
+
68
+ self.use_residual = use_residual
69
+
70
+ self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
71
+ self.bn1 = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
72
+ self.relu = nn.ReLU(inplace=True)
73
+ self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
74
+ self.bn2 = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
75
+
76
+ # Shortcut with upsampling (only if using residual)
77
+ if use_residual:
78
+ self.shortcut = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
79
+ else:
80
+ self.shortcut = nn.Identity()
81
+
82
+ def forward(self, x):
83
+ identity = self.shortcut(x) if self.use_residual else 0
84
+
85
+ out = self.upsample(x)
86
+ out = self.bn1(out)
87
+ out = self.relu(out)
88
+
89
+ out = self.conv(out)
90
+ out = self.bn2(out)
91
+
92
+ if self.use_residual:
93
+ out += identity
94
+ out = self.relu(out)
95
+
96
+ return out
97
+
98
+
99
+ class Encoder(nn.Module):
100
+ """
101
+ Encoder: 2D CNN spatial compression + optional ConvLSTM temporal modeling + flatten to (B, T, N)
102
+ Output: z_seq (B, T, latent_size) and z_last (B, latent_size)
103
+ Supports ablation: dropout rate, ConvLSTM on/off, residual connections, batch normalization
104
+ """
105
+
106
+ def __init__(self, input_channels=1, hidden_dim=256, num_layers=2, latent_size=4096,
107
+ dropout_rate=0.1, use_convlstm=True, use_residual=True, use_batchnorm=True):
108
+ super(Encoder, self).__init__()
109
+
110
+ self.hidden_dim = hidden_dim
111
+ self.latent_size = latent_size
112
+ self.use_convlstm = use_convlstm
113
+
114
+ # Spatial convolution with residual connections: process each frame separately
115
+ # 128x128 -> 64x64 -> 32x32 -> 16x16
116
+ self.spatial_cnn = nn.Sequential(
117
+ # Layer 1: 128 -> 64 (with downsampling)
118
+ ResidualBlock(input_channels, 64, downsample=True, use_residual=use_residual, use_batchnorm=use_batchnorm),
119
+
120
+ # Layer 2: 64 -> 32 (with downsampling)
121
+ ResidualBlock(64, 128, downsample=True, use_residual=use_residual, use_batchnorm=use_batchnorm),
122
+
123
+ # Layer 3: 32 -> 16 (with downsampling)
124
+ ResidualBlock(128, 256, downsample=True, use_residual=use_residual, use_batchnorm=use_batchnorm),
125
+ )
126
+
127
+ if use_convlstm:
128
+ # ConvLSTM: process temporal sequence
129
+ # Input: (B, T, 256, 16, 16)
130
+ # Output: (B, T, hidden_dim, 16, 16)
131
+ self.convlstm = ConvLSTM(
132
+ input_dim=256,
133
+ hidden_dim=hidden_dim,
134
+ kernel_size=(3, 3),
135
+ num_layers=num_layers,
136
+ batch_first=True,
137
+ return_all_layers=False
138
+ )
139
+ # Compress from hidden_dim * 16 * 16
140
+ compress_size = hidden_dim * 16 * 16
141
+ else:
142
+ # No ConvLSTM - just pass through spatial features
143
+ self.convlstm = None
144
+ # Compress from 256 * 16 * 16 (spatial CNN output)
145
+ compress_size = 256 * 16 * 16
146
+
147
+ # Dropout before latent compression
148
+ self.dropout = nn.Dropout(dropout_rate)
149
+
150
+ # Linear layer to compress spatial latent to fixed size
151
+ # Input: (B*T, compress_size)
152
+ # Output: (B*T, latent_size)
153
+ self.latent_compress = nn.Linear(compress_size, latent_size)
154
+
155
+ def forward(self, x):
156
+ """
157
+ Args:
158
+ x: (B, T, 1, H, W) - input video sequence (any size, will be resized to 128x128)
159
+
160
+ Returns:
161
+ z_seq: (B, T, latent_size) - compressed latent sequence
162
+ z_last: (B, latent_size) - last timestep compressed latent
163
+ """
164
+ B, T, C, H, W = x.shape
165
+
166
+ # Resize to 128x128 if needed
167
+ x = x.view(B * T, C, H, W) # (B*T, 1, H, W)
168
+ if H != 128 or W != 128:
169
+ x = torch.nn.functional.interpolate(x, size=(128, 128), mode='bilinear', align_corners=True)
170
+
171
+ # Spatial compression: process each frame separately
172
+ x = self.spatial_cnn(x) # (B*T, 256, 16, 16)
173
+ _, C2, H2, W2 = x.shape
174
+ x = x.view(B, T, C2, H2, W2) # (B, T, 256, 16, 16)
175
+
176
+ if self.use_convlstm:
177
+ # ConvLSTM processes temporal sequence
178
+ lstm_out, _ = self.convlstm(x) # list of (B, T, hidden_dim, 16, 16)
179
+ h_seq = lstm_out[0] # (B, T, hidden_dim, 16, 16)
180
+ else:
181
+ # No temporal processing - just pass through spatial features
182
+ h_seq = x # (B, T, 256, 16, 16)
183
+
184
+ # Flatten and compress spatial dimensions with linear layer
185
+ B, T, C, H, W = h_seq.shape
186
+ h_flat = h_seq.view(B * T, C * H * W) # (B*T, C * 16 * 16)
187
+ h_flat = self.dropout(h_flat) # Apply dropout
188
+ z_compressed = self.latent_compress(h_flat) # (B*T, latent_size)
189
+ z_compressed = torch.nn.functional.relu(z_compressed)
190
+ z_seq = z_compressed.view(B, T, self.latent_size) # (B, T, latent_size)
191
+
192
+ # Take last timestep
193
+ z_last = z_seq[:, -1] # (B, latent_size)
194
+
195
+ return z_seq, z_last
196
+
197
+
198
+ class Decoder(nn.Module):
199
+ """
200
+ Decoder: Linear expansion + optional ConvLSTM temporal decoding + ConvTranspose spatial reconstruction
201
+ Input: z_seq (B, T, latent_size)
202
+ Output: x_rec (B, T, 1, 128, 128)
203
+ Supports ablation: ConvLSTM on/off, residual connections, batch normalization
204
+ """
205
+
206
+ def __init__(self, seq_len, latent_size=4096, latent_dim=256, hidden_dim=128, num_layers=2,
207
+ use_latent_split=False, use_convlstm=True, use_residual=True, use_batchnorm=True):
208
+ super(Decoder, self).__init__()
209
+ self.seq_len = seq_len
210
+ self.latent_dim = latent_dim
211
+ self.latent_size = latent_size
212
+ self.use_latent_split = use_latent_split
213
+ self.use_convlstm = use_convlstm
214
+
215
+ # If using latent split, we only use half the latent for reconstruction
216
+ effective_latent_size = latent_size // 2 if use_latent_split else latent_size
217
+
218
+ # Linear layer to expand compressed latent to spatial dimensions
219
+ # Input: (B*T, effective_latent_size)
220
+ # Output: (B*T, latent_dim * 16 * 16)
221
+ self.latent_expand = nn.Linear(effective_latent_size, latent_dim * 16 * 16)
222
+ self.latent_expand_empty = nn.Linear(effective_latent_size, latent_dim * 16 * 16)
223
+ if use_convlstm:
224
+ # ConvLSTM decodes temporal dimension
225
+ self.convlstm = ConvLSTM(
226
+ input_dim=latent_dim,
227
+ hidden_dim=hidden_dim,
228
+ kernel_size=(3, 3),
229
+ num_layers=num_layers,
230
+ batch_first=True,
231
+ return_all_layers=False
232
+ )
233
+ # Spatial decoder input channels
234
+ spatial_input_channels = hidden_dim
235
+ else:
236
+ # No ConvLSTM - just pass through expanded latent
237
+ self.convlstm = None
238
+ # Spatial decoder input channels
239
+ spatial_input_channels = latent_dim
240
+
241
+ # Spatial decoding with residual connections: 16x16 -> 32x32 -> 64x64 -> 128x128
242
+ self.spatial_decoder = nn.Sequential(
243
+ # 16 -> 32 (with upsampling)
244
+ ResidualUpBlock(spatial_input_channels, 128, use_residual=use_residual, use_batchnorm=use_batchnorm),
245
+
246
+ # 32 -> 64 (with upsampling)
247
+ ResidualUpBlock(128, 64, use_residual=use_residual, use_batchnorm=use_batchnorm),
248
+
249
+ # 64 -> 128 (with upsampling)
250
+ ResidualUpBlock(64, 32, use_residual=use_residual, use_batchnorm=use_batchnorm),
251
+
252
+ # Final output layer
253
+ nn.Conv2d(32, 1, kernel_size=3, padding=1),
254
+ nn.Sigmoid() # Assume pixels normalized to [0,1]
255
+ )
256
+
257
+ def forward(self, z_seq, empty_well=False):
258
+ """
259
+ Args:
260
+ z_seq: (B, T, latent_size) - compressed latent sequence from encoder
261
+ empty_well: bool - whether this is an empty well (uses first half of latent)
262
+
263
+ Returns:
264
+ x_rec: (B, T, 1, 128, 128) - reconstructed video sequence
265
+ """
266
+ B, T, L = z_seq.shape
267
+
268
+ # If using latent split, select which half to use
269
+ if self.use_latent_split:
270
+ if empty_well:
271
+ z_seq = z_seq[:, :, :L//2] # First half for empty wells
272
+ else:
273
+ z_seq = z_seq[:, :, L//2:] # Second half for embryos
274
+
275
+ # Expand compressed latent to spatial dimensions
276
+ z_flat = z_seq.view(B * T, -1) # (B*T, effective_latent_size)
277
+ z_expanded = self.latent_expand_empty(z_flat) if self.use_latent_split and empty_well else self.latent_expand(z_flat)
278
+ z_expanded = torch.nn.functional.relu(z_expanded)
279
+ z_spatial = z_expanded.view(B, T, self.latent_dim, 16, 16) # (B, T, latent_dim, 16, 16)
280
+
281
+ if self.use_convlstm:
282
+ # ConvLSTM decodes temporal dimension
283
+ lstm_out, _ = self.convlstm(z_spatial) # list of (B, T, hidden_dim, 16, 16)
284
+ h_seq = lstm_out[0] # (B, T, hidden_dim, 16, 16)
285
+ else:
286
+ # No temporal processing - just pass through expanded latent
287
+ h_seq = z_spatial # (B, T, latent_dim, 16, 16)
288
+
289
+ # Spatial decoding: process each timestep separately
290
+ B, T, C, H, W = h_seq.shape
291
+ h_seq = h_seq.view(B * T, C, H, W) # (B*T, C, 16, 16)
292
+ x_rec = self.spatial_decoder(h_seq) # (B*T, 1, 128, 128)
293
+ x_rec = x_rec.view(B, T, 1, 128, 128) # (B, T, 1, 128, 128)
294
+
295
+ return x_rec
296
+
297
+
298
+ class LatentClassifier(nn.Module):
299
+ """
300
+ Empty / Non-empty Well Classifier
301
+ Classifies based on last timestep latent
302
+ """
303
+
304
+ def __init__(self, latent_size=4096, num_classes=2, dropout=0.3):
305
+ super(LatentClassifier, self).__init__()
306
+
307
+ self.head = nn.Sequential(
308
+ # Classification head - input is already flattened (B, latent_size)
309
+ nn.Linear(latent_size, 512),
310
+ nn.BatchNorm1d(512),
311
+ nn.ReLU(inplace=True),
312
+ nn.Dropout(dropout),
313
+
314
+ nn.Linear(512, 256),
315
+ nn.BatchNorm1d(256),
316
+ nn.ReLU(inplace=True),
317
+ nn.Dropout(dropout),
318
+
319
+ nn.Linear(256, num_classes)
320
+ )
321
+
322
+ def forward(self, z_last):
323
+ """
324
+ Args:
325
+ z_last: (B, latent_size) - last timestep compressed latent
326
+
327
+ Returns:
328
+ logits: (B, num_classes) - classification logits
329
+ """
330
+ return self.head(z_last)
331
+
332
+
333
+ class ConvLSTMAutoencoder(nn.Module, PyTorchModelHubMixin):
334
+ """
335
+ Complete ConvLSTM Autoencoder
336
+ Includes Encoder, Decoder, and optional Classifier
337
+ Compatible with HuggingFace Hub
338
+ Works with 128x128 images
339
+ Supports ablation studies: dropout, ConvLSTM, residual connections, batch normalization
340
+ """
341
+
342
+ def __init__(
343
+ self,
344
+ seq_len=20,
345
+ input_channels=1,
346
+ encoder_hidden_dim=256,
347
+ encoder_layers=2,
348
+ decoder_hidden_dim=128,
349
+ decoder_layers=2,
350
+ latent_size=4096,
351
+ use_classifier=True,
352
+ num_classes=2,
353
+ use_latent_split=False,
354
+ # Ablation parameters
355
+ dropout_rate=0.1,
356
+ use_convlstm=True,
357
+ use_residual=True,
358
+ use_batchnorm=True
359
+ ):
360
+ super(ConvLSTMAutoencoder, self).__init__()
361
+
362
+ self.seq_len = seq_len
363
+ self.use_classifier = use_classifier
364
+ self.encoder_hidden_dim = encoder_hidden_dim
365
+ self.latent_size = latent_size
366
+ self.use_latent_split = use_latent_split
367
+ # Store ablation settings for reproducibility
368
+ self.dropout_rate = dropout_rate
369
+ self.use_convlstm = use_convlstm
370
+ self.use_residual = use_residual
371
+ self.use_batchnorm = use_batchnorm
372
+
373
+ # Core components
374
+ self.encoder = Encoder(
375
+ input_channels=input_channels,
376
+ hidden_dim=encoder_hidden_dim,
377
+ num_layers=encoder_layers,
378
+ latent_size=latent_size,
379
+ dropout_rate=dropout_rate,
380
+ use_convlstm=use_convlstm,
381
+ use_residual=use_residual,
382
+ use_batchnorm=use_batchnorm
383
+ )
384
+
385
+ self.decoder = Decoder(
386
+ seq_len=seq_len,
387
+ latent_size=latent_size,
388
+ latent_dim=encoder_hidden_dim,
389
+ hidden_dim=decoder_hidden_dim,
390
+ num_layers=decoder_layers,
391
+ use_latent_split=use_latent_split,
392
+ use_convlstm=use_convlstm,
393
+ use_residual=use_residual,
394
+ use_batchnorm=use_batchnorm
395
+ )
396
+
397
+ # Optional classifier
398
+ if use_classifier:
399
+ self.classifier = LatentClassifier(
400
+ latent_size=latent_size,
401
+ num_classes=num_classes
402
+ )
403
+
404
+ def forward(self, x, empty_well=False, return_all=False):
405
+ """
406
+ Args:
407
+ x: (B, T, 1, H, W) - input video sequence (any size, will be resized internally)
408
+ empty_well: bool - whether this is an empty well (for latent split)
409
+ return_all: whether to return all intermediate results
410
+
411
+ Returns:
412
+ Tuple of (reconstruction, lat_vec_seq) where:
413
+ - reconstruction: (B, T, 1, H, W) - reconstructed video (same size as input)
414
+ - lat_vec_seq: (B, T, latent_size or latent_size//2) - compressed latent sequence
415
+
416
+ If return_all is True, returns dict with keys:
417
+ - reconstruction: (B, T, 1, H, W) - reconstructed video
418
+ - z_seq: (B, T, latent_size) - compressed latent sequence (full)
419
+ - z_last: (B, latent_size) - last timestep compressed latent (full)
420
+ - logits: (B, num_classes) - classification logits (if enabled)
421
+ """
422
+ B, T, C, orig_H, orig_W = x.shape
423
+
424
+ # Encode (will resize to 128x128 internally)
425
+ z_seq, z_last = self.encoder(x)
426
+
427
+ # Decode (outputs 128x128)
428
+ x_rec = self.decoder(z_seq, empty_well=empty_well)
429
+
430
+ # Resize back to original input size if needed
431
+ if orig_H != 128 or orig_W != 128:
432
+ x_rec_flat = x_rec.view(B * T, C, 128, 128)
433
+ x_rec_flat = torch.nn.functional.interpolate(x_rec_flat, size=(orig_H, orig_W), mode='bilinear', align_corners=True)
434
+ x_rec = x_rec_flat.view(B, T, C, orig_H, orig_W)
435
+
436
+ if return_all:
437
+ # Build output dictionary
438
+ output = {
439
+ "reconstruction": x_rec,
440
+ "z_seq": z_seq,
441
+ "z_last": z_last,
442
+ }
443
+
444
+ # Optional classification
445
+ if self.use_classifier:
446
+ logits = self.classifier(z_last)
447
+ output["logits"] = logits
448
+
449
+ return output
450
+ else:
451
+ # Return tuple: (reconstruction, latent_vector)
452
+ # If using latent split, return only the relevant half
453
+ if self.use_latent_split:
454
+ if empty_well:
455
+ return x_rec, z_seq[:, :, :self.latent_size//2]
456
+ else:
457
+ return x_rec, z_seq[:, :, self.latent_size//2:]
458
+ else:
459
+ return x_rec, z_seq
460
+
461
+ def encode(self, x):
462
+ """Encode only, for extracting latent"""
463
+ z_seq, z_last = self.encoder(x)
464
+ return z_seq, z_last
465
+
466
+ def decode(self, z_seq, empty_well=False):
467
+ """Decode only, for reconstructing from latent"""
468
+ return self.decoder(z_seq, empty_well=empty_well)