Compare predictions between model and RSN#

  • see differences in imputation for diverging cases

  • dumps top5

Hide code cell source

import logging
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import njab
import pandas as pd
import seaborn

import pimmslearn
import pimmslearn.analyzers
import pimmslearn.imputation
import pimmslearn.io.datasplits

# Notebook-wide logger and plotting defaults.
logger = pimmslearn.logging.setup_nb_logger()
# Only show WARNING and above from fontTools to reduce log noise.
logging.getLogger('fontTools').setLevel(logging.WARNING)

plt.rcParams['figure.figsize'] = [4, 2.5]  # [16.0, 7.0] , [4, 3]
pimmslearn.plotting.make_large_descriptors(7)

# catch passed parameters:
# snapshot the names present in the global namespace *before* the parameters
# cell runs. `args` is bound first so that it is itself part of the snapshot.
# The keys are later passed to pimmslearn.nb.get_params(args, ...) —
# presumably to separate notebook parameters from setup names (confirm).
args = None
args = dict(globals()).keys()

Parameters#

# Default parameters. The assignments after the "# Parameters" comment below
# re-assign several of them (presumably injected by a notebook runner such as
# papermill) and therefore take precedence.
folder_experiment = 'runs/appl_ald_data/plasma/proteinGroups'
fn_clinical_data = "data/ALD_study/processed/ald_metadata_cli.csv"
make_plots = True  # create histograms and swarmplots of diverging results
model_key = 'VAE'
sample_id_col = 'Sample ID'  # column of fn_clinical_data used as sample index
target = 'kleiner'  # column of fn_clinical_data holding the target values
cutoff_target: float = 2  # => for binarization target >= cutoff_target
out_folder = 'diff_analysis'
file_format = 'csv'
baseline = 'RSN'  # default is RSN, but could be any other trained model
template_pred = 'pred_real_na_{}.csv'  # fixed, do not change
ref_method_score = None  # filepath to reference method score
# Parameters
cutoff_target = 0.5
make_plots = False
ref_method_score = None
folder_experiment = "runs/alzheimer_study"
target = "AD"
baseline = "PI"
out_folder = "diff_analysis"
fn_clinical_data = "runs/alzheimer_study/data/clinical_data.csv"

Hide code cell source

# Gather the parameters (defaults plus injected overrides) into a Config whose
# outputs all live below <folder_experiment>/<out_folder>/<target>.
params = pimmslearn.nb.get_params(args, globals=globals())
args = pimmslearn.nb.Config()
args.folder_experiment = Path(params["folder_experiment"])
out_root = args.folder_experiment.joinpath(params["out_folder"],
                                           params["target"])
args = pimmslearn.nb.add_default_paths(args, out_root=out_root)
args.folder_scores = out_root.joinpath('scores')
args.update_from_dict(params)
args
root - INFO     Removed from global namespace: folder_experiment
root - INFO     Removed from global namespace: fn_clinical_data
root - INFO     Removed from global namespace: make_plots
root - INFO     Removed from global namespace: model_key
root - INFO     Removed from global namespace: sample_id_col
root - INFO     Removed from global namespace: target
root - INFO     Removed from global namespace: cutoff_target
root - INFO     Removed from global namespace: out_folder
root - INFO     Removed from global namespace: file_format
root - INFO     Removed from global namespace: baseline
root - INFO     Removed from global namespace: template_pred
root - INFO     Removed from global namespace: ref_method_score
root - INFO     Already set attribute: folder_experiment has value runs/alzheimer_study
root - INFO     Already set attribute: out_folder has value diff_analysis
{'baseline': 'PI',
 'cutoff_target': 0.5,
 'data': PosixPath('runs/alzheimer_study/data'),
 'file_format': 'csv',
 'fn_clinical_data': 'runs/alzheimer_study/data/clinical_data.csv',
 'folder_experiment': PosixPath('runs/alzheimer_study'),
 'folder_scores': PosixPath('runs/alzheimer_study/diff_analysis/AD/scores'),
 'make_plots': False,
 'model_key': 'VAE',
 'out_figures': PosixPath('runs/alzheimer_study/figures'),
 'out_folder': PosixPath('runs/alzheimer_study/diff_analysis/AD'),
 'out_metrics': PosixPath('runs/alzheimer_study'),
 'out_models': PosixPath('runs/alzheimer_study'),
 'out_preds': PosixPath('runs/alzheimer_study/preds'),
 'ref_method_score': None,
 'sample_id_col': 'Sample ID',
 'target': 'AD',
 'template_pred': 'pred_real_na_{}.csv'}

