Fit logistic regression model#

  • based on different imputation methods

  • baseline: reference

  • model: any other selected imputation method

Hide code cell source

import logging
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import njab.sklearn
import pandas as pd
import sklearn
from njab.plotting.metrics import plot_split_auc, plot_split_prc
from njab.sklearn.types import Splits

import pimmslearn
import pimmslearn.analyzers
import pimmslearn.io.datasplits

# Global plot defaults: small figures suited for multi-panel publication layouts.
plt.rcParams['figure.figsize'] = (2.5, 2.5)
plt.rcParams['lines.linewidth'] = 1
plt.rcParams['lines.markersize'] = 2
fontsize = 5
figsize = (2.5, 2.5)
# Enlarge font descriptors relative to the small figure size.
pimmslearn.plotting.make_large_descriptors(fontsize)


logger = pimmslearn.logging.setup_nb_logger()
# fontTools logs verbosely at INFO when figures are saved; silence it.
logging.getLogger('fontTools').setLevel(logging.ERROR)


def parse_roc(*res: 'njab.sklearn.types.Results') -> pd.DataFrame:
    """Collect test-split ROC curves of several models into one wide DataFrame.

    Parameters
    ----------
    *res : njab.sklearn.types.Results
        One ``Results`` object per model. Only ``.test.roc`` (a
        (fpr, tpr, cutoffs) triple) and ``.name`` are read.
        Note: each positional argument is a single ``Results`` — the previous
        annotation ``List[Results]`` mis-stated the varargs element type.

    Returns
    -------
    pd.DataFrame
        Concatenated curves with a (model name, {'fpr', 'tpr'}) column
        MultiIndex; the cutoffs row is dropped.
    """
    curves = []
    for result in res:
        # rows: fpr / tpr / cutoffs -> keep the two curve coordinates only
        roc = pd.DataFrame(result.test.roc,
                           index='fpr tpr cutoffs'.split()
                           ).loc[['fpr', 'tpr']]
        roc = roc.T
        roc.columns = pd.MultiIndex.from_product([[result.name], roc.columns])
        curves.append(roc)
    return pd.concat(curves, axis=1)


def parse_prc(*res: 'njab.sklearn.types.Results') -> pd.DataFrame:
    """Collect test-split precision-recall curves of several models.

    Parameters
    ----------
    *res : njab.sklearn.types.Results
        One ``Results`` object per model. Only ``.test.prc`` (a
        (precision, recall, cutoffs) triple) and ``.name`` are read.
        Note: each positional argument is a single ``Results`` — the previous
        annotation ``List[Results]`` mis-stated the varargs element type.

    Returns
    -------
    pd.DataFrame
        Concatenated curves with a (model name, {'precision', 'tpr'}) column
        MultiIndex; recall is renamed to 'tpr' so PRC and ROC frames share a
        column label, and the cutoffs row is dropped.
    """
    curves = []
    for result in res:
        prc = pd.DataFrame(result.test.prc,
                           index='precision recall cutoffs'.split()
                           ).loc[['precision', 'recall']]
        # align column name with parse_roc output
        prc = prc.T.rename(columns={'recall': 'tpr'})
        prc.columns = pd.MultiIndex.from_product([[result.name], prc.columns])
        curves.append(prc)
    return pd.concat(curves, axis=1)


# catch passed parameters
args = None
# Snapshot the global namespace BEFORE the parameter cells below run;
# pimmslearn.nb.get_params later diffs against this to find the parameters.
args = dict(globals()).keys()

Parameters#

Default and set parameters for the notebook.

folder_data: str = ''  # specify data directory if needed
fn_clinical_data: str = "data/ALD_study/processed/ald_metadata_cli.csv"  # clinical metadata incl. target column
folder_experiment: str = "runs/appl_ald_data/plasma/proteinGroups"  # experiment output root folder
model_key: str = 'VAE'  # imputation model to compare against the baseline
target: str = 'kleiner'  # name of the clinical target column
sample_id_col: str = 'Sample ID'  # index column in the clinical data
cutoff_target: int = 2  # => for binarization target >= cutoff_target
file_format: str = "csv"  # file format of the data splits on disk
out_folder: str = 'diff_analysis'  # subfolder name for analysis outputs
fn_qc_samples = ''  # 'data/ALD_study/processed/qc_plasma_proteinGroups.pkl'

baseline = 'RSN'  # default is RSN, as this was used in the original ALD study (Niu et al. 2022)
template_pred = 'pred_real_na_{}.csv'  # fixed, do not change
# Parameters (injected for this run, overriding the defaults above)
cutoff_target = 0.5
folder_experiment = "runs/alzheimer_study"
target = "AD"
baseline = "PI"
model_key = "DAE"
out_folder = "diff_analysis"
fn_clinical_data = "runs/alzheimer_study/data/clinical_data.csv"

Hide code cell source

# Collect everything added to globals() since the `args` snapshot as parameters.
params = pimmslearn.nb.get_params(args, globals=globals())
args = pimmslearn.nb.Config()
args.folder_experiment = Path(params["folder_experiment"])
# Derive default output paths under
# <folder_experiment>/<out_folder>/<target>/<baseline>_vs_<model_key>
args = pimmslearn.nb.add_default_paths(args,
                                 out_root=(args.folder_experiment
                                           / params["out_folder"]
                                           / params["target"]
                                           / f"{params['baseline']}_vs_{params['model_key']}"))
args.update_from_dict(params)
files_out = dict()  # collects paths of files written by this notebook
args
root - INFO     Removed from global namespace: folder_data
root - INFO     Removed from global namespace: fn_clinical_data
root - INFO     Removed from global namespace: folder_experiment
root - INFO     Removed from global namespace: model_key
root - INFO     Removed from global namespace: target
root - INFO     Removed from global namespace: sample_id_col
root - INFO     Removed from global namespace: cutoff_target
root - INFO     Removed from global namespace: file_format
root - INFO     Removed from global namespace: out_folder
root - INFO     Removed from global namespace: fn_qc_samples
root - INFO     Removed from global namespace: baseline
root - INFO     Removed from global namespace: template_pred
root - INFO     Already set attribute: folder_experiment has value runs/alzheimer_study
root - INFO     Already set attribute: out_folder has value diff_analysis
{'baseline': 'PI',
 'cutoff_target': 0.5,
 'data': PosixPath('runs/alzheimer_study/data'),
 'file_format': 'csv',
 'fn_clinical_data': 'runs/alzheimer_study/data/clinical_data.csv',
 'fn_qc_samples': '',
 'folder_data': '',
 'folder_experiment': PosixPath('runs/alzheimer_study'),
 'model_key': 'DAE',
 'out_figures': PosixPath('runs/alzheimer_study/figures'),
 'out_folder': PosixPath('runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE'),
 'out_metrics': PosixPath('runs/alzheimer_study'),
 'out_models': PosixPath('runs/alzheimer_study'),
 'out_preds': PosixPath('runs/alzheimer_study/preds'),
 'sample_id_col': 'Sample ID',
 'target': 'AD',
 'template_pred': 'pred_real_na_{}.csv'}

Load data#

Load target#

# Load only the sample-id and target columns; drop samples without a target value.
target = pd.read_csv(args.fn_clinical_data,
                     index_col=0,
                     usecols=[args.sample_id_col, args.target])
target = target.dropna()
target
AD
Sample ID
Sample_000 0
Sample_001 1
Sample_002 1
Sample_003 1
Sample_004 1
... ...
Sample_205 1
Sample_206 0
Sample_207 0
Sample_208 0
Sample_209 0

210 rows × 1 columns

MS proteomics or specified omics data#

Aggregated from data splits of the imputation workflow run before.

Hide code cell source

# Recombine all observed intensities (train/val/test splits) into one long Series
# indexed by (sample, protein group).
data = pimmslearn.io.datasplits.DataSplits.from_folder(
    args.data, file_format=args.file_format)
data = pd.concat([data.train_X, data.val_y, data.test_y])
data.sample(5)
pimmslearn.io.datasplits - INFO     Loaded 'train_X' from file: runs/alzheimer_study/data/train_X.csv
pimmslearn.io.datasplits - INFO     Loaded 'val_y' from file: runs/alzheimer_study/data/val_y.csv
pimmslearn.io.datasplits - INFO     Loaded 'test_y' from file: runs/alzheimer_study/data/test_y.csv
Sample ID   protein groups          
Sample_101  Q99538                     17.740
Sample_153  O60512                     12.061
Sample_036  Q9ULF5                     17.128
Sample_139  P32004;P32004-2;P32004-3   15.988
Sample_171  Q96JF0                     18.118
Name: intensity, dtype: float64

Get overlap between independent features and target

Select by ALD criteria#

Use parameters as specified in ALD study.

Hide code cell source

DATA_COMPLETENESS = 0.6  # feature must be quantified in >= 60% of samples
MIN_N_PROTEIN_GROUPS: int = 200  # NOTE(review): unused in this cell — verify whether still needed
FRAC_PROTEIN_GROUPS: float = 0.622  # min fraction of protein groups a sample must have quantified
CV_QC_SAMPLE: float = 0.4  # max coefficient of variation allowed in QC samples

# Apply the ALD-study feature/sample selection to the wide (samples x features) table.
ald_study, cutoffs = pimmslearn.analyzers.diff_analysis.select_raw_data(data.unstack(
), data_completeness=DATA_COMPLETENESS, frac_protein_groups=FRAC_PROTEIN_GROUPS)

