Variational Autoencoder#

Hide code cell source

import logging
from functools import partial

import pandas as pd
import sklearn
import torch
from fastai import learner
from fastai.basics import *
from fastai.callback.all import *
from fastai.callback.all import EarlyStoppingCallback
from fastai.learner import Learner
from fastai.torch_basics import *
from IPython.display import display
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from torch.nn import Sigmoid

import pimmslearn
import pimmslearn.model
import pimmslearn.models as models
import pimmslearn.nb
from pimmslearn.analyzers import analyzers
from pimmslearn.io import datasplits
# overwriting Recorder callback with custom plot_loss
from pimmslearn.models import ae, plot_loss

# Monkey-patch fastai's Recorder so learn.recorder.plot_loss() uses the
# project's customized loss plot.
learner.Recorder.plot_loss = plot_loss


# Set up the package-level logger once; all pimmslearn modules log through it.
logger = pimmslearn.logging.setup_logger(logging.getLogger('pimmslearn'))
# Fix: corrected typo "comparisions" -> "comparisons" in the log message.
logger.info(
    "Experiment 03 - Analysis of latent spaces and performance comparisons")

figures = {}  # collection of ax or figures, keyed by name, for later saving
pimmslearn - INFO     Experiment 03 - Analysis of latent spaces and performance comparisons

Hide code cell source

# catch passed parameters
# Snapshot the names currently in the notebook's global namespace;
# pimmslearn.nb.get_params later diffs against this snapshot to discover
# which globals were injected as papermill parameters.
args = None
args = dict(globals()).keys()

Papermill script parameters:

# files and folders
# Datasplit folder with data for experiment
folder_experiment: str = 'runs/example'
folder_data: str = ''  # specify data directory if needed
file_format: str = 'csv'  # file format of the created splits (here: csv)
# Machine parsed metadata from rawfile workflow
fn_rawfile_metadata: str = 'data/dev_datasets/HeLa_6070/files_selected_metadata_N50.csv'
# training
epochs_max: int = 50  # Maximum number of epochs
batch_size: int = 64  # Batch size for training (and evaluation)
cuda: bool = True  # Whether to use a GPU for training
# model
# Dimensionality of encoding dimension (latent space of model)
latent_dim: int = 25
# An underscore-separated string of layer sizes, e.g. '256_128' for the
# encoder; the reverse is used for the decoder
hidden_layers: str = '256_128'
# force_train:bool = True # Force training when saved model could be used. Per default re-train model
patience: int = 50  # Patience (epochs without improvement) for early stopping
sample_idx_position: int = 0  # position of index which is sample ID
model: str = 'VAE'  # model name
model_key: str = 'VAE'  # potentially alternative key for model (grid search)
save_pred_real_na: bool = True  # Save all predictions for missing values
# metadata -> defaults for metadata extracted from machine data
meta_date_col: str = None  # date column in meta data
meta_cat_col: str = None  # category column in meta data
# Parameters
# (the assignments below were injected by papermill and override the defaults)
model = "VAE"
latent_dim = 10
batch_size = 64
epochs_max = 300
hidden_layers = "64"
sample_idx_position = 0
cuda = False
save_pred_real_na = True
fn_rawfile_metadata = "https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv"
folder_experiment = "runs/alzheimer_study"
model_key = "VAE"

Some argument transformations

Hide code cell source

# Collect the papermill parameters (globals added after the snapshot above)
# into a single dictionary.
args = pimmslearn.nb.get_params(args, globals=globals())
args
{'folder_experiment': 'runs/alzheimer_study',
 'folder_data': '',
 'file_format': 'csv',
 'fn_rawfile_metadata': 'https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv',
 'epochs_max': 300,
 'batch_size': 64,
 'cuda': False,
 'latent_dim': 10,
 'hidden_layers': '64',
 'patience': 50,
 'sample_idx_position': 0,
 'model': 'VAE',
 'model_key': 'VAE',
 'save_pred_real_na': True,
 'meta_date_col': None,
 'meta_cat_col': None}

Hide code cell source

# Build an argument namespace from the parameter dict (also derives the
# output folder paths seen in the printed result below).
args = pimmslearn.nb.args_from_dict(args)

# 'hidden_layers' arrives as an underscore-separated string, e.g. '256_128';
# convert it to a list of ints for the model constructor.
if isinstance(args.hidden_layers, str):
    args.overwrite_entry("hidden_layers", [int(x)
                         for x in args.hidden_layers.split('_')])
else:
    # Any non-string value is rejected — only the string encoding is supported.
    raise ValueError(
        f"hidden_layers is of unknown type {type(args.hidden_layers)}")
args
{'batch_size': 64,
 'cuda': False,
 'data': Path('runs/alzheimer_study/data'),
 'epochs_max': 300,
 'file_format': 'csv',
 'fn_rawfile_metadata': 'https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv',
 'folder_data': '',
 'folder_experiment': Path('runs/alzheimer_study'),
 'hidden_layers': [64],
 'latent_dim': 10,
 'meta_cat_col': None,
 'meta_date_col': None,
 'model': 'VAE',
 'model_key': 'VAE',
 'out_figures': Path('runs/alzheimer_study/figures'),
 'out_folder': Path('runs/alzheimer_study'),
 'out_metrics': Path('runs/alzheimer_study'),
 'out_models': Path('runs/alzheimer_study'),
 'out_preds': Path('runs/alzheimer_study/preds'),
 'patience': 50,
 'sample_idx_position': 0,
 'save_pred_real_na': True}