Write outputs to excel

Hide code cell source

# `files_out` records every written file (name -> POSIX path); all tables
# below go into one Excel workbook handled by `writer`.
files_out = {}
fname = args.out_folder.joinpath('diff_analysis_compare_DA.xlsx')
writer = pd.ExcelWriter(fname)
files_out.update({fname.name: fname.as_posix()})
logger.info("Writing to excel file: %s", fname)
root - INFO     Writing to excel file: runs/alzheimer_study/diff_analysis/AD/diff_analysis_compare_DA.xlsx

Load scores#

List dump of scores:

Hide code cell source

# Pickled score tables in the scores folder (one file per method, e.g.
# diff_analysis_scores_VAE.pkl). Sorted, because `iterdir()` yields entries in
# an arbitrary, file-system dependent order; sorting makes the column order of
# the concatenated `scores` table reproducible.
score_dumps = sorted(fname for fname in Path(args.folder_scores).iterdir()
                     if fname.suffix == '.pkl')
score_dumps
[PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_VAE.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_RF.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_DAE.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_QRILC.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_TRKNN.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_PI.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_Median.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_CF.pkl'),
 PosixPath('runs/alzheimer_study/diff_analysis/AD/scores/diff_analysis_scores_None.pkl')]

Load scores from dumps:

Hide code cell source

# Place all loaded score tables side by side (columns: method -> statistic).
scores = pd.concat(map(pd.read_pickle, score_dumps), axis=1)
scores
model VAE RF ... CF None
var SS DF F p-unc np2 -Log10 pvalue qvalue rejected SS DF ... qvalue rejected SS DF F p-unc np2 -Log10 pvalue qvalue rejected
protein groups Source
A0A024QZX5;A0A087X1N8;P35237 AD 1.029 1 7.334 0.007 0.037 2.132 0.021 True 0.989 1 ... 0.019 True 0.834 1.000 6.088 0.015 0.033 1.837 0.043 True
age 0.011 1 0.080 0.777 0.000 0.109 0.851 False 0.003 1 ... 0.937 False 0.002 1.000 0.015 0.903 0.000 0.044 0.943 False
Kiel 0.318 1 2.267 0.134 0.012 0.873 0.229 False 0.210 1 ... 0.308 False 0.145 1.000 1.061 0.304 0.006 0.517 0.461 False
Magdeburg 0.534 1 3.806 0.053 0.020 1.280 0.108 False 0.389 1 ... 0.139 False 0.273 1.000 1.996 0.159 0.011 0.797 0.286 False
Sweden 1.814 1 12.934 0.000 0.063 3.386 0.002 True 1.494 1 ... 0.002 True 1.209 1.000 8.827 0.003 0.047 2.472 0.013 True
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
S4R3U6 AD 1.769 1 3.573 0.060 0.018 1.220 0.121 False 1.173 1 ... 0.190 False 0.095 1.000 0.151 0.698 0.001 0.156 0.803 False
age 0.624 1 1.260 0.263 0.007 0.580 0.389 False 0.686 1 ... 0.265 False 1.370 1.000 2.171 0.143 0.018 0.844 0.265 False
Kiel 2.754 1 5.562 0.019 0.028 1.713 0.047 True 2.153 1 ... 0.142 False 1.396 1.000 2.213 0.139 0.018 0.856 0.259 False
Magdeburg 2.388 1 4.821 0.029 0.025 1.533 0.066 False 1.711 1 ... 0.217 False 0.556 1.000 0.882 0.350 0.007 0.456 0.507 False
Sweden 19.067 1 38.503 0.000 0.168 8.479 0.000 True 14.171 1 ... 0.000 True 8.519 1.000 13.502 0.000 0.101 3.447 0.002 True

7105 rows × 72 columns

If reference dump is provided, add it to the scores

Hide code cell source

# Optionally join the scores of a reference run. Its 'None' (no imputation)
# columns are renamed to 'None (100%)' so they do not collide with this run's
# 'None' columns (DataFrame.join raises on overlapping column names).
if args.ref_method_score:
    scores_reference = (pd
                        .read_pickle(args.ref_method_score)
                        .rename({'None': 'None (100%)'},
                                axis=1))
    scores = scores.join(scores_reference)
    # lazy %-formatting: the message is only built if INFO is enabled
    logger.info('Added reference method scores from %s',
                args.ref_method_score)

