Compare predictions between model and RSN#

  • see differences in imputation for diverging cases

  • dumps top5

Hide code cell source

import logging
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import njab
import pandas as pd
import seaborn

import pimmslearn
import pimmslearn.analyzers
import pimmslearn.imputation
import pimmslearn.io.datasplits

# Notebook-wide logger from the project's helper; the 'fontTools' logger is
# limited to WARNING and above.
logger = pimmslearn.logging.setup_nb_logger()
logging.getLogger('fontTools').setLevel(logging.WARNING)

# Default figure size; the trailing comment lists alternatives that were tried.
plt.rcParams['figure.figsize'] = [4, 2.5]  # [16.0, 7.0] , [4, 3]
pimmslearn.plotting.make_large_descriptors(7)

# catch passed parameters
# `args` is bound first so that its own name is part of the snapshot taken on
# the next line. The snapshot holds the global names that exist *before* the
# parameter assignments further down.
# NOTE(review): presumably `pimmslearn.nb.get_params` uses this snapshot to tell
# parameters apart from pre-existing globals -- confirm in pimmslearn.nb.
args = None
args = dict(globals()).keys()

Parameters#

# Default parameter values. Several of them are re-assigned by the second
# group of assignments further down, which therefore takes precedence.
folder_experiment = 'runs/appl_ald_data/plasma/proteinGroups'
fn_clinical_data = "data/ALD_study/processed/ald_metadata_cli.csv"
make_plots = True  # create histograms and swarmplots of diverging results
model_key = 'VAE'
sample_id_col = 'Sample ID'
target = 'kleiner'
cutoff_target: int = 2  # => for binarization target >= cutoff_target
out_folder = 'diff_analysis'
file_format = 'csv'
baseline = 'RSN'  # default is RSN, but could be any other trained model
template_pred = 'pred_real_na_{}.csv'  # fixed, do not change
ref_method_score = None  # filepath to reference method score
# Parameters
# NOTE(review): this group looks like run-specific overrides injected by the
# notebook runner (e.g. papermill) -- confirm. Note that `cutoff_target` is
# annotated as int above but set to a float here.
cutoff_target = 0.5
make_plots = False
ref_method_score = None
folder_experiment = "runs/alzheimer_study"
target = "AD"
baseline = "PI"
out_folder = "diff_analysis"
fn_clinical_data = "runs/alzheimer_study/data/clinical_data.csv"

Hide code cell source

# Collect the notebook parameters; `args` still holds the snapshot of the
# global names taken before the parameters were defined.
params = pimmslearn.nb.get_params(args, globals=globals())

# From here on `args` is the run configuration object.
args = pimmslearn.nb.Config()
args.folder_experiment = Path(params["folder_experiment"])
# Default paths are rooted at <folder_experiment>/<out_folder>/<target>.
args = pimmslearn.nb.add_default_paths(
    args,
    out_root=args.folder_experiment.joinpath(params["out_folder"],
                                             params["target"]))
# Score dumps are expected in the 'scores' sub-folder of that root.
args.folder_scores = args.folder_experiment.joinpath(params["out_folder"],
                                                     params["target"],
                                                     'scores')
args.update_from_dict(params)
args
root - INFO     Removed from global namespace: folder_experiment
root - INFO     Removed from global namespace: fn_clinical_data
root - INFO     Removed from global namespace: make_plots
root - INFO     Removed from global namespace: model_key
root - INFO     Removed from global namespace: sample_id_col
root - INFO     Removed from global namespace: target
root - INFO     Removed from global namespace: cutoff_target
root - INFO     Removed from global namespace: out_folder
root - INFO     Removed from global namespace: file_format
root - INFO     Removed from global namespace: baseline
root - INFO     Removed from global namespace: template_pred
root - INFO     Removed from global namespace: ref_method_score
root - INFO     Already set attribute: folder_experiment has value runs/alzheimer_study
root - INFO     Already set attribute: out_folder has value diff_analysis
{'baseline': 'PI',
 'cutoff_target': 0.5,
 'data': PosixPath('runs/alzheimer_study/data'),
 'file_format': 'csv',
 'fn_clinical_data': 'runs/alzheimer_study/data/clinical_data.csv',
 'folder_experiment': PosixPath('runs/alzheimer_study'),
 'folder_scores': PosixPath('runs/alzheimer_study/diff_analysis/AD/scores'),
 'make_plots': False,
 'model_key': 'VAE',
 'out_figures': PosixPath('runs/alzheimer_study/figures'),
 'out_folder': PosixPath('runs/alzheimer_study/diff_analysis/AD'),
 'out_metrics': PosixPath('runs/alzheimer_study'),
 'out_models': PosixPath('runs/alzheimer_study'),
 'out_preds': PosixPath('runs/alzheimer_study/preds'),
 'ref_method_score': None,
 'sample_id_col': 'Sample ID',
 'target': 'AD',
 'template_pred': 'pred_real_na_{}.csv'}