Some naming conventions

Hide code cell source

# Filename template for persisting model hyper-parameters (filled with the model key)
TEMPLATE_MODEL_PARAMS = 'model_params_{}.json'

Load data in long format#

Hide code cell source

# Load the train/val/test splits (long format) from the experiment data folder.
data = datasplits.DataSplits.from_folder(
    args.data, file_format=args.file_format)
pimmslearn.io.datasplits - INFO     Loaded 'train_X' from file: runs/alzheimer_study/data/train_X.csv
pimmslearn.io.datasplits - INFO     Loaded 'val_y' from file: runs/alzheimer_study/data/val_y.csv
pimmslearn.io.datasplits - INFO     Loaded 'test_y' from file: runs/alzheimer_study/data/test_y.csv

data is loaded in long format

Hide code cell source

data.train_X.sample(5)  # peek at 5 random (sample, feature) intensity values
Sample ID   protein groups        
Sample_189  P05937                   16.652
Sample_061  P00739                   19.051
Sample_024  Q9BUD6                   15.834
Sample_143  C9JE82;Q9NY47;Q9NY47-2   16.275
Sample_094  H0YAC1;P03952            17.487
Name: intensity, dtype: float64

Infer index names from long format

Hide code cell source

# The long-format index is (sample ID, feature); remove the sample-ID level
# to determine the remaining feature index level(s).
index_columns = list(data.train_X.index.names)
sample_id = index_columns.pop(args.sample_idx_position)
if len(index_columns) == 1:
    index_column = index_columns.pop()
    index_columns = None  # signals the single-feature case handled below
    logger.info(f"{sample_id = }, single feature: {index_column = }")
else:
    logger.info(f"{sample_id = }, multiple features: {index_columns = }")

if not index_columns:
    # single feature level: full index is [sample ID, feature]
    index_columns = [sample_id, index_column]
else:
    raise NotImplementedError(
        "More than one feature: Needs to be implemented. see above logging output.")
pimmslearn - INFO     sample_id = 'Sample ID', single feature: index_column = 'protein groups'

load meta data for splits

Hide code cell source

# Load sample metadata (if configured) and display the rows matching the
# training samples; otherwise continue without metadata.
if args.fn_rawfile_metadata:
    df_meta = pd.read_csv(args.fn_rawfile_metadata, index_col=0)
    display(df_meta.loc[data.train_X.index.levels[0]])
else:
    df_meta = None
_collection site _age at CSF collection _gender _t-tau [ng/L] _p-tau [ng/L] _Abeta-42 [ng/L] _Abeta-40 [ng/L] _Abeta-42/Abeta-40 ratio _primary biochemical AD classification _clinical AD diagnosis _MMSE score
Sample ID
Sample_000 Sweden 71.000 f 703.000 85.000 562.000 NaN NaN biochemical control NaN NaN
Sample_001 Sweden 77.000 m 518.000 91.000 334.000 NaN NaN biochemical AD NaN NaN
Sample_002 Sweden 75.000 m 974.000 87.000 515.000 NaN NaN biochemical AD NaN NaN
Sample_003 Sweden 72.000 f 950.000 109.000 394.000 NaN NaN biochemical AD NaN NaN
Sample_004 Sweden 63.000 f 873.000 88.000 234.000 NaN NaN biochemical AD NaN NaN
... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 Berlin 69.000 f 1,945.000 NaN 699.000 12,140.000 0.058 biochemical AD AD 17.000
Sample_206 Berlin 73.000 m 299.000 NaN 1,420.000 16,571.000 0.086 biochemical control non-AD 28.000
Sample_207 Berlin 71.000 f 262.000 NaN 639.000 9,663.000 0.066 biochemical control non-AD 28.000
Sample_208 Berlin 83.000 m 289.000 NaN 1,436.000 11,285.000 0.127 biochemical control non-AD 24.000
Sample_209 Berlin 63.000 f 591.000 NaN 1,299.000 11,232.000 0.116 biochemical control non-AD 29.000

210 rows × 11 columns

Initialize Comparison#

  • replicates idea for truly missing values: Define truth as by using n=3 replicates to impute each sample

  • real test data:

    • Not used for predictions or early stopping.

    • [x] add some additional NAs based on distribution of data

Hide code cell source

# Per-feature frequency: number of observations per feature in the training data
freq_feat = pimmslearn.io.datasplits.load_freq(args.data)
freq_feat.head()  # training data
protein groups
A0A024QZX5;A0A087X1N8;P35237                                                     197
A0A024R0T9;K7ER74;P02655                                                         208
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8   185
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503                                          208
A0A075B6H7                                                                        97
Name: freq, dtype: int64

Produce some additional simulated samples#

The validation simulated NA is used by all models to evaluate training performance.

Hide code cell source

# Validation simulated-missing values: observed values held out from training
val_pred_simulated_na = data.val_y.to_frame(name='observed')
val_pred_simulated_na
observed
Sample ID protein groups
Sample_158 Q9UN70;Q9UN70-2 14.630
Sample_050 Q9Y287 15.755
Sample_107 Q8N475;Q8N475-2 15.029
Sample_199 P06307 19.376
Sample_067 Q5VUB5 15.309
... ... ...
Sample_111 F6SYF8;Q9UBP4 22.822
Sample_002 A0A0A0MT36 18.165
Sample_049 Q8WY21;Q8WY21-2;Q8WY21-3;Q8WY21-4 15.525
Sample_182 Q8NFT8 14.379
Sample_123 Q16853;Q16853-2 14.504