if args.fn_qc_samples:
    # Optional: restrict to features with low variability in QC samples.
    qc_samples = pd.read_pickle(args.fn_qc_samples)
    qc_samples = qc_samples[ald_study.columns]
    qc_cv_feat = qc_samples.std() / qc_samples.mean()  # coefficient of variation per feature
    qc_cv_feat = qc_cv_feat.rename(qc_samples.columns.name)
    fig, ax = plt.subplots(figsize=(4, 7))
    ax = qc_cv_feat.plot.box(ax=ax)
    ax.set_ylabel('Coefficient of Variation')
    print((qc_cv_feat < CV_QC_SAMPLE).value_counts())
    ald_study = ald_study[pimmslearn.analyzers.diff_analysis.select_feat(qc_samples)]

# Map the first protein accession of each group back to the full protein-group id.
column_name_first_prot_to_pg = {
    pg.split(';')[0]: pg for pg in data.unstack().columns}

ald_study = ald_study.rename(columns=column_name_first_prot_to_pg)
ald_study
root - INFO     Initally: N samples: 210, M feat: 1421
root - INFO     Dropped features quantified in less than 126 samples.
root - INFO     After feat selection: N samples: 210, M feat: 1213
root - INFO     Min No. of Protein-Groups in single sample: 754
root - INFO     Finally: N samples: 210, M feat: 1213
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 A0A075B6J9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 15.912 16.852 15.570 16.481 20.246 16.764 17.584 16.988 20.054 NaN ... 16.012 15.178 NaN 15.050 16.842 19.863 NaN 19.563 12.837 12.805
Sample_001 15.936 16.874 15.519 16.387 19.941 18.786 17.144 NaN 19.067 16.188 ... 15.528 15.576 NaN 14.833 16.597 20.299 15.556 19.386 13.970 12.442
Sample_002 16.111 14.523 15.935 16.416 19.251 16.832 15.671 17.012 18.569 NaN ... 15.229 14.728 13.757 15.118 17.440 19.598 15.735 20.447 12.636 12.505
Sample_003 16.107 17.032 15.802 16.979 19.628 17.852 18.877 14.182 18.985 13.438 ... 15.495 14.590 14.682 15.140 17.356 19.429 NaN 20.216 12.627 12.445
Sample_004 15.603 15.331 15.375 16.679 20.450 18.682 17.081 14.140 19.686 14.495 ... 14.757 15.094 14.048 15.256 17.075 19.582 15.328 19.867 13.145 12.235
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 15.682 16.886 14.910 16.482 17.705 17.039 NaN 16.413 19.102 16.064 ... 15.235 15.684 14.236 15.415 17.551 17.922 16.340 19.928 12.929 11.802
Sample_206 15.798 17.554 15.600 15.938 18.154 18.152 16.503 16.860 18.538 15.288 ... 15.422 16.106 NaN 15.345 17.084 18.708 14.249 19.433 NaN NaN
Sample_207 15.739 16.877 15.469 16.898 18.636 17.950 16.321 16.401 18.849 17.580 ... 15.808 16.098 14.403 15.715 16.586 18.725 16.138 19.599 13.637 11.174
Sample_208 15.477 16.779 14.995 16.132 14.908 17.530 NaN 16.119 18.368 15.202 ... 15.157 16.712 NaN 14.640 16.533 19.411 15.807 19.545 13.216 NaN
Sample_209 15.727 17.261 15.175 16.235 17.893 17.744 16.371 15.780 18.806 16.532 ... 15.237 15.652 15.211 14.205 16.749 19.275 15.732 19.577 11.042 11.791

210 rows × 1213 columns

Number of complete cases which can be used:

Hide code cell source

# Restrict everything to samples present in both the proteomics data and the target.
mask_has_target = data.index.levels[0].intersection(target.index)
assert not mask_has_target.empty, f"No data for target: {data.index.levels[0]} and {target.index}"
print(
    f"Samples available both in proteomics data and for target: {len(mask_has_target)}")
target, data, ald_study = target.loc[mask_has_target], data.loc[mask_has_target], ald_study.loc[mask_has_target]
Samples available both in proteomics data and for target: 210

Load imputations from specified model#

Hide code cell source

# Imputations (predictions for real missing values) from the selected model.
fname = args.out_preds / args.template_pred.format(args.model_key)
print(f"missing values pred. by {args.model_key}: {fname}")
load_single_csv_pred_file = pimmslearn.analyzers.compare_predictions.load_single_csv_pred_file
pred_real_na = load_single_csv_pred_file(fname).loc[mask_has_target]
pred_real_na.sample(3)
missing values pred. by DAE: runs/alzheimer_study/preds/pred_real_na_DAE.csv
Sample ID   protein groups 
Sample_027  O43854;O43854-2   13.393
Sample_020  P04179            16.894
Sample_101  P09960;P09960-4   13.080
Name: intensity, dtype: float64

Load imputations from baseline model#

Hide code cell source

# Imputations from the baseline method (kept unfiltered here; the overlap
# with selected samples/features is applied when ald_study is rebuilt below).
fname = args.out_preds / args.template_pred.format(args.baseline)
pred_real_na_baseline = load_single_csv_pred_file(fname)  # .loc[mask_has_target]
pred_real_na_baseline
Sample ID   protein groups          
Sample_000  A0A075B6J9                 13.412
            A0A075B6Q5                 13.967
            A0A075B6R2                 12.053
            A0A075B6S5                 13.419
            A0A087WSY4                 14.256
                                        ...  
Sample_209  Q9P1W8;Q9P1W8-2;Q9P1W8-4   12.540
            Q9UI40;Q9UI40-2            11.810
            Q9UIW2                     12.704
            Q9UMX0;Q9UMX0-2;Q9UMX0-4   11.626
            Q9UP79                     12.553
Name: intensity, Length: 46401, dtype: float64

Modeling setup#

General approach:

  • use one train, test split of the data

  • select best 10 features from training data X_train, y_train before binarization of target

  • dichotomize (binarize) data into two groups (zero and 1)

  • evaluate model on the test data X_test, y_test

Repeat general approach for

  1. all original ALD data: all features used in the original ALD study

  2. all model data: all features available by using the self-supervised deep learning model

  3. newly available feat only: the subset of features available from the self supervised deep learning model which were newly retained using the new approach

All data:

Hide code cell source

# Observed intensities plus model-imputed values, reshaped to samples x features.
X = pd.concat([data, pred_real_na]).unstack()
X
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H7 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 ... Q9Y653;Q9Y653-2;Q9Y653-3 Q9Y696 Q9Y6C2 Q9Y6N6 Q9Y6N7;Q9Y6N7-2;Q9Y6N7-4 Q9Y6R7 Q9Y6X5 Q9Y6Y8;Q9Y6Y8-2 Q9Y6Y9 S4R3U6
Sample ID
Sample_000 15.912 16.852 15.570 16.481 17.301 20.246 16.764 17.584 16.988 20.054 ... 16.012 15.178 14.548 15.050 16.842 19.863 15.734 19.563 12.837 12.805
Sample_001 15.936 16.874 15.519 16.387 13.796 19.941 18.786 17.144 16.604 19.067 ... 15.528 15.576 14.070 14.833 16.597 20.299 15.556 19.386 13.970 12.442
Sample_002 16.111 14.523 15.935 16.416 18.175 19.251 16.832 15.671 17.012 18.569 ... 15.229 14.728 13.757 15.118 17.440 19.598 15.735 20.447 12.636 12.505
Sample_003 16.107 17.032 15.802 16.979 15.963 19.628 17.852 18.877 14.182 18.985 ... 15.495 14.590 14.682 15.140 17.356 19.429 15.783 20.216 12.627 12.445
Sample_004 15.603 15.331 15.375 16.679 15.473 20.450 18.682 17.081 14.140 19.686 ... 14.757 15.094 14.048 15.256 17.075 19.582 15.328 19.867 13.145 12.235
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 15.682 16.886 14.910 16.482 14.833 17.705 17.039 15.768 16.413 19.102 ... 15.235 15.684 14.236 15.415 17.551 17.922 16.340 19.928 12.929 11.802
Sample_206 15.798 17.554 15.600 15.938 14.787 18.154 18.152 16.503 16.860 18.538 ... 15.422 16.106 14.642 15.345 17.084 18.708 14.249 19.433 11.131 10.907
Sample_207 15.739 16.877 15.469 16.898 13.539 18.636 17.950 16.321 16.401 18.849 ... 15.808 16.098 14.403 15.715 16.586 18.725 16.138 19.599 13.637 11.174
Sample_208 15.477 16.779 14.995 16.132 14.768 14.908 17.530 16.177 16.119 18.368 ... 15.157 16.712 14.352 14.640 16.533 19.411 15.807 19.545 13.216 10.426
Sample_209 15.727 17.261 15.175 16.235 14.099 17.893 17.744 16.371 15.780 18.806 ... 15.237 15.652 15.211 14.205 16.749 19.275 15.732 19.577 11.042 11.791

210 rows × 1421 columns

Subset of data by ALD criteria#

Hide code cell source

