Example: Predicting COVID-19 patient outcomes using MRCs

In this example we will use MRCpy.MRC and MRCpy.CMRC to predict the outcome of a COVID-19 positive patient at the moment of hospital triage. The dataset comprises demographic variables and biomarkers of the patients, together with a binary outcome, Status, where Status = 0 denotes the survivors and Status = 1 denotes the patients who died.

The data is provided by the Covid Data Saves Lives initiative carried out by HM Hospitales, with information from the first wave of the COVID outbreak in Spanish hospitals. The data is available upon request through HM Hospitales.

See also

For more information about the dataset and the creation of a risk indicator using logistic regression, refer to:

[1] Ruben Armañanzas et al. “Derivation of a Cost-Sensitive COVID-19 Mortality Risk Indicator Using a Multistart Framework”, in 2021 IEEE International Conference on Bioinformatics and Biomedicine (BIBM), 2021, pp. 2179–2186.

First we will see how to deal with class imbalance when training a model, using the synthetic minority over-sampling technique (SMOTE). Furthermore, we will compare two MRCs with two state-of-the-art machine learning models for probability estimation. The selected models are CMRC(phi = 'threshold', loss = 'log') and MRC(phi = 'fourier', loss = 'log') for the group of MRCs, and Logistic Regression (LR) and C-Support Vector Classifier (SVC) as implemented in Scikit-Learn.

# Import needed modules
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from imblearn.over_sampling import SMOTENC
from sklearn import preprocessing
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
)
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

from MRCpy import CMRC, MRC

COVID dataset Loader:

def load_covid(norm=False, array=True):
    data_consensus = pd.read_csv("data/data_consensus.csv", sep=";")
    # rename variables
    variable_dict = {
        "CD0000AGE": "Age",
        "CORE": "PATIENT_ID",
        "CT000000U": "Urea",
        "CT00000BT": "Bilirubin",
        "CT00000NA": "Sodium",
        "CT00000TP": "Proth_time",
        "CT0000COM": "Com",
        "CT0000LDH": "LDH",
        "CT0000NEU": "Neutrophils",
        "CT0000PCR": "Pro_C_Rea",
        "CT0000VCM": "Med_corp_vol",
        "CT000APTT": "Ceph_time",
        "CT000CHCM": "Mean_corp_Hgb",
        "CT000EOSP": "Eosinophils%",
        "CT000LEUC": "Leukocytes",
        "CT000LINP": "Lymphocytes%",
        "CT000NEUP": "Neutrophils%",
        "CT000PLAQ": "Platelet_count",
        "CTHSDXXRATE": "Rate",
        "CTHSDXXSAT": "Sat",
        "ED0DISWHY": "Status",
        "F_INGRESO/ADMISSION_D_ING/INPAT": "Fecha_admision",
        "SEXO/SEX": "Sexo",
    }
    data_consensus = data_consensus.rename(columns=variable_dict)
    if norm:  # if we want the data standardised
        x_consensus = data_consensus[
            data_consensus.columns.difference(["Status", "PATIENT_ID"])
        ][:]
        std_scale = preprocessing.StandardScaler().fit(x_consensus)
        x_consensus_std = std_scale.transform(x_consensus)
        dataframex_consensus = pd.DataFrame(
            x_consensus_std, columns=x_consensus.columns
        )
        data_consensus.reset_index(drop=True, inplace=True)
        data_consensus = pd.concat(
            [dataframex_consensus, data_consensus[["Status"]]], axis=1
        )
    data_consensus = data_consensus[
        data_consensus.columns.difference(["PATIENT_ID"])
    ]
    X = data_consensus[
        data_consensus.columns.difference(["Status", "PATIENT_ID"])
    ]
    y = data_consensus["Status"]
    if array:
        X = X.to_numpy()
        y = y.to_numpy()
    return X, y

Addressing dataset imbalance with SMOTE

The COVID dataset has a significant class imbalance problem: the survivor class (Status = 0) has a prevalence of 85% (1522 patients), whilst the deceased class (Status = 1) has only 276. In this example, oversampling will be used to add synthetic records until the dataset is almost balanced. SMOTE (Synthetic Minority Over-sampling Technique) is the oversampling method used, through the imblearn package.

X, y = load_covid(array=False)
described = (
    X.describe(percentiles=[0.5])
    .round(2)
    .transpose()[["count", "mean", "std"]]
)
pd.DataFrame(y.value_counts().rename({0.0: "Survive", 1.0: "Decease"}))
Status
Survive 1522
Decease 276



We therefore create synthetic cases using 5 nearest neighbors until the class imbalance is almost removed. For more information about SMOTE, refer to its documentation. We will use the SMOTE-NC method, which handles both numerical and categorical variables.
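To make the nearest-neighbor idea concrete, the following is a minimal sketch of how a SMOTE-style synthetic point can be built for numerical features only: pick a minority case, pick one of its k nearest minority neighbors, and interpolate between them. The function name `smote_like_sample` and the toy values are hypothetical, and this is not the imblearn implementation; categorical columns, which SMOTE-NC treats separately, are not covered.

```python
import math
import random


def smote_like_sample(minority, k=5, rng=None):
    """Create one synthetic point by interpolating between a random
    minority sample and one of its k nearest minority neighbours."""
    rng = rng or random.Random(0)
    base = rng.choice(minority)
    # Rank the other minority points by Euclidean distance to `base`.
    others = [p for p in minority if p is not base]
    neighbours = sorted(others, key=lambda p: math.dist(base, p))[:k]
    neighbour = rng.choice(neighbours)
    gap = rng.random()  # position along the segment, in [0, 1)
    return [b + gap * (n - b) for b, n in zip(base, neighbour)]


# Toy minority class with two numerical features (hypothetical values).
minority = [
    [60.0, 0.5],
    [72.0, 0.9],
    [68.0, 0.7],
    [80.0, 1.1],
    [75.0, 0.8],
    [66.0, 0.6],
]
synthetic = smote_like_sample(minority, k=5, rng=random.Random(42))
print(synthetic)
```

Because the new point lies on the segment between two real minority cases, it stays inside the region already occupied by that class.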

See also

For more information about the SMOTE technique refer to:

[2] Chawla, N. V., Bowyer, K. W., Hall, L. O., & Kegelmeyer, W. P. (2002). SMOTE: Synthetic Minority Over-sampling Technique. Journal of Artificial Intelligence Research, 16, 321-357.

# We fit the data to the oversampler
smotefit = SMOTENC(sampling_strategy=0.75, categorical_features=[3])
X_resampled, y_resampled = smotefit.fit_resample(X, y)
described_resample = (
    X_resampled.describe(percentiles=[0.5])
    .round(2)
    .transpose()[["count", "mean", "std"]]
)
described_resample = described_resample.add_suffix("_SMT")
pd.concat([described, described_resample], axis=1)
Variable         count    mean     std  count_SMT  mean_SMT  std_SMT
Age             1798.0   67.79   15.67     2663.0     71.71    14.78
Bilirubin       1798.0    0.57    0.45     2663.0      0.60     0.49
Ceph_time       1798.0   32.94    7.03     2663.0     33.32     7.50
Com             1798.0    0.50    0.78     2663.0      0.49     0.78
Eosinophils%    1798.0    0.70    1.57     2663.0      0.55     1.33
LDH             1798.0  601.10  367.24     2663.0    675.02   471.53
Leukocytes      1798.0    7.62    4.54     2663.0      8.23     4.86
Lymphocytes%    1798.0   18.19   10.44     2663.0     16.24     9.92
Mean_corp_Hgb   1798.0   33.62    1.42     2663.0     33.52     1.35
Med_corp_vol    1798.0   88.23    5.77     2663.0     88.63     5.88
Neutrophils     1798.0    5.75    3.77     2663.0      6.44     4.09
Neutrophils%    1798.0   73.01   12.99     2663.0     75.54    12.56
Platelet_count  1798.0  225.32   96.93     2663.0    219.27    93.65
Pro_C_Rea       1798.0  101.00  100.87     2663.0    121.41   110.35
Proth_time      1798.0   15.39   13.89     2663.0     16.17    15.14
Rate            1798.0   79.29   14.75     2663.0     80.69    14.81
Sat             1798.0   94.67    4.81     2663.0     93.60     5.96
Sodium          1798.0  136.92    4.50     2663.0    137.21     4.93
Urea            1798.0   43.17   30.72     2663.0     49.75    32.74


We see that the overall distribution of the resampled data differs from that of the real data; for example, the mean age rises from 67.79 to 71.71. However, the distribution within each class is kept similar, because the synthetic cases are created from the 5 nearest neighbors of real cases.

pd.DataFrame(
    y_resampled.value_counts().rename({0.0: "Survive", 1.0: "Decease"})
)
Status
Survive 1522
Decease 1141
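
The resampled counts can be checked by hand. The sketch below assumes that sampling_strategy is the desired ratio of minority to majority cases and that the target minority count is truncated to an integer, which is consistent with the counts shown above.

```python
n_survive = 1522  # majority class count reported above
n_decease = 276   # minority class count before resampling
sampling_strategy = 0.75

# Assumed rule: target minority size = ratio * majority size, truncated.
target_decease = int(sampling_strategy * n_survive)
n_synthetic = target_decease - n_decease
total = n_survive + target_decease

print(target_decease, n_synthetic, total)
print(round(n_decease / (n_survive + n_decease), 3), round(target_decease / total, 3))
```

Under this assumption, the total matches the count_SMT column of the summary table (2663 rows), and the share of deceased patients rises from roughly 15% to roughly 43%.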


Probability estimation

In this section we will estimate the conditional probabilities and analyse how the estimated probabilities are distributed depending on the real outcome. Probability estimation is better when using loss = 'log'. We use CMRC(phi = 'threshold', loss = 'log') and MRC(phi = 'fourier', loss = 'log'). We will then compare these MRCs with SVC and LR with default parameters (SVC is created with probability=True so that it can output probabilities).

Load classification function:

This function assigns each case to its corresponding confusion matrix category. It also allows setting the desired cut-off for the predictions.

def defDataFrame(model, x_test, y_test, threshold=0.5):
    """
    Takes the test data (x_test, y_test) and a fitted model, computes the
    predicted probabilities and classifies each case as TP, TN, FP or FN.
    """
    if hasattr(model, "predict_proba"):
        probabilities = model.predict_proba(x_test)[:, 1]
        predictions = [1 if i > threshold else 0 for i in probabilities]
        df = pd.DataFrame(
            {
                "Real": y_test.tolist(),
                "Prediction": predictions,
                "Probabilities": probabilities.tolist(),
            }
        )
    else:
        df = pd.DataFrame(
            {"Real": y_test.tolist(), "Prediction": model.predict(x_test)}
        )
    conditions = [
        (df["Real"] == 1) & (df["Prediction"] == 1),
        (df["Real"] == 1) & (df["Prediction"] == 0),
        (df["Real"] == 0) & (df["Prediction"] == 0),
        (df["Real"] == 0) & (df["Prediction"] == 1),
    ]
    choices = [
        "True Positive",
        "False Negative",
        "True Negative",
        "False Positive",
    ]
    df["Category"] = np.select(conditions, choices, default="No")
    df.sort_index(inplace=True)
    df.sort_values(by="Category", ascending=False, inplace=True)
    return df
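
As a quick illustration of the categorisation logic, here is a minimal sketch in plain Python on a handful of made-up cases. The helper `categorize` and the toy probabilities are hypothetical; the comparison `p > threshold` mirrors the rule used in defDataFrame.

```python
def categorize(real, probabilities, threshold=0.5):
    """Label each case as TP, FN, TN or FP for a given cut-off."""
    labels = []
    for y, p in zip(real, probabilities):
        predicted = 1 if p > threshold else 0
        if y == 1:
            labels.append("True Positive" if predicted == 1 else "False Negative")
        else:
            labels.append("False Positive" if predicted == 1 else "True Negative")
    return labels


# Toy data: 1 = decease, 0 = survive (hypothetical values).
real = [1, 1, 0, 0, 1]
probabilities = [0.90, 0.45, 0.10, 0.60, 0.38]

print(categorize(real, probabilities))
print(categorize(real, probabilities, threshold=0.4))
```

Lowering the cut-off from 0.5 to 0.4 turns the second case from a false negative into a true positive, while the last case (0.38) remains a false negative.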

Train models:

We will train the models with 80% of the data and then test with the other 20% selected randomly.
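
The size of each partition can be worked out from the resampled row count. This sketch assumes that the test share is rounded up and the remainder is used for training, which agrees with the support of 533 test cases shown in the classification reports.

```python
import math

n_samples = 2663  # rows after SMOTE-NC resampling
test_size = 0.2

# Assumption: the test share is rounded up; the rest goes to training.
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test

print(n_train, n_test)
```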

X_train, X_test, y_train, y_test = train_test_split(
    X_resampled, y_resampled, test_size=0.2, random_state=1
)

clf_MRC = MRC(phi="fourier", use_cvx=True, loss="log").fit(X_train, y_train)
df_MRC = defDataFrame(model=clf_MRC, x_test=X_test, y_test=y_test)
# Name the counts after the model; the default column name returned by
# value_counts differs between pandas versions.
MRC_values = (
    df_MRC.Category.value_counts().rename(type(clf_MRC).__name__).to_frame()
)
MRC_values["Freq_MRC"] = MRC_values["MRC"] / sum(MRC_values["MRC"]) * 100

clf_CMRC = CMRC(phi="threshold", use_cvx=True, loss="log").fit(
    X_train, y_train
)
df_CMRC = defDataFrame(model=clf_CMRC, x_test=X_test, y_test=y_test)
# Name the counts after the model; the default column name returned by
# value_counts differs between pandas versions.
CMRC_values = (
    df_CMRC.Category.value_counts().rename(type(clf_CMRC).__name__).to_frame()
)
CMRC_values["Freq_CMRC"] = CMRC_values["CMRC"] / sum(CMRC_values["CMRC"]) * 100

clf_SVC = SVC(probability=True).fit(X_train, y_train)
df_SVC = defDataFrame(model=clf_SVC, x_test=X_test, y_test=y_test)
# Name the counts after the model; the default column name returned by
# value_counts differs between pandas versions.
SVC_values = (
    df_SVC.Category.value_counts().rename(type(clf_SVC).__name__).to_frame()
)
SVC_values["Freq_SVC"] = SVC_values["SVC"] / sum(SVC_values["SVC"]) * 100

clf_LR = LogisticRegression().fit(X_train, y_train)
df_LR = defDataFrame(model=clf_LR, x_test=X_test, y_test=y_test)
# Name the counts after the model; the default column name returned by
# value_counts differs between pandas versions.
LR_values = (
    df_LR.Category.value_counts().rename(type(clf_LR).__name__).to_frame()
)
LR_values["Freq_LR"] = (
    LR_values["LogisticRegression"]
    / sum(LR_values["LogisticRegression"])
    * 100
)


pd.concat(
    [MRC_values, CMRC_values, SVC_values, LR_values], axis=1
).style.set_caption("Classification results by model").format(precision=2)
Classification results by model
Category        MRC  Freq_MRC  CMRC  Freq_CMRC  SVC  Freq_SVC  LogisticRegression  Freq_LR
True Negative   278     52.16   279      52.35  249     46.72                 267    50.09
False Negative  140     26.27    35       6.57   76     14.26                  46     8.63
True Positive    83     15.57   188      35.27  147     27.58                 177    33.21
False Positive   32      6.00    31       5.82   61     11.44                  43     8.07


We will now compare the histograms of the conditional probability for the two possible outcomes. Overlap between the histograms indicates cases that are hard to separate and therefore prone to misclassification. Considering a cut-off of 0.5, pink cases below this point are false negatives (FN) and blue cases above the threshold are false positives (FP). It is important to consider that in this classification problem the misclassification of a patient with a fatal outcome (FN) is considered a much more serious error.

def scatterPlot(df, ax):
    """
    Takes the DataFrame created with defDataFrame and draws a beeswarm plot
    with overlaid boxplots of the mortality probability for each category.
    """
    sns.swarmplot(
        ax=ax,
        y="Category",
        x="Probabilities",
        data=df,
        size=4,
        palette=sns.color_palette("tab10"),
        linewidth=0,
        dodge=False,
        alpha=0.6,
        order=[
            "True Negative",
            "False Negative",
            "True Positive",
            "False Positive",
        ],
    )
    sns.boxplot(
        ax=ax,
        x="Probabilities",
        y="Category",
        color="White",
        data=df,
        order=[
            "True Negative",
            "False Negative",
            "True Positive",
            "False Positive",
        ],
        saturation=15,
    )
    ax.set_xlabel("Probability of mortality")
    ax.set_ylabel("")


def plotHisto(df, ax, threshold=0.5, normalize=True):
    """
    Takes DF created with defDataFrame and plots histograms based on the
    probability of mortality by real Status at a selected @threshold.
    """
    if normalize:
        norm_params = {"stat": "density", "common_norm": False}
    else:
        norm_params = {}
    sns.histplot(
        ax=ax,
        data=df[df["Real"] == 1],
        x="Probabilities",
        color="deeppink",
        label="Deceased",
        bins=15,
        binrange=[0, 1],
        alpha=0.6,
        element="step",
        **norm_params
    )
    sns.histplot(
        ax=ax,
        data=df[df["Real"] == 0],
        x="Probabilities",
        color="dodgerblue",
        label="Survived",
        bins=15,
        binrange=[0, 1],
        alpha=0.4,
        element="step",
        **norm_params
    )
    ax.axvline(
        threshold, 0, 1, linestyle=(0, (1, 10)), linewidth=0.7, color="black"
    )


# visualize results
fig, ax = plt.subplots(
    nrows=2,
    ncols=2,
    sharex="all",
    sharey="all",
    gridspec_kw={"wspace": 0.1, "hspace": 0.35},
)
plotHisto(df_CMRC, ax=ax[0, 0], normalize=False)
ax[0, 0].set_title("CMRC")
plotHisto(df_MRC, ax=ax[1, 0], normalize=False)
ax[1, 0].set_title("MRC")
plotHisto(df_LR, ax=ax[0, 1], normalize=False)
ax[0, 1].set_title("LR")
ax[0, 1].legend()
plotHisto(df_SVC, ax=ax[1, 1], normalize=False)
ax[1, 1].set_title("SVC")
fig.tight_layout()
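[Figure: histograms of the estimated probability of mortality by real outcome (Deceased, Survived) for CMRC, LR, MRC and SVC]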

In the MRC model the estimated probabilities lie mostly in the range from 0.25 to 0.75, so this estimation is very sensitive to cut-off changes. The CMRC model shows a distribution where most of the cases are grouped around 0 and 1 for survive and decease respectively. These results are similar to those of Logistic Regression but with less overlap. SVC is the model with the worst performance of all, with many patients who survived receiving high probabilities of death.

cm_cmrc = confusion_matrix(y_test, clf_CMRC.predict(X_test))  # CMRC
cm_mrc = confusion_matrix(y_test, clf_MRC.predict(X_test))  # MRC
cm_lr = confusion_matrix(y_test, clf_LR.predict(X_test))  # Logistic Regression
cm_svc = confusion_matrix(
    y_test, clf_SVC.predict(X_test)
)  # C-Support Vector Machine

fig, ax = plt.subplots(
    nrows=2,
    ncols=2,
    sharex="all",
    sharey="all",
    gridspec_kw={"wspace": 0, "hspace": 0.35},
)
ConfusionMatrixDisplay(cm_cmrc, display_labels=["Survive", "Decease"]).plot(
    colorbar=False, ax=ax[0, 0]
)
ax[0, 0].set_title("CMRC")
ConfusionMatrixDisplay(cm_mrc, display_labels=["Survive", "Decease"]).plot(
    colorbar=False, ax=ax[1, 0]
)
ax[1, 0].set_title("MRC")
ConfusionMatrixDisplay(cm_lr, display_labels=["Survive", "Decease"]).plot(
    colorbar=False, ax=ax[0, 1]
)
ax[0, 1].set_title("LR")
ConfusionMatrixDisplay(cm_svc, display_labels=["Survive", "Decease"]).plot(
    colorbar=False, ax=ax[1, 1]
)
ax[1, 1].set_title("SVC")
fig.tight_layout()
[Figure: confusion matrices for CMRC, LR, MRC and SVC]
pd.DataFrame(
    classification_report(
        y_test,
        clf_CMRC.predict(X_test),
        target_names=["Survive", "Decease"],
        output_dict=True,
    )
).style.set_caption("Classification report CMRC").format(precision=3)
Classification report CMRC
Metric     Survive  Decease  accuracy  macro avg  weighted avg
precision    0.889    0.858     0.876      0.873         0.876
recall       0.900    0.843     0.876      0.872         0.876
f1-score     0.894    0.851     0.876      0.872         0.876
support    310.000  223.000     0.876    533.000       533.000



pd.DataFrame(
    classification_report(
        y_test,
        clf_MRC.predict(X_test),
        target_names=["Survive", "Decease"],
        output_dict=True,
    )
).style.set_caption("Classification report MRC").format(precision=3)
Classification report MRC
Metric     Survive  Decease  accuracy  macro avg  weighted avg
precision    0.665    0.722     0.677      0.693         0.689
recall       0.897    0.372     0.677      0.634         0.677
f1-score     0.764    0.491     0.677      0.627         0.650
support    310.000  223.000     0.677    533.000       533.000



pd.DataFrame(
    classification_report(
        y_test,
        clf_LR.predict(X_test),
        target_names=["Survive", "Decease"],
        output_dict=True,
    )
).style.set_caption("Classification report LR").format(precision=3)
Classification report LR
Metric     Survive  Decease  accuracy  macro avg  weighted avg
precision    0.853    0.805     0.833      0.829         0.833
recall       0.861    0.794     0.833      0.828         0.833
f1-score     0.857    0.799     0.833      0.828         0.833
support    310.000  223.000     0.833    533.000       533.000



pd.DataFrame(
    classification_report(
        y_test,
        clf_SVC.predict(X_test),
        target_names=["Survive", "Decease"],
        output_dict=True,
    )
).style.set_caption("Classification report SVC").format(precision=3)
Classification report SVC
Metric     Survive  Decease  accuracy  macro avg  weighted avg
precision    0.743    0.717     0.734      0.730         0.732
recall       0.829    0.601     0.734      0.715         0.734
f1-score     0.784    0.654     0.734      0.719         0.729
support    310.000  223.000     0.734    533.000       533.000


The classification reports and the confusion matrices show that CMRC outperforms the other models, with an accuracy of 0.876 compared with 0.833 for LR, 0.734 for SVC and 0.677 for MRC.
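
The figures in the CMRC report can be reproduced from the four confusion matrix counts listed for CMRC in the classification results table. The sketch below only applies the standard definitions of recall, precision, F1 and accuracy to those counts, treating decease as the positive class.

```python
# CMRC counts from the classification results table (decease = positive).
tp, fn, tn, fp = 188, 35, 279, 31

recall_decease = tp / (tp + fn)      # share of deceased patients detected
precision_decease = tp / (tp + fp)   # share of decease predictions that are right
recall_survive = tn / (tn + fp)      # share of survivors detected
f1_decease = (
    2 * precision_decease * recall_decease / (precision_decease + recall_decease)
)
accuracy = (tp + tn) / (tp + fn + tn + fp)

for name, value in [
    ("recall (Decease)", recall_decease),
    ("precision (Decease)", precision_decease),
    ("recall (Survive)", recall_survive),
    ("f1-score (Decease)", f1_decease),
    ("accuracy", accuracy),
]:
    print(f"{name}: {value:.3f}")
```

The rounded values agree with the CMRC classification report (0.843, 0.858, 0.900, 0.851 and 0.876).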

Setting the cut-off point for binary classification:

In this section we will use beeswarm-boxplots to select the cut-off point that optimises the trade-off between false positives and false negatives. The beeswarm-boxplot shows how the model performs in each category of the confusion matrix. In an ideal scenario the errors are located near the cut-off point and the correct predictions are located near the values 0 and 1.

fig, ax = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(10, 12),
    sharex="all",
    sharey="all",
    gridspec_kw={"wspace": 0.1, "hspace": 0.20},
)
scatterPlot(df_CMRC, ax[0, 0])
ax[0, 0].set_title("CMRC")
scatterPlot(df_MRC, ax[1, 0])
ax[1, 0].set_title("MRC")
scatterPlot(df_LR, ax[0, 1])
ax[0, 1].set_title("LR")
scatterPlot(df_SVC, ax[1, 1])
ax[1, 1].set_title("SVC")
plt.tight_layout()
[Figure: beeswarm-boxplots of the probability of mortality by classification category for CMRC, LR, MRC and SVC]

We see in the CMRC that the correct cases have a very good conditional probability estimation, with around 75% of the cases very close to the extreme values. The most problematic cases are those with a low estimated probability of mortality that had a fatal outcome (FN). In the CMRC model, adjusting the threshold to 0.35 reduces the false negatives by about a third (from 35 to 23) while adding only some cases to the FP. In the MRC model, adjusting the cut-off to 0.4 reduces the false negatives by more than half (from 140 to 48), at the cost of about a quarter of the true negatives (TN).
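
The cut-offs in this example are chosen by eye from the plots. As a complementary illustration, the sketch below sweeps a few candidate cut-offs on made-up data and picks the one with the lowest weighted error cost. The helpers `count_errors` and `best_threshold`, the toy probabilities and the cost weights (a false negative counted three times as heavily as a false positive) are all hypothetical and not part of the original analysis.

```python
def count_errors(real, probabilities, threshold):
    """Return (false_negatives, false_positives) at a given cut-off."""
    fn = sum(1 for y, p in zip(real, probabilities) if y == 1 and p <= threshold)
    fp = sum(1 for y, p in zip(real, probabilities) if y == 0 and p > threshold)
    return fn, fp


def best_threshold(real, probabilities, candidates, fn_cost=3.0, fp_cost=1.0):
    """Pick the candidate cut-off with the lowest weighted error cost."""
    def cost(threshold):
        fn, fp = count_errors(real, probabilities, threshold)
        return fn_cost * fn + fp_cost * fp

    return min(candidates, key=cost)


# Toy data: 1 = decease, 0 = survive (hypothetical values).
real = [1, 1, 1, 1, 0, 0, 0, 0]
probabilities = [0.95, 0.70, 0.45, 0.38, 0.05, 0.20, 0.33, 0.65]
candidates = [0.30, 0.35, 0.40, 0.50]

for t in candidates:
    print(t, count_errors(real, probabilities, t))
print("selected cut-off:", best_threshold(real, probabilities, candidates))
```

Weighting false negatives more heavily pushes the selected cut-off below 0.5, which is the same direction as the manual adjustment described above.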

threshold = 0.35
df_CMRC = defDataFrame(
    model=clf_CMRC, x_test=X_test, y_test=y_test, threshold=threshold
)
threshold = 0.4
df_MRC = defDataFrame(
    model=clf_MRC, x_test=X_test, y_test=y_test, threshold=threshold
)
pd.DataFrame(
    classification_report(
        df_CMRC.Real,
        df_CMRC.Prediction,
        target_names=["Survive", "Decease"],
        output_dict=True,
    )
).style.set_caption("Classification report CMRC \n adjusted threshold").format(
    precision=3
)
Classification report CMRC adjusted threshold
Metric     Survive  Decease  accuracy  macro avg  weighted avg
precision    0.919    0.800     0.863      0.859         0.869
recall       0.839    0.897     0.863      0.868         0.863
f1-score     0.877    0.846     0.863      0.861         0.864
support    310.000  223.000     0.863    533.000       533.000



pd.DataFrame(
    classification_report(
        df_MRC.Real,
        df_MRC.Prediction,
        target_names=["Survive", "Decease"],
        output_dict=True,
    )
).style.set_caption("Classification report MRC \n adjusted threshold").format(
    precision=3
)
Classification report MRC adjusted threshold
Metric     Survive  Decease  accuracy  macro avg  weighted avg
precision    0.811    0.627     0.715      0.719         0.734
recall       0.665    0.785     0.715      0.725         0.715
f1-score     0.730    0.697     0.715      0.714         0.717
support    310.000  223.000     0.715    533.000       533.000


Comparing the outputs of this example, we can conclude that MRCs can work well for estimating the outcome of COVID-19 patients at hospital triage.

Furthermore, the CMRC model with threshold feature mapping has shown great performance both for classifying and for estimating conditional probabilities. Finally, we have seen how to select the cut-off values based on data visualization with beeswarm-boxplots to increase the recall in the desired class.
