Denoising Autoencoder#


import logging

import sklearn
from fastai import learner
from fastai.basics import *
from fastai.callback.all import *
from fastai.torch_basics import *
from IPython.display import display
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

import pimmslearn
import pimmslearn.model
import pimmslearn.models as models
from pimmslearn.analyzers import analyzers
from pimmslearn.io import datasplits
# overwriting Recorder callback with custom plot_loss
from pimmslearn.models import ae, plot_loss

learner.Recorder.plot_loss = plot_loss


logger = pimmslearn.logging.setup_logger(logging.getLogger('pimmslearn'))
logger.info(
    "Experiment 03 - Analysis of latent spaces and performance comparisions")

figures = {}  # collection of ax or figures
pimmslearn - INFO     Experiment 03 - Analysis of latent spaces and performance comparisons


# catch passed parameters
args = None
args = dict(globals()).keys()

Papermill script parameters:

# files and folders
# Datasplit folder with data for experiment
folder_experiment: str = 'runs/example'
folder_data: str = ''  # specify data directory if needed
file_format: str = 'csv'  # file format of the created splits, default pickle (pkl)
# Machine parsed metadata from rawfile workflow
fn_rawfile_metadata: str = 'data/dev_datasets/HeLa_6070/files_selected_metadata_N50.csv'
# training
epochs_max: int = 50  # Maximum number of epochs
# early_stopping: bool = True  # Whether to use early stopping or not
patience: int = 25  # Patience for early stopping
batch_size: int = 64  # Batch size for training (and evaluation)
cuda: bool = True  # Whether to use a GPU for training
# model
# Dimensionality of encoding dimension (latent space of model)
latent_dim: int = 25
# An underscore-separated string of layer sizes, e.g. '128_64' for the encoder; the reverse is used for the decoder
hidden_layers: str = '512'

sample_idx_position: int = 0  # position of index which is sample ID
model: str = 'DAE'  # model name
model_key: str = 'DAE'  # potentially alternative key for model (grid search)
save_pred_real_na: bool = True  # Save all predictions for missing values
# metadata -> defaults for metadata extracted from machine data
meta_date_col: str = None  # date column in meta data
meta_cat_col: str = None  # category column in meta data
# Parameters
model = "DAE"
latent_dim = 10
batch_size = 64
epochs_max = 300
hidden_layers = "64"
sample_idx_position = 0
cuda = False
save_pred_real_na = True
fn_rawfile_metadata = "https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv"
folder_experiment = "runs/alzheimer_study"
model_key = "DAE"
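
The cell above (from # Parameters onward) was injected by papermill and overrides the defaults declared earlier. A minimal sketch of how such a parameterized run could be launched (the notebook paths are illustrative, not taken from this run):

import papermill as pm

pm.execute_notebook(
    '01_1_train_DAE.ipynb',                       # hypothetical input notebook
    'runs/alzheimer_study/01_1_train_DAE.ipynb',  # hypothetical executed copy
    parameters=dict(model='DAE', latent_dim=10, batch_size=64, epochs_max=300,
                    hidden_layers='64', cuda=False, model_key='DAE'),
)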

Some argument transformations


args = pimmslearn.nb.get_params(args, globals=globals())
args
{'folder_experiment': 'runs/alzheimer_study',
 'folder_data': '',
 'file_format': 'csv',
 'fn_rawfile_metadata': 'https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv',
 'epochs_max': 300,
 'patience': 25,
 'batch_size': 64,
 'cuda': False,
 'latent_dim': 10,
 'hidden_layers': '64',
 'sample_idx_position': 0,
 'model': 'DAE',
 'model_key': 'DAE',
 'save_pred_real_na': True,
 'meta_date_col': None,
 'meta_cat_col': None}


args = pimmslearn.nb.args_from_dict(args)

if isinstance(args.hidden_layers, str):
    args.overwrite_entry("hidden_layers", [int(x)
                         for x in args.hidden_layers.split('_')])
else:
    raise ValueError(
        f"hidden_layers is of unknown type {type(args.hidden_layers)}")
args
{'batch_size': 64,
 'cuda': False,
 'data': Path('runs/alzheimer_study/data'),
 'epochs_max': 300,
 'file_format': 'csv',
 'fn_rawfile_metadata': 'https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv',
 'folder_data': '',
 'folder_experiment': Path('runs/alzheimer_study'),
 'hidden_layers': [64],
 'latent_dim': 10,
 'meta_cat_col': None,
 'meta_date_col': None,
 'model': 'DAE',
 'model_key': 'DAE',
 'out_figures': Path('runs/alzheimer_study/figures'),
 'out_folder': Path('runs/alzheimer_study'),
 'out_metrics': Path('runs/alzheimer_study'),
 'out_models': Path('runs/alzheimer_study'),
 'out_preds': Path('runs/alzheimer_study/preds'),
 'patience': 25,
 'sample_idx_position': 0,
 'save_pred_real_na': True}
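
The hidden_layers string was parsed into a list of layer sizes by the cell above, for instance:

[int(x) for x in '128_64'.split('_')]  # -> [128, 64]; here '64' became [64]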

Some naming conventions


TEMPLATE_MODEL_PARAMS = 'model_params_{}.json'
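
The placeholder is later filled with the model key, e.g. TEMPLATE_MODEL_PARAMS.format('DAE') gives 'model_params_DAE.json'.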

Load data in long format#


data = datasplits.DataSplits.from_folder(
    args.data, file_format=args.file_format)
pimmslearn.io.datasplits - INFO     Loaded 'train_X' from file: runs/alzheimer_study/data/train_X.csv
pimmslearn.io.datasplits - INFO     Loaded 'val_y' from file: runs/alzheimer_study/data/val_y.csv
pimmslearn.io.datasplits - INFO     Loaded 'test_y' from file: runs/alzheimer_study/data/test_y.csv

data is loaded in long format


data.train_X.sample(5)
Sample ID   protein groups   
Sample_071  O60476              15.758
Sample_047  A0A182DWH7;P49908   19.899
Sample_049  P01912;Q5Y7A7       15.570
Sample_024  Q9UHG2              22.191
Sample_133  P33908              19.115
Name: intensity, dtype: float64

Infer index names from long format


index_columns = list(data.train_X.index.names)
sample_id = index_columns.pop(args.sample_idx_position)
if len(index_columns) == 1:
    index_column = index_columns.pop()
    index_columns = None
    logger.info(f"{sample_id = }, single feature: {index_column = }")
else:
    logger.info(f"{sample_id = }, multiple features: {index_columns = }")

if not index_columns:
    index_columns = [sample_id, index_column]
else:
    raise NotImplementedError(
        "More than one feature: Needs to be implemented. see above logging output.")
pimmslearn - INFO     sample_id = 'Sample ID', single feature: index_column = 'protein groups'

load metadata for splits


if args.fn_rawfile_metadata:
    df_meta = pd.read_csv(args.fn_rawfile_metadata, index_col=0)
    display(df_meta.loc[data.train_X.index.levels[0]])
else:
    df_meta = None
_collection site _age at CSF collection _gender _t-tau [ng/L] _p-tau [ng/L] _Abeta-42 [ng/L] _Abeta-40 [ng/L] _Abeta-42/Abeta-40 ratio _primary biochemical AD classification _clinical AD diagnosis _MMSE score
Sample ID
Sample_000 Sweden 71.000 f 703.000 85.000 562.000 NaN NaN biochemical control NaN NaN
Sample_001 Sweden 77.000 m 518.000 91.000 334.000 NaN NaN biochemical AD NaN NaN
Sample_002 Sweden 75.000 m 974.000 87.000 515.000 NaN NaN biochemical AD NaN NaN
Sample_003 Sweden 72.000 f 950.000 109.000 394.000 NaN NaN biochemical AD NaN NaN
Sample_004 Sweden 63.000 f 873.000 88.000 234.000 NaN NaN biochemical AD NaN NaN
... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 Berlin 69.000 f 1,945.000 NaN 699.000 12,140.000 0.058 biochemical AD AD 17.000
Sample_206 Berlin 73.000 m 299.000 NaN 1,420.000 16,571.000 0.086 biochemical control non-AD 28.000
Sample_207 Berlin 71.000 f 262.000 NaN 639.000 9,663.000 0.066 biochemical control non-AD 28.000
Sample_208 Berlin 83.000 m 289.000 NaN 1,436.000 11,285.000 0.127 biochemical control non-AD 24.000
Sample_209 Berlin 63.000 f 591.000 NaN 1,299.000 11,232.000 0.116 biochemical control non-AD 29.000

210 rows × 11 columns

Produce some additional simulated samples#

The simulated NAs in the validation data are used by all models to evaluate training performance.


val_pred_simulated_na = data.val_y.to_frame(name='observed')
val_pred_simulated_na
observed
Sample ID protein groups
Sample_158 Q9UN70;Q9UN70-2 14.630
Sample_050 Q9Y287 15.755
Sample_107 Q8N475;Q8N475-2 15.029
Sample_199 P06307 19.376
Sample_067 Q5VUB5 15.309
... ... ...
Sample_111 F6SYF8;Q9UBP4 22.822
Sample_002 A0A0A0MT36 18.165
Sample_049 Q8WY21;Q8WY21-2;Q8WY21-3;Q8WY21-4 15.525
Sample_182 Q8NFT8 14.379
Sample_123 Q16853;Q16853-2 14.504

12600 rows × 1 columns


test_pred_simulated_na = data.test_y.to_frame(name='observed')
test_pred_simulated_na.describe()
observed
count 12,600.000
mean 16.339
std 2.741
min 7.209
25% 14.412
50% 15.935
75% 17.910
max 30.140

Data in wide format#

  • Autoencoders need data in wide format (samples × features); see the sketch below
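
data.to_wide_format() pivots all splits from long to wide format. Conceptually, for a single split this amounts to unstacking the feature level of the MultiIndex (a sketch, assuming the index layout shown earlier):

# a long-format Series with a (Sample ID, protein groups) MultiIndex
# becomes a samples x features matrix with NaN for unobserved entries
wide = data.train_X.unstack('protein groups')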


data.to_wide_format()
args.M = data.train_X.shape[-1]
data.train_X.head()
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 15.912 16.852 15.570 16.481 17.301 20.246 16.764 17.584 16.988 20.054 ... 16.012 15.178 NaN 15.050 16.842 NaN NaN 19.563 NaN 12.805
Sample_001 NaN 16.874 15.519 16.387 NaN 19.941 18.786 17.144 NaN 19.067 ... 15.528 15.576 NaN 14.833 16.597 20.299 15.556 19.386 13.970 12.442
Sample_002 16.111 NaN 15.935 16.416 18.175 19.251 16.832 15.671 17.012 18.569 ... 15.229 14.728 13.757 15.118 17.440 19.598 15.735 20.447 12.636 12.505
Sample_003 16.107 17.032 15.802 16.979 15.963 19.628 17.852 18.877 14.182 18.985 ... 15.495 14.590 14.682 15.140 17.356 19.429 NaN 20.216 NaN 12.445
Sample_004 15.603 15.331 15.375 16.679 NaN 20.450 18.682 17.081 14.140 19.686 ... 14.757 NaN NaN 15.256 17.075 19.582 15.328 NaN 13.145 NaN

5 rows × 1421 columns

Fill Validation data with potentially missing features#


data.train_X
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 15.912 16.852 15.570 16.481 17.301 20.246 16.764 17.584 16.988 20.054 ... 16.012 15.178 NaN 15.050 16.842 NaN NaN 19.563 NaN 12.805
Sample_001 NaN 16.874 15.519 16.387 NaN 19.941 18.786 17.144 NaN 19.067 ... 15.528 15.576 NaN 14.833 16.597 20.299 15.556 19.386 13.970 12.442
Sample_002 16.111 NaN 15.935 16.416 18.175 19.251 16.832 15.671 17.012 18.569 ... 15.229 14.728 13.757 15.118 17.440 19.598 15.735 20.447 12.636 12.505
Sample_003 16.107 17.032 15.802 16.979 15.963 19.628 17.852 18.877 14.182 18.985 ... 15.495 14.590 14.682 15.140 17.356 19.429 NaN 20.216 NaN 12.445
Sample_004 15.603 15.331 15.375 16.679 NaN 20.450 18.682 17.081 14.140 19.686 ... 14.757 NaN NaN 15.256 17.075 19.582 15.328 NaN 13.145 NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 15.682 16.886 14.910 16.482 NaN 17.705 17.039 NaN 16.413 19.102 ... NaN 15.684 14.236 15.415 17.551 17.922 16.340 19.928 12.929 NaN
Sample_206 15.798 17.554 15.600 15.938 NaN 18.154 18.152 16.503 16.860 18.538 ... 15.422 16.106 NaN 15.345 17.084 18.708 NaN 19.433 NaN NaN
Sample_207 15.739 NaN 15.469 16.898 NaN 18.636 17.950 16.321 16.401 18.849 ... 15.808 16.098 14.403 15.715 NaN 18.725 16.138 19.599 13.637 11.174
Sample_208 15.477 16.779 14.995 16.132 NaN 14.908 NaN NaN 16.119 18.368 ... 15.157 16.712 NaN 14.640 16.533 19.411 15.807 19.545 NaN NaN
Sample_209 NaN 17.261 15.175 16.235 NaN 17.893 17.744 16.371 15.780 18.806 ... 15.237 15.652 15.211 14.205 16.749 19.275 15.732 19.577 11.042 11.791

210 rows × 1421 columns


data.val_y  # potentially has less features
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN 19.863 NaN NaN NaN NaN
Sample_001 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_002 NaN 14.523 NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_003 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_004 NaN NaN NaN NaN 15.473 NaN NaN NaN NaN NaN ... NaN NaN 14.048 NaN NaN NaN NaN 19.867 NaN 12.235
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN 11.802
Sample_206 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_207 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_208 NaN NaN NaN NaN NaN NaN 17.530 NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_209 15.727 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

210 rows × 1419 columns


data.val_y = pd.DataFrame(pd.NA, index=data.train_X.index,
                          columns=data.train_X.columns).fillna(data.val_y)
data.val_y
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN 19.863 NaN NaN NaN NaN
Sample_001 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_002 NaN 14.523 NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_003 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_004 NaN NaN NaN NaN 15.473 NaN NaN NaN NaN NaN ... NaN NaN 14.048 NaN NaN NaN NaN 19.867 NaN 12.235
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN 11.802
Sample_206 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_207 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_208 NaN NaN NaN NaN NaN NaN 17.530 NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_209 15.727 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

210 rows × 1421 columns
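
The fillna construction above expands data.val_y to the full 1,421-column layout of the training data (it previously had 1,419 columns). An equivalent, perhaps more direct formulation (assuming, as here, that the validation index matches the training index):

data.val_y = data.val_y.reindex(index=data.train_X.index,
                                columns=data.train_X.columns)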

Denoising Autoencoder#

Analysis: DataLoaders, Model, transform#


default_pipeline = sklearn.pipeline.Pipeline(
    [
        ('normalize', StandardScaler()),
        ('impute', SimpleImputer(add_indicator=False))
    ])

analysis = ae.AutoEncoderAnalysis(
    train_df=data.train_X,
    val_df=data.val_y,
    model=ae.Autoencoder,
    transform=default_pipeline,
    decode=['normalize'],
    model_kwargs=dict(n_features=data.train_X.shape[-1],
                      n_neurons=args.hidden_layers,
                      last_decoder_activation=None,
                      dim_latent=args.latent_dim),
    bs=args.batch_size)
args.n_params = analysis.n_params_ae

if args.cuda:
    analysis.model = analysis.model.cuda()
analysis.model
Autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=1421, out_features=64, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.1)
    (4): Linear(in_features=64, out_features=10, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=64, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.1)
    (4): Linear(in_features=64, out_features=1421, bias=True)
  )
)
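
args.n_params is taken from the analysis object; as a sanity check it can be recomputed directly from the torch module (a sketch; for the architecture printed above it evaluates to 184,983):

n_params = sum(p.numel() for p in analysis.model.parameters())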

Training#


analysis.learn = Learner(dls=analysis.dls,
                         model=analysis.model,
                         loss_func=MSELossFlat(reduction='sum'),
                         cbs=[EarlyStoppingCallback(patience=args.patience),
                              ae.ModelAdapter(p=0.2)]
                         )

analysis.learn.show_training_loop()
Start Fit
   - before_fit     : [TrainEvalCallback, Recorder, ProgressCallback, EarlyStoppingCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_batch   : [ModelAdapter, CastToTensor]
         - after_pred     : [ModelAdapter]
         - after_loss     : [ModelAdapter]
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, ProgressCallback]
  End Epoch Loop
   - after_cancel_epoch: []
   - after_epoch    : [Recorder, EarlyStoppingCallback]
End Fit
 - after_cancel_fit: []
 - after_fit      : [ProgressCallback, EarlyStoppingCallback]

Adding an EarlyStoppingCallback results in an error. A potential fix in PR3509 is not yet in the current version. Try again later.


# learn.summary()


suggested_lr = analysis.learn.lr_find()
analysis.params['suggested_inital_lr'] = suggested_lr.valley
suggested_lr
SuggestedLRs(valley=0.015848932787775993)
Figure: learning-rate finder curve (loss vs. learning rate); the suggested rate is taken at the valley, here ≈0.0158.

dump model config


pimmslearn.io.dump_json(analysis.params, args.out_models /
                        TEMPLATE_MODEL_PARAMS.format(args.model_key))


# papermill_description=train
analysis.learn.fit_one_cycle(args.epochs_max, lr_max=suggested_lr.valley)
epoch train_loss valid_loss time
0 65047.398438 4029.035400 00:00
1 63578.144531 4023.114258 00:00
2 62329.570312 3986.107178 00:00
3 61129.144531 3917.755127 00:00
4 59978.433594 3821.969971 00:00
5 58819.417969 3704.607422 00:00
6 57624.792969 3577.509521 00:00
7 56416.496094 3446.568848 00:00
8 55235.574219 3321.771484 00:00
9 54063.312500 3204.467285 00:00
10 52859.621094 3106.539795 00:00
11 51673.683594 3024.505127 00:00
12 50511.734375 2951.987549 00:00
13 49414.390625 2885.842285 00:00
14 48331.101562 2830.227051 00:00
15 47370.621094 2765.861816 00:00
16 46421.132812 2713.672363 00:00
17 45520.882812 2668.125488 00:00
18 44725.875000 2635.396973 00:00
19 43959.066406 2603.470947 00:00
20 43238.167969 2575.289307 00:00
21 42551.625000 2551.041260 00:00
22 41900.695312 2519.899414 00:00
23 41251.843750 2492.994141 00:00
24 40645.839844 2466.200684 00:00
25 40061.582031 2450.185059 00:00
26 39484.136719 2438.543701 00:00
27 38924.531250 2413.075195 00:00
28 38374.566406 2391.310547 00:00
29 37909.023438 2378.687500 00:00
30 37396.984375 2356.452881 00:00
31 36901.707031 2323.841309 00:00
32 36440.449219 2324.284180 00:00
33 35997.441406 2308.147705 00:00
34 35562.980469 2295.411621 00:00
35 35142.492188 2305.115967 00:00
36 34734.355469 2278.166260 00:00
37 34390.847656 2273.932617 00:00
38 34014.925781 2273.499268 00:00
39 33648.304688 2262.505615 00:00
40 33329.125000 2285.490723 00:00
41 33020.570312 2262.563721 00:00
42 32739.882812 2257.814453 00:00
43 32462.277344 2266.280762 00:00
44 32151.097656 2263.616455 00:00
45 31919.031250 2286.169678 00:00
46 31692.945312 2256.155273 00:00
47 31422.667969 2242.641846 00:00
48 31160.117188 2241.624756 00:00
49 30925.265625 2250.856689 00:00
50 30694.537109 2262.384766 00:00
51 30500.697266 2247.102295 00:00
52 30288.521484 2227.596680 00:00
53 30108.224609 2216.532471 00:00
54 29928.226562 2256.358154 00:00
55 29743.099609 2219.596436 00:00
56 29595.244141 2230.572266 00:00
57 29415.333984 2232.175049 00:00
58 29282.919922 2244.931641 00:00
59 29171.138672 2243.504395 00:00
60 29071.353516 2242.700928 00:00
61 28963.369141 2291.220459 00:00
62 28878.476562 2246.756348 00:00
63 28742.667969 2252.593506 00:00
64 28636.265625 2255.120605 00:00
65 28540.189453 2263.010010 00:00
66 28472.281250 2249.164551 00:00
67 28358.972656 2290.010498 00:00
68 28269.501953 2243.434814 00:00
69 28205.601562 2215.960205 00:00
70 28076.234375 2245.869141 00:00
71 28020.210938 2267.534180 00:00
72 27933.232422 2220.168945 00:00
73 27846.435547 2230.931885 00:00
74 27760.193359 2233.823486 00:00
75 27669.808594 2219.014648 00:00
76 27588.158203 2240.370605 00:00
77 27497.019531 2233.413818 00:00
78 27400.927734 2220.022949 00:00
79 27371.419922 2212.560059 00:00
80 27290.406250 2217.832031 00:00
81 27242.259766 2201.520752 00:00
82 27210.460938 2208.253174 00:00
83 27131.003906 2205.788818 00:00
84 27084.117188 2251.291504 00:00
85 26998.642578 2194.063477 00:00
86 26914.437500 2210.968262 00:00
87 26852.554688 2202.776367 00:00
88 26775.943359 2224.156738 00:00
89 26799.671875 2198.655029 00:00
90 26733.277344 2210.150146 00:00
91 26677.130859 2213.698242 00:00
92 26647.519531 2210.801270 00:00
93 26584.431641 2211.145508 00:00
94 26557.937500 2201.384033 00:00
95 26471.179688 2191.711426 00:00
96 26472.556641 2196.753906 00:00
97 26469.371094 2214.307617 00:00
98 26429.302734 2210.192139 00:00
99 26354.488281 2190.020996 00:00
100 26336.621094 2194.452393 00:00
101 26275.496094 2189.320312 00:00
102 26254.810547 2191.666504 00:00
103 26232.875000 2192.506348 00:00
104 26167.220703 2163.927246 00:00
105 26104.683594 2187.917725 00:00
106 26102.078125 2177.876465 00:00
107 26100.470703 2173.075195 00:00
108 26078.929688 2195.084473 00:00
109 26066.794922 2186.219238 00:00
110 26028.544922 2167.401855 00:00
111 26018.171875 2184.507812 00:00
112 25997.771484 2178.821777 00:00
113 25949.974609 2193.187988 00:00
114 25879.470703 2163.616455 00:00
115 25861.878906 2165.429932 00:00
116 25856.777344 2185.792236 00:00
117 25811.880859 2191.451172 00:00
118 25755.136719 2181.149902 00:00
119 25723.521484 2182.024170 00:00
120 25665.927734 2194.557129 00:00
121 25673.671875 2190.107666 00:00
122 25683.417969 2163.060791 00:00
123 25639.236328 2185.309326 00:00
124 25610.482422 2171.929688 00:00
125 25564.578125 2187.369385 00:00
126 25542.607422 2157.970703 00:00
127 25475.318359 2170.475586 00:00
128 25445.447266 2151.240723 00:00
129 25410.453125 2176.892334 00:00
130 25397.609375 2165.850098 00:00
131 25360.771484 2161.014404 00:00
132 25335.025391 2159.956299 00:00
133 25389.765625 2196.891113 00:00
134 25370.658203 2161.992676 00:00
135 25306.843750 2187.204102 00:00
136 25296.476562 2184.557129 00:00
137 25263.191406 2174.261719 00:00
138 25285.597656 2175.478271 00:00
139 25287.558594 2182.325928 00:00
140 25272.869141 2172.144775 00:00
141 25225.062500 2176.628662 00:00
142 25153.960938 2158.908203 00:00
143 25123.347656 2182.579834 00:00
144 25096.099609 2180.518555 00:00
145 25069.359375 2166.120117 00:00
146 25026.996094 2155.732178 00:00
147 24993.705078 2163.889648 00:00
148 24959.103516 2160.813477 00:00
149 24938.197266 2184.463623 00:00
150 24928.759766 2208.087646 00:00
151 24952.160156 2183.072998 00:00
152 24904.560547 2152.320557 00:00
153 24928.025391 2148.575195 00:00
154 24910.185547 2180.729980 00:00
155 24906.388672 2190.656738 00:00
156 24894.894531 2188.920166 00:00
157 24917.083984 2187.139648 00:00
158 24913.609375 2164.842041 00:00
159 24910.617188 2181.503662 00:00
160 24929.462891 2176.875244 00:00
161 24892.683594 2165.759033 00:00
162 24886.013672 2176.412842 00:00
163 24859.917969 2169.091309 00:00
164 24824.279297 2174.040039 00:00
165 24779.080078 2175.631104 00:00
166 24752.855469 2163.416748 00:00
167 24766.738281 2162.556641 00:00
168 24743.207031 2170.221191 00:00
169 24757.509766 2179.925537 00:00
170 24730.193359 2180.827637 00:00
171 24694.998047 2178.066895 00:00
172 24731.380859 2183.089844 00:00
173 24739.914062 2178.639404 00:00
174 24698.880859 2157.802734 00:00
175 24667.474609 2175.934570 00:00
176 24669.558594 2158.381836 00:00
177 24659.394531 2152.438232 00:00
178 24637.953125 2153.844971 00:00
No improvement since epoch 153: early stopping

Save the number of actually trained epochs. Training stopped after epoch 178 because the validation loss had not improved for patience = 25 epochs (best epoch: 153).

Hide code cell source

args.epoch_trained = analysis.learn.epoch + 1
args.epoch_trained
179

Loss normalized by total number of measurements#

Hide code cell source

N_train_notna = data.train_X.notna().sum().sum()
N_val_notna = data.val_y.notna().sum().sum()
fig = models.plot_training_losses(analysis.learn, args.model_key,
                                  folder=args.out_figures,
                                  norm_factors=[N_train_notna, N_val_notna])
pimmslearn.plotting - INFO     Saved Figures to runs/alzheimer_study/figures/dae_training
Figure: training and validation loss per epoch, normalized by the number of measurements in each split (saved to runs/alzheimer_study/figures/dae_training).

Why is the validation loss better than the training loss?

  • during training, input data is masked and needs to be reconstructed

  • when evaluating the model, all input data is provided and only the artificially masked data is used for evaluation; a sketch of the masking follows below.
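
A minimal sketch of the masking idea implemented by ae.ModelAdapter(p=0.2); the details here are simplified, the actual callback hooks into the fastai batch loop shown above:

import torch

x = torch.randn(4, 8)               # a batch of standardized intensities
mask = torch.rand_like(x) < 0.2     # corrupt ~20% of the entries (p=0.2)
x_noisy = x.masked_fill(mask, 0.0)  # masked input fed to the model
# during training, reconstructing the masked entries drives the loss, e.g.
# MSELossFlat(reduction='sum')(model(x_noisy)[mask], x[mask])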

Predictions#

  • the data used to create predictions for both the training and validation splits is the training data itself.

  • predictions include missing values (which are not further compared)

  • [ ] double check ModelAdapter

create predictions and select those at the validation coordinates; the assignments below work by aligning on the (Sample ID, protein groups) MultiIndex


analysis.model.eval()
pred, target = analysis.get_preds_from_df(df_wide=data.train_X)  # train_X
pred = pred.stack()
pred
Sample ID   protein groups                                                                
Sample_000  A0A024QZX5;A0A087X1N8;P35237                                                     15.899
            A0A024R0T9;K7ER74;P02655                                                         17.145
            A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8   15.866
            A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503                                          16.833
            A0A075B6H7                                                                       17.127
                                                                                              ...  
Sample_209  Q9Y6R7                                                                           19.105
            Q9Y6X5                                                                           15.418
            Q9Y6Y8;Q9Y6Y8-2                                                                  19.320
            Q9Y6Y9                                                                           11.279
            S4R3U6                                                                           11.314
Length: 298410, dtype: float32


val_pred_simulated_na['DAE'] = pred  # model_key ?
val_pred_simulated_na
observed DAE
Sample ID protein groups
Sample_158 Q9UN70;Q9UN70-2 14.630 15.590
Sample_050 Q9Y287 15.755 16.721
Sample_107 Q8N475;Q8N475-2 15.029 14.021
Sample_199 P06307 19.376 18.697
Sample_067 Q5VUB5 15.309 15.179
... ... ... ...
Sample_111 F6SYF8;Q9UBP4 22.822 23.040
Sample_002 A0A0A0MT36 18.165 15.755
Sample_049 Q8WY21;Q8WY21-2;Q8WY21-3;Q8WY21-4 15.525 15.822
Sample_182 Q8NFT8 14.379 13.557
Sample_123 Q16853;Q16853-2 14.504 14.539

12600 rows × 2 columns


test_pred_simulated_na['DAE'] = pred  # model_key?
test_pred_simulated_na
observed DAE
Sample ID protein groups
Sample_000 A0A075B6P5;P01615 17.016 17.333
A0A087X089;Q16627;Q16627-2 18.280 18.181
A0A0B4J2B5;S4R460 21.735 22.503
A0A140T971;O95865;Q5SRR8;Q5SSV3 14.603 14.975
A0A140TA33;A0A140TA41;A0A140TA52;P22105;P22105-3;P22105-4 16.143 16.502
... ... ... ...
Sample_209 Q96ID5 16.074 15.762
Q9H492;Q9H492-2 13.173 13.337
Q9HC57 14.207 13.647
Q9NPH3;Q9NPH3-2;Q9NPH3-5 14.962 15.302
Q9UGM5;Q9UGM5-2 16.871 16.437

12600 rows × 2 columns

save predictions for real missing values


if args.save_pred_real_na:
    pred_real_na = ae.get_missing_values(df_train_wide=data.train_X,
                                         val_idx=val_pred_simulated_na.index,
                                         test_idx=test_pred_simulated_na.index,
                                         pred=pred)
    display(pred_real_na)
    pred_real_na.to_csv(args.out_preds / f"pred_real_na_{args.model_key}.csv")
Sample ID   protein groups          
Sample_000  A0A075B6J9                 15.484
            A0A075B6Q5                 16.228
            A0A075B6R2                 16.833
            A0A075B6S5                 16.608
            A0A087WSY4                 16.692
                                        ...  
Sample_209  Q9P1W8;Q9P1W8-2;Q9P1W8-4   15.852
            Q9UI40;Q9UI40-2            15.779
            Q9UIW2                     16.935
            Q9UMX0;Q9UMX0-2;Q9UMX0-4   13.593
            Q9UP79                     15.852
Name: intensity, Length: 46401, dtype: float32
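
These stored predictions can later be used to complete the measured data, e.g. (a sketch, reading the file back with its two-level index):

pred_real_na = pd.read_csv(args.out_preds / f"pred_real_na_{args.model_key}.csv",
                           index_col=[0, 1]).squeeze("columns")
imputed = data.train_X.fillna(pred_real_na.unstack())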

Plots#

  • validation data


analysis.model.cpu()
df_latent = pimmslearn.model.get_latent_space(analysis.model.encoder,
                                              dl=analysis.dls.valid,
                                              dl_index=analysis.dls.valid.data.index)
df_latent
latent dimension 1 latent dimension 2 latent dimension 3 latent dimension 4 latent dimension 5 latent dimension 6 latent dimension 7 latent dimension 8 latent dimension 9 latent dimension 10
Sample ID
Sample_000 -0.782 2.272 1.529 -1.820 3.445 -4.642 -2.089 -4.293 4.514 -2.005
Sample_001 1.507 2.337 1.301 -2.400 1.204 -2.746 -1.664 -5.468 1.968 -0.595
Sample_002 0.802 1.775 -1.060 -4.645 3.212 -3.437 2.748 -1.731 0.901 -1.300
Sample_003 0.589 1.469 1.387 -2.663 4.421 -2.469 -1.195 -2.821 4.020 -3.745
Sample_004 0.287 2.585 0.583 0.629 4.542 -3.030 -0.277 -4.292 1.397 -4.080
... ... ... ... ... ... ... ... ... ... ...
Sample_205 2.438 -3.252 -2.383 -0.899 1.628 -2.676 2.546 -1.251 1.805 -1.384
Sample_206 0.834 -1.689 0.257 2.392 2.193 -4.331 0.897 0.233 -1.465 3.045
Sample_207 0.753 -1.559 -0.893 5.454 -0.663 -5.042 1.841 -4.525 2.140 -1.346
Sample_208 6.162 -0.333 0.625 1.069 -1.741 -2.833 1.520 -2.213 2.286 1.657
Sample_209 2.328 0.950 3.727 1.128 1.853 -0.136 3.339 -0.066 0.923 -0.869

210 rows × 10 columns


# ! calculate embeddings only if meta data is available? Optional argument to save embeddings?
ana_latent = analyzers.LatentAnalysis(df_latent,
                                      df_meta,
                                      args.model_key,
                                      folder=args.out_figures)
if args.meta_date_col and df_meta is not None:
    figures[f'latent_{args.model_key}_by_date'], ax = ana_latent.plot_by_date(
        args.meta_date_col)


if args.meta_cat_col and df_meta is not None:
    figures[f'latent_{args.model_key}_by_{"_".join(args.meta_cat_col.split())}'], ax = ana_latent.plot_by_category(
        args.meta_cat_col)
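
Since meta_date_col and meta_cat_col are both None in this run, neither plot is produced. A quick alternative look at the 10-dimensional latent space (a sketch, not part of the pipeline):

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

pcs = PCA(n_components=2).fit_transform(df_latent)
fig, ax = plt.subplots()
ax.scatter(pcs[:, 0], pcs[:, 1], s=10)
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')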

Comparisons#

Simulated NAs: artificially created NAs. Some data was sampled and set explicitly to missing before it was fed to the model for reconstruction.

Validation data#

  • all measured (identified, observed) protein groups in the validation data


# papermill_description=metrics
d_metrics = models.Metrics()

The simulated NAs for the validation step are real measured data held out from training (used neither for training nor for early stopping).


added_metrics = d_metrics.add_metrics(val_pred_simulated_na, 'valid_simulated_na')
added_metrics
Selected as truth to compare to: observed
{'DAE': {'MSE': 0.45735084631874146,
  'MAE': 0.4309430387022311,
  'N': 12600,
  'prop': 1.0}}
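
The reported numbers can be recomputed directly from the prediction frame (a sketch):

err = val_pred_simulated_na['observed'] - val_pred_simulated_na['DAE']
mse, mae = (err ** 2).mean(), err.abs().mean()  # ~0.457 and ~0.431 here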

Test Datasplit#


added_metrics = d_metrics.add_metrics(test_pred_simulated_na, 'test_simulated_na')
added_metrics
Selected as truth to compare to: observed
{'DAE': {'MSE': 0.4742600729134556,
  'MAE': 0.4345407509047897,
  'N': 12600,
  'prop': 1.0}}

Save all metrics as json


pimmslearn.io.dump_json(d_metrics.metrics, args.out_metrics /
                        f'metrics_{args.model_key}.json')
d_metrics
{ 'test_simulated_na': { 'DAE': { 'MAE': 0.4345407509047897,
                                  'MSE': 0.4742600729134556,
                                  'N': 12600,
                                  'prop': 1.0}},
  'valid_simulated_na': { 'DAE': { 'MAE': 0.4309430387022311,
                                   'MSE': 0.45735084631874146,
                                   'N': 12600,
                                   'prop': 1.0}}}


metrics_df = models.get_df_from_nested_dict(d_metrics.metrics,
                                            column_levels=['model', 'metric_name']).T
metrics_df
subset valid_simulated_na test_simulated_na
model metric_name
DAE MSE 0.457 0.474
MAE 0.431 0.435
N 12,600.000 12,600.000
prop 1.000 1.000

Save predictions#


# save simulated missing values for both splits
val_pred_simulated_na.to_csv(args.out_preds / f"pred_val_{args.model_key}.csv")
test_pred_simulated_na.to_csv(args.out_preds / f"pred_test_{args.model_key}.csv")

Config#


figures  # switch to fnames?
{}


args.dump(fname=args.out_models / f"model_config_{args.model_key}.yaml")
args
{'M': 1421,
 'batch_size': 64,
 'cuda': False,
 'data': Path('runs/alzheimer_study/data'),
 'epoch_trained': 179,
 'epochs_max': 300,
 'file_format': 'csv',
 'fn_rawfile_metadata': 'https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv',
 'folder_data': '',
 'folder_experiment': Path('runs/alzheimer_study'),
 'hidden_layers': [64],
 'latent_dim': 10,
 'meta_cat_col': None,
 'meta_date_col': None,
 'model': 'DAE',
 'model_key': 'DAE',
 'n_params': 184983,
 'out_figures': Path('runs/alzheimer_study/figures'),
 'out_folder': Path('runs/alzheimer_study'),
 'out_metrics': Path('runs/alzheimer_study'),
 'out_models': Path('runs/alzheimer_study'),
 'out_preds': Path('runs/alzheimer_study/preds'),
 'patience': 25,
 'sample_idx_position': 0,
 'save_pred_real_na': True}