Compare predictions between model and RSN#

  • see differences in imputation for diverging cases

  • dumps top5

Hide code cell source

import logging
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import njab
import pandas as pd
import seaborn

import pimmslearn
import pimmslearn.analyzers
import pimmslearn.imputation
import pimmslearn.io.datasplits

# Notebook logger; messages of the 'fontTools' logger below WARNING are suppressed.
logger = pimmslearn.logging.setup_nb_logger()
logging.getLogger('fontTools').setLevel(logging.WARNING)

# Default figure size for the plots created in this notebook.
plt.rcParams['figure.figsize'] = [4, 2.5]  # [16.0, 7.0] , [4, 3]
pimmslearn.plotting.make_large_descriptors(7)

# catch passed parameters
# Snapshot of the global names defined so far. `args` is bound to None first
# so that the name `args` itself is part of the snapshot. The snapshot is
# passed to `pimmslearn.nb.get_params` further down -- presumably to single
# out the parameters defined after this point (confirm against that helper).
args = None
args = dict(globals()).keys()

Parameters#

# Default parameters (paths point to the ALD study).
folder_experiment = 'runs/appl_ald_data/plasma/proteinGroups'
fn_clinical_data = "data/ALD_study/processed/ald_metadata_cli.csv"
make_plots = True  # create histograms and swarmplots of diverging results
model_key = 'VAE'
sample_id_col = 'Sample ID'
target = 'kleiner'
cutoff_target: int = 2  # => for binarization target >= cutoff_target
out_folder = 'diff_analysis'
file_format = 'csv'
baseline = 'RSN'  # default is RSN, but could be any other trained model
template_pred = 'pred_real_na_{}.csv'  # fixed, do not change
ref_method_score = None  # filepath to reference method score
# Parameters
# NOTE(review): the assignments below override the defaults above for this run
# (Alzheimer study); presumably injected by the workflow -- confirm.
# `cutoff_target` is annotated as int above but is set to a float here.
cutoff_target = 0.5
make_plots = False
ref_method_score = None
folder_experiment = "runs/alzheimer_study"
target = "AD"
baseline = "PI"
out_folder = "diff_analysis"
fn_clinical_data = "runs/alzheimer_study/data/clinical_data.csv"

Hide code cell source

# Collect the parameters from the global namespace, using the snapshot stored
# in `args`; the helper logs every name it removes from the globals.
params = pimmslearn.nb.get_params(args, globals=globals())
args = pimmslearn.nb.Config()
args.folder_experiment = Path(params["folder_experiment"])
# Output root of this notebook: <folder_experiment>/<out_folder>/<target>
args = pimmslearn.nb.add_default_paths(args,
                                 out_root=(args.folder_experiment
                                           / params["out_folder"]
                                           / params["target"]))
# Folder with the pickled score tables: <folder_experiment>/<out_folder>/<target>/scores
args.folder_scores = (args.folder_experiment
                      / params["out_folder"]
                      / params["target"]
                      / 'scores'
                      )
# Add the remaining parameters. Attributes that are already set are reported
# in the log and, per the printed configuration, keep their value.
args.update_from_dict(params)
args
root - INFO     Removed from global namespace: folder_experiment
root - INFO     Removed from global namespace: fn_clinical_data
root - INFO     Removed from global namespace: make_plots
root - INFO     Removed from global namespace: model_key
root - INFO     Removed from global namespace: sample_id_col
root - INFO     Removed from global namespace: target
root - INFO     Removed from global namespace: cutoff_target
root - INFO     Removed from global namespace: out_folder
root - INFO     Removed from global namespace: file_format
root - INFO     Removed from global namespace: baseline
root - INFO     Removed from global namespace: template_pred
root - INFO     Removed from global namespace: ref_method_score
root - INFO     Already set attribute: folder_experiment has value runs/alzheimer_study
root - INFO     Already set attribute: out_folder has value diff_analysis
{'baseline': 'PI',
 'cutoff_target': 0.5,
 'data': PosixPath('runs/alzheimer_study/data'),
 'file_format': 'csv',
 'fn_clinical_data': 'runs/alzheimer_study/data/clinical_data.csv',
 'folder_experiment': PosixPath('runs/alzheimer_study'),
 'folder_scores': PosixPath('runs/alzheimer_study/diff_analysis/AD/scores'),
 'make_plots': False,
 'model_key': 'VAE',
 'out_figures': PosixPath('runs/alzheimer_study/figures'),
 'out_folder': PosixPath('runs/alzheimer_study/diff_analysis/AD'),
 'out_metrics': PosixPath('runs/alzheimer_study'),
 'out_models': PosixPath('runs/alzheimer_study'),
 'out_preds': PosixPath('runs/alzheimer_study/preds'),
 'ref_method_score': None,
 'sample_id_col': 'Sample ID',
 'target': 'AD',
 'template_pred': 'pred_real_na_{}.csv'}