# could be just observed, drop columns with missing values
ald_study = pd.concat(
    [ald_study.stack(),
     pred_real_na_baseline.loc[
        # only select columns in selected in ald_study
        pd.IndexSlice[:, pred_real_na.index.levels[-1].intersection(ald_study.columns)]
    ]
    ]
).unstack()
ald_study
protein groups A0A024QZX5;A0A087X1N8;P35237 A0A024R0T9;K7ER74;P02655 A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 A0A075B6H9 A0A075B6I0 A0A075B6I1 A0A075B6I6 A0A075B6I9 A0A075B6K4 ... O14793 O95479;R4GMU1 P01282;P01282-2 P10619;P10619-2;X6R5C5;X6R8A1 P21810 Q14956;Q14956-2 Q6ZMP0;Q6ZMP0-2 Q9HBW1 Q9NY15 P17050
Sample ID
Sample_000 15.912 16.852 15.570 16.481 20.246 16.764 17.584 16.988 20.054 16.148 ... 12.449 13.051 13.220 12.867 13.424 13.300 13.211 14.187 13.768 11.609
Sample_001 15.936 16.874 15.519 16.387 19.941 18.786 17.144 13.309 19.067 16.127 ... 12.691 12.670 13.043 12.063 12.694 12.414 12.431 11.873 13.834 12.740
Sample_002 16.111 14.523 15.935 16.416 19.251 16.832 15.671 17.012 18.569 15.387 ... 14.611 11.361 14.329 12.444 13.242 13.526 12.610 12.567 12.622 12.551
Sample_003 16.107 17.032 15.802 16.979 19.628 17.852 18.877 14.182 18.985 16.565 ... 12.778 12.245 14.173 12.847 12.345 13.331 13.855 13.287 11.289 13.354
Sample_004 15.603 15.331 15.375 16.679 20.450 18.682 17.081 14.140 19.686 16.418 ... 13.803 14.433 12.759 13.537 13.614 12.396 14.186 14.139 11.627 11.973
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Sample_205 15.682 16.886 14.910 16.482 17.705 17.039 13.688 16.413 19.102 15.350 ... 14.269 14.064 16.826 18.182 15.225 15.044 14.192 16.605 14.995 14.257
Sample_206 15.798 17.554 15.600 15.938 18.154 18.152 16.503 16.860 18.538 16.582 ... 14.273 17.700 16.802 20.202 15.280 15.086 13.978 18.086 15.557 14.171
Sample_207 15.739 16.877 15.469 16.898 18.636 17.950 16.321 16.401 18.849 15.768 ... 14.473 16.882 16.917 20.105 15.690 15.135 13.138 17.066 15.706 15.690
Sample_208 15.477 16.779 14.995 16.132 14.908 17.530 13.150 16.119 18.368 17.560 ... 15.234 17.175 16.521 18.859 15.305 15.161 13.006 17.917 15.396 14.371
Sample_209 15.727 17.261 15.175 16.235 17.893 17.744 16.371 15.780 18.806 16.338 ... 14.556 16.656 16.954 18.493 15.823 14.626 13.385 17.767 15.687 13.573

210 rows × 1213 columns

Features which would not have been included using ALD criteria:

Hide code cell source

# Features present in the model-based data X but excluded by the ALD criteria.
new_features = X.columns.difference(ald_study.columns)
new_features
Index(['A0A075B6H7', 'A0A075B6Q5', 'A0A075B7B8', 'A0A087WSY4',
       'A0A087WTT8;A0A0A0MQX5;O94779;O94779-2', 'A0A087WXB8;Q9Y274',
       'A0A087WXE9;E9PQ70;Q6UXH9;Q6UXH9-2;Q6UXH9-3',
       'A0A087X1Z2;C9JTV4;H0Y4Y4;Q8WYH2;Q96C19;Q9BUP0;Q9BUP0-2',
       'A0A0A0MQS9;A0A0A0MTC7;Q16363;Q16363-2', 'A0A0A0MSN4;P12821;P12821-2',
       ...
       'Q9NZ94;Q9NZ94-2;Q9NZ94-3', 'Q9NZU1', 'Q9P1W8;Q9P1W8-2;Q9P1W8-4',
       'Q9UHI8', 'Q9UI40;Q9UI40-2',
       'Q9UIB8;Q9UIB8-2;Q9UIB8-3;Q9UIB8-4;Q9UIB8-5;Q9UIB8-6',
       'Q9UKZ4;Q9UKZ4-2', 'Q9UMX0;Q9UMX0-2;Q9UMX0-4', 'Q9Y281;Q9Y281-3',
       'Q9Y490'],
      dtype='object', name='protein groups', length=208)

Binarize targets, but also keep groups for stratification

Hide code cell source

# Keep the original target values for stratified grouping, then binarize
# the target at the configured cutoff (True where target >= cutoff_target).
target_to_group = target.copy()
target = target >= args.cutoff_target
pd.crosstab(target.squeeze(), target_to_group.squeeze())
AD 0 1
AD
False 122 0
True 0 88

Determine best number of parameters by cross validation procedure#

using subset of data by ALD criteria:

Hide code cell source

# Cross-validated performance for increasing numbers of selected features,
# using the ALD-criteria feature set; aggregate over CV repeats per n_features.
cv_feat_ald = njab.sklearn.find_n_best_features(X=ald_study, y=target, name=args.target,
                                                groups=target_to_group)
cv_feat_ald = (cv_feat_ald
               .drop('test_case', axis=1)
               .groupby('n_features')
               .agg(['mean', 'std']))
cv_feat_ald
  0%|          | 0/1 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 124.83it/s]
  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00,  7.90it/s]