12600 rows × 1 columns

Hide code cell source

# Test simulated-missing values: observed values held out for final evaluation
test_pred_simulated_na = data.test_y.to_frame(name='observed')
test_pred_simulated_na.describe()
observed
count 12,600.000
mean 16.339
std 2.741
min 7.209
25% 14.412
50% 15.935
75% 17.910
max 30.140

Data in wide format#

  • Autoencoders need data in wide format

Hide code cell source

# Pivot splits to wide format (samples as rows, features as columns).
data.to_wide_format()
args.M = data.train_X.shape[-1]  # number of features = model input dimension
data.train_X.head()
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 15.912 16.852 15.570 16.481 17.301 20.246 16.764 17.584 16.988 20.054 ... 16.012 15.178 NaN 15.050 16.842 NaN NaN 19.563 NaN 12.805
Sample_001 NaN 16.874 15.519 16.387 NaN 19.941 18.786 17.144 NaN 19.067 ... 15.528 15.576 NaN 14.833 16.597 20.299 15.556 19.386 13.970 12.442
Sample_002 16.111 NaN 15.935 16.416 18.175 19.251 16.832 15.671 17.012 18.569 ... 15.229 14.728 13.757 15.118 17.440 19.598 15.735 20.447 12.636 12.505
Sample_003 16.107 17.032 15.802 16.979 15.963 19.628 17.852 18.877 14.182 18.985 ... 15.495 14.590 14.682 15.140 17.356 19.429 NaN 20.216 NaN 12.445
Sample_004 15.603 15.331 15.375 16.679 NaN 20.450 18.682 17.081 14.140 19.686 ... 14.757 NaN NaN 15.256 17.075 19.582 15.328 NaN 13.145 NaN

5 rows × 1421 columns

Add interpolation performance#

Fill Validation data with potentially missing features#

Hide code cell source

data.train_X  # training data in wide format (samples x features)
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 15.912 16.852 15.570 16.481 17.301 20.246 16.764 17.584 16.988 20.054 ... 16.012 15.178 NaN 15.050 16.842 NaN NaN 19.563 NaN 12.805
Sample_001 NaN 16.874 15.519 16.387 NaN 19.941 18.786 17.144 NaN 19.067 ... 15.528 15.576 NaN 14.833 16.597 20.299 15.556 19.386 13.970 12.442
Sample_002 16.111 NaN 15.935 16.416 18.175 19.251 16.832 15.671 17.012 18.569 ... 15.229 14.728 13.757 15.118 17.440 19.598 15.735 20.447 12.636 12.505
Sample_003 16.107 17.032 15.802 16.979 15.963 19.628 17.852 18.877 14.182 18.985 ... 15.495 14.590 14.682 15.140 17.356 19.429 NaN 20.216 NaN 12.445
Sample_004 15.603 15.331 15.375 16.679 NaN 20.450 18.682 17.081 14.140 19.686 ... 14.757 NaN NaN 15.256 17.075 19.582 15.328 NaN 13.145 NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 15.682 16.886 14.910 16.482 NaN 17.705 17.039 NaN 16.413 19.102 ... NaN 15.684 14.236 15.415 17.551 17.922 16.340 19.928 12.929 NaN
Sample_206 15.798 17.554 15.600 15.938 NaN 18.154 18.152 16.503 16.860 18.538 ... 15.422 16.106 NaN 15.345 17.084 18.708 NaN 19.433 NaN NaN
Sample_207 15.739 NaN 15.469 16.898 NaN 18.636 17.950 16.321 16.401 18.849 ... 15.808 16.098 14.403 15.715 NaN 18.725 16.138 19.599 13.637 11.174
Sample_208 15.477 16.779 14.995 16.132 NaN 14.908 NaN NaN 16.119 18.368 ... 15.157 16.712 NaN 14.640 16.533 19.411 15.807 19.545 NaN NaN
Sample_209 NaN 17.261 15.175 16.235 NaN 17.893 17.744 16.371 15.780 18.806 ... 15.237 15.652 15.211 14.205 16.749 19.275 15.732 19.577 11.042 11.791

210 rows × 1421 columns

Hide code cell source

data.val_y  # potentially has fewer features than train_X
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN 19.863 NaN NaN NaN NaN
Sample_001 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_002 NaN 14.523 NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_003 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_004 NaN NaN NaN NaN 15.473 NaN NaN NaN NaN NaN ... NaN NaN 14.048 NaN NaN NaN NaN 19.867 NaN 12.235
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN 11.802
Sample_206 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_207 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_208 NaN NaN NaN NaN NaN NaN 17.530 NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_209 15.727 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

210 rows × 1419 columns

Hide code cell source

# Align validation targets to the full training feature set: start from an
# all-missing frame with the training index/columns, then fill in the
# observed validation values (features absent from val_y remain missing).
_template = pd.DataFrame(pd.NA,
                         index=data.train_X.index,
                         columns=data.train_X.columns)
data.val_y = _template.fillna(data.val_y)
data.val_y
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN 19.863 NaN NaN NaN NaN
Sample_001 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_002 NaN 14.523 NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_003 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_004 NaN NaN NaN NaN 15.473 NaN NaN NaN NaN NaN ... NaN NaN 14.048 NaN NaN NaN NaN 19.867 NaN 12.235
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN 11.802
Sample_206 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_207 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_208 NaN NaN NaN NaN NaN NaN 17.530 NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_209 15.727 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