Load frequencies of observed features#

Hide code cell source

# Observed-value frequency per feature. The single column is relabelled as
# ('data', 'frequency') so it can be joined onto the two-level score columns.
fname = args.folder_experiment / 'freq_features_observed.csv'
freq_feat = (pd.read_csv(fname, index_col=0)
             .set_axis(pd.MultiIndex.from_tuples([('data', 'frequency')]),
                       axis=1))
freq_feat
data
frequency
protein groups
A0A024QZX5;A0A087X1N8;P35237 186
A0A024R0T9;K7ER74;P02655 195
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 174
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 196
A0A075B6H7 91
... ...
Q9Y6R7 197
Q9Y6X5 173
Q9Y6Y8;Q9Y6Y8-2 197
Q9Y6Y9 119
S4R3U6 126

1421 rows × 1 columns

Assemble qvalues#

Hide code cell source

# q-values of the target term for every method. The observed frequency of each
# feature is appended as an extra index level named 'frequency'.
qvalues = (scores
           .loc[pd.IndexSlice[:, args.target], pd.IndexSlice[:, 'qvalue']]
           .join(freq_feat)
           .set_index(('data', 'frequency'), append=True))
qvalues.index.names = [*qvalues.index.names[:-1], 'frequency']

fname = args.out_folder / 'qvalues_target.pkl'
files_out[fname.name] = fname.as_posix()
qvalues.to_pickle(fname)
qvalues.to_excel(writer, sheet_name='qvalues_all')
qvalues
VAE RF DAE QRILC TRKNN PI Median CF None
qvalue qvalue qvalue qvalue qvalue qvalue qvalue qvalue qvalue
protein groups Source frequency
A0A024QZX5;A0A087X1N8;P35237 AD 186 0.021 0.021 0.014 0.094 0.023 0.502 0.039 0.019 0.043
A0A024R0T9;K7ER74;P02655 AD 195 0.072 0.069 0.069 0.071 0.071 0.109 0.087 0.076 0.092
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 AD 174 0.403 0.502 0.468 0.460 0.394 0.161 0.832 0.467 0.586
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 AD 196 0.383 0.394 0.375 0.446 0.396 0.724 0.418 0.385 0.404
A0A075B6H7 AD 91 0.012 0.009 0.027 0.277 0.048 0.227 0.124 0.004 0.027
... ... ... ... ... ... ... ... ... ... ... ...
Q9Y6R7 AD 197 0.283 0.292 0.282 0.301 0.289 0.318 0.315 0.286 0.307
Q9Y6X5 AD 173 0.361 0.289 0.341 0.107 0.205 0.086 0.455 0.261 0.501
Q9Y6Y8;Q9Y6Y8-2 AD 197 0.157 0.162 0.156 0.171 0.160 0.182 0.178 0.159 0.174
Q9Y6Y9 AD 119 0.936 0.550 0.874 0.797 0.472 0.572 0.667 0.600 0.651
S4R3U6 AD 126 0.121 0.205 0.049 0.654 0.080 0.538 0.829 0.190 0.803

1421 rows × 9 columns

Assemble pvalues#

Hide code cell source

# Uncorrected p-values of the target term for every method, with the observed
# frequency of each feature as an extra index level named 'frequency'.
pvalues = (scores
           .loc[pd.IndexSlice[:, args.target], pd.IndexSlice[:, 'p-unc']]
           .join(freq_feat)
           .set_index(('data', 'frequency'), append=True))
pvalues.index.names = [*pvalues.index.names[:-1], 'frequency']