100%|██████████| 2/2 [00:00<00:00,  7.83it/s]
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.94it/s]
100%|██████████| 3/3 [00:00<00:00,  5.23it/s]
100%|██████████| 3/3 [00:00<00:00,  5.48it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
 50%|█████     | 2/4 [00:00<00:00,  9.82it/s]
 75%|███████▌  | 3/4 [00:00<00:00,  6.81it/s]
100%|██████████| 4/4 [00:00<00:00,  5.94it/s]
100%|██████████| 4/4 [00:00<00:00,  6.45it/s]
  0%|          | 0/5 [00:00<?, ?it/s]
 40%|████      | 2/5 [00:00<00:00,  7.71it/s]
 60%|██████    | 3/5 [00:00<00:00,  6.23it/s]
 80%|████████  | 4/5 [00:00<00:00,  5.55it/s]
100%|██████████| 5/5 [00:00<00:00,  5.00it/s]
100%|██████████| 5/5 [00:00<00:00,  5.47it/s]
  0%|          | 0/6 [00:00<?, ?it/s]
 33%|███▎      | 2/6 [00:00<00:00,  5.00it/s]
 50%|█████     | 3/6 [00:00<00:00,  3.68it/s]
 67%|██████▋   | 4/6 [00:01<00:00,  3.46it/s]
 83%|████████▎ | 5/6 [00:01<00:00,  3.46it/s]
100%|██████████| 6/6 [00:01<00:00,  3.70it/s]
100%|██████████| 6/6 [00:01<00:00,  3.72it/s]
  0%|          | 0/7 [00:00<?, ?it/s]
 29%|██▊       | 2/7 [00:00<00:00,  7.74it/s]
 43%|████▎     | 3/7 [00:00<00:00,  5.37it/s]
 57%|█████▋    | 4/7 [00:00<00:00,  4.93it/s]
 71%|███████▏  | 5/7 [00:01<00:00,  4.39it/s]
 86%|████████▌ | 6/7 [00:01<00:00,  3.47it/s]
100%|██████████| 7/7 [00:01<00:00,  3.14it/s]
100%|██████████| 7/7 [00:01<00:00,  3.84it/s]
  0%|          | 0/8 [00:00<?, ?it/s]
 25%|██▌       | 2/8 [00:00<00:00,  7.05it/s]
 38%|███▊      | 3/8 [00:00<00:01,  4.80it/s]
 50%|█████     | 4/8 [00:00<00:00,  4.49it/s]
 62%|██████▎   | 5/8 [00:01<00:00,  4.47it/s]
 75%|███████▌  | 6/8 [00:01<00:00,  4.49it/s]
 88%|████████▊ | 7/8 [00:01<00:00,  3.73it/s]
100%|██████████| 8/8 [00:02<00:00,  3.28it/s]
100%|██████████| 8/8 [00:02<00:00,  3.94it/s]
  0%|          | 0/9 [00:00<?, ?it/s]
 22%|██▏       | 2/9 [00:00<00:01,  6.77it/s]
 33%|███▎      | 3/9 [00:00<00:01,  4.17it/s]
 44%|████▍     | 4/9 [00:01<00:01,  3.43it/s]
 56%|█████▌    | 5/9 [00:01<00:01,  2.96it/s]
 67%|██████▋   | 6/9 [00:01<00:01,  2.64it/s]
 78%|███████▊  | 7/9 [00:02<00:00,  2.72it/s]
 89%|████████▉ | 8/9 [00:02<00:00,  2.68it/s]
100%|██████████| 9/9 [00:02<00:00,  2.88it/s]
100%|██████████| 9/9 [00:02<00:00,  3.05it/s]
  0%|          | 0/10 [00:00<?, ?it/s]
 20%|██        | 2/10 [00:00<00:01,  6.52it/s]
 30%|███       | 3/10 [00:00<00:01,  4.05it/s]
 40%|████      | 4/10 [00:01<00:01,  3.47it/s]
 50%|█████     | 5/10 [00:01<00:01,  3.54it/s]
 60%|██████    | 6/10 [00:01<00:01,  3.65it/s]
 70%|███████   | 7/10 [00:01<00:00,  3.43it/s]
 80%|████████  | 8/10 [00:02<00:00,  3.64it/s]
 90%|█████████ | 9/10 [00:02<00:00,  3.85it/s]
100%|██████████| 10/10 [00:02<00:00,  3.90it/s]
100%|██████████| 10/10 [00:02<00:00,  3.82it/s]
  0%|          | 0/11 [00:00<?, ?it/s]
 18%|█▊        | 2/11 [00:00<00:01,  6.92it/s]
 27%|██▋       | 3/11 [00:00<00:01,  5.06it/s]
 36%|███▋      | 4/11 [00:00<00:01,  4.36it/s]
 45%|████▌     | 5/11 [00:01<00:01,  4.34it/s]
 55%|█████▍    | 6/11 [00:01<00:01,  4.38it/s]
 64%|██████▎   | 7/11 [00:01<00:00,  4.40it/s]
 73%|███████▎  | 8/11 [00:01<00:00,  4.38it/s]
 82%|████████▏ | 9/11 [00:01<00:00,  4.34it/s]
 91%|█████████ | 10/11 [00:02<00:00,  4.27it/s]
100%|██████████| 11/11 [00:02<00:00,  4.34it/s]
100%|██████████| 11/11 [00:02<00:00,  4.48it/s]
  0%|          | 0/12 [00:00<?, ?it/s]
 17%|█▋        | 2/12 [00:00<00:01,  9.01it/s]
 25%|██▌       | 3/12 [00:00<00:01,  6.05it/s]
 33%|███▎      | 4/12 [00:00<00:01,  5.30it/s]
 42%|████▏     | 5/12 [00:00<00:01,  4.96it/s]
 50%|█████     | 6/12 [00:01<00:01,  4.85it/s]
 58%|█████▊    | 7/12 [00:01<00:01,  4.49it/s]
 67%|██████▋   | 8/12 [00:01<00:00,  4.38it/s]
 75%|███████▌  | 9/12 [00:01<00:00,  4.44it/s]
 83%|████████▎ | 10/12 [00:02<00:00,  3.38it/s]
 92%|█████████▏| 11/12 [00:02<00:00,  3.43it/s]
100%|██████████| 12/12 [00:02<00:00,  3.42it/s]
100%|██████████| 12/12 [00:02<00:00,  4.16it/s]
  0%|          | 0/13 [00:00<?, ?it/s]
 15%|█▌        | 2/13 [00:00<00:01,  8.48it/s]
 23%|██▎       | 3/13 [00:00<00:01,  5.52it/s]
 31%|███       | 4/13 [00:00<00:01,  5.01it/s]
 38%|███▊      | 5/13 [00:00<00:01,  4.72it/s]
 46%|████▌     | 6/13 [00:01<00:01,  4.62it/s]
 54%|█████▍    | 7/13 [00:01<00:01,  4.53it/s]
 62%|██████▏   | 8/13 [00:01<00:01,  4.55it/s]
 69%|██████▉   | 9/13 [00:01<00:00,  4.52it/s]
 77%|███████▋  | 10/13 [00:02<00:00,  3.90it/s]
 85%|████████▍ | 11/13 [00:02<00:00,  3.13it/s]
 92%|█████████▏| 12/13 [00:03<00:00,  2.94it/s]
100%|██████████| 13/13 [00:03<00:00,  2.75it/s]
100%|██████████| 13/13 [00:03<00:00,  3.74it/s]
  0%|          | 0/14 [00:00<?, ?it/s]
 14%|█▍        | 2/14 [00:00<00:01,  8.91it/s]
 21%|██▏       | 3/14 [00:00<00:01,  6.25it/s]
 29%|██▊       | 4/14 [00:00<00:01,  5.27it/s]
 36%|███▌      | 5/14 [00:00<00:01,  4.74it/s]
 43%|████▎     | 6/14 [00:01<00:01,  4.38it/s]
 50%|█████     | 7/14 [00:01<00:01,  3.61it/s]
 57%|█████▋    | 8/14 [00:01<00:01,  3.23it/s]
 64%|██████▍   | 9/14 [00:02<00:01,  3.02it/s]
 71%|███████▏  | 10/14 [00:02<00:01,  2.92it/s]
 79%|███████▊  | 11/14 [00:03<00:01,  2.86it/s]
 86%|████████▌ | 12/14 [00:03<00:00,  3.10it/s]
 93%|█████████▎| 13/14 [00:03<00:00,  3.23it/s]
100%|██████████| 14/14 [00:03<00:00,  3.27it/s]
100%|██████████| 14/14 [00:03<00:00,  3.57it/s]
  0%|          | 0/15 [00:00<?, ?it/s]
 13%|█▎        | 2/15 [00:00<00:01,  8.33it/s]
 20%|██        | 3/15 [00:00<00:01,  6.30it/s]
 27%|██▋       | 4/15 [00:00<00:02,  4.37it/s]
 33%|███▎      | 5/15 [00:01<00:02,  3.67it/s]
 40%|████      | 6/15 [00:01<00:02,  3.33it/s]
 47%|████▋     | 7/15 [00:01<00:02,  3.16it/s]
 53%|█████▎    | 8/15 [00:02<00:02,  3.03it/s]
 60%|██████    | 9/15 [00:02<00:01,  3.22it/s]
 67%|██████▋   | 10/15 [00:02<00:01,  3.31it/s]
 73%|███████▎  | 11/15 [00:03<00:01,  3.53it/s]
 80%|████████  | 12/15 [00:03<00:00,  3.53it/s]
 87%|████████▋ | 13/15 [00:03<00:00,  3.62it/s]
 93%|█████████▎| 14/15 [00:03<00:00,  3.40it/s]
100%|██████████| 15/15 [00:04<00:00,  3.41it/s]
100%|██████████| 15/15 [00:04<00:00,  3.58it/s]
fit_time score_time test_precision test_recall test_f1 test_balanced_accuracy test_roc_auc test_average_precision n_observations
mean std mean std mean std mean std mean std mean std mean std mean std mean std
n_features
1 0.004 0.002 0.047 0.018 0.899 0.158 0.169 0.089 0.274 0.124 0.576 0.043 0.856 0.060 0.823 0.086 210.000 0.000
2 0.004 0.001 0.041 0.012 0.629 0.134 0.431 0.141 0.497 0.115 0.618 0.071 0.693 0.083 0.633 0.095 210.000 0.000
3 0.004 0.001 0.042 0.013 0.663 0.098 0.611 0.122 0.631 0.094 0.691 0.072 0.789 0.070 0.736 0.099 210.000 0.000
4 0.005 0.002 0.058 0.019 0.662 0.097 0.620 0.126 0.635 0.096 0.694 0.073 0.780 0.073 0.725 0.101 210.000 0.000
5 0.005 0.003 0.057 0.023 0.754 0.102 0.716 0.099 0.728 0.077 0.768 0.065 0.857 0.057 0.836 0.063 210.000 0.000
6 0.003 0.000 0.040 0.007 0.810 0.078 0.829 0.095 0.815 0.062 0.842 0.054 0.908 0.046 0.889 0.053 210.000 0.000
7 0.006 0.002 0.066 0.022 0.810 0.079 0.827 0.100 0.814 0.066 0.841 0.057 0.906 0.048 0.887 0.055 210.000 0.000
8 0.005 0.002 0.054 0.018 0.825 0.087 0.825 0.098 0.820 0.065 0.846 0.055 0.909 0.047 0.882 0.063 210.000 0.000
9 0.004 0.001 0.040 0.007 0.817 0.083 0.807 0.104 0.806 0.066 0.835 0.055 0.908 0.048 0.885 0.059 210.000 0.000
10 0.004 0.002 0.045 0.015 0.843 0.088 0.824 0.105 0.828 0.072 0.853 0.061 0.919 0.049 0.906 0.056 210.000 0.000
11 0.005 0.002 0.051 0.019 0.835 0.087 0.817 0.108 0.821 0.075 0.848 0.064 0.920 0.049 0.907 0.056 210.000 0.000
12 0.006 0.003 0.054 0.020 0.829 0.085 0.830 0.098 0.825 0.073 0.851 0.063 0.920 0.050 0.909 0.055 210.000 0.000
13 0.005 0.002 0.043 0.016 0.831 0.089 0.828 0.099 0.825 0.073 0.850 0.063 0.919 0.050 0.908 0.055 210.000 0.000
14 0.004 0.001 0.039 0.011 0.821 0.086 0.825 0.092 0.819 0.066 0.845 0.057 0.918 0.049 0.908 0.053 210.000 0.000
15 0.005 0.002 0.042 0.010 0.828 0.089 0.825 0.092 0.822 0.069 0.848 0.059 0.919 0.049 0.911 0.051 210.000 0.000

Using all data:

Hide code cell source

# Same cross-validation procedure, but on all features (model-imputed data X).
cv_feat_all = njab.sklearn.find_n_best_features(X=X, y=target, name=args.target,
                                                groups=target_to_group)
cv_feat_all = cv_feat_all.drop('test_case', axis=1).groupby('n_features').agg(['mean', 'std'])
cv_feat_all
  0%|          | 0/1 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 170.39it/s]
  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00,  8.36it/s]