210 rows × 1421 columns

Variational Autoencoder#

Analysis: DataLoaders, Model, transform#

Hide code cell source

# Preprocessing applied before the VAE: z-score each feature, then
# mean-impute remaining missing values (no missing-indicator columns).
# Step names matter: 'normalize' is referenced later via decode=['normalize'].
_steps = [
    ('normalize', StandardScaler()),
    ('impute', SimpleImputer(add_indicator=False)),
]
default_pipeline = sklearn.pipeline.Pipeline(_steps)

Analysis: DataLoaders, Model#

Hide code cell source

# Bundle DataLoaders, model and preprocessing into a single analysis object.
analysis = ae.AutoEncoderAnalysis(  # datasplits=data,
    train_df=data.train_X,
    val_df=data.val_y,
    model=models.vae.VAE,
    model_kwargs=dict(n_features=data.train_X.shape[-1],
                      n_neurons=args.hidden_layers,
                      # last_encoder_activation=None,
                      last_decoder_activation=None,
                      dim_latent=args.latent_dim),
    transform=default_pipeline,
    decode=['normalize'],  # only invert the scaler step when decoding predictions
    bs=args.batch_size)
args.n_params = analysis.n_params_ae  # record the parameter count for reporting
if args.cuda:
    analysis.model = analysis.model.cuda()
analysis.model
VAE(
  (encoder): Sequential(
    (0): Linear(in_features=1421, out_features=64, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.1)
    (4): Linear(in_features=64, out_features=20, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=64, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.1)
    (4): Linear(in_features=64, out_features=2842, bias=True)
  )
)

Training#

Hide code cell source

# `results` collects per-batch loss components: the VAE loss function
# appends to this list as a side effect during training.
results = []
loss_fct = partial(models.vae.loss_fct, results=results)

Hide code cell source

# fastai Learner with a callback adapting the VAE's tuple output for the
# training loop, plus early stopping on the validation loss.
analysis.learn = Learner(dls=analysis.dls,
                         model=analysis.model,
                         loss_func=loss_fct,
                         cbs=[ae.ModelAdapterVAE(),
                              EarlyStoppingCallback(patience=args.patience)
                              ])

# Show where each callback hooks into the fastai training loop.
analysis.learn.show_training_loop()
Start Fit
   - before_fit     : [TrainEvalCallback, Recorder, ProgressCallback, EarlyStoppingCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_batch   : [ModelAdapterVAE, CastToTensor]
         - after_pred     : [ModelAdapterVAE]
         - after_loss     : [ModelAdapterVAE]
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, ProgressCallback]
  End Epoch Loop
   - after_cancel_epoch: []
   - after_epoch    : [Recorder, EarlyStoppingCallback]
End Fit
 - after_cancel_fit: []
 - after_fit      : [ProgressCallback, EarlyStoppingCallback]

Adding a EarlyStoppingCallback results in an error. Potential fix in PR3509 is not yet in current version. Try again later

Hide code cell source

# learn.summary()

Hide code cell source

# Run the fastai learning-rate finder; the 'valley' suggestion is used below.
suggested_lr = analysis.learn.lr_find()
# NOTE(review): key is misspelled ('inital') — kept as-is since the params
# dict is serialized to JSON and downstream consumers may rely on this key.
analysis.params['suggested_inital_lr'] = suggested_lr.valley
suggested_lr
SuggestedLRs(valley=0.001737800776027143)
_images/b3941f7b857d62688adf7009ff5308efdcbd158196a7544db111304a9f66f5b2.png

Hide code cell source

results.clear()  # reset losses collected during lr_find before real training

dump model config

Hide code cell source

# needs class as argument, not instance, but serialization needs instance
# Temporarily swap the activation class for an instance so it can be
# stringified, dump all params to JSON, then restore the class afterwards.
analysis.params['last_decoder_activation'] = Sigmoid()

pimmslearn.io.dump_json(
    pimmslearn.io.parse_dict(
        analysis.params, types=[
            (torch.nn.modules.module.Module, lambda m: str(m))
        ]),
    args.out_models / TEMPLATE_MODEL_PARAMS.format(args.model_key))

# restore original value
analysis.params['last_decoder_activation'] = Sigmoid

Hide code cell source