fname = args.out_folder / 'pvalues_target.pkl'
files_out[fname.name] = fname.as_posix()
pvalues.to_pickle(fname)
pvalues.to_excel(writer, sheet_name='pvalues_all')
pvalues
VAE RF DAE QRILC TRKNN PI Median CF None
p-unc p-unc p-unc p-unc p-unc p-unc p-unc p-unc p-unc
protein groups Source frequency
A0A024QZX5;A0A087X1N8;P35237 AD 186 0.007 0.007 0.005 0.039 0.008 0.336 0.012 0.006 0.015
A0A024R0T9;K7ER74;P02655 AD 195 0.032 0.029 0.030 0.028 0.031 0.043 0.033 0.034 0.037
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 AD 174 0.277 0.363 0.339 0.311 0.264 0.072 0.736 0.331 0.432
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 AD 196 0.258 0.259 0.252 0.297 0.266 0.586 0.259 0.257 0.254
A0A075B6H7 AD 91 0.004 0.002 0.010 0.158 0.020 0.111 0.053 0.001 0.008
... ... ... ... ... ... ... ... ... ... ... ...
Q9Y6R7 AD 197 0.175 0.175 0.175 0.175 0.175 0.175 0.175 0.175 0.175
Q9Y6X5 AD 173 0.239 0.173 0.223 0.046 0.113 0.032 0.291 0.156 0.344
Q9Y6Y8;Q9Y6Y8-2 AD 197 0.083 0.083 0.083 0.083 0.083 0.083 0.083 0.083 0.083
Q9Y6Y9 AD 119 0.896 0.412 0.812 0.697 0.334 0.409 0.520 0.473 0.505
S4R3U6 AD 126 0.060 0.112 0.021 0.519 0.036 0.373 0.730 0.104 0.698

1421 rows × 9 columns

Assemble rejected features#

Hide code cell source

# 'rejected' flags of the target term for every method (frequency appended as
# an index level), plus a per-method count of True/False decisions.
da_target = (scores
             .loc[pd.IndexSlice[:, args.target],
                  pd.IndexSlice[:, 'rejected']]
             .join(freq_feat)
             .set_index(('data', 'frequency'), append=True))
da_target.index.names = [*da_target.index.names[:-1], 'frequency']

fname = args.out_folder / 'equality_rejected_target.pkl'
files_out[fname.name] = fname.as_posix()
da_target.to_pickle(fname)

decisions_by_method = da_target.droplevel(-1, axis=1)
count_rejected = njab.pandas.combine_value_counts(decisions_by_method)
count_rejected.to_excel(writer, sheet_name='count_rejected')
count_rejected
VAE RF DAE QRILC TRKNN PI Median CF None
False 935 965 933 992 936 1,025 1,069 958 1,054
True 486 456 488 429 485 396 352 463 367

Tabulate rejected decisions by method:#

Hide code cell source

# Features for which every method has a decision (no missing entry in any
# column), and their per-method True/False counts.
# NOTE(review): this relies implicitly on some method (originally RSN) lacking
# results for some protein groups; an explicit list of the protein groups
# available in the original data would be more robust.
mask_common = da_target.notna().all(axis=1)
count_rejected_common = njab.pandas.combine_value_counts(
    da_target[mask_common].droplevel(-1, axis=1))
count_rejected_common.to_excel(writer, sheet_name='count_rejected_common')
count_rejected_common
VAE RF DAE QRILC TRKNN PI Median CF None
False 935 965 933 992 936 1,025 1,069 958 1,054
True 486 456 488 429 485 396 352 463 367

Tabulate rejected decisions by method for newly included features (if available)#

Hide code cell source

# Per-method decision counts restricted to features outside `mask_common`.
new_decisions = da_target[~mask_common].droplevel(-1, axis=1)
count_rejected_new = njab.pandas.combine_value_counts(new_decisions)
count_rejected_new.to_excel(writer, sheet_name='count_rejected_new')
count_rejected_new
VAE RF DAE QRILC TRKNN PI Median CF None

Tabulate rejected decisions by method for all features#

Hide code cell source

# Dump the complete decision table (all features, all methods).
da_target.to_excel(excel_writer=writer,
                   sheet_name='equality_rejected_all')
logger.info("Written to sheet 'equality_rejected_all' in excel file.")
da_target
root - INFO     Written to sheet 'equality_rejected_all' in excel file.
VAE RF DAE QRILC TRKNN PI Median CF None
rejected rejected rejected rejected rejected rejected rejected rejected rejected
protein groups Source frequency
A0A024QZX5;A0A087X1N8;P35237 AD 186 True True True False True False True True True
A0A024R0T9;K7ER74;P02655 AD 195 False False False False False False False False False
A0A024R3W6;A0A024R412;O60462;O60462-2;O60462-3;O60462-4;O60462-5;Q7LBX6;X5D2Q8 AD 174 False False False False False False False False False
A0A024R644;A0A0A0MRU5;A0A1B0GWI2;O75503 AD 196 False False False False False False False False False
A0A075B6H7 AD 91 True True True False True False False True True
... ... ... ... ... ... ... ... ... ... ... ...
Q9Y6R7 AD 197 False False False False False False False False False
Q9Y6X5 AD 173 False False False False False False False False False
Q9Y6Y8;Q9Y6Y8-2 AD 197 False False False False False False False False False
Q9Y6Y9 AD 119 False False False False False False False False False
S4R3U6 AD 126 False False True False False False False False False