100%|██████████| 2/2 [00:00<00:00,  8.27it/s]
  0%|          | 0/3 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  9.23it/s]
100%|██████████| 3/3 [00:00<00:00,  5.52it/s]
100%|██████████| 3/3 [00:00<00:00,  5.97it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
 50%|█████     | 2/4 [00:00<00:00,  4.62it/s]
 75%|███████▌  | 3/4 [00:00<00:00,  3.09it/s]
100%|██████████| 4/4 [00:01<00:00,  2.55it/s]
100%|██████████| 4/4 [00:01<00:00,  2.82it/s]
  0%|          | 0/5 [00:00<?, ?it/s]
 40%|████      | 2/5 [00:00<00:00,  6.95it/s]
 60%|██████    | 3/5 [00:00<00:00,  4.83it/s]
 80%|████████  | 4/5 [00:00<00:00,  4.31it/s]
100%|██████████| 5/5 [00:01<00:00,  4.45it/s]
100%|██████████| 5/5 [00:01<00:00,  4.68it/s]
  0%|          | 0/6 [00:00<?, ?it/s]
 33%|███▎      | 2/6 [00:00<00:00,  7.28it/s]
 50%|█████     | 3/6 [00:00<00:00,  5.58it/s]
 67%|██████▋   | 4/6 [00:00<00:00,  4.91it/s]
 83%|████████▎ | 5/6 [00:00<00:00,  4.66it/s]
100%|██████████| 6/6 [00:01<00:00,  4.35it/s]
100%|██████████| 6/6 [00:01<00:00,  4.78it/s]
  0%|          | 0/7 [00:00<?, ?it/s]
 29%|██▊       | 2/7 [00:00<00:00,  7.70it/s]
 43%|████▎     | 3/7 [00:00<00:00,  5.58it/s]
 57%|█████▋    | 4/7 [00:00<00:00,  3.85it/s]
 71%|███████▏  | 5/7 [00:01<00:00,  3.19it/s]
 86%|████████▌ | 6/7 [00:01<00:00,  2.89it/s]
100%|██████████| 7/7 [00:02<00:00,  2.70it/s]
100%|██████████| 7/7 [00:02<00:00,  3.25it/s]
  0%|          | 0/8 [00:00<?, ?it/s]
 25%|██▌       | 2/8 [00:00<00:01,  4.89it/s]
 38%|███▊      | 3/8 [00:00<00:01,  3.82it/s]
 50%|█████     | 4/8 [00:01<00:01,  3.45it/s]
 62%|██████▎   | 5/8 [00:01<00:00,  3.61it/s]
 75%|███████▌  | 6/8 [00:01<00:00,  3.67it/s]
 88%|████████▊ | 7/8 [00:01<00:00,  3.83it/s]
100%|██████████| 8/8 [00:02<00:00,  4.11it/s]
100%|██████████| 8/8 [00:02<00:00,  3.91it/s]
  0%|          | 0/9 [00:00<?, ?it/s]
 22%|██▏       | 2/9 [00:00<00:01,  6.66it/s]
 33%|███▎      | 3/9 [00:00<00:01,  4.38it/s]
 44%|████▍     | 4/9 [00:00<00:01,  3.90it/s]
 56%|█████▌    | 5/9 [00:01<00:01,  3.74it/s]
 67%|██████▋   | 6/9 [00:01<00:00,  4.01it/s]
 78%|███████▊  | 7/9 [00:01<00:00,  4.13it/s]
 89%|████████▉ | 8/9 [00:01<00:00,  4.00it/s]
100%|██████████| 9/9 [00:02<00:00,  3.84it/s]
100%|██████████| 9/9 [00:02<00:00,  4.05it/s]
  0%|          | 0/10 [00:00<?, ?it/s]
 20%|██        | 2/10 [00:00<00:00,  8.10it/s]
 30%|███       | 3/10 [00:00<00:01,  5.74it/s]
 40%|████      | 4/10 [00:00<00:01,  5.03it/s]
 50%|█████     | 5/10 [00:00<00:01,  4.57it/s]
 60%|██████    | 6/10 [00:01<00:00,  4.18it/s]
 70%|███████   | 7/10 [00:01<00:00,  4.14it/s]
 80%|████████  | 8/10 [00:01<00:00,  4.10it/s]
 90%|█████████ | 9/10 [00:02<00:00,  3.97it/s]
100%|██████████| 10/10 [00:02<00:00,  3.87it/s]
100%|██████████| 10/10 [00:02<00:00,  4.33it/s]
  0%|          | 0/11 [00:00<?, ?it/s]
 18%|█▊        | 2/11 [00:00<00:01,  8.49it/s]
 27%|██▋       | 3/11 [00:00<00:01,  6.33it/s]
 36%|███▋      | 4/11 [00:00<00:01,  5.58it/s]
 45%|████▌     | 5/11 [00:00<00:01,  5.22it/s]
 55%|█████▍    | 6/11 [00:01<00:01,  4.95it/s]
 64%|██████▎   | 7/11 [00:01<00:00,  4.83it/s]
 73%|███████▎  | 8/11 [00:01<00:00,  3.79it/s]
 82%|████████▏ | 9/11 [00:02<00:00,  3.35it/s]
 91%|█████████ | 10/11 [00:02<00:00,  2.92it/s]
100%|██████████| 11/11 [00:02<00:00,  2.74it/s]
100%|██████████| 11/11 [00:02<00:00,  3.74it/s]
  0%|          | 0/12 [00:00<?, ?it/s]
 17%|█▋        | 2/12 [00:00<00:01,  8.22it/s]
 25%|██▌       | 3/12 [00:00<00:01,  6.38it/s]
 33%|███▎      | 4/12 [00:00<00:01,  5.24it/s]
 42%|████▏     | 5/12 [00:00<00:01,  4.48it/s]
 50%|█████     | 6/12 [00:01<00:01,  3.82it/s]
 58%|█████▊    | 7/12 [00:01<00:01,  3.28it/s]
 67%|██████▋   | 8/12 [00:02<00:01,  3.02it/s]
 75%|███████▌  | 9/12 [00:02<00:01,  2.80it/s]
 83%|████████▎ | 10/12 [00:02<00:00,  2.73it/s]
 92%|█████████▏| 11/12 [00:03<00:00,  2.88it/s]
100%|██████████| 12/12 [00:03<00:00,  2.98it/s]
100%|██████████| 12/12 [00:03<00:00,  3.41it/s]
  0%|          | 0/13 [00:00<?, ?it/s]
 15%|█▌        | 2/13 [00:00<00:02,  5.27it/s]
 23%|██▎       | 3/13 [00:00<00:02,  3.68it/s]
 31%|███       | 4/13 [00:01<00:02,  3.02it/s]
 38%|███▊      | 5/13 [00:01<00:02,  2.98it/s]
 46%|████▌     | 6/13 [00:01<00:02,  2.89it/s]
 54%|█████▍    | 7/13 [00:02<00:02,  2.91it/s]
 62%|██████▏   | 8/13 [00:02<00:01,  3.01it/s]
 69%|██████▉   | 9/13 [00:02<00:01,  2.86it/s]
 77%|███████▋  | 10/13 [00:03<00:01,  2.95it/s]
 85%|████████▍ | 11/13 [00:03<00:00,  3.07it/s]
 92%|█████████▏| 12/13 [00:03<00:00,  3.11it/s]
100%|██████████| 13/13 [00:04<00:00,  3.17it/s]
100%|██████████| 13/13 [00:04<00:00,  3.12it/s]
  0%|          | 0/14 [00:00<?, ?it/s]
 14%|█▍        | 2/14 [00:00<00:02,  4.75it/s]
 21%|██▏       | 3/14 [00:00<00:03,  3.39it/s]
 29%|██▊       | 4/14 [00:01<00:03,  3.28it/s]
 36%|███▌      | 5/14 [00:01<00:02,  3.56it/s]
 43%|████▎     | 6/14 [00:01<00:02,  3.48it/s]
 50%|█████     | 7/14 [00:01<00:01,  3.55it/s]
 57%|█████▋    | 8/14 [00:02<00:01,  3.48it/s]
 64%|██████▍   | 9/14 [00:02<00:01,  3.43it/s]
 71%|███████▏  | 10/14 [00:02<00:01,  3.30it/s]
 79%|███████▊  | 11/14 [00:03<00:00,  3.32it/s]
 86%|████████▌ | 12/14 [00:03<00:00,  3.37it/s]
 93%|█████████▎| 13/14 [00:03<00:00,  3.34it/s]
100%|██████████| 14/14 [00:04<00:00,  3.27it/s]
100%|██████████| 14/14 [00:04<00:00,  3.41it/s]
  0%|          | 0/15 [00:00<?, ?it/s]
 13%|█▎        | 2/15 [00:00<00:02,  5.92it/s]
 20%|██        | 3/15 [00:00<00:03,  3.97it/s]
 27%|██▋       | 4/15 [00:01<00:03,  3.47it/s]
 33%|███▎      | 5/15 [00:01<00:03,  3.12it/s]
 40%|████      | 6/15 [00:01<00:02,  3.20it/s]
 47%|████▋     | 7/15 [00:02<00:02,  3.21it/s]
 53%|█████▎    | 8/15 [00:02<00:02,  3.08it/s]
 60%|██████    | 9/15 [00:02<00:02,  2.88it/s]
 67%|██████▋   | 10/15 [00:03<00:01,  2.58it/s]
 73%|███████▎  | 11/15 [00:03<00:01,  2.46it/s]
 80%|████████  | 12/15 [00:04<00:01,  2.51it/s]
 87%|████████▋ | 13/15 [00:04<00:00,  2.56it/s]
 93%|█████████▎| 14/15 [00:04<00:00,  2.70it/s]
100%|██████████| 15/15 [00:05<00:00,  2.96it/s]
100%|██████████| 15/15 [00:05<00:00,  2.96it/s]
fit_time score_time test_precision test_recall test_f1 test_balanced_accuracy test_roc_auc test_average_precision n_observations
mean std mean std mean std mean std mean std mean std mean std mean std mean std
n_features
1 0.005 0.002 0.049 0.017 0.112 0.270 0.014 0.037 0.024 0.062 0.502 0.016 0.874 0.057 0.838 0.086 210.000 0.000
2 0.004 0.002 0.051 0.019 0.634 0.100 0.521 0.113 0.565 0.095 0.649 0.066 0.728 0.081 0.649 0.102 210.000 0.000
3 0.005 0.002 0.055 0.020 0.726 0.104 0.634 0.107 0.669 0.076 0.725 0.056 0.795 0.065 0.775 0.083 210.000 0.000
4 0.007 0.003 0.072 0.028 0.804 0.092 0.622 0.116 0.694 0.087 0.753 0.059 0.818 0.064 0.791 0.079 210.000 0.000
5 0.004 0.002 0.046 0.020 0.798 0.071 0.801 0.108 0.794 0.066 0.825 0.055 0.905 0.050 0.879 0.062 210.000 0.000
6 0.004 0.002 0.041 0.019 0.807 0.082 0.798 0.112 0.797 0.075 0.827 0.062 0.907 0.049 0.879 0.063 210.000 0.000
7 0.004 0.001 0.040 0.010 0.810 0.084 0.814 0.107 0.807 0.074 0.835 0.062 0.906 0.049 0.879 0.062 210.000 0.000
8 0.002 0.000 0.024 0.005 0.799 0.079 0.817 0.109 0.803 0.073 0.832 0.061 0.903 0.053 0.877 0.063 210.000 0.000
9 0.005 0.002 0.049 0.022 0.789 0.081 0.815 0.108 0.797 0.074 0.827 0.062 0.901 0.053 0.875 0.063 210.000 0.000
10 0.005 0.002 0.055 0.021 0.794 0.086 0.822 0.111 0.802 0.075 0.831 0.064 0.902 0.053 0.875 0.063 210.000 0.000
11 0.004 0.001 0.036 0.012 0.792 0.089 0.798 0.106 0.789 0.071 0.820 0.061 0.907 0.048 0.880 0.058 210.000 0.000
12 0.005 0.002 0.045 0.017 0.821 0.092 0.798 0.101 0.804 0.072 0.833 0.062 0.921 0.042 0.900 0.052 210.000 0.000
13 0.004 0.001 0.039 0.010 0.824 0.082 0.787 0.099 0.799 0.063 0.830 0.052 0.922 0.040 0.904 0.050 210.000 0.000
14 0.007 0.003 0.065 0.021 0.821 0.083 0.794 0.099 0.802 0.064 0.831 0.054 0.921 0.041 0.903 0.050 210.000 0.000
15 0.005 0.002 0.044 0.014 0.820 0.081 0.795 0.099 0.802 0.063 0.832 0.053 0.920 0.040 0.901 0.050 210.000 0.000

Using only new features:

Hide code cell source

# Cross-validated feature search restricted to the newly imputed features only.
cv_feat_new = njab.sklearn.find_n_best_features(
    X=X.loc[:, new_features],
    y=target,
    name=args.target,
    groups=target_to_group,
)
# Aggregate CV metrics per number of selected features.
cv_feat_new = (cv_feat_new
               .drop(columns='test_case')
               .groupby('n_features')
               .agg(['mean', 'std']))
cv_feat_new
  0%|          | 0/1 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 1031.81it/s]
  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 33.80it/s]
  0%|          | 0/3 [00:00<?, ?it/s]