Write outputs to excel

Hide code cell source

# All tables of this notebook are collected in one Excel workbook.
fname = args.out_folder / 'diff_analysis_compare_DA.xlsx'
# Registry of every file written by this notebook: file name -> posix path.
files_out = {fname.name: fname.as_posix()}
writer = pd.ExcelWriter(fname)
logger.info("Writing to excel file: %s", fname)
root - INFO     Writing to excel file: runs/alzheimer_study/diff_analysis/AD/diff_analysis_compare_DA.xlsx

Load scores#

List dump of scores:

Hide code cell source

# Keep only the pickled score tables found in the scores folder.
score_dumps = list(filter(lambda path: path.suffix == '.pkl',
                          Path(args.folder_scores).iterdir()))
score_dumps
[PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_Median.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_RF.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_TRKNN.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_None.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_VAE.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_PI.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_CF.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_DAE.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_QRILC.pkl')]

Load scores from dumps:

Hide code cell source

# One pickled table per dump, placed side by side (column-wise).
scores = pd.concat(
    [pd.read_pickle(dump) for dump in score_dumps],
    axis='columns',
)
scores
model Median RF ... DAE QRILC
var SS DF F p-unc np2 -Log10 pvalue qvalue rejected SS DF ... qvalue rejected SS DF F p-unc np2 -Log10 pvalue qvalue rejected
protein groups Source
A0A024QZX5;A0A087X1N8;P35237 AD 0.830 1 6.377 0.012 0.032 1.907 0.039 True 1.010 1 ... 0.016 True 0.741 1 4.819 0.029 0.025 1.532 0.075 False
age 0.001 1 0.006 0.939 0.000 0.027 0.966 False 0.002 1 ... 0.903 False 0.010 1 0.062 0.803 0.000 0.095 0.871 False
Kiel 0.106 1 0.815 0.368 0.004 0.435 0.532 False 0.206 1 ... 0.279 False 0.394 1 2.560 0.111 0.013 0.954 0.214 False
Magdeburg 0.219 1 1.680 0.197 0.009 0.707 0.343 False 0.388 1 ... 0.142 False 0.877 1 5.702 0.018 0.029 1.747 0.050 True
Sweden 1.101 1 8.461 0.004 0.042 2.392 0.016 True 1.521 1 ... 0.003 True 2.352 1 15.296 0.000 0.074 3.894 0.001 True
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
S4R3U6 AD 0.051 1 0.119 0.730 0.001 0.136 0.829 False 1.041 1 ... 0.087 False 7.824 1 3.739 0.055 0.019 1.262 0.122 False
age 1.214 1 2.845 0.093 0.015 1.030 0.194 False 0.807 1 ... 0.367 False 0.011 1 0.005 0.941 0.000 0.026 0.965 False
Kiel 0.861 1 2.018 0.157 0.010 0.804 0.289 False 1.830 1 ... 0.068 False 6.913 1 3.303 0.071 0.017 1.150 0.150 False
Magdeburg 0.216 1 0.506 0.478 0.003 0.321 0.631 False 1.832 1 ... 0.036 True 21.543 1 10.294 0.002 0.051 2.805 0.006 True
Sweden 3.965 1 9.288 0.003 0.046 2.580 0.011 True 12.678 1 ... 0.000 True 0.048 1 0.023 0.879 0.000 0.056 0.926 False

7105 rows × 72 columns

If reference dump is provided, add it to the scores

Hide code cell source

# Optionally append the scores of a reference method to the score table.
if args.ref_method_score:
    # The 'None' label is renamed to 'None (100%)' before joining.
    # NOTE(review): presumably to keep it apart from the 'None' columns that
    # `scores` already contains -- confirm.
    scores_reference = (pd
                        .read_pickle(args.ref_method_score)
                        .rename({'None': 'None (100%)'},
                                axis=1))
    scores = scores.join(scores_reference)
    # lazy %-style arguments, consistent with the other logger calls
    logger.info('Added reference method scores from %s',
                args.ref_method_score)

Load frequencies of observed features#

Hide code cell source

# Number of observations per feature, as dumped in the experiment folder.
fname = args.folder_experiment / 'freq_features_observed.csv'
freq_feat = pd.read_csv(fname, index_col=0)
# Give the single column the two-level label ('data', 'frequency') so it can
# be joined onto the two-level columns of the score tables.
freq_feat.columns = pd.MultiIndex.from_arrays([['data'], ['frequency']])
freq_feat
data
frequency
protein groups
A0A024QZX5;A0A087X1N8;P35237 186
A0A024R0T9;K7ER74;P02655 195
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 174
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 196
A0A075B6H7 91
... ...
Q9Y6R7 197
Q9Y6X5 173
Q9Y6Y8;Q9Y6Y8-2 197
Q9Y6Y9 119
S4R3U6 126

1421 rows × 1 columns

Assemble qvalues#

Hide code cell source

# q-values of the target variable for every method.
qvalues = scores.loc[pd.IndexSlice[:, args.target], pd.IndexSlice[:, 'qvalue']]
# Attach the feature frequency and move it into the row index as last level.
qvalues = qvalues.join(freq_feat)
qvalues = qvalues.set_index(('data', 'frequency'), append=True)
qvalues.index = qvalues.index.set_names('frequency', level=-1)

# Persist as pickle and as a sheet of the Excel workbook.
fname = args.out_folder / 'qvalues_target.pkl'
files_out[fname.name] = fname.as_posix()
qvalues.to_pickle(fname)
qvalues.to_excel(writer, sheet_name='qvalues_all')
qvalues
Median RF TRKNN None VAE PI CF DAE QRILC
qvalue qvalue qvalue qvalue qvalue qvalue qvalue qvalue qvalue
protein groups Source frequency
A0A024QZX5;A0A087X1N8;P35237 AD 186 0.039 0.019 0.023 0.043 0.019 0.397 0.020 0.016 0.075
A0A024R0T9;K7ER74;P02655 AD 195 0.087 0.072 0.071 0.092 0.071 0.139 0.077 0.074 0.085
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 AD 174 0.832 0.540 0.394 0.586 0.392 0.103 0.510 0.380 0.454
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 AD 196 0.418 0.390 0.396 0.404 0.372 0.580 0.376 0.376 0.453
A0A075B6H7 AD 91 0.124 0.026 0.048 0.027 0.014 0.075 0.013 0.036 0.288
... ... ... ... ... ... ... ... ... ... ... ...
Q9Y6R7 AD 197 0.315 0.292 0.289 0.307 0.283 0.316 0.284 0.283 0.302
Q9Y6X5 AD 173 0.455 0.327 0.205 0.501 0.390 0.159 0.301 0.240 0.199
Q9Y6Y8;Q9Y6Y8-2 AD 197 0.178 0.162 0.160 0.174 0.156 0.182 0.158 0.156 0.171
Q9Y6Y9 AD 119 0.667 0.468 0.472 0.651 0.531 0.512 0.733 0.808 0.891
S4R3U6 AD 126 0.829 0.238 0.080 0.802 0.120 0.623 0.117 0.087 0.122

1421 rows × 9 columns

Assemble pvalues#

Hide code cell source

# Uncorrected p-values of the target variable for every method.
pvalues = scores.loc[pd.IndexSlice[:, args.target], pd.IndexSlice[:, 'p-unc']]
# Attach the feature frequency and move it into the row index as last level.
pvalues = pvalues.join(freq_feat)
pvalues = pvalues.set_index(('data', 'frequency'), append=True)
pvalues.index = pvalues.index.set_names('frequency', level=-1)

# Persist as pickle and as a sheet of the Excel workbook.
fname = args.out_folder / 'pvalues_target.pkl'
files_out[fname.name] = fname.as_posix()
pvalues.to_pickle(fname)
pvalues.to_excel(writer, sheet_name='pvalues_all')
pvalues
Median RF TRKNN None VAE PI CF DAE QRILC
p-unc p-unc p-unc p-unc p-unc p-unc p-unc p-unc p-unc
protein groups Source frequency
A0A024QZX5;A0A087X1N8;P35237 AD 186 0.012 0.006 0.008 0.015 0.007 0.240 0.007 0.005 0.029
A0A024R0T9;K7ER74;P02655 AD 195 0.033 0.031 0.031 0.037 0.032 0.059 0.035 0.033 0.035
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 AD 174 0.736 0.403 0.264 0.432 0.266 0.040 0.380 0.257 0.305
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 AD 196 0.259 0.257 0.266 0.254 0.249 0.420 0.251 0.253 0.303
A0A075B6H7 AD 91 0.053 0.009 0.020 0.008 0.005 0.027 0.004 0.014 0.165
... ... ... ... ... ... ... ... ... ... ... ...
Q9Y6R7 AD 197 0.175 0.175 0.175 0.175 0.175 0.175 0.175 0.175 0.175
Q9Y6X5 AD 173 0.291 0.203 0.113 0.344 0.265 0.070 0.189 0.143 0.101
Q9Y6Y8;Q9Y6Y8-2 AD 197 0.083 0.083 0.083 0.083 0.083 0.083 0.083 0.083 0.083
Q9Y6Y9 AD 119 0.520 0.328 0.334 0.505 0.403 0.348 0.628 0.723 0.832
S4R3U6 AD 126 0.730 0.135 0.036 0.698 0.060 0.469 0.058 0.041 0.055

1421 rows × 9 columns

Assemble rejected features#

Hide code cell source

# Rejection decision for the target variable, per feature and method.
da_target = scores.loc[pd.IndexSlice[:, args.target], pd.IndexSlice[:, 'rejected']]
# Attach the feature frequency and move it into the row index as last level.
da_target = da_target.join(freq_feat)
da_target = da_target.set_index(('data', 'frequency'), append=True)
da_target.index = da_target.index.set_names('frequency', level=-1)

fname = args.out_folder / 'equality_rejected_target.pkl'
files_out[fname.name] = fname.as_posix()
da_target.to_pickle(fname)

# Count the decisions per method (inner column level is dropped first).
count_rejected = njab.pandas.combine_value_counts(
    da_target.droplevel(-1, axis='columns'))
count_rejected.to_excel(writer, sheet_name='count_rejected')
count_rejected
Median RF TRKNN None VAE PI CF DAE QRILC
False 1,069 961 936 1,054 939 1,031 952 938 999
True 352 460 485 367 482 390 469 483 422

Tabulate rejected decisions by method:#

Hide code cell source

# ! Features are "common" if no method has a missing decision. This relies
# ! implicitly on a method (RSN in the original study) not being available for
# ! some protein groups; an explicit list of the protein groups available in
# ! the original data would be more robust.
mask_common = ~da_target.isna().any(axis='columns')
count_rejected_common = njab.pandas.combine_value_counts(
    da_target.loc[mask_common].droplevel(-1, axis='columns'))
count_rejected_common.to_excel(writer, sheet_name='count_rejected_common')
count_rejected_common
Median RF TRKNN None VAE PI CF DAE QRILC
False 1,069 961 936 1,054 939 1,031 952 938 999
True 352 460 485 367 482 390 469 483 422

Tabulate rejected decisions by method for newly included features (if available)#

Hide code cell source

# Same tabulation for the features that are not part of the common set.
count_rejected_new = njab.pandas.combine_value_counts(
    da_target[~mask_common].droplevel(-1, axis='columns'))
count_rejected_new.to_excel(writer, sheet_name='count_rejected_new')
count_rejected_new
Median RF TRKNN None VAE PI CF DAE QRILC

Tabulate rejected decisions by method for all features#

Hide code cell source

# Write the full decision table (all features, all methods) to its own sheet
# of the workbook and display it.
da_target.to_excel(writer, sheet_name='equality_rejected_all')
logger.info("Written to sheet 'equality_rejected_all' in excel file.")
da_target
root - INFO     Written to sheet 'equality_rejected_all' in excel file.
Median RF TRKNN None VAE PI CF DAE QRILC
rejected rejected rejected rejected rejected rejected rejected rejected rejected
protein groups Source frequency
A0A024QZX5;A0A087X1N8;P35237 AD 186 True True True True True False True True False
A0A024R0T9;K7ER74;P02655 AD 195 False False False False False False False False False
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 AD 174 False False False False False False False False False
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 AD 196 False False False False False False False False False
A0A075B6H7 AD 91 False True True True True False True True False
... ... ... ... ... ... ... ... ... ... ... ...
Q9Y6R7 AD 197 False False False False False False False False False
Q9Y6X5 AD 173 False False False False False False False False False
Q9Y6Y8;Q9Y6Y8-2 AD 197 False False False False False False False False False
Q9Y6Y9 AD 119 False False False False False False False False False
S4R3U6 AD 126 False False False False False False False False False

1421 rows × 9 columns

Tabulate the number of features for which all methods reach the same decision (True) versus those for which the decision varies depending on the method (False)

Hide code cell source

# True where all methods agree: either no method rejects (row sum is zero)
# or every method rejects.
da_target_same = (da_target.sum(axis='columns').eq(0)
                  | da_target.all(axis='columns'))
da_target_same.value_counts()
True    1,106
False     315
Name: count, dtype: int64

List frequency of features with varying decisions

Hide code cell source

# Index entries of the features for which the methods disagree.
feat_idx_w_diff = da_target_same.loc[~da_target_same].index
# Show their frequency, with the frequency level removed from the row index.
feat_idx_w_diff.to_frame().loc[:, ['frequency']].droplevel(-1)
frequency
protein groups Source
A0A024QZX5;A0A087X1N8;P35237 AD 186
A0A075B6H7 AD 91
A0A075B6H9 AD 189
A0A075B6I0 AD 194
A0A075B6J9 AD 156
... ... ...
Q9UNW1 AD 171
Q9UP79 AD 135
Q9UPU3 AD 163
Q9UQ52 AD 188
Q9Y6C2 AD 119

315 rows × 1 columns

take only those with different decisions

Hide code cell source

# Write the q-values of features with diverging decisions to the workbook,
# sorted by the q-value obtained without imputation ('None'):
#   - all diverging features
#   - diverging features of the common set
#   - diverging features outside the common set (if any)
# The writer is closed in a `finally` clause so that the workbook is finalised
# even if one of the sheet writes raises an unexpected error.
try:
    (qvalues
     .loc[feat_idx_w_diff]
     .sort_values(('None', 'qvalue'))
     .to_excel(writer, sheet_name='qvalues_diff')
     )

    (qvalues
     .loc[feat_idx_w_diff]
     .loc[mask_common]  # mask automatically aligned
     .sort_values(('None', 'qvalue'))
     .to_excel(writer, sheet_name='qvalues_diff_common')
     )

    try:
        (qvalues
         .loc[feat_idx_w_diff]
         .loc[~mask_common]
         .sort_values(('None', 'qvalue'))
         .to_excel(writer, sheet_name='qvalues_diff_new')
         )
    except IndexError:
        # raised when there is nothing to write for the 'new' features
        print("No new features or no new ones (with diverging decisions.)")
finally:
    writer.close()
No new features or no new ones (with diverging decisions.)

Plots for inspecting imputations (for diverging decisions)#

Hide code cell source

# Everything below only produces diagnostic plots: stop the run here (exit
# status 0) unless plots were requested.
if not args.make_plots:
    logger.warning("No plots requested.")
    import sys
    sys.exit(0)
root - WARNING  Not plots requested.
/home/runner/work/pimms/pimms/project/.snakemake/conda/43fbe714d68d8fe6f9b0c93f5652adb3_/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3756: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
An exception has occurred, use %tb to see the full traceback.

SystemExit: 0

Load target#

Hide code cell source

# Load the target variable, indexed by sample ID. The index column is selected
# by name: a positional `index_col=0` combined with `usecols` would pick the
# target column as index whenever it precedes the sample ID column in the file.
target = pd.read_csv(args.fn_clinical_data,
                     index_col=args.sample_id_col,
                     usecols=[args.sample_id_col, args.target])
# Samples without a target value cannot be grouped and are dropped.
target = target.dropna()
target

Hide code cell source

# Keep the original target values for the cross-tabulation below.
target_to_group = target.copy()
# Binarise at the cutoff and replace the booleans by readable group labels.
target = (target
          .ge(args.cutoff_target)
          .replace({True: f'{args.target} >= {args.cutoff_target}',
                    False: f'{args.target} < {args.cutoff_target}'})
          .astype('category'))
pd.crosstab(target.squeeze(), target_to_group.squeeze())

Measurements#

Hide code cell source

# Load the data splits from the experiment's data folder.
data = pimmslearn.io.datasplits.DataSplits.from_folder(
    args.data,
    file_format=args.file_format)
# Concatenate the train, validation and test parts and unstack the innermost
# index level into columns.
# NOTE(review): presumably this yields a samples x features table -- confirm.
data = pd.concat([data.train_X, data.val_y, data.test_y]).unstack()
data

Plot all of the new protein groups (pgs) which are significant at least once and which are not already dumped.

Hide code cell source

# Features outside the common set that are rejected by at least one method.
feat_new_abundant = da_target[~mask_common].any(axis='columns')
# Keep the first index level of the rows flagged True.
feat_new_abundant = feat_new_abundant[feat_new_abundant].index.get_level_values(0)
feat_new_abundant

Hide code cell source

# Features to plot: those with diverging decisions plus the new ones that are
# rejected at least once.
feat_sel = feat_idx_w_diff.get_level_values(0).union(feat_new_abundant)
len(feat_sel)

Hide code cell source

# Restrict the measurements to the columns of the selected features.
data = data.loc[:, feat_sel]
data
  • RSN predictions are based on the mean and std of all samples (N=455), as in the original study

  • The VAE is also trained on all samples (self-supervised). One could also reduce the selected data to only the samples with a valid target marker, but this was not done in the original study, which considered several different target markers.

RSN : shifted per sample, not per feature!

Load all prediction files and reshape

Hide code cell source

# exclude 'None' as this is without imputation (-> data)
model_keys = [k for k in qvalues.columns.get_level_values(0) if k != 'None']
pred_paths = [
    args.out_preds / args.template_pred.format(method)
    for method in model_keys]
pred_paths

Hide code cell source

# Short alias for the project's loader of a single prediction file.
load_single_csv_pred_file = pimmslearn.analyzers.compare_predictions.load_single_csv_pred_file
# Loaded predictions per method; turned into one DataFrame with one column
# per method after the loop.
pred_real_na = dict()
for method in model_keys:
    fname = args.out_preds / args.template_pred.format(method)
    print(f"missing values pred. by {method}: {fname}")
    pred_real_na[method] = load_single_csv_pred_file(fname)
pred_real_na = pd.DataFrame(pred_real_na)
pred_real_na

After imputation, reduce to target samples only (samples with a target score)

Hide code cell source

# select samples with target information
data = data.loc[target.index]
pred_real_na = pred_real_na.loc[target.index]

# assert len(data) == len(pred_real_na)

Hide code cell source

# First selected feature, used as an example in the next cell.
idx = feat_sel[0]

Hide code cell source

# Non-missing values of the example feature.
feat_observed = data[idx].dropna()
feat_observed

Hide code cell source

# Output folder for the swarmplots: one plot per feature with diverging
# decisions. It is created if it does not exist yet.
# Open questions kept from the original notes:
#   - use args.out_folder.parent / 'intensity_plots' instead?
#   - plot every feature that is significant for at least one method?
folder = args.out_folder / 'intensities_for_diff_in_DA_decision'
folder.mkdir(parents=True, exist_ok=True)

Hide code cell source

# Common y-axis limits for all plots, derived from the measured and the
# imputed values together (computed by the project helper).
min_y_int, max_y_int = pimmslearn.plotting.data.get_min_max_iterable(
    [data.stack(), pred_real_na.stack()])
min_max = min_y_int, max_y_int

# Name of the single target column, used in the plot titles.
target_name = target.columns[0]

min_max, target_name

Compare with target annotation#

Hide code cell source

def get_centered_label(method, n, q):
    """Build the two-line x-tick label for one group of points.

    The first line is the method name, the second line is ``(N=<n>, q=<q>)``.
    The shorter line is padded with trailing spaces to the length of the
    longer one (i.e. both lines are left-aligned, despite the name).

    Parameters
    ----------
    method : str
        Name shown on the first line.
    n : int
        Number of values in the group.
    q : float
        q-value, shown with three decimals.

    Returns
    -------
    str
        ``'<method>\\n(N=<n>, q=<q>)'`` with both lines of equal length.
    """
    model_str = f'{method}'
    stats_str = f'(N={n:,d}, q={q:.3f})'
    width = max(len(model_str), len(stats_str))
    return f'{model_str:<{width}}\n{stats_str:<{width}}'


# One swarmplot per selected feature: observed intensities next to the
# values imputed by each method, coloured by the binarised target.
for i, idx in enumerate(feat_sel):
    # '<3' pads the counter to width 3; the former spec '3<' was parsed as
    # fill character '3' without a width and therefore never padded.
    print(f"Swarmplot {i:<3}: {idx}:")
    fig, ax = plt.subplots()

    # dummy plot, just to get the Path object of the 'X' marker
    tmp_dot = ax.scatter([1, 2], [3, 4], marker='X')
    new_mk, = tmp_dot.get_paths()
    tmp_dot.remove()

    feat_observed = data[idx].dropna()

    # q-value of the analysis without imputation. `.item()` extracts the
    # single element explicitly (calling float() on a Series is deprecated)
    # and raises if the selection does not hold exactly one value.
    q_observed = float(qvalues.loc[idx, ('None', 'qvalue')].item())
    key = get_centered_label(method='observed',
                             n=len(feat_observed),
                             q=q_observed)
    to_plot = {key: feat_observed}
    for method in model_keys:
        try:
            pred = pred_real_na.loc[pd.IndexSlice[:,
                                                  idx], method].dropna().droplevel(-1)
            if len(pred) == 0:
                # in case no values was imputed -> qvalue is as based on measured
                key = get_centered_label(method=method,
                                         n=len(pred),
                                         q=q_observed)
            elif qvalues.loc[idx, (method, 'qvalue')].notna().all():
                key = get_centered_label(
                    method=method,
                    n=len(pred),
                    q=float(qvalues.loc[idx, (method, 'qvalue')].item()))
            elif qvalues.loc[idx, (method, 'qvalue')].isna().all():
                logger.info("NA qvalues for %s: %s", idx, method)
                continue
            else:
                raise ValueError("Unknown case.")
            to_plot[key] = pred
        except KeyError:
            print(f"No missing values for {idx}: {method}")
            continue

    # long format: one row per value with its group label and target class
    to_plot = pd.DataFrame.from_dict(to_plot)
    to_plot.columns.name = 'group'
    groups_order = to_plot.columns.to_list()
    to_plot = to_plot.stack().to_frame('intensity').reset_index(-1)
    to_plot = to_plot.join(target.astype('category'), how='inner')
    to_plot = to_plot.astype({'group': 'category'})

    ax = seaborn.swarmplot(data=to_plot,
                           x='group',
                           y='intensity',
                           order=groups_order,
                           dodge=True,
                           hue=args.target,
                           size=2,
                           ax=ax)
    first_pg = idx.split(";")[0]
    ax.set_title(
        f'Imputation for protein group {first_pg} with target {target_name} (N= {len(data):,d} samples)')

    _ = ax.set_ylim(min_y_int, max_y_int)
    _ = ax.locator_params(axis='y', integer=True)
    _ = ax.set_xlabel('')
    # fix the tick positions before setting the tick labels explicitly
    _xticks = ax.get_xticks()
    ax.xaxis.set_major_locator(
        matplotlib.ticker.FixedLocator(_xticks)
    )
    _ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45,
                           horizontalalignment='right')

    # the first two collections get the 'X' marker
    # NOTE(review): presumably these are the observed values of the two
    # target classes -- confirm the collection order created by seaborn.
    _ = ax.collections[0].set_paths([new_mk])
    _ = ax.collections[1].set_paths([new_mk])

    # legend: relabel the last two collections as imputed and add two empty
    # 'X' scatters as handles for the observed values
    label_target_0, label_target_1 = ax.collections[-2].get_label(), ax.collections[-1].get_label()
    _ = ax.collections[-2].set_label(f'imputed, {label_target_0}')
    _ = ax.collections[-1].set_label(f'imputed, {label_target_1}')
    _obs_label0 = ax.scatter([], [], color='C0', marker='X', label=f'observed, {label_target_0}')
    _obs_label1 = ax.scatter([], [], color='C1', marker='X', label=f'observed, {label_target_1}')
    _ = ax.legend(
        handles=[_obs_label0, _obs_label1, *ax.collections[-4:-2]],
        fontsize=5, title_fontsize=5, markerscale=0.4,)
    fname = (folder /
             f'{first_pg}_swarmplot.pdf')
    files_out[fname.name] = fname.as_posix()
    pimmslearn.savefig(
        fig,
        name=fname)
    # close this figure explicitly rather than whichever one is current
    plt.close(fig)

Saved files:

Hide code cell source

# Display the registry of files written by this notebook (name -> path).
files_out