1421 rows × 9 columns

Tabulate the number of features with the same decision across all methods (True) versus those whose decision varies depending on the method (False)

Hide code cell source

# A feature gets the same decision from every method if no method rejects it
# or if all methods reject it.
none_rejected = da_target.sum(axis=1) == 0
all_rejected = da_target.all(axis=1)
da_target_same = none_rejected | all_rejected
da_target_same.value_counts()
True    1,093
False     328
Name: count, dtype: int64

List frequency of features with varying decisions

Hide code cell source

# Index entries of the features whose decision depends on the method.
feat_idx_w_diff = da_target_same.index[~da_target_same]
feat_idx_w_diff.to_frame().loc[:, ['frequency']].droplevel(-1)
frequency
protein groups Source
A0A024QZX5;A0A087X1N8;P35237 AD 186
A0A075B6H7 AD 91
A0A075B6H9 AD 189
A0A075B6I0 AD 194
A0A075B6J9 AD 156
... ... ...
Q9UPU3 AD 163
Q9UQ52 AD 188
Q9Y281;Q9Y281-3 AD 51
Q9Y6C2 AD 119
S4R3U6 AD 126

328 rows × 1 columns

Take only those features with different decisions

Hide code cell source

# Export the q-values of the features with diverging decisions, sorted by the
# q-value without imputation ('None'): all of them, those in `mask_common`, and
# the remaining ("new") ones, each to its own sheet. The workbook is closed
# (i.e. written to disk) in `finally`, so it is not left open if an export fails.
qvalues_diff = qvalues.loc[feat_idx_w_diff]
try:
    (qvalues_diff
     .sort_values(('None', 'qvalue'))
     .to_excel(writer, sheet_name='qvalues_diff')
     )

    (qvalues_diff
     .loc[mask_common]  # mask automatically aligned
     .sort_values(('None', 'qvalue'))
     .to_excel(writer, sheet_name='qvalues_diff_common')
     )

    try:
        (qvalues_diff
         .loc[~mask_common]
         .sort_values(('None', 'qvalue'))
         .to_excel(writer, sheet_name='qvalues_diff_new')
         )
    except IndexError:
        # An IndexError here is treated as "nothing new to export".
        logger.info("No new features or no new ones (with diverging decisions.)")
finally:
    writer.close()
No new features or no new ones (with diverging decisions.)

Plots for inspecting imputations (for diverging decisions)#

Hide code cell source

# The remaining cells only produce plots: stop here when none are requested.
# In a notebook, sys.exit surfaces as a SystemExit message in the cell output.
if not args.make_plots:
    logger.warning("No plots requested.")
    import sys
    sys.exit(0)
root - WARNING  Not plots requested.
/home/runner/work/pimms/pimms/project/.snakemake/conda/924ec7e362d761ecf0807b9074d79999_/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3707: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
An exception has occurred, use %tb to see the full traceback.

SystemExit: 0

Load target#

Hide code cell source

# Target values indexed by sample id; samples without a target value are dropped.
target = (pd.read_csv(args.fn_clinical_data,
                      usecols=[args.sample_id_col, args.target],
                      index_col=0)
          .dropna())
target

Hide code cell source

# Keep the raw values for the crosstab, then binarize at `cutoff_target` into
# two labelled categories.
target_to_group = target.copy()
labels = {False: f'{args.target} < {args.cutoff_target}',
          True: f'{args.target} >= {args.cutoff_target}'}
target = (target >= args.cutoff_target).replace(labels).astype('category')
pd.crosstab(target.squeeze(), target_to_group.squeeze())

Measurements#

Hide code cell source

# Recombine train, validation and test splits into one wide table
# (samples x features).
data = pimmslearn.io.datasplits.DataSplits.from_folder(
    args.data, file_format=args.file_format)