100%|██████████| 3/3 [00:00<00:00, 28.90it/s]
100%|██████████| 3/3 [00:00<00:00, 28.29it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
 50%|█████     | 2/4 [00:00<00:00, 16.43it/s]
100%|██████████| 4/4 [00:00<00:00, 12.99it/s]
100%|██████████| 4/4 [00:00<00:00, 13.23it/s]
  0%|          | 0/5 [00:00<?, ?it/s]
 60%|██████    | 3/5 [00:00<00:00, 22.12it/s]
100%|██████████| 5/5 [00:00<00:00, 17.25it/s]
  0%|          | 0/6 [00:00<?, ?it/s]
 33%|███▎      | 2/6 [00:00<00:00, 18.28it/s]
 67%|██████▋   | 4/6 [00:00<00:00, 11.24it/s]
100%|██████████| 6/6 [00:00<00:00, 10.61it/s]
100%|██████████| 6/6 [00:00<00:00, 11.13it/s]
  0%|          | 0/7 [00:00<?, ?it/s]
 29%|██▊       | 2/7 [00:00<00:00, 18.14it/s]
 57%|█████▋    | 4/7 [00:00<00:00, 10.82it/s]
 86%|████████▌ | 6/7 [00:00<00:00, 10.02it/s]
100%|██████████| 7/7 [00:00<00:00, 10.07it/s]
  0%|          | 0/8 [00:00<?, ?it/s]
 38%|███▊      | 3/8 [00:00<00:00, 25.53it/s]
 75%|███████▌  | 6/8 [00:00<00:00, 18.85it/s]
100%|██████████| 8/8 [00:00<00:00, 17.47it/s]
100%|██████████| 8/8 [00:00<00:00, 18.34it/s]
  0%|          | 0/9 [00:00<?, ?it/s]
 33%|███▎      | 3/9 [00:00<00:00, 18.32it/s]
 56%|█████▌    | 5/9 [00:00<00:00, 17.20it/s]
 78%|███████▊  | 7/9 [00:00<00:00, 14.77it/s]
100%|██████████| 9/9 [00:00<00:00, 14.43it/s]
100%|██████████| 9/9 [00:00<00:00, 15.12it/s]
  0%|          | 0/10 [00:00<?, ?it/s]
 30%|███       | 3/10 [00:00<00:00, 22.67it/s]
 60%|██████    | 6/10 [00:00<00:00, 16.73it/s]
 80%|████████  | 8/10 [00:00<00:00, 13.78it/s]
100%|██████████| 10/10 [00:00<00:00, 12.30it/s]
100%|██████████| 10/10 [00:00<00:00, 13.65it/s]
  0%|          | 0/11 [00:00<?, ?it/s]
 27%|██▋       | 3/11 [00:00<00:00, 20.23it/s]
 55%|█████▍    | 6/11 [00:00<00:00, 15.95it/s]
 73%|███████▎  | 8/11 [00:00<00:00, 15.74it/s]
 91%|█████████ | 10/11 [00:00<00:00, 16.51it/s]
100%|██████████| 11/11 [00:00<00:00, 16.50it/s]
  0%|          | 0/12 [00:00<?, ?it/s]
 25%|██▌       | 3/12 [00:00<00:00, 25.79it/s]
 50%|█████     | 6/12 [00:00<00:00, 16.53it/s]
 67%|██████▋   | 8/12 [00:00<00:00, 14.86it/s]
 83%|████████▎ | 10/12 [00:00<00:00, 13.86it/s]
100%|██████████| 12/12 [00:00<00:00, 14.48it/s]
100%|██████████| 12/12 [00:00<00:00, 15.15it/s]
  0%|          | 0/13 [00:00<?, ?it/s]
 23%|██▎       | 3/13 [00:00<00:00, 24.54it/s]
 46%|████▌     | 6/13 [00:00<00:00, 18.19it/s]
 62%|██████▏   | 8/13 [00:00<00:00, 18.23it/s]
 77%|███████▋  | 10/13 [00:00<00:00, 17.22it/s]
 92%|█████████▏| 12/13 [00:00<00:00, 16.10it/s]
100%|██████████| 13/13 [00:00<00:00, 16.68it/s]
  0%|          | 0/14 [00:00<?, ?it/s]
 21%|██▏       | 3/14 [00:00<00:00, 17.62it/s]
 36%|███▌      | 5/14 [00:00<00:00, 15.05it/s]
 50%|█████     | 7/14 [00:00<00:00, 13.57it/s]
 64%|██████▍   | 9/14 [00:00<00:00, 12.83it/s]
 79%|███████▊  | 11/14 [00:00<00:00, 12.17it/s]
 93%|█████████▎| 13/14 [00:00<00:00, 12.87it/s]
100%|██████████| 14/14 [00:01<00:00, 13.35it/s]
  0%|          | 0/15 [00:00<?, ?it/s]
 20%|██        | 3/15 [00:00<00:00, 13.39it/s]
 33%|███▎      | 5/15 [00:00<00:00, 12.88it/s]
 47%|████▋     | 7/15 [00:00<00:00, 13.14it/s]
 60%|██████    | 9/15 [00:00<00:00, 12.38it/s]
 73%|███████▎  | 11/15 [00:00<00:00, 11.69it/s]
 87%|████████▋ | 13/15 [00:01<00:00, 12.51it/s]
100%|██████████| 15/15 [00:01<00:00, 12.57it/s]
100%|██████████| 15/15 [00:01<00:00, 12.54it/s]
fit_time score_time test_precision test_recall test_f1 test_balanced_accuracy test_roc_auc test_average_precision n_observations
mean std mean std mean std mean std mean std mean std mean std mean std mean std
n_features
1 0.005 0.002 0.060 0.025 0.000 0.000 0.000 0.000 0.000 0.000 0.500 0.000 0.749 0.066 0.691 0.087 210.000 0.000
2 0.005 0.002 0.053 0.019 0.606 0.100 0.477 0.108 0.525 0.087 0.622 0.060 0.699 0.062 0.658 0.071 210.000 0.000
3 0.005 0.002 0.055 0.024 0.632 0.078 0.573 0.086 0.596 0.065 0.663 0.051 0.763 0.057 0.721 0.069 210.000 0.000
4 0.005 0.004 0.053 0.023 0.700 0.086 0.657 0.103 0.671 0.070 0.723 0.056 0.799 0.052 0.740 0.070 210.000 0.000
5 0.005 0.003 0.053 0.023 0.695 0.091 0.655 0.103 0.668 0.072 0.719 0.058 0.795 0.053 0.736 0.070 210.000 0.000
6 0.005 0.002 0.051 0.020 0.685 0.081 0.639 0.110 0.655 0.077 0.711 0.058 0.793 0.050 0.737 0.063 210.000 0.000
7 0.004 0.002 0.042 0.012 0.680 0.082 0.628 0.108 0.646 0.074 0.704 0.056 0.788 0.053 0.731 0.066 210.000 0.000
8 0.005 0.002 0.054 0.020 0.682 0.089 0.629 0.124 0.648 0.094 0.707 0.069 0.794 0.057 0.734 0.078 210.000 0.000
9 0.004 0.002 0.040 0.009 0.674 0.084 0.623 0.121 0.642 0.090 0.701 0.066 0.792 0.056 0.726 0.076 210.000 0.000
10 0.004 0.002 0.045 0.016 0.685 0.081 0.662 0.127 0.667 0.089 0.720 0.067 0.798 0.058 0.727 0.076 210.000 0.000
11 0.005 0.002 0.049 0.020 0.679 0.079 0.665 0.123 0.666 0.085 0.717 0.066 0.796 0.057 0.727 0.074 210.000 0.000
12 0.004 0.002 0.041 0.016 0.671 0.083 0.660 0.132 0.659 0.094 0.712 0.070 0.792 0.057 0.721 0.073 210.000 0.000
13 0.006 0.004 0.061 0.025 0.665 0.083 0.651 0.133 0.652 0.096 0.706 0.071 0.790 0.058 0.719 0.077 210.000 0.000
14 0.004 0.002 0.044 0.014 0.657 0.076 0.655 0.130 0.650 0.092 0.703 0.066 0.787 0.059 0.716 0.077 210.000 0.000
15 0.005 0.002 0.046 0.014 0.657 0.076 0.650 0.127 0.647 0.089 0.701 0.063 0.784 0.060 0.709 0.082 210.000 0.000

Best number of features by subset of the data:#

Hide code cell source

# For each data subset, pick the number of features that maximises each
# metric's cross-validated mean.
_cv_results = {'ald': cv_feat_ald, 'all': cv_feat_all, 'new': cv_feat_new}
n_feat_best = pd.DataFrame(
    {key: res.loc[:, pd.IndexSlice[:, 'mean']].idxmax()
     for key, res in _cv_results.items()}
).droplevel(-1)
n_feat_best
ald all new
fit_time 7 14 13
score_time 7 4 13
test_precision 1 13 4
test_recall 12 10 11
test_f1 10 7 4
test_balanced_accuracy 10 7 4
test_roc_auc 11 13 4
test_average_precision 15 13 4
n_observations 1 1 1

Train, test split#

Show number of cases in train and test data

Hide code cell source

# Stratified 80/20 split; fixed seed for reproducibility.
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X,
    target,
    test_size=.2,
    stratify=target_to_group,
    random_state=42)
