Denoising Autoencoder#


import logging

import sklearn
from fastai import learner
from fastai.basics import *
from fastai.callback.all import *
from fastai.torch_basics import *
from IPython.display import display
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

import pimmslearn
import pimmslearn.model
import pimmslearn.models as models
from pimmslearn.analyzers import analyzers
from pimmslearn.io import datasplits
# overwriting Recorder callback with custom plot_loss
from pimmslearn.models import ae, plot_loss

learner.Recorder.plot_loss = plot_loss


logger = pimmslearn.logging.setup_logger(logging.getLogger('pimmslearn'))
logger.info(
    "Experiment 03 - Analysis of latent spaces and performance comparisons")

figures = {}  # collection of ax or figures
pimmslearn - INFO     Experiment 03 - Analysis of latent spaces and performance comparisons


# catch passed parameters
args = None
args = dict(globals()).keys()

Papermill script parameters:

# files and folders
# Datasplit folder with data for experiment
folder_experiment: str = 'runs/example'
folder_data: str = ''  # specify data directory if needed
file_format: str = 'csv'  # file format of the created data splits, e.g. 'csv' or 'pkl'
# Machine parsed metadata from rawfile workflow
fn_rawfile_metadata: str = 'data/dev_datasets/HeLa_6070/files_selected_metadata_N50.csv'
# training
epochs_max: int = 50  # Maximum number of epochs
# early_stopping:bool = True # Whether to use early stopping or not
patience: int = 25  # Patience for early stopping
batch_size: int = 64  # Batch size for training (and evaluation)
cuda: bool = True  # Whether to use a GPU for training
# model
# Dimensionality of encoding dimension (latent space of model)
latent_dim: int = 25
# An underscore-separated string of layer sizes, e.g. '128_64' for the encoder; the reverse is used for the decoder
hidden_layers: str = '512'

sample_idx_position: int = 0  # position of index which is sample ID
model: str = 'DAE'  # model name
model_key: str = 'DAE'  # potentially alternative key for model (grid search)
save_pred_real_na: bool = True  # Save all predictions for missing values
# metadata -> defaults for metadata extracted from machine data
meta_date_col: str = None  # date column in meta data
meta_cat_col: str = None  # category column in meta data
# Parameters
model = "DAE"
latent_dim = 10
batch_size = 64
epochs_max = 300
hidden_layers = "64"
sample_idx_position = 0
cuda = False
save_pred_real_na = True
fn_rawfile_metadata = "https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv"
folder_experiment = "runs/alzheimer_study"
model_key = "DAE"

Some argument transformations


args = pimmslearn.nb.get_params(args, globals=globals())
args
{'folder_experiment': 'runs/alzheimer_study',
 'folder_data': '',
 'file_format': 'csv',
 'fn_rawfile_metadata': 'https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv',
 'epochs_max': 300,
 'patience': 25,
 'batch_size': 64,
 'cuda': False,
 'latent_dim': 10,
 'hidden_layers': '64',
 'sample_idx_position': 0,
 'model': 'DAE',
 'model_key': 'DAE',
 'save_pred_real_na': True,
 'meta_date_col': None,
 'meta_cat_col': None}


args = pimmslearn.nb.args_from_dict(args)

if isinstance(args.hidden_layers, str):
    args.overwrite_entry("hidden_layers", [int(x)
                         for x in args.hidden_layers.split('_')])
else:
    raise ValueError(
        f"hidden_layers is of unknown type {type(args.hidden_layers)}")
args
{'batch_size': 64,
 'cuda': False,
 'data': Path('runs/alzheimer_study/data'),
 'epochs_max': 300,
 'file_format': 'csv',
 'fn_rawfile_metadata': 'https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv',
 'folder_data': '',
 'folder_experiment': Path('runs/alzheimer_study'),
 'hidden_layers': [64],
 'latent_dim': 10,
 'meta_cat_col': None,
 'meta_date_col': None,
 'model': 'DAE',
 'model_key': 'DAE',
 'out_figures': Path('runs/alzheimer_study/figures'),
 'out_folder': Path('runs/alzheimer_study'),
 'out_metrics': Path('runs/alzheimer_study'),
 'out_models': Path('runs/alzheimer_study'),
 'out_preds': Path('runs/alzheimer_study/preds'),
 'patience': 25,
 'sample_idx_position': 0,
 'save_pred_real_na': True}

Some naming conventions


TEMPLATE_MODEL_PARAMS = 'model_params_{}.json'

Load data in long format#


data = datasplits.DataSplits.from_folder(
    args.data, file_format=args.file_format)
pimmslearn.io.datasplits - INFO     Loaded 'train_X' from file: runs/alzheimer_study/data/train_X.csv
pimmslearn.io.datasplits - INFO     Loaded 'val_y' from file: runs/alzheimer_study/data/val_y.csv
pimmslearn.io.datasplits - INFO     Loaded 'test_y' from file: runs/alzheimer_study/data/test_y.csv

data is loaded in long format


data.train_X.sample(5)
Sample ID   protein groups                   
Sample_172  P53634                              16.782
Sample_068  A0A087WYK9;Q02985;Q02985-2;Q6NSD3   16.619
Sample_172  Q9UPU3                              18.652
Sample_144  P10451-5                            19.524
Sample_161  P01817                              15.385
Name: intensity, dtype: float64

Infer index names from long format


index_columns = list(data.train_X.index.names)
sample_id = index_columns.pop(args.sample_idx_position)
if len(index_columns) == 1:
    index_column = index_columns.pop()
    index_columns = None
    logger.info(f"{sample_id = }, single feature: {index_column = }")
else:
    logger.info(f"{sample_id = }, multiple features: {index_columns = }")

if not index_columns:
    index_columns = [sample_id, index_column]
else:
    raise NotImplementedError(
        "More than one feature: Needs to be implemented. see above logging output.")
pimmslearn - INFO     sample_id = 'Sample ID', single feature: index_column = 'protein groups'

Load metadata for splits


if args.fn_rawfile_metadata:
    df_meta = pd.read_csv(args.fn_rawfile_metadata, index_col=0)
    display(df_meta.loc[data.train_X.index.levels[0]])
else:
    df_meta = None
_collection site _age at CSF collection _gender _t-tau [ng/L] _p-tau [ng/L] _Abeta-42 [ng/L] _Abeta-40 [ng/L] _Abeta-42/Abeta-40 ratio _primary biochemical AD classification _clinical AD diagnosis _MMSE score
Sample ID
Sample_000 Sweden 71.000 f 703.000 85.000 562.000 NaN NaN biochemical control NaN NaN
Sample_001 Sweden 77.000 m 518.000 91.000 334.000 NaN NaN biochemical AD NaN NaN
Sample_002 Sweden 75.000 m 974.000 87.000 515.000 NaN NaN biochemical AD NaN NaN
Sample_003 Sweden 72.000 f 950.000 109.000 394.000 NaN NaN biochemical AD NaN NaN
Sample_004 Sweden 63.000 f 873.000 88.000 234.000 NaN NaN biochemical AD NaN NaN
... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 Berlin 69.000 f 1,945.000 NaN 699.000 12,140.000 0.058 biochemical AD AD 17.000
Sample_206 Berlin 73.000 m 299.000 NaN 1,420.000 16,571.000 0.086 biochemical control non-AD 28.000
Sample_207 Berlin 71.000 f 262.000 NaN 639.000 9,663.000 0.066 biochemical control non-AD 28.000
Sample_208 Berlin 83.000 m 289.000 NaN 1,436.000 11,285.000 0.127 biochemical control non-AD 24.000
Sample_209 Berlin 63.000 f 591.000 NaN 1,299.000 11,232.000 0.116 biochemical control non-AD 29.000

210 rows × 11 columns

Produce some additional simulated samples#

The validation split of simulated NAs is used by all models to evaluate training performance.
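For intuition, such a split could be produced along these lines (a conceptual sketch, not the pimmslearn data-split code): sample a fraction of the observed long-format intensities, keep them as validation targets, and remove them from the training data so the model has to reconstruct them.

# conceptual sketch: simulate missing values by holding out observed intensities
val_y_sketch = data.train_X.sample(frac=0.05, random_state=42)
train_X_sketch = data.train_X.drop(val_y_sketch.index)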


val_pred_simulated_na = data.val_y.to_frame(name='observed')
val_pred_simulated_na
observed
Sample ID protein groups
Sample_158 Q9UN70;Q9UN70-2 14.630
Sample_050 Q9Y287 15.755
Sample_107 Q8N475;Q8N475-2 15.029
Sample_199 P06307 19.376
Sample_067 Q5VUB5 15.309
... ... ...
Sample_111 F6SYF8;Q9UBP4 22.822
Sample_002 A0A0A0MT36 18.165
Sample_049 Q8WY21;Q8WY21-2;Q8WY21-3;Q8WY21-4 15.525
Sample_182 Q8NFT8 14.379
Sample_123 Q16853;Q16853-2 14.504

12600 rows × 1 columns


test_pred_simulated_na = data.test_y.to_frame(name='observed')
test_pred_simulated_na.describe()
observed
count 12,600.000
mean 16.339
std 2.741
min 7.209
25% 14.412
50% 15.935
75% 17.910
max 30.140

Data in wide format#

  • Autoencoders need data in wide format (see the sketch below)
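Conceptually, the conversion pivots the feature level of the long-format index into columns. A sketch of what to_wide_format presumably does for each split:

# pivot the long-format Series with index (Sample ID, protein groups)
# into a samples-by-features matrix
train_wide_sketch = data.train_X.unstack('protein groups')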


data.to_wide_format()
args.M = data.train_X.shape[-1]
data.train_X.head()
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 15.912 16.852 15.570 16.481 17.301 20.246 16.764 17.584 16.988 20.054 ... 16.012 15.178 NaN 15.050 16.842 NaN NaN 19.563 NaN 12.805
Sample_001 NaN 16.874 15.519 16.387 NaN 19.941 18.786 17.144 NaN 19.067 ... 15.528 15.576 NaN 14.833 16.597 20.299 15.556 19.386 13.970 12.442
Sample_002 16.111 NaN 15.935 16.416 18.175 19.251 16.832 15.671 17.012 18.569 ... 15.229 14.728 13.757 15.118 17.440 19.598 15.735 20.447 12.636 12.505
Sample_003 16.107 17.032 15.802 16.979 15.963 19.628 17.852 18.877 14.182 18.985 ... 15.495 14.590 14.682 15.140 17.356 19.429 NaN 20.216 NaN 12.445
Sample_004 15.603 15.331 15.375 16.679 NaN 20.450 18.682 17.081 14.140 19.686 ... 14.757 NaN NaN 15.256 17.075 19.582 15.328 NaN 13.145 NaN

5 rows × 1421 columns

Fill validation data with potentially missing features#


data.train_X
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 15.912 16.852 15.570 16.481 17.301 20.246 16.764 17.584 16.988 20.054 ... 16.012 15.178 NaN 15.050 16.842 NaN NaN 19.563 NaN 12.805
Sample_001 NaN 16.874 15.519 16.387 NaN 19.941 18.786 17.144 NaN 19.067 ... 15.528 15.576 NaN 14.833 16.597 20.299 15.556 19.386 13.970 12.442
Sample_002 16.111 NaN 15.935 16.416 18.175 19.251 16.832 15.671 17.012 18.569 ... 15.229 14.728 13.757 15.118 17.440 19.598 15.735 20.447 12.636 12.505
Sample_003 16.107 17.032 15.802 16.979 15.963 19.628 17.852 18.877 14.182 18.985 ... 15.495 14.590 14.682 15.140 17.356 19.429 NaN 20.216 NaN 12.445
Sample_004 15.603 15.331 15.375 16.679 NaN 20.450 18.682 17.081 14.140 19.686 ... 14.757 NaN NaN 15.256 17.075 19.582 15.328 NaN 13.145 NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 15.682 16.886 14.910 16.482 NaN 17.705 17.039 NaN 16.413 19.102 ... NaN 15.684 14.236 15.415 17.551 17.922 16.340 19.928 12.929 NaN
Sample_206 15.798 17.554 15.600 15.938 NaN 18.154 18.152 16.503 16.860 18.538 ... 15.422 16.106 NaN 15.345 17.084 18.708 NaN 19.433 NaN NaN
Sample_207 15.739 NaN 15.469 16.898 NaN 18.636 17.950 16.321 16.401 18.849 ... 15.808 16.098 14.403 15.715 NaN 18.725 16.138 19.599 13.637 11.174
Sample_208 15.477 16.779 14.995 16.132 NaN 14.908 NaN NaN 16.119 18.368 ... 15.157 16.712 NaN 14.640 16.533 19.411 15.807 19.545 NaN NaN
Sample_209 NaN 17.261 15.175 16.235 NaN 17.893 17.744 16.371 15.780 18.806 ... 15.237 15.652 15.211 14.205 16.749 19.275 15.732 19.577 11.042 11.791

210 rows × 1421 columns


data.val_y  # potentially has less features
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN 19.863 NaN NaN NaN NaN
Sample_001 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_002 NaN 14.523 NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_003 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_004 NaN NaN NaN NaN 15.473 NaN NaN NaN NaN NaN ... NaN NaN 14.048 NaN NaN NaN NaN 19.867 NaN 12.235
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN 11.802
Sample_206 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_207 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_208 NaN NaN NaN NaN NaN NaN 17.530 NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_209 15.727 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

210 rows × 1419 columns


data.val_y = pd.DataFrame(pd.NA, index=data.train_X.index,
                          columns=data.train_X.columns).fillna(data.val_y)
data.val_y
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN 19.863 NaN NaN NaN NaN
Sample_001 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_002 NaN 14.523 NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_003 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_004 NaN NaN NaN NaN 15.473 NaN NaN NaN NaN NaN ... NaN NaN 14.048 NaN NaN NaN NaN 19.867 NaN 12.235
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN 11.802
Sample_206 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_207 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_208 NaN NaN NaN NaN NaN NaN 17.530 NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
Sample_209 15.727 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

210 rows × 1421 columns
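The all-NA frame followed by fillna is equivalent to reindexing the validation targets against the training data's axes, which is arguably more direct (a sketch):

# align validation targets to the full training index and columns;
# features absent from val_y simply become NaN
val_y_aligned = data.val_y.reindex(index=data.train_X.index,
                                   columns=data.train_X.columns)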

Denoising Autoencoder#

Analysis: DataLoaders, Model, transform#


default_pipeline = sklearn.pipeline.Pipeline(
    [
        ('normalize', StandardScaler()),
        ('impute', SimpleImputer(add_indicator=False))
    ])

analysis = ae.AutoEncoderAnalysis(
    train_df=data.train_X,
    val_df=data.val_y,
    model=ae.Autoencoder,
    transform=default_pipeline,
    decode=['normalize'],
    model_kwargs=dict(n_features=data.train_X.shape[-1],
                      n_neurons=args.hidden_layers,
                      last_decoder_activation=None,
                      dim_latent=args.latent_dim),
    bs=args.batch_size)
args.n_params = analysis.n_params_ae

if args.cuda:
    analysis.model = analysis.model.cuda()
analysis.model
Autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=1421, out_features=64, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.1)
    (4): Linear(in_features=64, out_features=10, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=64, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.1)
    (4): Linear(in_features=64, out_features=1421, bias=True)
  )
)
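For orientation, the printed module is a symmetric stack of Linear → Dropout → BatchNorm1d → LeakyReLU blocks around the latent layer. A minimal hand-rolled equivalent in plain PyTorch (a sketch, not the ae.Autoencoder implementation):

import torch.nn as nn

n_features, n_hidden, latent_dim = 1421, 64, 10
encoder = nn.Sequential(
    nn.Linear(n_features, n_hidden),
    nn.Dropout(0.2),
    nn.BatchNorm1d(n_hidden),
    nn.LeakyReLU(0.1),
    nn.Linear(n_hidden, latent_dim),
)
decoder = nn.Sequential(  # mirror of the encoder
    nn.Linear(latent_dim, n_hidden),
    nn.Dropout(0.2),
    nn.BatchNorm1d(n_hidden),
    nn.LeakyReLU(0.1),
    nn.Linear(n_hidden, n_features),
)
model_sketch = nn.Sequential(encoder, decoder)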

Training#


analysis.learn = Learner(dls=analysis.dls,
                         model=analysis.model,
                         loss_func=MSELossFlat(reduction='sum'),
                         cbs=[EarlyStoppingCallback(patience=args.patience),
                              ae.ModelAdapter(p=0.2)]
                         )

analysis.learn.show_training_loop()
Start Fit
   - before_fit     : [TrainEvalCallback, Recorder, ProgressCallback, EarlyStoppingCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_batch   : [ModelAdapter, CastToTensor]
         - after_pred     : [ModelAdapter]
         - after_loss     : [ModelAdapter]
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, ProgressCallback]
  End Epoch Loop
   - after_cancel_epoch: []
   - after_epoch    : [Recorder, EarlyStoppingCallback]
End Fit
 - after_cancel_fit: []
 - after_fit      : [ProgressCallback, EarlyStoppingCallback]

Adding an EarlyStoppingCallback results in an error. A potential fix in PR3509 is not yet in the current version; try again later.


# learn.summary()


suggested_lr = analysis.learn.lr_find()
analysis.params['suggested_inital_lr'] = suggested_lr.valley
suggested_lr
SuggestedLRs(valley=0.015848932787775993)
[Figure: learning rate finder loss curve]

dump model config


pimmslearn.io.dump_json(analysis.params, args.out_models /
                        TEMPLATE_MODEL_PARAMS.format(args.model_key))


# papermill_description=train
analysis.learn.fit_one_cycle(args.epochs_max, lr_max=suggested_lr.valley)
epoch train_loss valid_loss time
0 65189.332031 4065.009033 00:00
1 63970.929688 4073.339355 00:00
2 62738.941406 4046.097900 00:00
3 61545.886719 3982.603516 00:00
4 60375.144531 3881.793701 00:00
5 59170.667969 3760.335449 00:00
6 57996.183594 3628.615967 00:00
7 56775.031250 3494.257080 00:00
8 55579.671875 3364.098145 00:00
9 54338.664062 3249.452637 00:00
10 53160.007812 3153.119385 00:00
11 51986.457031 3067.656250 00:00
12 50857.687500 3002.722900 00:00
13 49810.578125 2947.004150 00:00
14 48791.207031 2878.490479 00:00
15 47812.417969 2814.804688 00:00
16 46881.757812 2750.369873 00:00
17 45990.117188 2702.312500 00:00
18 45146.527344 2654.875000 00:00
19 44375.152344 2621.254150 00:00
20 43638.242188 2584.091064 00:00
21 42901.175781 2559.164551 00:00
22 42242.066406 2536.586670 00:00
23 41581.832031 2511.079102 00:00
24 40969.992188 2497.289307 00:00
25 40374.402344 2475.123779 00:00
26 39790.597656 2456.556885 00:00
27 39230.460938 2424.583984 00:00
28 38685.246094 2403.708252 00:00
29 38147.367188 2381.049072 00:00
30 37676.078125 2371.700928 00:00
31 37175.718750 2357.399902 00:00
32 36696.914062 2349.020020 00:00
33 36259.953125 2348.249023 00:00
34 35844.031250 2341.917480 00:00
35 35464.050781 2325.163330 00:00
36 35095.621094 2340.296631 00:00
37 34718.742188 2347.518066 00:00
38 34372.738281 2312.507812 00:00
39 34058.503906 2307.323486 00:00
40 33731.542969 2296.672119 00:00
41 33409.183594 2302.800781 00:00
42 33101.085938 2309.259277 00:00
43 32836.953125 2310.265869 00:00
44 32562.789062 2280.129883 00:00
45 32289.310547 2277.523193 00:00
46 32056.732422 2297.179443 00:00
47 31818.916016 2281.912598 00:00
48 31587.218750 2276.459717 00:00
49 31336.210938 2250.835449 00:00
50 31123.798828 2261.320801 00:00
51 30911.251953 2265.617188 00:00
52 30709.378906 2287.955566 00:00
53 30528.097656 2267.725342 00:00
54 30352.703125 2295.426025 00:00
55 30143.472656 2269.196045 00:00
56 29963.767578 2265.019775 00:00
57 29800.945312 2283.927734 00:00
58 29627.294922 2265.821045 00:00
59 29552.359375 2222.951416 00:00
60 29456.718750 2280.027100 00:00
61 29349.886719 2254.399170 00:00
62 29216.562500 2275.485352 00:00
63 29062.291016 2261.553467 00:00
64 28973.205078 2247.205322 00:00
65 28861.738281 2226.112305 00:00
66 28738.486328 2239.219971 00:00
67 28612.404297 2232.740234 00:00
68 28481.646484 2222.366455 00:00
69 28355.685547 2213.064697 00:00
70 28273.449219 2234.936523 00:00
71 28183.482422 2269.694580 00:00
72 28049.103516 2268.020752 00:00
73 27988.853516 2264.024902 00:00
74 27924.660156 2239.363281 00:00
75 27831.146484 2296.860352 00:00
76 27807.638672 2222.490479 00:00
77 27719.265625 2246.213623 00:00
78 27658.500000 2220.362061 00:00
79 27597.330078 2228.690430 00:00
80 27561.689453 2227.278564 00:00
81 27462.822266 2207.621826 00:00
82 27348.320312 2229.434814 00:00
83 27266.439453 2229.196533 00:00
84 27219.341797 2223.225586 00:00
85 27141.863281 2216.499512 00:00
86 27085.958984 2204.550049 00:00
87 27020.142578 2188.401123 00:00
88 27000.318359 2185.473633 00:00
89 26931.087891 2186.194824 00:00
90 26872.074219 2249.486572 00:00
91 26838.744141 2198.011230 00:00
92 26764.652344 2179.640869 00:00
93 26699.521484 2195.164551 00:00
94 26651.605469 2216.068848 00:00
95 26577.832031 2203.488770 00:00
96 26520.025391 2220.129395 00:00
97 26479.972656 2199.240967 00:00
98 26446.634766 2183.968506 00:00
99 26427.878906 2240.149658 00:00
100 26370.634766 2223.751221 00:00
101 26349.404297 2184.207275 00:00
102 26351.498047 2207.575195 00:00
103 26300.734375 2229.634766 00:00
104 26263.832031 2192.597412 00:00
105 26262.826172 2223.088623 00:00
106 26219.982422 2202.072266 00:00
107 26189.517578 2172.981934 00:00
108 26200.792969 2200.256836 00:00
109 26207.058594 2219.857422 00:00
110 26163.671875 2211.435303 00:00
111 26140.054688 2222.508301 00:00
112 26103.554688 2186.026123 00:00
113 26048.662109 2194.691162 00:00
114 25961.123047 2208.794678 00:00
115 25908.806641 2221.994629 00:00
116 25880.289062 2197.366455 00:00
117 25867.888672 2186.518066 00:00
118 25855.978516 2184.555908 00:00
119 25855.585938 2209.547119 00:00
120 25805.609375 2206.538086 00:00
121 25784.732422 2193.061768 00:00
122 25713.316406 2200.547852 00:00
123 25694.386719 2192.185791 00:00
124 25645.117188 2183.558594 00:00
125 25662.468750 2194.983643 00:00
126 25654.291016 2194.555176 00:00
127 25657.900391 2224.259033 00:00
128 25626.945312 2190.095215 00:00
129 25611.156250 2177.753418 00:00
130 25593.054688 2198.444580 00:00
131 25563.583984 2206.655273 00:00
132 25513.742188 2188.066895 00:00
No improvement since epoch 107: early stopping

Save number of actually trained epochs


args.epoch_trained = analysis.learn.epoch + 1
args.epoch_trained
133

Loss normalized by total number of measurements#


N_train_notna = data.train_X.notna().sum().sum()
N_val_notna = data.val_y.notna().sum().sum()
fig = models.plot_training_losses(analysis.learn, args.model_key,
                                  folder=args.out_figures,
                                  norm_factors=[N_train_notna, N_val_notna])
pimmslearn.plotting - INFO     Saved Figures to runs/alzheimer_study/figures/dae_training
[Figure: DAE training and validation losses, normalized per measurement]
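Since the loss is summed per batch (MSELossFlat(reduction='sum')), dividing by the number of observed values gives a comparable per-measurement error. A minimal sketch of the normalization, assuming fastai's learn.recorder.values holds [train_loss, valid_loss] per epoch:

last_train_loss, last_valid_loss = analysis.learn.recorder.values[-1][:2]
print(f"train loss per measurement: {last_train_loss / N_train_notna:.3f}")
print(f"valid loss per measurement: {last_valid_loss / N_val_notna:.3f}")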

Why is the validation loss better than the training loss?

  • during training, input data is masked and needs to be reconstructed

  • when evaluating the model, all input data is provided and only the artificially masked data is used for evaluation (see the sketch below)
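A conceptual sketch of the masking logic during training (not the pimmslearn ModelAdapter implementation, which is wired into the fastai training loop as a callback):

import torch
import torch.nn.functional as F

def dae_training_loss(model, x, mask_p=0.2):
    """One denoising step: mask a fraction of the observed inputs,
    reconstruct, and score the reconstruction on the observed values."""
    observed = ~torch.isnan(x)
    x_filled = torch.nan_to_num(x)             # zero-fill truly missing entries
    keep = torch.rand_like(x_filled) > mask_p  # randomly drop ~20% of the inputs
    recon = model(x_filled * keep)
    return F.mse_loss(recon[observed], x_filled[observed], reduction='sum')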

Predictions#

  • the input used to create predictions for both the training and validation split is the training data

  • predictions include missing values (which are not further compared)

  • [ ] double check ModelAdapter

Create predictions and select those for the validation data


analysis.model.eval()
pred, target = analysis.get_preds_from_df(df_wide=data.train_X)  # train_X
pred = pred.stack()
pred
Sample ID   protein groups                                                                
Sample_000  A0A024QZX5;A0A087X1N8;P35237                                                     16.002
            A0A024R0T9;K7ER74;P02655                                                         16.601
            A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8   15.901
            A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503                                          16.829
            A0A075B6H7                                                                       16.929
                                                                                              ...  
Sample_209  Q9Y6R7                                                                           19.203
            Q9Y6X5                                                                           15.723
            Q9Y6Y8;Q9Y6Y8-2                                                                  19.493
            Q9Y6Y9                                                                           11.193
            S4R3U6                                                                           11.386
Length: 298410, dtype: float32
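The assignments below rely on pandas index alignment: pred covers all (sample, feature) pairs, but assigning it as a column keeps only the rows present in the target frame. A toy example with hypothetical data:

import pandas as pd

names = ['Sample ID', 'protein groups']
obs = pd.DataFrame({'observed': [15.0, 16.0]},
                   index=pd.MultiIndex.from_tuples([('S1', 'P1'), ('S2', 'P2')],
                                                   names=names))
full = pd.Series([14.9, 11.0, 12.0, 16.2],
                 index=pd.MultiIndex.from_product([['S1', 'S2'], ['P1', 'P2']],
                                                  names=names))
obs['pred'] = full  # alignment selects only ('S1','P1') and ('S2','P2')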


val_pred_simulated_na['DAE'] = pred  # model_key ?
val_pred_simulated_na
observed DAE
Sample ID protein groups
Sample_158 Q9UN70;Q9UN70-2 14.630 15.686
Sample_050 Q9Y287 15.755 16.760
Sample_107 Q8N475;Q8N475-2 15.029 14.591
Sample_199 P06307 19.376 18.837
Sample_067 Q5VUB5 15.309 15.094
... ... ... ...
Sample_111 F6SYF8;Q9UBP4 22.822 22.872
Sample_002 A0A0A0MT36 18.165 16.226
Sample_049 Q8WY21;Q8WY21-2;Q8WY21-3;Q8WY21-4 15.525 15.621
Sample_182 Q8NFT8 14.379 13.961
Sample_123 Q16853;Q16853-2 14.504 14.513

12600 rows × 2 columns


test_pred_simulated_na['DAE'] = pred  # model_key?
test_pred_simulated_na
observed DAE
Sample ID protein groups
Sample_000 A0A075B6P5;P01615 17.016 17.051
A0A087X089;Q16627;Q16627-2 18.280 17.982
A0A0B4J2B5;S4R460 21.735 22.380
A0A140T971;O95865;Q5SRR8;Q5SSV3 14.603 14.978
A0A140TA33;A0A140TA41;A0A140TA52;P22105;P22105-3;P22105-4 16.143 16.577
... ... ... ...
Sample_209 Q96ID5 16.074 15.935
Q9H492;Q9H492-2 13.173 13.489
Q9HC57 14.207 13.757
Q9NPH3;Q9NPH3-2;Q9NPH3-5 14.962 15.254
Q9UGM5;Q9UGM5-2 16.871 16.364

12600 rows × 2 columns

Save predictions for real missing values


if args.save_pred_real_na:
    pred_real_na = ae.get_missing_values(df_train_wide=data.train_X,
                                         val_idx=val_pred_simulated_na.index,
                                         test_idx=test_pred_simulated_na.index,
                                         pred=pred)
    display(pred_real_na)
    pred_real_na.to_csv(args.out_preds / f"pred_real_na_{args.model_key}.csv")
Sample ID   protein groups          
Sample_000  A0A075B6J9                 15.551
            A0A075B6Q5                 16.444
            A0A075B6R2                 16.475
            A0A075B6S5                 16.235
            A0A087WSY4                 16.707
                                        ...  
Sample_209  Q9P1W8;Q9P1W8-2;Q9P1W8-4   15.985
            Q9UI40;Q9UI40-2            16.272
            Q9UIW2                     16.648
            Q9UMX0;Q9UMX0-2;Q9UMX0-4   13.432
            Q9UP79                     15.823
Name: intensity, Length: 46401, dtype: float32
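A conceptual sketch of selecting predictions for truly missing values (not the ae.get_missing_values implementation): take the entries that are NaN in the wide training data and exclude the simulated NAs belonging to the validation and test splits.

is_na = data.train_X.isna().stack()   # train_X is in wide format here
real_na_idx = is_na[is_na].index      # entries missing in the training data
real_na_idx = (real_na_idx
               .difference(val_pred_simulated_na.index)
               .difference(test_pred_simulated_na.index))
pred_real_na_sketch = pred.loc[real_na_idx]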

Plots#

  • validation data


analysis.model.cpu()
df_latent = pimmslearn.model.get_latent_space(analysis.model.encoder,
                                              dl=analysis.dls.valid,
                                              dl_index=analysis.dls.valid.data.index)
df_latent
latent dimension 1 latent dimension 2 latent dimension 3 latent dimension 4 latent dimension 5 latent dimension 6 latent dimension 7 latent dimension 8 latent dimension 9 latent dimension 10
Sample ID
Sample_000 2.955 -4.229 -2.044 1.863 -0.252 -0.927 -1.193 3.173 0.656 -0.073
Sample_001 1.797 -4.163 -1.219 -0.205 -0.441 0.337 -0.335 1.236 0.423 1.435
Sample_002 0.274 -3.224 -3.118 -2.083 3.095 -0.212 0.665 3.830 -1.946 -2.671
Sample_003 1.641 -3.175 -3.978 2.153 -0.342 0.490 -0.798 3.122 0.186 -0.534
Sample_004 0.360 -3.812 -2.570 1.657 0.868 -0.597 -1.617 1.482 1.071 -1.230
... ... ... ... ... ... ... ... ... ... ...
Sample_205 -1.552 0.477 0.987 0.653 2.789 -0.325 -1.672 3.013 -1.320 -2.326
Sample_206 5.264 0.335 3.767 -2.156 1.896 2.804 -2.063 0.780 1.554 -2.718
Sample_207 1.384 -4.817 4.881 1.635 0.847 -0.102 -1.043 1.881 0.209 -6.002
Sample_208 -1.577 -2.423 3.580 0.267 2.188 1.808 0.848 -0.450 1.776 -2.244
Sample_209 0.685 -1.645 1.296 1.548 2.147 5.799 3.832 1.854 0.759 -2.740

210 rows × 10 columns
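A rough manual equivalent of the extraction above (a sketch: it refits the preprocessing on the wide training data, so it will not reproduce df_latent exactly, since get_latent_space works on the fitted DataLoaders):

import pandas as pd
import torch

X = default_pipeline.fit_transform(data.train_X)  # standardize, then mean-impute
analysis.model.encoder.eval()
with torch.no_grad():
    Z = analysis.model.encoder(torch.tensor(X, dtype=torch.float32)).numpy()
df_latent_sketch = pd.DataFrame(Z, index=data.train_X.index)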


# ! calculate embeddings only if meta data is available? Optional argument to save embeddings?
ana_latent = analyzers.LatentAnalysis(df_latent,
                                      df_meta,
                                      args.model_key,
                                      folder=args.out_figures)
if args.meta_date_col and df_meta is not None:
    figures[f'latent_{args.model_key}_by_date'], ax = ana_latent.plot_by_date(
        args.meta_date_col)


if args.meta_cat_col and df_meta is not None:
    figures[f'latent_{args.model_key}_by_{"_".join(args.meta_cat_col.split())}'], ax = ana_latent.plot_by_category(
        args.meta_cat_col)

Comparisons#

Simulated NAs: artificially created NAs. Some data was sampled and set explicitly to missing before it was fed to the model for reconstruction.

Validation data#

  • all measured (identified, observed) features in the validation data


# papermill_description=metrics
d_metrics = models.Metrics()

The simulated NAs for the validation step are genuinely held-out data (used neither for training nor for early stopping)


added_metrics = d_metrics.add_metrics(val_pred_simulated_na, 'valid_simulated_na')
added_metrics
Selected as truth to compare to: observed
{'DAE': {'MSE': 0.4578401596479822,
  'MAE': 0.433157106839031,
  'N': 12600,
  'prop': 1.0}}
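The reported values can be reproduced directly from the prediction frame (a sketch; presumably what Metrics.add_metrics computes under the hood):

from sklearn.metrics import mean_absolute_error, mean_squared_error

mse = mean_squared_error(val_pred_simulated_na['observed'],
                         val_pred_simulated_na['DAE'])
mae = mean_absolute_error(val_pred_simulated_na['observed'],
                          val_pred_simulated_na['DAE'])
print(f"MSE: {mse:.3f}, MAE: {mae:.3f}")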

Test Datasplit#


added_metrics = d_metrics.add_metrics(test_pred_simulated_na, 'test_simulated_na')
added_metrics
Selected as truth to compare to: observed
{'DAE': {'MSE': 0.4764879431798487,
  'MAE': 0.4383871249810212,
  'N': 12600,
  'prop': 1.0}}

Save all metrics as json


pimmslearn.io.dump_json(d_metrics.metrics, args.out_metrics /
                        f'metrics_{args.model_key}.json')
d_metrics
{ 'test_simulated_na': { 'DAE': { 'MAE': 0.4383871249810212,
                                  'MSE': 0.4764879431798487,
                                  'N': 12600,
                                  'prop': 1.0}},
  'valid_simulated_na': { 'DAE': { 'MAE': 0.433157106839031,
                                   'MSE': 0.4578401596479822,
                                   'N': 12600,
                                   'prop': 1.0}}}


metrics_df = models.get_df_from_nested_dict(d_metrics.metrics,
                                            column_levels=['model', 'metric_name']).T
metrics_df
subset valid_simulated_na test_simulated_na
model metric_name
DAE MSE 0.458 0.476
MAE 0.433 0.438
N 12,600.000 12,600.000
prop 1.000 1.000

Save predictions#


# save simulated missing values for both splits
val_pred_simulated_na.to_csv(args.out_preds / f"pred_val_{args.model_key}.csv")
test_pred_simulated_na.to_csv(args.out_preds / f"pred_test_{args.model_key}.csv")

Config#


figures  # switch to fnames?
{}


args.dump(fname=args.out_models / f"model_config_{args.model_key}.yaml")
args
{'M': 1421,
 'batch_size': 64,
 'cuda': False,
 'data': Path('runs/alzheimer_study/data'),
 'epoch_trained': 133,
 'epochs_max': 300,
 'file_format': 'csv',
 'fn_rawfile_metadata': 'https://raw.githubusercontent.com/RasmussenLab/njab/HEAD/docs/tutorial/data/alzheimer/meta.csv',
 'folder_data': '',
 'folder_experiment': Path('runs/alzheimer_study'),
 'hidden_layers': [64],
 'latent_dim': 10,
 'meta_cat_col': None,
 'meta_date_col': None,
 'model': 'DAE',
 'model_key': 'DAE',
 'n_params': 184983,
 'out_figures': Path('runs/alzheimer_study/figures'),
 'out_folder': Path('runs/alzheimer_study'),
 'out_metrics': Path('runs/alzheimer_study'),
 'out_models': Path('runs/alzheimer_study'),
 'out_preds': Path('runs/alzheimer_study/preds'),
 'patience': 25,
 'sample_idx_position': 0,
 'save_pred_real_na': True}