data = pd.concat([getattr(data, split)
                  for split in ('train_X', 'val_y', 'test_y')]).unstack()
data

Plot all new protein groups (pgs) that are significant for at least one method and that are not already dumped.

Hide code cell source

# Features outside `mask_common` that are rejected by at least one method
# (first index level only).
feat_new_abundant = (da_target.loc[~mask_common]
                     .any(axis=1)
                     .loc[lambda is_rejected: is_rejected]
                     .index
                     .get_level_values(0))
feat_new_abundant

Hide code cell source

# Features to plot: diverging decisions plus newly significant features.
feat_sel = feat_idx_w_diff.get_level_values(0).union(feat_new_abundant)
len(feat_sel)

Hide code cell source

# Keep only the selected feature columns.
data = data[feat_sel]
data
  • RSN predictions are based on the mean and standard deviation over all samples (N=455), as in the original study.

  • The VAE was also trained on all samples (self-supervised). One could also reduce the selected data to only the samples with a valid target marker, but this was not done in the original study, which considered several different target markers.

RSN : shifted per sample, not per feature!

Load all prediction files and reshape

Hide code cell source

# Methods whose imputed values are compared. Excluded are 'None' (no
# imputation, i.e. the measured data) and, when a reference run was joined in,
# its renamed 'None (100%)' columns. Neither has a prediction file
# (template_pred), so keeping 'None (100%)' would make loading fail.
model_keys = [k for k in qvalues.columns.get_level_values(0)
              if k not in ('None', 'None (100%)')]
pred_paths = [
    args.out_preds / args.template_pred.format(method)
    for method in model_keys]
pred_paths

Hide code cell source

# Values predicted for the missing entries, one column per method.
load_single_csv_pred_file = pimmslearn.analyzers.compare_predictions.load_single_csv_pred_file
loaded_preds = {}
for method in model_keys:
    fname = args.out_preds.joinpath(args.template_pred.format(method))
    print(f"missing values pred. by {method}: {fname}")
    loaded_preds[method] = load_single_csv_pred_file(fname)
pred_real_na = pd.DataFrame(loaded_preds)
pred_real_na

Once imputed, reduce to target samples only (samples with a target score)

Hide code cell source

# Keep only samples that have target information, both in the measured data
# and in the predictions.
# NOTE(review): pred_real_na is indexed per (sample, feature) pair (see the
# IndexSlice lookup in the plotting loop), so its length differs from `data`.
samples_with_target = target.index
data = data.loc[samples_with_target]
pred_real_na = pred_real_na.loc[samples_with_target]

Hide code cell source

# First selected feature; its observed values are shown in the next cell.
idx = feat_sel[0]

Hide code cell source

# Observed (non-missing) values of that feature.
feat_observed = data.loc[:, idx].dropna()
feat_observed

Hide code cell source

# Output folder for the per-feature swarmplots (created if missing).
folder = args.out_folder.joinpath('intensities_for_diff_in_DA_decision')
folder.mkdir(parents=True, exist_ok=True)

Hide code cell source

# Shared y-axis limits over measured and imputed values, so that all plots use
# the same scale; plus the name of the target column for plot titles.
target_name = target.columns[0]
min_y_int, max_y_int = pimmslearn.plotting.data.get_min_max_iterable(
    [data.stack(), pred_real_na.stack()])
min_max = (min_y_int, max_y_int)

min_max, target_name

Compare with target annotation#

Hide code cell source

# For every selected feature, draw a swarmplot of the observed values next to
# the values imputed by each method, coloured (hue) by the binarized target,
# and save it as PDF into `folder`.


def get_centered_label(method, n, q):
    """Build a two-line label: ``method`` above ``(N=<n>, q=<q>)``.

    The shorter line is right-padded with spaces so both lines have the same
    width.

    Parameters
    ----------
    method : str
        Name shown on the first line.
    n : int
        Number of values, formatted with thousands separators.
    q : float
        q-value, formatted with three decimals.
    """
    model_str = f'{method}'
    stats_str = f'(N={n:,d}, q={q:.3f})'
    width = max(len(model_str), len(stats_str))
    return f'{model_str:<{width}}\n{stats_str:<{width}}'