# papermill_description=train
# Train with the one-cycle policy; early stopping may end training sooner.
analysis.learn.fit_one_cycle(args.epochs_max, lr_max=suggested_lr.valley)
epoch train_loss valid_loss time
0 1679.747314 93.264580 00:00
1 1687.586670 93.806671 00:00
2 1688.333008 93.741295 00:00
3 1685.789795 94.492981 00:00
4 1684.839844 94.405746 00:00
5 1682.767944 94.679359 00:00
6 1680.854492 95.105850 00:00
7 1678.374146 95.352272 00:00
8 1676.186157 95.920074 00:00
9 1674.153442 95.541725 00:00
10 1671.604248 95.511749 00:00
11 1669.333740 94.859673 00:00
12 1666.523438 95.256424 00:00
13 1664.035156 95.500076 00:00
14 1661.356934 95.341209 00:00
15 1659.449097 95.099937 00:00
16 1656.055176 94.962540 00:00
17 1652.710571 94.866142 00:00
18 1649.237305 94.313187 00:00
19 1644.597778 94.806648 00:00
20 1641.108398 94.324890 00:00
21 1635.779907 93.901337 00:00
22 1630.403687 94.015587 00:00
23 1625.109741 93.968086 00:00
24 1619.699341 93.510147 00:00
25 1612.748779 93.597626 00:00
26 1607.192261 93.579819 00:00
27 1601.231445 93.353172 00:00
28 1594.549072 93.113098 00:00
29 1586.864258 93.182617 00:00
30 1579.277344 92.649315 00:00
31 1570.719482 92.410507 00:00
32 1562.344360 92.277550 00:00
33 1554.878540 92.094910 00:00
34 1545.750610 92.386932 00:00
35 1536.902344 92.500809 00:00
36 1527.360107 92.074570 00:00
37 1517.319214 92.255219 00:00
38 1506.760742 92.130829 00:00
39 1497.557251 92.305214 00:00
40 1486.848389 92.177727 00:00
41 1477.448120 92.107079 00:00
42 1466.730225 91.859467 00:00
43 1456.724731 91.679298 00:00
44 1446.690430 91.749420 00:00
45 1437.525391 92.483406 00:00
46 1427.474854 93.100471 00:00
47 1419.457275 93.787056 00:00
48 1410.569702 93.598267 00:00
49 1401.533936 93.716080 00:00
50 1392.246582 93.572701 00:00
51 1383.956787 93.513519 00:00
52 1376.168335 93.527390 00:00
53 1367.707397 93.561592 00:00
54 1360.462891 93.389397 00:00
55 1353.390869 93.286911 00:00
56 1345.000000 93.539749 00:00
57 1337.416504 93.289948 00:00
58 1331.714966 92.887390 00:00
59 1324.526855 92.560875 00:00
60 1319.150513 92.585854 00:00
61 1313.045288 92.441544 00:00
62 1308.540649 92.163765 00:00
63 1301.767578 92.621353 00:00
64 1297.094238 92.269318 00:00
65 1291.561035 92.414154 00:00
66 1286.119995 92.259399 00:00
67 1279.629883 92.340088 00:00
68 1273.601562 92.285278 00:00
69 1268.223633 92.094482 00:00
70 1263.059082 92.065628 00:00
71 1257.644165 91.414749 00:00
72 1253.616943 91.352089 00:00
73 1248.113037 91.562447 00:00
74 1243.282104 91.645470 00:00
75 1238.748413 91.310043 00:00
76 1236.047485 91.536850 00:00
77 1231.263428 91.945412 00:00
78 1226.393311 91.623886 00:00
79 1223.724365 91.863571 00:00
80 1217.870605 91.334831 00:00
81 1214.563232 91.120125 00:00
82 1209.820312 91.256958 00:00
83 1206.147339 91.167900 00:00
84 1202.597290 91.214325 00:00
85 1198.552612 91.184822 00:00
86 1195.105347 90.729103 00:00
87 1191.384766 90.890106 00:00
88 1189.034912 91.291016 00:00
89 1189.107788 91.518478 00:00
90 1185.570679 90.798477 00:00
91 1182.302368 90.427101 00:00
92 1178.783691 90.796112 00:00
93 1175.132690 90.722870 00:00
94 1172.284668 90.933929 00:00
95 1169.241333 90.920288 00:00
96 1165.925781 90.671997 00:00
97 1164.308960 90.385498 00:00
98 1161.237061 90.524673 00:00
99 1159.630249 90.748047 00:00
100 1157.623901 90.548233 00:00
101 1154.378052 90.681824 00:00
102 1153.070679 90.530739 00:00
103 1150.992676 90.861694 00:00
104 1148.669800 90.873062 00:00
105 1146.234863 91.036034 00:00
106 1143.574829 90.831833 00:00
107 1140.920532 90.468056 00:00
108 1138.749512 90.691711 00:00
109 1136.838257 90.516762 00:00
110 1135.445679 90.355865 00:00
111 1133.353638 90.392067 00:00
112 1133.221558 90.678444 00:00
113 1131.347046 90.930367 00:00
114 1129.799316 90.838821 00:00
115 1128.222656 90.702187 00:00
116 1127.667358 90.979233 00:00
117 1126.755371 91.064568 00:00
118 1124.534180 90.975754 00:00
119 1125.123413 90.892509 00:00
120 1123.730103 90.771858 00:00
121 1122.260498 90.524635 00:00
122 1120.302490 90.168068 00:00
123 1117.336304 89.960754 00:00
124 1115.096191 90.372833 00:00
125 1113.867432 90.360260 00:00
126 1111.920532 90.463020 00:00
127 1109.843140 90.569412 00:00
128 1107.539185 90.411812 00:00
129 1108.063721 90.553383 00:00
130 1105.898682 90.643204 00:00
131 1104.882080 90.786690 00:00
132 1104.270508 91.152550 00:00
133 1102.231689 91.315849 00:00
134 1100.986206 91.206139 00:00
135 1099.403320 91.023964 00:00
136 1099.239380 90.633232 00:00
137 1096.340820 90.501305 00:00
138 1094.595947 90.238480 00:00
139 1093.564331 90.421631 00:00
140 1093.175537 90.355370 00:00
141 1091.858154 90.665054 00:00
142 1090.725220 90.547417 00:00
143 1091.094360 90.854782 00:00
144 1089.631714 90.817902 00:00
145 1089.246582 91.108017 00:00
146 1087.465332 90.981712 00:00
147 1086.172729 90.794289 00:00
148 1084.614746 90.783791 00:00
149 1084.220093 90.821625 00:00
150 1082.702637 90.953201 00:00
151 1081.160645 90.956665 00:00
152 1081.739746 90.817177 00:00
153 1080.257080 90.975533 00:00
154 1079.959106 90.945656 00:00
155 1079.967529 91.077652 00:00
156 1078.666260 91.122330 00:00
157 1077.766479 91.457504 00:00
158 1076.759277 91.417130 00:00
159 1075.275146 91.421455 00:00
160 1075.460815 91.279930 00:00
161 1075.438354 91.554939 00:00
162 1073.707642 91.133316 00:00
163 1074.200684 91.055656 00:00
164 1074.840088 91.217209 00:00
165 1072.674927 91.032051 00:00
166 1073.468872 90.930435 00:00
167 1072.017212 91.364990 00:00
168 1069.771484 91.579491 00:00
169 1067.674072 91.381737 00:00
170 1066.507690 91.381699 00:00
171 1065.070068 91.327263 00:00
172 1064.083618 91.260643 00:00
173 1064.306274 91.043556 00:00
No improvement since epoch 123: early stopping