Write outputs to excel

Hide code cell source

# Registry of the files written by this notebook: file name -> path (posix string).
files_out = dict()

# Excel workbook collecting the tables of this notebook. The writer stays open
# across the following cells and is closed after the last sheet was written.
fname = args.out_folder / 'diff_analysis_compare_DA.xlsx'
writer = pd.ExcelWriter(fname)
files_out[fname.name] = fname.as_posix()
logger.info("Writing to excel file: %s", fname)
root - INFO     Writing to excel file: runs/alzheimer_study/diff_analysis/AD/diff_analysis_compare_DA.xlsx

Load scores#

List dump of scores:

Hide code cell source

# Collect the pickled score tables found in the scores folder.
# `Path.iterdir` yields entries in arbitrary order; sorting the paths makes
# the order of the methods (table columns, plot groups) reproducible.
score_dumps = sorted(
    fname for fname in Path(args.folder_scores).iterdir()
    if fname.suffix == '.pkl')
score_dumps
[PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_PI.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_VAE.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_None.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_DAE.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_RF.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_Median.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_QRILC.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_CF.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_TRKNN.pkl')]

Load scores from dumps:

Hide code cell source

# Read every score dump and place the tables side by side (column-wise).
scores = pd.concat(list(map(pd.read_pickle, score_dumps)), axis=1)
scores
model PI VAE ... CF TRKNN
var SS DF F p-unc np2 -Log10 pvalue qvalue rejected SS DF ... qvalue rejected SS DF F p-unc np2 -Log10 pvalue qvalue rejected
protein groups Source
A0A024QZX5;A0A087X1N8;P35237 AD 0.261 1 0.456 0.501 0.002 0.301 0.653 False 1.005 1 ... 0.020 True 0.994 1 7.134 0.008 0.036 2.085 0.023 True
age 0.036 1 0.063 0.802 0.000 0.096 0.878 False 0.011 1 ... 0.907 False 0.004 1 0.029 0.864 0.000 0.063 0.913 False
Kiel 1.631 1 2.848 0.093 0.015 1.031 0.198 False 0.262 1 ... 0.281 False 0.269 1 1.933 0.166 0.010 0.780 0.277 False
Magdeburg 4.653 1 8.127 0.005 0.041 2.315 0.018 True 0.416 1 ... 0.125 False 0.519 1 3.727 0.055 0.019 1.259 0.114 False
Sweden 7.129 1 12.451 0.001 0.061 3.281 0.003 True 1.584 1 ... 0.002 True 1.796 1 12.893 0.000 0.063 3.378 0.002 True
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
S4R3U6 AD 0.832 1 0.895 0.345 0.005 0.462 0.508 False 1.745 1 ... 0.068 False 2.295 1 4.480 0.036 0.023 1.449 0.080 False
age 0.017 1 0.019 0.891 0.000 0.050 0.934 False 0.489 1 ... 0.396 False 0.398 1 0.777 0.379 0.004 0.421 0.516 False
Kiel 0.472 1 0.508 0.477 0.003 0.322 0.630 False 2.470 1 ... 0.055 False 2.981 1 5.819 0.017 0.030 1.775 0.043 True
Magdeburg 1.695 1 1.824 0.178 0.009 0.749 0.320 False 2.088 1 ... 0.035 True 3.440 1 6.716 0.010 0.034 1.987 0.028 True
Sweden 13.108 1 14.110 0.000 0.069 3.640 0.001 True 15.708 1 ... 0.000 True 27.114 1 52.939 0.000 0.217 11.062 0.000 True

7105 rows × 72 columns

If reference dump is provided, add it to the scores

Hide code cell source

# Optionally append the scores of a reference method stored in a separate dump.
if args.ref_method_score:
    scores_reference = pd.read_pickle(args.ref_method_score)
    # relabel 'None' so it can be told apart from the 'None' scores loaded above
    scores_reference = scores_reference.rename({'None': 'None (100%)'}, axis=1)
    scores = scores.join(scores_reference)
    logger.info(f'Added reference method scores from {args.ref_method_score}')

Load frequencies of observed features#

Hide code cell source

# Table with one count per feature, read from 'freq_features_observed.csv'.
fname = args.folder_experiment / 'freq_features_observed.csv'
freq_feat = pd.read_csv(fname, index_col=0)
# Use a two-level column label ('data', 'frequency'); the table is joined to
# the score tables, which have two-level columns, further down.
freq_feat.columns = pd.MultiIndex.from_tuples([('data', 'frequency'),])
freq_feat
data
frequency
protein groups
A0A024QZX5;A0A087X1N8;P35237 186
A0A024R0T9;K7ER74;P02655 195
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 174
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 196
A0A075B6H7 91
... ...
Q9Y6R7 197
Q9Y6X5 173
Q9Y6Y8;Q9Y6Y8-2 197
Q9Y6Y9 119
S4R3U6 126

1421 rows × 1 columns

Assemble qvalues#

Hide code cell source

# q-values of the target variable for every method.
qvalues = scores.loc[pd.IndexSlice[:, args.target], pd.IndexSlice[:, 'qvalue']]
# Attach the feature frequency and move it into the row index as last level.
qvalues = qvalues.join(freq_feat)
qvalues = qvalues.set_index(('data', 'frequency'), append=True)
qvalues.index.names = [*qvalues.index.names[:-1], 'frequency']
# Dump as pickle (registered in `files_out`) and as sheet of the workbook.
fname = args.out_folder / 'qvalues_target.pkl'
files_out[fname.name] = fname.as_posix()
qvalues.to_pickle(fname)
qvalues.to_excel(writer, sheet_name='qvalues_all')
qvalues
PI VAE None DAE RF Median QRILC CF TRKNN
qvalue qvalue qvalue qvalue qvalue qvalue qvalue qvalue qvalue
protein groups Source frequency
A0A024QZX5;A0A087X1N8;P35237 AD 186 0.653 0.020 0.043 0.017 0.020 0.039 0.103 0.020 0.023
A0A024R0T9;K7ER74;P02655 AD 195 0.092 0.067 0.092 0.077 0.072 0.087 0.080 0.084 0.071
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 AD 174 0.214 0.428 0.586 0.474 0.523 0.832 0.517 0.707 0.394
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 AD 196 0.747 0.377 0.404 0.375 0.389 0.418 0.455 0.385 0.396
A0A075B6H7 AD 91 0.146 0.019 0.027 0.024 0.007 0.124 0.283 0.023 0.048
... ... ... ... ... ... ... ... ... ... ... ...
Q9Y6R7 AD 197 0.317 0.282 0.307 0.283 0.291 0.315 0.304 0.285 0.289
Q9Y6X5 AD 173 0.136 0.332 0.501 0.473 0.323 0.455 0.078 0.171 0.205
Q9Y6Y8;Q9Y6Y8-2 AD 197 0.183 0.155 0.174 0.157 0.162 0.178 0.171 0.158 0.160
Q9Y6Y9 AD 119 0.697 0.822 0.651 0.939 0.518 0.667 0.891 0.449 0.472
S4R3U6 AD 126 0.508 0.122 0.803 0.080 0.159 0.829 0.603 0.068 0.080

1421 rows × 9 columns

Assemble pvalues#

Hide code cell source

# Uncorrected p-values of the target variable for every method.
pvalues = scores.loc[pd.IndexSlice[:, args.target], pd.IndexSlice[:, 'p-unc']]
# Attach the feature frequency and move it into the row index as last level.
pvalues = pvalues.join(freq_feat)
pvalues = pvalues.set_index(('data', 'frequency'), append=True)
pvalues.index.names = [*pvalues.index.names[:-1], 'frequency']
# Dump as pickle (registered in `files_out`) and as sheet of the workbook.
fname = args.out_folder / 'pvalues_target.pkl'
files_out[fname.name] = fname.as_posix()
pvalues.to_pickle(fname)
pvalues.to_excel(writer, sheet_name='pvalues_all')
pvalues
PI VAE None DAE RF Median QRILC CF TRKNN
p-unc p-unc p-unc p-unc p-unc p-unc p-unc p-unc p-unc
protein groups Source frequency
A0A024QZX5;A0A087X1N8;P35237 AD 186 0.501 0.007 0.015 0.006 0.007 0.012 0.044 0.007 0.008
A0A024R0T9;K7ER74;P02655 AD 195 0.035 0.030 0.037 0.035 0.031 0.033 0.032 0.038 0.031
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 AD 174 0.102 0.298 0.432 0.343 0.385 0.736 0.366 0.598 0.264
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 AD 196 0.620 0.253 0.254 0.250 0.256 0.259 0.304 0.259 0.266
A0A075B6H7 AD 91 0.063 0.007 0.008 0.009 0.002 0.053 0.160 0.008 0.020
... ... ... ... ... ... ... ... ... ... ... ...
Q9Y6R7 AD 197 0.175 0.175 0.175 0.175 0.175 0.175 0.175 0.175 0.175
Q9Y6X5 AD 173 0.057 0.215 0.344 0.342 0.200 0.291 0.031 0.092 0.113
Q9Y6Y8;Q9Y6Y8-2 AD 197 0.083 0.083 0.083 0.083 0.083 0.083 0.083 0.083 0.083
Q9Y6Y9 AD 119 0.561 0.742 0.505 0.904 0.380 0.520 0.829 0.317 0.334
S4R3U6 AD 126 0.345 0.062 0.698 0.037 0.082 0.730 0.462 0.030 0.036

1421 rows × 9 columns

Assemble rejected features#

Hide code cell source

# Decisions ('rejected' column) for the target variable for every method.
da_target = scores.loc[pd.IndexSlice[:, args.target], pd.IndexSlice[:, 'rejected']]
# Attach the feature frequency and move it into the row index as last level.
da_target = da_target.join(freq_feat)
da_target = da_target.set_index(('data', 'frequency'), append=True)
da_target.index.names = [*da_target.index.names[:-1], 'frequency']
# Dump the decisions as pickle (registered in `files_out`).
fname = args.out_folder / 'equality_rejected_target.pkl'
files_out[fname.name] = fname.as_posix()
da_target.to_pickle(fname)
# Tabulate the decisions per method (second column level dropped) and
# write the counts to the workbook.
count_rejected = njab.pandas.combine_value_counts(
    da_target.droplevel(-1, axis=1))
count_rejected.to_excel(writer, sheet_name='count_rejected')
count_rejected
PI VAE None DAE RF Median QRILC CF TRKNN
False 1,025 947 1,054 951 973 1,069 995 932 936
True 396 474 367 470 448 352 426 489 485

Tabulate rejected decisions by method:#

Hide code cell source

# Features with a decision for every method (no missing entry in any column).
# NOTE: this implicitly relies on a method (RSN in the original analysis) not
# being available for some protein groups.
# TODO: make an explicit list of the protein groups available in the original
# data (313 in the original analysis).
mask_common = da_target.notna().all(axis=1)
count_rejected_common = njab.pandas.combine_value_counts(da_target.loc[mask_common].droplevel(-1, axis=1))
count_rejected_common.to_excel(writer, sheet_name='count_rejected_common')
count_rejected_common
PI VAE None DAE RF Median QRILC CF TRKNN
False 1,025 947 1,054 951 973 1,069 995 932 936
True 396 474 367 470 448 352 426 489 485

Tabulate rejected decisions by method for newly included features (if available)#

Hide code cell source

# Same tabulation for the features outside `mask_common`, i.e. with a missing
# decision for at least one method (empty if there are none).
count_rejected_new = njab.pandas.combine_value_counts(da_target.loc[~mask_common].droplevel(-1, axis=1))
count_rejected_new.to_excel(writer, sheet_name='count_rejected_new')
count_rejected_new
PI VAE None DAE RF Median QRILC CF TRKNN

Tabulate rejected decisions by method for all features#

Hide code cell source

# Write the per-feature decisions of all methods to the workbook.
da_target.to_excel(writer, sheet_name='equality_rejected_all')
logger.info("Written to sheet 'equality_rejected_all' in excel file.")
da_target
root - INFO     Written to sheet 'equality_rejected_all' in excel file.
PI VAE None DAE RF Median QRILC CF TRKNN
rejected rejected rejected rejected rejected rejected rejected rejected rejected
protein groups Source frequency
A0A024QZX5;A0A087X1N8;P35237 AD 186 False True True True True True False True True
A0A024R0T9;K7ER74;P02655 AD 195 False False False False False False False False False
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 AD 174 False False False False False False False False False
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 AD 196 False False False False False False False False False
A0A075B6H7 AD 91 False True True True True False False True True
... ... ... ... ... ... ... ... ... ... ... ...
Q9Y6R7 AD 197 False False False False False False False False False
Q9Y6X5 AD 173 False False False False False False False False False
Q9Y6Y8;Q9Y6Y8-2 AD 197 False False False False False False False False False
Q9Y6Y9 AD 119 False False False False False False False False False
S4R3U6 AD 126 False False False False False False False False False

1421 rows × 9 columns

Tabulate the number of features with the same decision for all methods (True) versus the features whose decision varies with the method (False)

Hide code cell source

# True where all methods agree on a feature: either no method rejects
# (row sum of 0) or every method rejects.
da_target_same = (da_target.sum(axis=1) == 0) | da_target.all(axis=1)
da_target_same.value_counts()
True    1,098
False     323
Name: count, dtype: int64

List frequency of features with varying decisions

Hide code cell source

# Index entries of the features with diverging decisions.
feat_idx_w_diff = da_target_same[~da_target_same].index
# Display only their frequency (the last index level is dropped for the display).
feat_idx_w_diff.to_frame()[['frequency']].reset_index(-1, drop=True)
frequency
protein groups Source
A0A024QZX5;A0A087X1N8;P35237 AD 186
A0A075B6H7 AD 91
A0A075B6H9 AD 189
A0A075B6I0 AD 194
A0A075B6J9 AD 156
... ... ...
Q9UP79 AD 135
Q9UPU3 AD 163
Q9UQ52 AD 188
Q9Y281;Q9Y281-3 AD 51
Q9Y6C2 AD 119

323 rows × 1 columns

take only those with different decisions

Hide code cell source

# Export the q-values of the features with diverging decisions, sorted by the
# q-value obtained without imputation ('None'). The writer is closed in the
# `finally` clause so that the file handle is released even if an export fails.
try:
    # all features with diverging decisions
    (qvalues
     .loc[feat_idx_w_diff]
     .sort_values(('None', 'qvalue'))
     .to_excel(writer, sheet_name='qvalues_diff')
     )

    # features with a decision for every method
    (qvalues
     .loc[feat_idx_w_diff]
     .loc[mask_common]  # mask automatically aligned
     .sort_values(('None', 'qvalue'))
     .to_excel(writer, sheet_name='qvalues_diff_common')
     )

    # features without a decision for at least one method; an IndexError is
    # reported as "nothing to write" (presumably raised for an empty selection)
    try:
        (qvalues
         .loc[feat_idx_w_diff]
         .loc[~mask_common]
         .sort_values(('None', 'qvalue'))
         .to_excel(writer, sheet_name='qvalues_diff_new')
         )
    except IndexError:
        print("No new features or no new ones (with diverging decisions.)")
finally:
    writer.close()
No new features or no new ones (with diverging decisions.)

Plots for inspecting imputations (for diverging decisions)#

Hide code cell source

# Stop here unless plots were requested. `sys.exit(0)` raises SystemExit,
# which ends the run before the plotting cells below.
if not args.make_plots:
    logger.warning("Not plots requested.")
    import sys
    sys.exit(0)
root - WARNING  Not plots requested.
/home/runner/work/pimms/pimms/project/.snakemake/conda/43fbe714d68d8fe6f9b0c93f5652adb3_/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3755: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
An exception has occurred, use %tb to see the full traceback.

SystemExit: 0

Load target#

Hide code cell source

# Read the sample ID (used as index) and the target column from the clinical
# data and keep only the samples with a target value.
target = pd.read_csv(
    args.fn_clinical_data,
    usecols=[args.sample_id_col, args.target],
    index_col=0,
).dropna()
target

Hide code cell source

# Keep the raw target values for the cross-tabulation below.
target_to_group = target.copy()
# Binarize: True where the target value is >= cutoff_target.
target = target >= args.cutoff_target
# Replace the booleans by readable labels and store them as categories.
target = target.replace({False: f'{args.target} < {args.cutoff_target}',
                        True: f'{args.target} >= {args.cutoff_target}'}
                        ).astype('category')
# Binarized labels against the raw target values.
pd.crosstab(target.squeeze(), target_to_group.squeeze())

Measurements#

Hide code cell source

# Load the data splits from the data folder.
data = pimmslearn.io.datasplits.DataSplits.from_folder(
    args.data,
    file_format=args.file_format)
# Combine training, validation and test values and pivot the last index level
# into the columns (assumes the splits are indexed by sample and feature,
# giving samples in rows and features in columns -- TODO confirm).
data = pd.concat([data.train_X, data.val_y, data.test_y]).unstack()
data

Plot all new protein groups that are significant at least once and have not been dumped already.

Hide code cell source

# Features outside `mask_common` that are rejected by at least one method.
feat_new_abundant = da_target.loc[~mask_common].any(axis=1)
# Keep only the first index level of the rows flagged True.
feat_new_abundant = feat_new_abundant.loc[feat_new_abundant].index.get_level_values(0)
feat_new_abundant

Hide code cell source

# Features to plot: those with diverging decisions plus the ones selected in
# `feat_new_abundant`.
feat_sel = feat_idx_w_diff.get_level_values(0)
feat_sel = feat_sel.union(feat_new_abundant)
len(feat_sel)

Hide code cell source

# Restrict the measurements to the selected features (columns).
data = data.loc[:, feat_sel]
data
  • RSN predictions are based on the mean and standard deviation of all samples (N=455), as in the original study

  • The VAE was also trained on all samples (self-supervised). One could also reduce the selected data to only the samples with a valid target marker, but this was not done in the original study, which considered several different target markers.

RSN : shifted per sample, not per feature!

Load all prediction files and reshape

Hide code cell source

# exclude 'None' as this is without imputation (-> data)
model_keys = [k for k in qvalues.columns.get_level_values(0) if k != 'None']
pred_paths = [
    args.out_preds / args.template_pred.format(method)
    for method in model_keys]
pred_paths

Hide code cell source

# Load the prediction file of every method and combine the results into one
# DataFrame with one column per method.
load_single_csv_pred_file = pimmslearn.analyzers.compare_predictions.load_single_csv_pred_file
pred_real_na = dict()
for method in model_keys:
    fname = args.out_preds / args.template_pred.format(method)
    print(f"missing values pred. by {method}: {fname}")
    pred_real_na[method] = load_single_csv_pred_file(fname)
pred_real_na = pd.DataFrame(pred_real_na)
pred_real_na

After imputation, reduce to the target samples only (samples with a target score)

Hide code cell source

# select samples with target information
data = data.loc[target.index]
pred_real_na = pred_real_na.loc[target.index]

# assert len(data) == len(pred_real_na)

Hide code cell source

# Example: first of the selected features.
idx = feat_sel[0]

Hide code cell source

# Non-missing values of the example feature.
feat_observed = data[idx].dropna()
feat_observed

Hide code cell source

# Open questions kept from the original notes:
# - one plot per feature?
# - plot all features that are significant for at least one method?
# - alternative location considered: args.out_folder.parent / 'intensity_plots'
# Output folder for the plots; created if it does not exist yet.
folder = args.out_folder / 'intensities_for_diff_in_DA_decision'
folder.mkdir(parents=True, exist_ok=True)

Hide code cell source

# Intensity range over the measured and the imputed values (presumably the
# overall minimum and maximum -- confirm against the helper); used as common
# y-axis limits of the plots below.
min_y_int, max_y_int = pimmslearn.plotting.data.get_min_max_iterable(
    [data.stack(), pred_real_na.stack()])
min_max = min_y_int, max_y_int

# Name of the first (and presumably only) column of the target table.
target_name = target.columns[0]

min_max, target_name

Compare with target annotation#

Hide code cell source

# labels somehow?
# target.replace({True: f' >={args.cutoff_target}', False: f'<{args.cutoff_target}'})

for i, idx in enumerate(feat_sel):
    # `<3`: left-align the running number in a field of width 3
    print(f"Swarmplot {i:<3}: {idx}:")
    fig, ax = plt.subplots()

    # Draw and remove a throw-away scatter plot to obtain the path object of
    # the 'X' marker; it is assigned to swarmplot collections further down.
    tmp_dot = ax.scatter([1, 2], [3, 4], marker='X')
    new_mk, = tmp_dot.get_paths()
    tmp_dot.remove()

    # non-missing measurements of the current feature
    feat_observed = data[idx].dropna()

    def get_centered_label(method, n, q):
        """Build a two-line label: method name above '(N=..., q=...)'.

        The shorter of the two lines is padded with trailing spaces to the
        length of the longer one (i.e. left-aligned, despite the name).
        """
        lines = [f'{method}', f'(N={n:,d}, q={q:.3f})']
        width = max(map(len, lines))
        return '\n'.join(line.ljust(width) for line in lines)

    # Observed values; their q-value is the one obtained without imputation.
    # `.loc[idx, column]` selects a Series here (the index has more levels than
    # `idx`); `.item()` extracts its single value and raises if there is not
    # exactly one. Calling `float` directly on such a Series is deprecated in
    # recent pandas versions.
    key = get_centered_label(method='observed',
                             n=len(feat_observed),
                             q=float(qvalues.loc[idx, ('None', 'qvalue')].item())
                             )
    to_plot = {key: feat_observed}
    for method in model_keys:
        try:
            # imputed values of this method for the current feature
            pred = (pred_real_na
                    .loc[pd.IndexSlice[:, idx], method]
                    .dropna()
                    .droplevel(-1))
            if len(pred) == 0:
                # in case no values was imputed -> qvalue is as based on measured
                key = get_centered_label(method=method,
                                         n=len(pred),
                                         q=float(qvalues.loc[idx, ('None', 'qvalue')].item()
                                                 ))
            elif qvalues.loc[idx, (method, 'qvalue')].notna().all():
                key = get_centered_label(method=method,
                                         n=len(pred),
                                         q=float(qvalues.loc[idx, (method, 'qvalue')].item()
                                                 ))
            elif qvalues.loc[idx, (method, 'qvalue')].isna().all():
                logger.info(f"NA qvalues for {idx}: {method}")
                continue
            else:
                raise ValueError("Unknown case.")
            to_plot[key] = pred
        except KeyError:
            print(f"No missing values for {idx}: {method}")
            continue

    # Long format: one row per value with its group label, joined with the
    # target label of the sample.
    to_plot = pd.DataFrame.from_dict(to_plot)
    to_plot.columns.name = 'group'
    groups_order = to_plot.columns.to_list()
    to_plot = to_plot.stack().to_frame('intensity').reset_index(-1)
    to_plot = to_plot.join(target.astype('category'), how='inner')
    to_plot = to_plot.astype({'group': 'category'})

    ax = seaborn.swarmplot(data=to_plot,
                           x='group',
                           y='intensity',
                           order=groups_order,
                           dodge=True,
                           hue=args.target,
                           size=2,
                           ax=ax)
    # NOTE(review): only the first protein of the group is used for title and
    # file name; groups sharing their first protein would write to the same
    # file -- confirm this cannot happen.
    first_pg = idx.split(";")[0]
    ax.set_title(
        f'Imputation for protein group {first_pg} with target {target_name} (N= {len(data):,d} samples)')

    _ = ax.set_ylim(min_y_int, max_y_int)
    _ = ax.locator_params(axis='y', integer=True)
    _ = ax.set_xlabel('')
    # Fix the tick positions before setting the rotated tick labels.
    _xticks = ax.get_xticks()
    ax.xaxis.set_major_locator(
        matplotlib.ticker.FixedLocator(_xticks)
    )
    _ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45,
                           horizontalalignment='right')

    N_hues = len(pd.unique(to_plot[args.target]))

    # Use the 'X' marker for the first two collections (presumably the
    # observed values of the two target groups -- assumes two hue levels).
    _ = ax.collections[0].set_paths([new_mk])
    _ = ax.collections[1].set_paths([new_mk])

    # Relabel the last two collections as imputed values and add two empty
    # scatter plots that serve as legend handles for the observed values.
    label_target_0, label_target_1 = ax.collections[-2].get_label(), ax.collections[-1].get_label()
    _ = ax.collections[-2].set_label(f'imputed, {label_target_0}')
    _ = ax.collections[-1].set_label(f'imputed, {label_target_1}')
    _obs_label0 = ax.scatter([], [], color='C0', marker='X', label=f'observed, {label_target_0}')
    _obs_label1 = ax.scatter([], [], color='C1', marker='X', label=f'observed, {label_target_1}')
    # After adding the two handles, the relabelled collections are found at
    # positions -4 and -3.
    _ = ax.legend(
        handles=[_obs_label0, _obs_label1, *ax.collections[-4:-2]],
        fontsize=5, title_fontsize=5, markerscale=0.4,)
    fname = (folder /
             f'{first_pg}_swarmplot.pdf')
    files_out[fname.name] = fname.as_posix()
    pimmslearn.savefig(
        fig,
        name=fname)
    plt.close()

Saved files:

Hide code cell source

# Display the registry of the files written by this notebook.
files_out