def _get_qvalue(feature, method):
    """Return the q-value of ``feature`` for ``method`` as a float.

    The lookup on the first index level returns a Series over the remaining
    levels (expected to hold one element). It is squeezed to a scalar first,
    because calling ``float`` on a Series is deprecated in pandas 2.x.
    """
    return float(qvalues.loc[feature, (method, 'qvalue')].squeeze())


for i, idx in enumerate(feat_sel):
    print(f"Swarmplot {i:<3}: {idx}:")
    fig, ax = plt.subplots()

    # dummy plot, only to obtain the Path of the 'X' marker
    tmp_dot = ax.scatter([1, 2], [3, 4], marker='X')
    new_mk, = tmp_dot.get_paths()
    tmp_dot.remove()

    feat_observed = data[idx].dropna()
    q_observed = _get_qvalue(idx, 'None')

    key = get_centered_label(method='observed',
                             n=len(feat_observed),
                             q=q_observed)
    to_plot = {key: feat_observed}
    for method in model_keys:
        try:
            pred = pred_real_na.loc[pd.IndexSlice[:, idx],
                                    method].dropna().droplevel(-1)
        except KeyError:
            print(f"No missing values for {idx}: {method}")
            continue
        q_method = qvalues.loc[idx, (method, 'qvalue')]
        if len(pred) == 0:
            # no value was imputed -> the q-value is the one of the measured data
            key = get_centered_label(method=method, n=len(pred), q=q_observed)
        elif q_method.notna().all():
            key = get_centered_label(method=method,
                                     n=len(pred),
                                     q=float(q_method.squeeze()))
        elif q_method.isna().all():
            logger.info("NA qvalues for %s: %s", idx, method)
            continue
        else:
            raise ValueError("Unknown case.")
        to_plot[key] = pred

    # long format: one row per (sample, group) with its intensity and target
    to_plot = pd.DataFrame.from_dict(to_plot)
    to_plot.columns.name = 'group'
    groups_order = to_plot.columns.to_list()
    to_plot = to_plot.stack().to_frame('intensity').reset_index(-1)
    to_plot = to_plot.join(target.astype('category'), how='inner')
    to_plot = to_plot.astype({'group': 'category'})

    ax = seaborn.swarmplot(data=to_plot,
                           x='group',
                           y='intensity',
                           order=groups_order,
                           dodge=True,
                           hue=args.target,
                           size=2,
                           ax=ax)
    first_pg = idx.split(";")[0]
    ax.set_title(
        f'Imputation for protein group {first_pg} with target {target_name} (N= {len(data):,d} samples)')

    _ = ax.set_ylim(min_y_int, max_y_int)
    _ = ax.locator_params(axis='y', integer=True)
    _ = ax.set_xlabel('')
    _xticks = ax.get_xticks()
    # fix tick positions before relabelling (avoids matplotlib's FixedFormatter warning)
    ax.xaxis.set_major_locator(
        matplotlib.ticker.FixedLocator(_xticks)
    )
    _ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45,
                           horizontalalignment='right')

    # The first group plotted is 'observed'; with dodge and two target classes
    # its values are the first two collections -> mark them with 'X'.
    # NOTE(review): this and the legend below assume exactly two hue levels.
    _ = ax.collections[0].set_paths([new_mk])
    _ = ax.collections[1].set_paths([new_mk])

    label_target_0, label_target_1 = ax.collections[-2].get_label(), ax.collections[-1].get_label()
    _ = ax.collections[-2].set_label(f'imputed, {label_target_0}')
    _ = ax.collections[-1].set_label(f'imputed, {label_target_1}')
    # empty scatters only serve as legend handles for the observed values;
    # they shift the relabelled 'imputed' collections to positions -4 and -3
    _obs_label0 = ax.scatter([], [], color='C0', marker='X', label=f'observed, {label_target_0}')
    _obs_label1 = ax.scatter([], [], color='C1', marker='X', label=f'observed, {label_target_1}')
    _ = ax.legend(
        handles=[_obs_label0, _obs_label1, *ax.collections[-4:-2]],
        fontsize=5, title_fontsize=5, markerscale=0.4,)
    fname = (folder /
             f'{first_pg}_swarmplot.pdf')
    files_out[fname.name] = fname.as_posix()
    pimmslearn.savefig(
        fig,
        name=fname)
    plt.close(fig)

Saved files:

Hide code cell source

# Mapping of file name -> POSIX path of every file recorded in `files_out` by this notebook.
files_out