Save number of actually trained epochs

Hide code cell source

# `epoch` is zero-based, so add one to get the number of completed epochs.
args.epoch_trained = analysis.learn.epoch + 1
args.epoch_trained
174

Loss normalized by total number of measurements#

Hide code cell source

# Normalize losses by the number of observed (non-missing) measurements so
# that the train and validation curves are comparable.
N_train_notna = data.train_X.notna().sum().sum()
N_val_notna = data.val_y.notna().sum().sum()
fig = models.plot_training_losses(analysis.learn, args.model_key,
                                  folder=args.out_figures,
                                  norm_factors=[N_train_notna, N_val_notna])
pimmslearn.plotting - INFO     Saved Figures to runs/alzheimer_study/figures/vae_training
_images/8351a1ac4124665bb0e3ab91ad865c1eae3e64e7f4f6dc309302d87dc640b9f5.png

Predictions#

create predictions and select validation data predictions

Hide code cell source

# Switch to eval mode (disables dropout) and predict all entries of the
# training matrix; predictions come back in wide format.
analysis.model.eval()
pred, target = res = ae.get_preds_from_df(df=data.train_X, learn=analysis.learn,
                                          position_pred_tuple=0,
                                          transformer=analysis.transform)
pred = pred.stack()  # back to long format: (sample, feature) -> prediction
pred
Sample ID   protein groups                                                                
Sample_000  A0A024QZX5;A0A087X1N8;P35237                                                     15.936
            A0A024R0T9;K7ER74;P02655                                                         16.647
            A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8   15.833
            A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503                                          16.792
            A0A075B6H7                                                                       17.134
                                                                                              ...  
Sample_209  Q9Y6R7                                                                           19.077
            Q9Y6X5                                                                           15.867
            Q9Y6Y8;Q9Y6Y8-2                                                                  19.383
            Q9Y6Y9                                                                           12.077
            S4R3U6                                                                           11.465
Length: 298410, dtype: float32

Hide code cell source

# Select VAE predictions for the validation simulated-missing entries
# (index alignment picks the matching (sample, feature) pairs from `pred`).
val_pred_simulated_na['VAE'] = pred  # 'model_key' ?
val_pred_simulated_na
observed VAE
Sample ID protein groups
Sample_158 Q9UN70;Q9UN70-2 14.630 15.667
Sample_050 Q9Y287 15.755 16.818
Sample_107 Q8N475;Q8N475-2 15.029 14.404
Sample_199 P06307 19.376 19.018
Sample_067 Q5VUB5 15.309 15.004
... ... ... ...
Sample_111 F6SYF8;Q9UBP4 22.822 22.818
Sample_002 A0A0A0MT36 18.165 15.892
Sample_049 Q8WY21;Q8WY21-2;Q8WY21-3;Q8WY21-4 15.525 15.830
Sample_182 Q8NFT8 14.379 13.239
Sample_123 Q16853;Q16853-2 14.504 14.487

12600 rows × 2 columns

Hide code cell source

# Select VAE predictions for the test simulated-missing entries.
test_pred_simulated_na['VAE'] = pred  # model_key?
test_pred_simulated_na
observed VAE
Sample ID protein groups
Sample_000 A0A075B6P5;P01615 17.016 17.275
A0A087X089;Q16627;Q16627-2 18.280 17.850
A0A0B4J2B5;S4R460 21.735 22.300
A0A140T971;O95865;Q5SRR8;Q5SSV3 14.603 15.302
A0A140TA33;A0A140TA41;A0A140TA52;P22105;P22105-3;P22105-4 16.143 16.674
... ... ... ...
Sample_209 Q96ID5 16.074 16.083
Q9H492;Q9H492-2 13.173 13.592
Q9HC57 14.207 14.310
Q9NPH3;Q9NPH3-2;Q9NPH3-5 14.962 14.952
Q9UGM5;Q9UGM5-2 16.871 16.360

12600 rows × 2 columns

save missing values predictions

Hide code cell source

# Persist predictions for values missing in the original data (i.e. not part
# of the train/val/test simulated-missing entries).
if args.save_pred_real_na:
    pred_real_na = ae.get_missing_values(df_train_wide=data.train_X,
                                         val_idx=val_pred_simulated_na.index,
                                         test_idx=test_pred_simulated_na.index,
                                         pred=pred)
    display(pred_real_na)
    pred_real_na.to_csv(args.out_preds / f"pred_real_na_{args.model_key}.csv")