idx_train, idx_test = X_train.index, X_test.index

# Class counts per split (target classes as rows, train/test as columns).
_targets = pd.concat([y_train, y_test], axis=1, ignore_index=True)
njab.pandas.combine_value_counts(_targets.rename(columns={0: 'train', 1: 'test'}))
train test
False 98 24
True 70 18

Results#

  • run_model returns dataclasses containing the results needed downstream

  • mRMR feature selection is applied to the data (the best number of features is selected instead of being fixed)

Save results for final model on entire data, new features and ALD study criteria selected data.

Hide code cell source

def _fit_and_store(splits, n_feat, name):
    """Fit a model on *splits*, name it, register and pickle the results.

    Parameters
    ----------
    splits : Splits
        Train/test data for the model.
    n_feat : int
        Number of features to select (best CV ROC-AUC determined above).
    name : str
        Name stored on the results object; also used in the output file name.

    Returns
    -------
    The results object returned by ``njab.sklearn.run_model``.
    """
    results = njab.sklearn.run_model(splits, n_feat_to_select=n_feat)
    results.name = name
    fname = args.out_folder / f'results_{results.name}.pkl'
    files_out[fname.name] = fname
    pimmslearn.io.to_pickle(results, fname)
    return results


# Model on all features of the selected imputation method.
splits = Splits(X_train=X.loc[idx_train],
                X_test=X.loc[idx_test],
                y_train=y_train,
                y_test=y_test)
results_model_full = _fit_and_store(
    splits,
    n_feat_best.loc['test_roc_auc', 'all'],
    f'{args.model_key} all')

# Model restricted to the newly imputed features.
splits = Splits(X_train=X.loc[idx_train, new_features],
                X_test=X.loc[idx_test, new_features],
                y_train=y_train,
                y_test=y_test)
results_model_new = _fit_and_store(
    splits,
    n_feat_best.loc['test_roc_auc', 'new'],
    f'{args.model_key} new')

# Baseline: features as selected in the original ALD study.
splits_ald = Splits(X_train=ald_study.loc[idx_train],
                    X_test=ald_study.loc[idx_test],
                    y_train=y_train,
                    y_test=y_test)
results_ald_full = _fit_and_store(
    splits_ald,
    n_feat_best.loc['test_roc_auc', 'ald'],
    'ALD study all')
  0%|          | 0/13 [00:00<?, ?it/s]
 15%|█▌        | 2/13 [00:00<00:02,  4.46it/s]
 23%|██▎       | 3/13 [00:00<00:03,  3.25it/s]
 31%|███       | 4/13 [00:01<00:02,  3.07it/s]
 38%|███▊      | 5/13 [00:01<00:02,  3.09it/s]
 46%|████▌     | 6/13 [00:01<00:02,  2.95it/s]
 54%|█████▍    | 7/13 [00:02<00:01,  3.04it/s]
 62%|██████▏   | 8/13 [00:02<00:01,  2.94it/s]
 69%|██████▉   | 9/13 [00:02<00:01,  2.77it/s]
 77%|███████▋  | 10/13 [00:03<00:01,  2.88it/s]
 85%|████████▍ | 11/13 [00:03<00:00,  2.97it/s]
 92%|█████████▏| 12/13 [00:03<00:00,  3.03it/s]