Sample ID   protein groups          
Sample_000  A0A075B6J9                 15.531
            A0A075B6Q5                 15.862
            A0A075B6R2                 16.596
            A0A075B6S5                 16.187
            A0A087WSY4                 16.199
                                        ...  
Sample_209  Q9P1W8;Q9P1W8-2;Q9P1W8-4   16.134
            Q9UI40;Q9UI40-2            16.590
            Q9UIW2                     16.439
            Q9UMX0;Q9UMX0-2;Q9UMX0-4   14.065
            Q9UP79                     16.129
Name: intensity, Length: 46401, dtype: float32

Plots#

  • validation data

Hide code cell source

# Move model to CPU for inference on the latent representation.
analysis.model = analysis.model.cpu()
# underlying data is train_X for both
# assert analysis.dls.valid.data.equals(analysis.dls.train.data)
# Reconstruct DataLoader for case that during training singleton batches were dropped
_dl = torch.utils.data.DataLoader(
    pimmslearn.io.datasets.DatasetWithTarget(
        analysis.dls.valid.data),
    batch_size=args.batch_size,
    shuffle=False)  # keep order so latent rows align with the sample index
df_latent = pimmslearn.model.get_latent_space(analysis.model.get_mu_and_logvar,
                                              dl=_dl,
                                              dl_index=analysis.dls.valid.data.index)
df_latent
latent dimension 1 latent dimension 2 latent dimension 3 latent dimension 4 latent dimension 5 latent dimension 6 latent dimension 7 latent dimension 8 latent dimension 9 latent dimension 10
Sample ID
Sample_000 -0.223 -0.382 0.632 2.599 1.628 -0.824 0.794 -0.429 -0.159 -1.026
Sample_001 0.839 -1.302 0.573 1.348 1.622 -1.291 0.595 -1.175 -0.774 -0.071
Sample_002 0.288 -1.607 0.548 1.292 2.315 -0.810 1.370 0.329 1.419 -0.867
Sample_003 0.420 -0.932 0.631 2.793 1.845 -1.099 0.505 -0.126 1.047 -0.635
Sample_004 0.301 -0.766 0.699 1.957 1.554 -1.045 -0.067 -0.522 0.432 -0.905
... ... ... ... ... ... ... ... ... ... ...
Sample_205 0.077 0.601 -1.019 1.022 0.504 -0.713 0.178 -1.012 2.406 0.311
Sample_206 -1.376 0.866 -0.294 -1.904 0.958 -0.412 1.869 -0.347 0.117 -0.220
Sample_207 -1.597 0.708 -0.924 0.696 -0.247 -0.186 -0.166 -1.559 0.582 -2.276
Sample_208 0.194 0.790 -1.888 -1.063 1.063 -2.135 -0.287 -0.904 0.266 -0.860
Sample_209 -0.095 -0.082 -1.332 -0.550 1.098 -1.844 0.177 0.213 1.701 -1.336

210 rows × 10 columns

Hide code cell source

# Visualize the latent space, optionally colored by a metadata date column.
ana_latent = analyzers.LatentAnalysis(df_latent,
                                      df_meta,
                                      args.model_key,
                                      folder=args.out_figures)
if args.meta_date_col and df_meta is not None:
    figures[f'latent_{args.model_key}_by_date'], ax = ana_latent.plot_by_date(
        args.meta_date_col)

Hide code cell source

# Color the latent space by a categorical metadata column, if configured.
if df_meta is not None and args.meta_cat_col:
    cat_suffix = "_".join(args.meta_cat_col.split())
    fig_key = f'latent_{args.model_key}_by_{cat_suffix}'
    figures[fig_key], ax = ana_latent.plot_by_category(args.meta_cat_col)

Hide code cell source

# Count how many simulated NAs each feature has in the validation split.
feat_freq_val = val_pred_simulated_na.groupby(level=-1)['observed'].count()
feat_freq_val = feat_freq_val.rename('freq_val')
ax = feat_freq_val.plot.box()
_images/96272a0cb3cf2da9c48f374842e3dd120b7161772dfd94ca21b26df8f29a9a27.png

Hide code cell source

# Distribution of per-feature frequencies; require more than one feat?
feat_freq_val.value_counts().sort_index().head(5)
freq_val
1    12
2    18
3    50
4    82
5   108
Name: count, dtype: int64

Hide code cell source

# Mean absolute reconstruction error per feature, ordered by the feature's
# training-data frequency ('freq', joined from freq_feat).
observed = val_pred_simulated_na['observed']
errors_val = val_pred_simulated_na.drop(columns='observed').sub(observed, axis=0)
errors_val = (errors_val
              .abs()
              .groupby(level=-1)
              .mean()
              .join(freq_feat)
              .sort_values(by='freq', ascending=True))


# Smooth the per-feature errors with a rolling mean for plotting; the last
# column ('freq') is left untouched and serves as the x-axis.
errors_val_smoothed = errors_val.copy()  # .loc[feat_freq_val > 1]
model_cols = errors_val.columns[:-1]
errors_val_smoothed[model_cols] = errors_val[model_cols].rolling(
    window=200, min_periods=1).mean()
ax = errors_val_smoothed.plot(x='freq', figsize=(15, 10))
_images/d32f1f3e08e520fde15e97e5585288dee410b6f068feb9e00ef20c19b012b9d6.png

Hide code cell source

# Recompute the signed errors and show mean absolute error together with the
# number of simulated NAs per feature.
observed_col = val_pred_simulated_na['observed']
errors_val = val_pred_simulated_na.drop(columns='observed').sub(observed_col, axis=0)
errors_val.abs().groupby(level=-1).agg(['mean', 'count'])
VAE
mean count
protein groups
A0A024QZX5;A0A087X1N8;P35237 0.138 7
A0A024R0T9;K7ER74;P02655 1.337 4
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 0.254 9
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 0.247 6
A0A075B6H7 0.489 6
... ... ...
Q9Y6R7 0.407 10
Q9Y6X5 0.330 7
Q9Y6Y8;Q9Y6Y8-2 0.326 9
Q9Y6Y9 0.469 15
S4R3U6 0.470 24

1419 rows × 2 columns

Hide code cell source

errors_val  # signed per-sample, per-feature errors (prediction minus observed)
VAE
Sample ID protein groups
Sample_158 Q9UN70;Q9UN70-2 1.037
Sample_050 Q9Y287 1.063
Sample_107 Q8N475;Q8N475-2 -0.626
Sample_199 P06307 -0.358
Sample_067 Q5VUB5 -0.305
... ... ...
Sample_111 F6SYF8;Q9UBP4 -0.005
Sample_002 A0A0A0MT36 -2.273
Sample_049 Q8WY21;Q8WY21-2;Q8WY21-3;Q8WY21-4 0.305
Sample_182 Q8NFT8 -1.139
Sample_123 Q16853;Q16853-2 -0.017

12600 rows × 1 columns

Comparisons#

Simulated NAs: artificially created NAs. Some data was sampled and explicitly set to missing before it was fed to the model for reconstruction.

Validation data#

  • all measured (identified, observed) peptides in validation data

Hide code cell source

# papermill_description=metrics
# d_metrics = models.Metrics(no_na_key='NA interpolated', with_na_key='NA not interpolated')
# Collector for per-split metrics (MSE, MAE, N, prop — see outputs below).
d_metrics = models.Metrics()

The simulated NAs for the validation step are real held-out data (used neither for training nor for early stopping).

Hide code cell source

# Metrics on the simulated missing values of the validation split.
added_metrics = d_metrics.add_metrics(val_pred_simulated_na, 'valid_simulated_na')
added_metrics
Selected as truth to compare to: observed
{'VAE': {'MSE': 0.46059833913970355,
  'MAE': 0.4337781859714416,
  'N': 12600,
  'prop': 1.0}}

Test Datasplit#

Hide code cell source

# Metrics on the simulated missing values of the test split.
added_metrics = d_metrics.add_metrics(test_pred_simulated_na, 'test_simulated_na')
added_metrics
Selected as truth to compare to: observed
{'VAE': {'MSE': 0.4848059265006195,
  'MAE': 0.4392932054952363,
  'N': 12600,
  'prop': 1.0}}

Save all metrics as JSON

Hide code cell source

# Persist all collected metrics as JSON for later model comparison.
fname_metrics = args.out_metrics / f'metrics_{args.model_key}.json'
pimmslearn.io.dump_json(d_metrics.metrics, fname_metrics)
d_metrics
{ 'test_simulated_na': { 'VAE': { 'MAE': 0.4392932054952363,
                                  'MSE': 0.4848059265006195,
                                  'N': 12600,
                                  'prop': 1.0}},
  'valid_simulated_na': { 'VAE': { 'MAE': 0.4337781859714416,
                                   'MSE': 0.46059833913970355,
                                   'N': 12600,
                                   'prop': 1.0}}}

Hide code cell source

# Flatten the nested metrics dict into a (model, metric_name)-indexed table.
metrics_df = models.get_df_from_nested_dict(
    d_metrics.metrics,
    column_levels=['model', 'metric_name'],
).T
metrics_df
subset valid_simulated_na test_simulated_na
model metric_name
VAE MSE 0.461 0.485
MAE 0.434 0.439
N 12,600.000 12,600.000
prop 1.000 1.000

Save predictions#

Hide code cell source

# Save the simulated-missing-value predictions for both splits.
fname_val_pred = args.out_preds / f"pred_val_{args.model_key}.csv"
val_pred_simulated_na.to_csv(fname_val_pred)
fname_test_pred = args.out_preds / f"pred_test_{args.model_key}.csv"
test_pred_simulated_na.to_csv(fname_test_pred)

Config#

Hide code cell source

figures  # collected figure handles for this run; switch to fnames?
{}

Hide code cell source

# Dump the run configuration next to the trained model, then display it.
fname_config = args.out_models / f"model_config_{args.model_key}.yaml"
args.dump(fname=fname_config)
args
{'M': 1421,
 'batch_size': 64,
 'cuda': False,
 'data': Path('runs/alzheimer_study/data'),
 'epoch_trained': 174,
 'epochs_max': 300,
 'file_format': 'csv',
 'fn_rawfile_metadata': 'https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv',
 'folder_data': '',
 'folder_experiment': Path('runs/alzheimer_study'),
 'hidden_layers': [64],
 'latent_dim': 10,
 'meta_cat_col': None,
 'meta_date_col': None,
 'model': 'VAE',
 'model_key': 'VAE',
 'n_params': 277998,
 'out_figures': Path('runs/alzheimer_study/figures'),
 'out_folder': Path('runs/alzheimer_study'),
 'out_metrics': Path('runs/alzheimer_study'),
 'out_models': Path('runs/alzheimer_study'),
 'out_preds': Path('runs/alzheimer_study/preds'),
 'patience': 50,
 'sample_idx_position': 0,
 'save_pred_real_na': True}