100%|██████████| 13/13 [00:04<00:00,  3.13it/s]
100%|██████████| 13/13 [00:04<00:00,  3.07it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
 75%|███████▌  | 3/4 [00:00<00:00, 24.20it/s]
100%|██████████| 4/4 [00:00<00:00, 20.54it/s]
  0%|          | 0/11 [00:00<?, ?it/s]
 18%|█▊        | 2/11 [00:00<00:01,  8.41it/s]
 27%|██▋       | 3/11 [00:00<00:01,  5.81it/s]
 36%|███▋      | 4/11 [00:00<00:01,  4.90it/s]
 45%|████▌     | 5/11 [00:01<00:01,  4.46it/s]
 55%|█████▍    | 6/11 [00:01<00:01,  3.88it/s]
 64%|██████▎   | 7/11 [00:01<00:01,  3.33it/s]
 73%|███████▎  | 8/11 [00:02<00:00,  3.18it/s]
 82%|████████▏ | 9/11 [00:02<00:00,  2.90it/s]
 91%|█████████ | 10/11 [00:02<00:00,  2.95it/s]
100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
100%|██████████| 11/11 [00:03<00:00,  3.59it/s]

ROC-AUC on test split#

Hide code cell source

# Overlay the test-split ROC curves of the three models in one figure.
fig, ax = plt.subplots(1, 1, figsize=figsize)
for _results in (results_ald_full, results_model_full, results_model_new):
    plot_split_auc(_results.test, _results.name, ax)
fname = args.out_folder / 'auc_roc_curve.pdf'
files_out[fname.name] = fname
pimmslearn.savefig(fig, name=fname)
pimmslearn.plotting - INFO     Saved Figures to runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/auc_roc_curve.pdf
../../../_images/119a9067d144f01650fa4d0441ab3f280032509f6c1ff3bb34e5b27c1c82da8e.png

Data used to plot ROC:

Hide code cell source

# Collect the three results in plotting order; `res` is reused for the PRC table.
res = [results_ald_full, results_model_full, results_model_new]
excel_path = fname.with_suffix('.xlsx')
auc_roc_curve = parse_roc(*res)
auc_roc_curve.to_excel(excel_path)
auc_roc_curve
ALD study all DAE all DAE new
fpr tpr fpr tpr fpr tpr
0 0.000 0.000 0.000 0.000 0.000 0.000
1 0.000 0.056 0.000 0.056 0.042 0.000
2 0.000 0.611 0.000 0.167 0.083 0.000
3 0.042 0.611 0.042 0.167 0.083 0.056
4 0.042 0.833 0.042 0.500 0.208 0.056
5 0.167 0.833 0.125 0.500 0.208 0.333
6 0.167 0.889 0.125 0.667 0.250 0.333
7 0.542 0.889 0.167 0.667 0.250 0.556
8 0.542 0.944 0.167 0.778 0.292 0.556
9 0.583 0.944 0.250 0.778 0.292 0.611
10 0.583 1.000 0.250 0.833 0.333 0.611
11 1.000 1.000 0.292 0.833 0.333 0.667
12 NaN NaN 0.292 0.889 0.375 0.667
13 NaN NaN 0.417 0.889 0.375 0.778
14 NaN NaN 0.417 1.000 0.458 0.778
15 NaN NaN 1.000 1.000 0.458 0.889
16 NaN NaN NaN NaN 0.583 0.889
17 NaN NaN NaN NaN 0.583 0.944
18 NaN NaN NaN NaN 0.750 0.944
19 NaN NaN NaN NaN 0.750 1.000
20 NaN NaN NaN NaN 1.000 1.000

Features selected for final models#

Hide code cell source

# One column per model, rows ordered by mRMR selection rank; shorter feature
# lists are padded with None by the DataFrame constructor.
_models = [results_ald_full, results_model_full, results_model_new]
selected_features = pd.DataFrame(
    [m.selected_features for m in _models],
    index=[m.name for m in _models],
).T
selected_features.index.name = 'rank'
fname = args.out_folder / 'mrmr_feat_by_model.xlsx'
files_out[fname.name] = fname
selected_features.to_excel(fname)
selected_features
ALD study all DAE all DAE new
rank
0 P10636-2;P10636-6 P10636-2;P10636-6 P31321
1 K7ER15;Q9H0R4;Q9H0R4-2 A6NLU5 Q15847
2 P02741 A6NNI4;G8JLH6;P21926 Q14894
3 P61981 P35052 P51688
4 P04075 P61981 None
5 P14174 Q9Y2T3;Q9Y2T3-3 None
6 Q9Y2T3;Q9Y2T3-3 P04075 None
7 P08294 P14174 None
8 P00338;P00338-3 A0A0C4DGY8;D6RA00;Q9UHY7 None
9 P14618 P63104 None
10 Q6EMK4 Q14894 None
11 None P25189;P25189-2 None
12 None P00492 None

Precision-Recall plot on test data#

Hide code cell source

# Overlay the test-split precision-recall curves of the three models.
fig, ax = plt.subplots(1, 1, figsize=figsize)

ax = plot_split_prc(results_ald_full.test, results_ald_full.name, ax)
ax = plot_split_prc(results_model_full.test, results_model_full.name, ax)
ax = plot_split_prc(results_model_new.test, results_model_new.name, ax)
# Fix: dropped the stray `folder =` chained assignment (copy-paste artifact)
# that bound an unused `folder` variable to this PDF path.
fname = args.out_folder / 'prec_recall_curve.pdf'
files_out[fname.name] = fname
pimmslearn.savefig(fig, name=fname)
pimmslearn.plotting - INFO     Saved Figures to runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/prec_recall_curve.pdf
../../../_images/18123e7d5fdcc86cb849a4a36654773f7aff3afa1e46ef26a150fa993d0fda4e.png

Data used to plot PRC:

Hide code cell source

# Tabulate precision/recall pairs per model and export next to the figure.
prec_recall_curve = parse_prc(*res)
_xlsx = fname.with_suffix('.xlsx')
prec_recall_curve.to_excel(_xlsx)
prec_recall_curve
ALD study all DAE all DAE new
precision tpr precision tpr precision tpr
0 0.429 1.000 0.429 1.000 0.429 1.000
1 0.439 1.000 0.439 1.000 0.439 1.000
2 0.450 1.000 0.450 1.000 0.450 1.000
3 0.462 1.000 0.462 1.000 0.462 1.000
4 0.474 1.000 0.474 1.000 0.474 1.000
5 0.486 1.000 0.486 1.000 0.486 1.000
6 0.500 1.000 0.500 1.000 0.500 1.000
7 0.514 1.000 0.514 1.000 0.486 0.944
8 0.529 1.000 0.529 1.000 0.500 0.944
9 0.545 1.000 0.545 1.000 0.515 0.944
10 0.562 1.000 0.562 1.000 0.531 0.944
11 0.548 0.944 0.581 1.000 0.548 0.944
12 0.567 0.944 0.600 1.000 0.533 0.889
13 0.552 0.889 0.621 1.000 0.552 0.889
14 0.571 0.889 0.643 1.000 0.571 0.889
15 0.593 0.889 0.630 0.944 0.593 0.889
16 0.615 0.889 0.615 0.889 0.577 0.833
17 0.640 0.889 0.640 0.889 0.560 0.778
18 0.667 0.889 0.667 0.889 0.583 0.778
19 0.696 0.889 0.696 0.889 0.609 0.778
20 0.727 0.889 0.682 0.833 0.591 0.722
21 0.762 0.889 0.714 0.833 0.571 0.667
22 0.800 0.889 0.700 0.778 0.600 0.667
23 0.789 0.833 0.737 0.778 0.579 0.611
24 0.833 0.833 0.778 0.778 0.611 0.611
25 0.882 0.833 0.765 0.722 0.588 0.556
26 0.938 0.833 0.750 0.667 0.625 0.556
27 0.933 0.778 0.800 0.667 0.600 0.500
28 0.929 0.722 0.786 0.611 0.571 0.444
29 0.923 0.667 0.769 0.556 0.538 0.389
30 0.917 0.611 0.750 0.500 0.500 0.333
31 1.000 0.611 0.818 0.500 0.545 0.333
32 1.000 0.556 0.900 0.500 0.500 0.278
33 1.000 0.500 0.889 0.444 0.444 0.222
34 1.000 0.444 0.875 0.389 0.375 0.167
35 1.000 0.389 0.857 0.333 0.286 0.111
36 1.000 0.333 0.833 0.278 0.167 0.056
37 1.000 0.278 0.800 0.222 0.200 0.056
38 1.000 0.222 0.750 0.167 0.250 0.056
39 1.000 0.167 1.000 0.167 0.333 0.056
40 1.000 0.111 1.000 0.111 0.000 0.000
41 1.000 0.056 1.000 0.056 0.000 0.000
42 1.000 0.000 1.000 0.000 1.000 0.000

Train data plots#

Hide code cell source

# Same precision-recall comparison, on the training split.
fig, ax = plt.subplots(1, 1, figsize=figsize)

ax = plot_split_prc(results_ald_full.train, results_ald_full.name, ax)
ax = plot_split_prc(results_model_full.train, results_model_full.name, ax)
ax = plot_split_prc(results_model_new.train, results_model_new.name, ax)
# Fix: dropped the stray `folder =` chained assignment (copy-paste artifact).
fname = args.out_folder / 'prec_recall_curve_train.pdf'
files_out[fname.name] = fname
pimmslearn.savefig(fig, name=fname)
pimmslearn.plotting - INFO     Saved Figures to runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/prec_recall_curve_train.pdf
../../../_images/8279daa41c770548acba99a8f6899f64834817cdd63f190a9f3ad64bf76b5119.png

Hide code cell source

# Same ROC comparison, on the training split.
fig, ax = plt.subplots(1, 1, figsize=figsize)
plot_split_auc(results_ald_full.train, results_ald_full.name, ax)
plot_split_auc(results_model_full.train, results_model_full.name, ax)
plot_split_auc(results_model_new.train, results_model_new.name, ax)
# Fix: dropped the stray `folder =` chained assignment (copy-paste artifact).
fname = args.out_folder / 'auc_roc_curve_train.pdf'
files_out[fname.name] = fname
pimmslearn.savefig(fig, name=fname)
pimmslearn.plotting - INFO     Saved Figures to runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/auc_roc_curve_train.pdf
../../../_images/6c92a142979e2b22c28257f807d78a3a4267166f40ccbc5e27dd892bab91459b.png

Output files:

Hide code cell source

# Mapping of output file name -> saved path, collected throughout the notebook.
files_out
{'results_DAE all.pkl': PosixPath('runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/results_DAE all.pkl'),
 'results_DAE new.pkl': PosixPath('runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/results_DAE new.pkl'),
 'results_ALD study all.pkl': PosixPath('runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/results_ALD study all.pkl'),
 'auc_roc_curve.pdf': PosixPath('runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/auc_roc_curve.pdf'),
 'mrmr_feat_by_model.xlsx': PosixPath('runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/mrmr_feat_by_model.xlsx'),
 'prec_recall_curve.pdf': PosixPath('runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/prec_recall_curve.pdf'),
 'prec_recall_curve_train.pdf': PosixPath('runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/prec_recall_curve_train.pdf'),
 'auc_roc_curve_train.pdf': PosixPath('runs/alzheimer_study/diff_analysis/AD/PI_vs_DAE/auc_roc_curve_train.pdf')}