Stroke is a leading cause of disability and death globally, necessitating accurate prediction models to identify individuals at high risk for preventive interventions. This study aims to develop and deploy a machine learning model for predicting an individual's risk of stroke based on various health, lifestyle, and demographic factors. The stroke prediction model was built using supervised learning techniques on a comprehensive dataset of 5110 respondents containing relevant features such as age, gender, hypertension status, heart disease history, average glucose level, body mass index, marital status, residence type, and smoking status.
The model development process involved exploratory data analysis, data preprocessing, feature engineering, and the evaluation of logistic regression, support vector machine, random forest, k-nearest neighbours, artificial neural network, and extreme gradient boosting algorithms to select the best-performing model. Extreme gradient boosting was found to be the best-performing model; it was then integrated into a web application using the Flask framework, allowing users to input their personal information and receive a stroke risk assessment.
To facilitate accessibility and scalability, the Flask application was deployed on the Heroku cloud platform, enabling users to access the stroke risk prediction service from anywhere via a web interface. The deployed application was thoroughly tested for functionality, performance, and user experience.
The deployed stroke risk prediction model offers a user-friendly and accessible solution for individuals to assess their stroke risk conveniently. By promoting awareness and enabling early detection, this project contributes to the prevention and management of stroke-related complications, ultimately improving public health outcomes.
Since the 1920s, stroke has emerged as one of the foremost contributors to mortality in the United States, indiscriminately affecting individuals across all demographics (Sherwood, 2014; White, 2017). This debilitating disease transcends age, gender, socioeconomic status, and political standing, as evidenced by the deaths of prominent world leaders like former US President Franklin Delano Roosevelt, former British Prime Minister Margaret Thatcher, and former Israeli Prime Minister Ariel Sharon, all due to stroke complications (Thatcher, 2011; Weisman, 2014; White, 2017). In the United States alone, an estimated 800,000 people experience a stroke annually, placing a significant burden on the healthcare system and society. Moreover, the global impact of stroke extends beyond individual health, as its economic repercussions are substantial.
According to the Centres for Disease Control and Prevention (2023), a stroke is a grave health crisis that arises from an obstruction in blood circulation to the brain. This can be caused by a clot blocking an artery (ischaemic stroke) or a blood vessel bursting (haemorrhagic stroke). The lack of blood flow damages brain tissue, which can lead to problems like weakness, numbness, or speech difficulties, and can progress to devastating complications (Katan and Luft, 2018). According to Bustamante et al. (2021), stroke can be defined as an acute cerebrovascular condition characterized by a sudden disruption of the passage of blood to the brain, resulting in focal neurological dysfunction and potential cell death due to oxygen deprivation. Dritsas and Trigka (2022) believe that the consequences of stroke can range from short-term issues to life-altering complications, with symptoms such as paralysis of the arms or legs, trouble speaking, trouble walking, headaches, dizziness, reduced vision, vomiting, and numbness in the face, arms, or legs.
The history of stroke can be traced as far back as 460 to 370 BC, when it was called apoplexy, a Greek word meaning struck down with violence. Scholars at the time believed that the symptoms of stroke were convulsion and paralysis. In 1658, Johann Wepfer's studies identified that stroke was caused by bleeding into the brain (haemorrhagic stroke) and by obstruction of the arteries supplying blood to the brain (ischaemic stroke). Thus, stroke became known as cerebrovascular disease (Ashrafan, 2010; Thompson, 1996).
In 1948, pained by the death of Franklin Roosevelt, the United States Public Health Service launched the Framingham Heart Study. The study identified major causal risk factors for stroke and emphasized the importance of shifting focus from merely treating diseases as they arise to preventing them and tackling factors that can be modified to reduce risk. The Framingham Study identified two key groups of contributors to stroke risk: modifiable and non-modifiable factors (Nilsen, 2010). Modifiable risk factors are lifestyle choices and health conditions that an individual can influence, such as smoking, high blood pressure, obesity, high cholesterol, and physical inactivity; by managing these factors, it is possible to considerably reduce the probability of having a stroke. Non-modifiable factors are those that cannot be changed, such as age, genetics, or family history of stroke (Nilsen, 2010).
Early efforts at stroke risk analysis relied on traditional statistical models. Among these were the Framingham Stroke Risk Profile, the Cox regression model, the QRISK model, the CHADS2 stroke risk model, and the Reynolds Risk Score. Statistical models have been criticized for their inability to capture complex, non-linear interactions between various risk factors and stroke risk (Kleinbaum and Klein, 2012).
Developments in medical research came with the use of information technology systems such as artificial intelligence, machine learning, and cloud computing. Logistic regression (LR), k-nearest neighbours (KNN), support vector machines (SVM), eXtreme Gradient Boosting (XGB), random forests, multilayer perceptrons (MLP), and neural networks are now increasingly being adopted for early stroke risk prediction and the prognosis of medical conditions.
However, current studies in stroke risk prediction lack key elements for evaluating the practical application of stroke risk prediction models. Notably, none of them reported using decision-analytic measures to assess the models' clinical utility in real-world settings, whether to aid clinicians in making treatment decisions and estimating prognoses or to assist patients in stroke risk monitoring and prevention. So far, the models developed have a very limited reach, with little or no benefit to developing countries. This research focuses on developing and deploying a web application for stroke risk prediction on a cloud platform. By leveraging the cloud's reach, this web app empowers individuals everywhere, regardless of location, education, or financial background, to assess their stroke risk with ease. Users can securely input their biomedical and lifestyle data to receive real-time, personalized, and accurate risk predictions. Furthermore, the app allows users to monitor their risk over time: they can revisit the app to reassess their risk if their risk factors change. This web application has the potential to become a valuable companion for individuals identified as high-risk for stroke, providing ongoing support and encouraging proactive health management, while also assisting clinicians in making treatment decisions and estimating prognoses.
This research adopts a multi-modal approach to stroke risk prediction, analysing and correlating different biomedical and socio-economic risk factors with stroke. I shall employ several machine learning algorithms to build, test, and evaluate my models. The algorithm with the highest predictive accuracy will then be integrated into a Python web application framework, Flask.
Web application development is the process of creating software applications that run on websites. This process requires developing a front-end interface and a back-end structure. The front end of the web application will be created with HTML and CSS, and the back end will be created with Python, using Flask as a framework. Flask is a Python web framework that allows machine-learning developers to create scalable, user-friendly, and flexible machine-learning web applications. To make the model widely impactful, I shall deploy it on the Heroku cloud platform, where anyone, irrespective of geographical location or educational level, can have access to it.
This research aims to develop a predictive model for evaluating stroke risk and deploy it on the Heroku cloud platform using Flask, a micro web framework written in Python. The following are the specific objectives.
This study investigates the following research questions:
This dissertation utilized a publicly available dataset for stroke risk prediction model development. While this approach facilitated efficient exploration of machine learning techniques, it also limited the study's ability to address specific characteristics of a targeted population. However, it is important to acknowledge the significant resource requirements associated with primary data collection. Such endeavours can be time-consuming due to data-gathering processes. Additionally, funding limitations and ethical considerations surrounding patient privacy can pose significant challenges.
Another limitation to consider is the class imbalance within the dataset. The output variable exhibited a bias towards individuals without stroke. While this imbalance was addressed through SMOTE (Synthetic Minority Oversampling Technique), it is worth acknowledging this potential limitation for future research endeavours.
The ongoing cost of cloud deployment presents a potential limitation for long-term model operation. Exploring alternative deployment strategies, such as on-premises infrastructure or cost-optimized cloud services, could be a future consideration.
This dissertation utilized a secondary dataset of patients' electronic health records released by McKinsey and Company and available from a certified open-source repository. The dataset contains 5110 patient records. It was in its raw, uncleaned state and contained 11 input variables and 1 output variable. The output variable was binary, where 1 represents patients who have suffered a stroke and 0 represents patients who have not. The output variable had highly imbalanced classes: 95% of the patients had never had a stroke and only 5% had. The study employed SMOTE oversampling to handle this imbalance. The 11 input variables in the dataset are patient ID (dropped, as it is not useful to the study), gender, age, hypertension, heart disease, marital status, occupation, residence type, average glucose level, body mass index, and smoking status. The dataset publisher adheres to industry best practices for ethical data collection, such as informed consent and data anonymization, making the dataset suitable for this dissertation.
import matplotlib.pyplot as plt #Used for creating visualizations.
%matplotlib inline
import pandas as pd #Used for data manipulation and analysis
import numpy as np #Used for numerical computations and array operations.
import seaborn as sns #Built on top of matplotlib, offering a higher-level interface for creating statistical graphics.
stroke_df = pd.read_csv(r"C:\Users\User\OneDrive - Cardiff Metropolitan University\Dissertation\Data Analytics\healthcare-dataset-stroke-data.csv")
stroke_df
id | gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 9046 | Male | 67.0 | 0 | 1 | Yes | Private | Urban | 228.69 | 36.6 | formerly smoked | 1 |
1 | 51676 | Female | 61.0 | 0 | 0 | Yes | Self-employed | Rural | 202.21 | NaN | never smoked | 1 |
2 | 31112 | Male | 80.0 | 0 | 1 | Yes | Private | Rural | 105.92 | 32.5 | never smoked | 1 |
3 | 60182 | Female | 49.0 | 0 | 0 | Yes | Private | Urban | 171.23 | 34.4 | smokes | 1 |
4 | 1665 | Female | 79.0 | 1 | 0 | Yes | Self-employed | Rural | 174.12 | 24.0 | never smoked | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
5105 | 18234 | Female | 80.0 | 1 | 0 | Yes | Private | Urban | 83.75 | NaN | never smoked | 0 |
5106 | 44873 | Female | 81.0 | 0 | 0 | Yes | Self-employed | Urban | 125.20 | 40.0 | never smoked | 0 |
5107 | 19723 | Female | 35.0 | 0 | 0 | Yes | Self-employed | Rural | 82.99 | 30.6 | never smoked | 0 |
5108 | 37544 | Male | 51.0 | 0 | 0 | Yes | Private | Rural | 166.29 | 25.6 | formerly smoked | 0 |
5109 | 44679 | Female | 44.0 | 0 | 0 | Yes | Govt_job | Urban | 85.28 | 26.2 | Unknown | 0 |
5110 rows × 12 columns
sns.set_style("whitegrid")
plt.style.use("fivethirtyeight")
stroke_df.head() #inspect the first five rows of the dataset
id | gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 9046 | Male | 67.0 | 0 | 1 | Yes | Private | Urban | 228.69 | 36.6 | formerly smoked | 1 |
1 | 51676 | Female | 61.0 | 0 | 0 | Yes | Self-employed | Rural | 202.21 | NaN | never smoked | 1 |
2 | 31112 | Male | 80.0 | 0 | 1 | Yes | Private | Rural | 105.92 | 32.5 | never smoked | 1 |
3 | 60182 | Female | 49.0 | 0 | 0 | Yes | Private | Urban | 171.23 | 34.4 | smokes | 1 |
4 | 1665 | Female | 79.0 | 1 | 0 | Yes | Self-employed | Rural | 174.12 | 24.0 | never smoked | 1 |
stroke_df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 5110 entries, 0 to 5109 Data columns (total 12 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 id 5110 non-null int64 1 gender 5110 non-null object 2 age 5110 non-null float64 3 hypertension 5110 non-null int64 4 heart_disease 5110 non-null int64 5 ever_married 5110 non-null object 6 work_type 5110 non-null object 7 Residence_type 5110 non-null object 8 avg_glucose_level 5110 non-null float64 9 bmi 4909 non-null float64 10 smoking_status 5110 non-null object 11 stroke 5110 non-null int64 dtypes: float64(3), int64(4), object(5) memory usage: 479.2+ KB
stroke_df.isna().sum()
id 0 gender 0 age 0 hypertension 0 heart_disease 0 ever_married 0 work_type 0 Residence_type 0 avg_glucose_level 0 bmi 201 smoking_status 0 stroke 0 dtype: int64
stroke_df.describe()
id | age | hypertension | heart_disease | avg_glucose_level | bmi | stroke | |
---|---|---|---|---|---|---|---|
count | 5110.000000 | 5110.000000 | 5110.000000 | 5110.000000 | 5110.000000 | 4909.000000 | 5110.000000 |
mean | 36517.829354 | 43.226614 | 0.097456 | 0.054012 | 106.147677 | 28.893237 | 0.048728 |
std | 21161.721625 | 22.612647 | 0.296607 | 0.226063 | 45.283560 | 7.854067 | 0.215320 |
min | 67.000000 | 0.080000 | 0.000000 | 0.000000 | 55.120000 | 10.300000 | 0.000000 |
25% | 17741.250000 | 25.000000 | 0.000000 | 0.000000 | 77.245000 | 23.500000 | 0.000000 |
50% | 36932.000000 | 45.000000 | 0.000000 | 0.000000 | 91.885000 | 28.100000 | 0.000000 |
75% | 54682.000000 | 61.000000 | 0.000000 | 0.000000 | 114.090000 | 33.100000 | 0.000000 |
max | 72940.000000 | 82.000000 | 1.000000 | 1.000000 | 271.740000 | 97.600000 | 1.000000 |
stroke_df.duplicated()
Let's Fill the Null Values
The Body Mass Index (BMI) column has 201 null values. It would be wrong to fill these with the overall average BMI, because BMI can vary with factors such as sex or age. I will use box and violin plots to confirm whether there is a significant difference in BMI by sex and by age group.
sns.violinplot(x='gender',y='bmi', data = stroke_df)
<Axes: xlabel='gender', ylabel='bmi'>
stroke_df['AgeLabel'] = pd.cut(x=stroke_df['age'], bins=[0, 3, 17, 63, 99],
labels=['Toddler', 'Child','Adult','Elderly'])
stroke_df
id | gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke | AgeLabel | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 9046 | Male | 67.0 | 0 | 1 | Yes | Private | Urban | 228.69 | 36.6 | formerly smoked | 1 | Elderly |
1 | 51676 | Female | 61.0 | 0 | 0 | Yes | Self-employed | Rural | 202.21 | NaN | never smoked | 1 | Adult |
2 | 31112 | Male | 80.0 | 0 | 1 | Yes | Private | Rural | 105.92 | 32.5 | never smoked | 1 | Elderly |
3 | 60182 | Female | 49.0 | 0 | 0 | Yes | Private | Urban | 171.23 | 34.4 | smokes | 1 | Adult |
4 | 1665 | Female | 79.0 | 1 | 0 | Yes | Self-employed | Rural | 174.12 | 24.0 | never smoked | 1 | Elderly |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
5105 | 18234 | Female | 80.0 | 1 | 0 | Yes | Private | Urban | 83.75 | NaN | never smoked | 0 | Elderly |
5106 | 44873 | Female | 81.0 | 0 | 0 | Yes | Self-employed | Urban | 125.20 | 40.0 | never smoked | 0 | Elderly |
5107 | 19723 | Female | 35.0 | 0 | 0 | Yes | Self-employed | Rural | 82.99 | 30.6 | never smoked | 0 | Adult |
5108 | 37544 | Male | 51.0 | 0 | 0 | Yes | Private | Rural | 166.29 | 25.6 | formerly smoked | 0 | Adult |
5109 | 44679 | Female | 44.0 | 0 | 0 | Yes | Govt_job | Urban | 85.28 | 26.2 | Unknown | 0 | Adult |
5110 rows × 13 columns
# plt.figure(figsize=(10,4))
sns.boxplot(x='gender',y='bmi', data = stroke_df)
plt.show()
plt.figure(figsize=(10,8))
sns.violinplot(x='AgeLabel',y='bmi', data = stroke_df)
<Axes: xlabel='AgeLabel', ylabel='bmi'>
We will replace the null values with the average BMI for each age category.
plt.hist(stroke_df[stroke_df['AgeLabel']=='Toddler']['bmi'], color='red')
plt.show()
stroke_df[stroke_df['AgeLabel']=='Toddler']['bmi'].describe()
count 213.000000 mean 18.664789 std 3.096884 min 10.300000 25% 16.800000 50% 18.300000 75% 20.300000 max 33.100000 Name: bmi, dtype: float64
The mean BMI of the Toddler category is approximately 19.
plt.hist(stroke_df[stroke_df['AgeLabel']=='Child']['bmi'])
plt.show()
stroke_df[stroke_df['AgeLabel']=='Child']['bmi'].describe()
count 623.000000 mean 22.325361 std 6.799040 min 12.000000 25% 18.000000 50% 20.600000 75% 24.800000 max 97.600000 Name: bmi, dtype: float64
plt.hist(stroke_df[stroke_df['AgeLabel']=='Adult']['bmi'], color='y')
plt.show()
stroke_df[stroke_df['AgeLabel']=='Adult']['bmi'].describe()
count 3068.000000 mean 30.699772 std 7.638310 min 11.500000 25% 25.275000 50% 29.300000 75% 34.700000 max 92.000000 Name: bmi, dtype: float64
plt.hist(stroke_df[stroke_df['AgeLabel']=='Elderly']['bmi'], color='r')
plt.show()
stroke_df[stroke_df['AgeLabel']=='Elderly']['bmi'].describe()
count 1005.000000 mean 29.617612 std 5.761813 min 11.300000 25% 26.000000 50% 28.900000 75% 32.900000 max 54.600000 Name: bmi, dtype: float64
The mean BMI of the Elderly category is approximately 30.
We shall create a function that will replace the null values in the BMI column with the mean BMI for each age category
#let's create a function
def fill_bmi(col):
    bmi = col[0]
    AgeLabel = col[1]
    # Check whether the bmi value is missing
    if pd.isnull(bmi):
        # Toddler age group
        if AgeLabel == 'Toddler':
            return 19
        # Child age group
        elif AgeLabel == 'Child':
            return 22
        # Adult age group
        elif AgeLabel == 'Adult':
            return 31
        # Elderly age group
        else:
            return 30
    else:
        # if the value is not null, return the existing bmi value
        return bmi
We shall fill the BMI column with the mean BMI per age category.
stroke_df['bmi'] = stroke_df[['bmi','AgeLabel']].apply(fill_bmi,axis=1)
stroke_df.isna().sum()
id 0 gender 0 age 0 hypertension 0 heart_disease 0 ever_married 0 work_type 0 Residence_type 0 avg_glucose_level 0 bmi 0 smoking_status 0 stroke 0 AgeLabel 0 dtype: int64
fig,ax=plt.subplots(figsize = (5,5))
lab = ['Non-Stroke','Stroke']
col = sns.color_palette('pastel')[0:5]
stroke_df['stroke'].value_counts().plot.pie(explode=[0.1,0.0],autopct='%1.f%%',shadow=True,labels = lab, colors = col)
plt.title("Percentage Count of Stroke")
plt.show()
plt.figure(figsize=(18,8))
plt.subplot(1,2,1)
vc = stroke_df['stroke'].value_counts()
g = sns.barplot(x=vc.index,y=vc, palette='terrain_r')
for p in g.patches:
g.annotate('{:.0f}'.format(p.get_height()), (p.get_x()+0.4, p.get_height()+20),ha='center', va='bottom',
color= 'black')
plt.title('Count of Stroke')
plt.subplot(1,2,2)
colors = ['#CFD6E4', '#EFCFE3', '#E4F0CF', '#F3CFB6', '#B9DCCC']
stroke_df['stroke'].value_counts().plot(kind='pie', explode=[0.1,0], autopct='%.2f%%', colors=colors)
plt.title('% Distribution of Stroke')
plt.show()
plt.figure(figsize=(18,8))
plt.subplot(1,2,1)
vc = stroke_df['stroke'].value_counts()
g = sns.barplot(x=vc.index,y=vc, palette='dark')
for p in g.patches:
g.annotate('{:.0f}'.format(p.get_height()), (p.get_x()+0.4, p.get_height()+20),ha='center', va='bottom',
color= 'black')
stroke_df.columns
Index(['id', 'gender', 'age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status', 'stroke', 'AgeLabel'], dtype='object')
Drop the 'id' column, as it is not useful.
stroke_df=stroke_df.drop('id',axis=1)
stroke_df
gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke | AgeLabel | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Male | 67.0 | 0 | 1 | Yes | Private | Urban | 228.69 | 36.6 | formerly smoked | 1 | Elderly |
1 | Female | 61.0 | 0 | 0 | Yes | Self-employed | Rural | 202.21 | 31.0 | never smoked | 1 | Adult |
2 | Male | 80.0 | 0 | 1 | Yes | Private | Rural | 105.92 | 32.5 | never smoked | 1 | Elderly |
3 | Female | 49.0 | 0 | 0 | Yes | Private | Urban | 171.23 | 34.4 | smokes | 1 | Adult |
4 | Female | 79.0 | 1 | 0 | Yes | Self-employed | Rural | 174.12 | 24.0 | never smoked | 1 | Elderly |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
5105 | Female | 80.0 | 1 | 0 | Yes | Private | Urban | 83.75 | 30.0 | never smoked | 0 | Elderly |
5106 | Female | 81.0 | 0 | 0 | Yes | Self-employed | Urban | 125.20 | 40.0 | never smoked | 0 | Elderly |
5107 | Female | 35.0 | 0 | 0 | Yes | Self-employed | Rural | 82.99 | 30.6 | never smoked | 0 | Adult |
5108 | Male | 51.0 | 0 | 0 | Yes | Private | Rural | 166.29 | 25.6 | formerly smoked | 0 | Adult |
5109 | Female | 44.0 | 0 | 0 | Yes | Govt_job | Urban | 85.28 | 26.2 | Unknown | 0 | Adult |
5110 rows × 12 columns
Visualizing count plots for gender, ever_married, work_type, Residence_type, smoking_status, AgeLabel, hypertension, and heart_disease
stroke_df.columns
Index(['gender', 'age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status', 'stroke', 'AgeLabel'], dtype='object')
plt.figure(figsize=(15,10))
col= ['gender','ever_married','work_type','Residence_type','smoking_status', 'AgeLabel','hypertension', 'heart_disease']
i = 1
for a in col:
plt.subplot(4,2,i)
g = sns.countplot(x=a,data=stroke_df,palette="dark")
for p in g.patches:
g.annotate('{:.0f}'.format(p.get_height()), (p.get_x()+0.2, p.get_height()+20),ha='center', va='bottom',
color= 'black')
plt.xlabel(a,fontsize=30)
i = i+1
plt.tight_layout()
Separate features into categorical and continuous variables, based on the number of unique values, to aid understanding of the EDA.
categorical_val = []
continous_val = []
for column in stroke_df.columns:
#a column with 10 or fewer unique values is treated as categorical
if len(stroke_df[column].unique()) <= 10:
categorical_val.append(column)
else:
continous_val.append(column)
print(continous_val)
['age', 'avg_glucose_level', 'bmi']
Visualize count plots for continuous variables
plt.figure(figsize=(15,15))
for i, column in enumerate(continous_val):
plt.subplot(3, 3, i+1)
sns.histplot(x=column, bins=50, data= stroke_df, color='g')
plt.title('Count of '+ column)
plt.show()
Visualizing count plots for stroke by categorical variables
plt.figure(figsize=(17,20))
col= ['gender','ever_married','work_type','Residence_type','smoking_status', 'hypertension','heart_disease']
i = 1
for a in col:
plt.subplot(4,3,i)
g = sns.countplot(x=a,hue='stroke',data=stroke_df,palette='bright')
for p in g.patches:
g.annotate('{:.0f}'.format(p.get_height()), (p.get_x()+0.2, p.get_height()+20),ha='center', va='bottom',
color= 'black')
plt.xlabel(a,fontsize=30)
i = i+1
plt.tight_layout()
Visualizing count plots for stroke by continuous variables
plt.figure(figsize=(10,30))
for i, column in enumerate(continous_val):
plt.subplot(3, 3, i+1)
sns.histplot(x=column, bins=50, hue='stroke', data= stroke_df, color='red')
plt.title('Count of '+ column)
#plt.show()
This research leverages the visualization libraries Seaborn and Matplotlib to effectively explore the distribution of its variables. By employing histograms, count plots, and pie charts within a subplot layout, the analysis presents a comprehensive overview, enabling readers to grasp key insights immediately. This approach minimizes the need for scrolling through individual plots, promoting efficient data exploration and interpretation.
My dataset presents an array of different features with variations in distribution. This is expected when a random sampling technique is employed.
Stroke occurs less often in females than in males: 5.3% of the sampled male population had a stroke, while 4.7% of the sampled female population had a stroke. The figure also shows that people who never smoked are the least likely to have had a stroke. Age is a very significant factor in stroke risk, with the elderly most likely to have had a stroke. Urban dwellers, people with heart disease, and people with hypertension are the most likely to have had a stroke within their respective classes.
Label Encoder
from sklearn.preprocessing import LabelEncoder
le=LabelEncoder()
stroke_df['ever_married']= le.fit_transform(stroke_df['ever_married'])
stroke_df['gender']= le.fit_transform(stroke_df['gender'])
stroke_df['work_type']= le.fit_transform(stroke_df['work_type'])
stroke_df['Residence_type']= le.fit_transform(stroke_df['Residence_type'])
stroke_df['smoking_status']= le.fit_transform(stroke_df['smoking_status'])
stroke_df
gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke | AgeLabel | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 67.0 | 0 | 1 | 1 | 2 | 1 | 228.69 | 36.6 | 1 | 1 | Elderly |
1 | 0 | 61.0 | 0 | 0 | 1 | 3 | 0 | 202.21 | 31.0 | 2 | 1 | Adult |
2 | 1 | 80.0 | 0 | 1 | 1 | 2 | 0 | 105.92 | 32.5 | 2 | 1 | Elderly |
3 | 0 | 49.0 | 0 | 0 | 1 | 2 | 1 | 171.23 | 34.4 | 3 | 1 | Adult |
4 | 0 | 79.0 | 1 | 0 | 1 | 3 | 0 | 174.12 | 24.0 | 2 | 1 | Elderly |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
5105 | 0 | 80.0 | 1 | 0 | 1 | 2 | 1 | 83.75 | 30.0 | 2 | 0 | Elderly |
5106 | 0 | 81.0 | 0 | 0 | 1 | 3 | 1 | 125.20 | 40.0 | 2 | 0 | Elderly |
5107 | 0 | 35.0 | 0 | 0 | 1 | 3 | 0 | 82.99 | 30.6 | 2 | 0 | Adult |
5108 | 1 | 51.0 | 0 | 0 | 1 | 2 | 0 | 166.29 | 25.6 | 1 | 0 | Adult |
5109 | 0 | 44.0 | 0 | 0 | 1 | 0 | 1 | 85.28 | 26.2 | 0 | 0 | Adult |
5110 rows × 12 columns
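The encoding above re-fits a single LabelEncoder on each column, so the learned category-to-code mappings are not retained once the cell finishes. The following is a minimal, hedged sketch (an illustrative alternative to the cell above, not the method used in this dissertation; the file name encoders.pkl is an assumption) of how one fitted encoder per column could be kept and persisted, so the deployed application can apply the identical encoding to user input:

#Illustrative sketch: keep one fitted LabelEncoder per categorical column and persist them,
#so the same category-to-code mapping can be reused at prediction time.
#Note: this would replace the encoding cell above and must run on the raw (unencoded) columns.
import pickle
from sklearn.preprocessing import LabelEncoder

encoders = {}
for column in ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']:
    enc = LabelEncoder()
    stroke_df[column] = enc.fit_transform(stroke_df[column])
    encoders[column] = enc

with open('encoders.pkl', 'wb') as f:  #hypothetical file name
    pickle.dump(encoders, f)

#Example: encode a single user-supplied value with the same mapping
#encoders['smoking_status'].transform(['never smoked'])  -> array([2])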
Drop the AgeLabel column
stroke_df.drop('AgeLabel',axis=1, inplace=True)
stroke_df
gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke | |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 67.0 | 0 | 1 | 1 | 2 | 1 | 228.69 | 36.6 | 1 | 1 |
1 | 0 | 61.0 | 0 | 0 | 1 | 3 | 0 | 202.21 | 31.0 | 2 | 1 |
2 | 1 | 80.0 | 0 | 1 | 1 | 2 | 0 | 105.92 | 32.5 | 2 | 1 |
3 | 0 | 49.0 | 0 | 0 | 1 | 2 | 1 | 171.23 | 34.4 | 3 | 1 |
4 | 0 | 79.0 | 1 | 0 | 1 | 3 | 0 | 174.12 | 24.0 | 2 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
5105 | 0 | 80.0 | 1 | 0 | 1 | 2 | 1 | 83.75 | 30.0 | 2 | 0 |
5106 | 0 | 81.0 | 0 | 0 | 1 | 3 | 1 | 125.20 | 40.0 | 2 | 0 |
5107 | 0 | 35.0 | 0 | 0 | 1 | 3 | 0 | 82.99 | 30.6 | 2 | 0 |
5108 | 1 | 51.0 | 0 | 0 | 1 | 2 | 0 | 166.29 | 25.6 | 1 | 0 |
5109 | 0 | 44.0 | 0 | 0 | 1 | 0 | 1 | 85.28 | 26.2 | 0 | 0 |
5110 rows × 11 columns
plt.figure(figsize=(16,8))
sns.set_context('notebook',font_scale = 1.3)
sns.heatmap(stroke_df.corr(),annot=True,linewidth =2)
plt.tight_layout()
Correlation of stroke with input variables
stroke_corr= stroke_df.corr()['stroke'].sort_values(ascending=False).to_frame()
plt.figure(figsize=(2,8))
sns.heatmap(stroke_corr, cmap='Reds', cbar=False , annot=True)
plt.show()
To explore the strength of the relationships between stroke and the input features, I employed Pearson's correlation coefficient to generate a heatmap. The correlation value indicates the strength of the linear relationship between any two features of the patients' electronic health data. In the heatmap, dark red indicates a strong positive relationship, and the shade of red fades as the correlation moves towards a negative relationship.
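For reference, Pearson's correlation coefficient between two features $x$ and $y$ over $n$ patients is

$$ r_{xy} = \frac{\sum_{i=1}^{n}(x_i-\bar{x})(y_i-\bar{y})}{\sqrt{\sum_{i=1}^{n}(x_i-\bar{x})^{2}}\,\sqrt{\sum_{i=1}^{n}(y_i-\bar{y})^{2}}}, $$

which ranges from -1 (perfect negative linear relationship) to +1 (perfect positive linear relationship).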
The heatmap above shows the strength of the relationships between all features. Age is most correlated with body mass index (a correlation coefficient of 0.33), followed by hypertension at 0.28, and heart disease and average glucose level at 0.26 and 0.24 respectively. The implication is that it becomes increasingly important to maintain a healthy diet and lifestyle as we grow older, in order to avoid these debilitating diseases.
The correlation coefficients between stroke and age, heart disease, average glucose level, hypertension, and marital status are 0.25, 0.13, 0.13, 0.13, and 0.11 respectively. This makes age the feature most correlated with stroke. The positive correlations imply that the higher the age, glucose level, and body mass index, the higher the risk of having a stroke, and that the presence of heart disease or hypertension increases the chance of a stroke. We also see that gender has a very weak relationship with stroke.
X= np.asarray(stroke_df[['gender', 'age','hypertension', 'heart_disease',
'ever_married', 'work_type', 'Residence_type',
'avg_glucose_level',
'bmi','smoking_status']])
y=np.asarray(stroke_df['stroke'])
y[0:5]
array([1, 1, 1, 1, 1], dtype=int64)
I am going to use StandardScaler for data standardization.
from sklearn.preprocessing import StandardScaler
std=StandardScaler()
X=std.fit_transform(X)
X
array([[ 1.18807255, 1.05143428, -0.32860186, ..., 2.70637544, 0.99442704, -0.35178071], [-0.840344 , 0.78607007, -0.32860186, ..., 2.12155854, 0.26918169, 0.58155233], [ 1.18807255, 1.62639008, -0.32860186, ..., -0.0050283 , 0.46344384, 0.58155233], ..., [-0.840344 , -0.36384151, -0.32860186, ..., -0.51144264, 0.21737846, 0.58155233], [ 1.18807255, 0.34379639, -0.32860186, ..., 1.32825706, -0.43016203, -0.35178071], [-0.840344 , 0.03420481, -0.32860186, ..., -0.46086746, -0.35245718, -1.28511375]])
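StandardScaler standardizes each feature by subtracting its mean and dividing by its standard deviation, z = (x − μ) / σ, so that every feature ends up with zero mean and unit variance, as reflected in the transformed array above.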
Splitting is the act of dividing the dataset by rows into two parts, one for training and one for testing. For better accuracy and efficiency, I split the dataset into 80% training data and 20% testing data.
from sklearn.model_selection import train_test_split
X_train , X_test , y_train , y_test = train_test_split(X,y,test_size=0.2,random_state=None)
print("Number transactions X_train dataset: ", X_train.shape)
print("Number transactions y_train dataset: ", y_train.shape)
print("Number transactions X_test dataset: ", X_test.shape)
print("Number transactions y_test dataset: ", y_test.shape)
Number transactions X_train dataset: (4088, 10) Number transactions y_train dataset: (4088,) Number transactions X_test dataset: (1022, 10) Number transactions y_test dataset: (1022,)
Class imbalance in a dataset is a situation where one class of the output variable has significantly fewer instances than the other class. Class imbalance is a common issue and can impact the performance of machine learning models (Tahir et al., 2019). SMOTE addresses this by generating synthetic minority-class samples through interpolation between existing minority samples and their nearest minority-class neighbours.
from imblearn.over_sampling import SMOTE
print("Before OverSampling, counts of label '1': {}".format(sum(y_train==1)))
print("Before OverSampling, counts of label '0': {} \n".format(sum(y_train==0)))
Before OverSampling, counts of label '1': 195 Before OverSampling, counts of label '0': 3893
sm = SMOTE(random_state=2)
X_train_res, y_train_res = sm.fit_resample(X_train, y_train.ravel())
print("After OverSampling, counts of label '1': {}".format(sum(y_train_res==1)))
print("After OverSampling, counts of label '0': {}".format(sum(y_train_res==0)))
After OverSampling, counts of label '1': 3893 After OverSampling, counts of label '0': 3893
Model building involves three phases: training, testing, and evaluation. Each model is built using 80% of the data and tested using the remaining 20%, then evaluated using accuracy score, area under the curve, confusion matrix, and precision. I have described six different classification models: logistic regression, support vector machine, k-nearest neighbours, random forest, artificial neural network, and XGBoost. The first step is to import the Python evaluation metrics. Each model starts with a brief introduction, followed by model training, testing and evaluation, and a short discussion.
#Import evaluation libraries
from sklearn.metrics import confusion_matrix, accuracy_score,classification_report,roc_auc_score,roc_curve
from sklearn.linear_model import LogisticRegression #import libraries
lg_model = LogisticRegression() #create model
lg_model.fit(X_train_res, y_train_res) #train model
y_pred_lr = lg_model.predict(X_test) #test model
ac_lr = lg_model.score(X_test,y_test)*100
ac_lr
73.77690802348337
auc_lr = roc_auc_score(y_test,y_pred_lr)*100
auc_lr
75.66574839302112
plt.figure(figsize=(6, 6))
ax = plt.subplot()
cm = confusion_matrix(y_test,y_pred_lr)
sns.heatmap(cm, annot=True, ax = ax, fmt = 'g' ,cmap=plt.cm.Greys)
ax.set_xlabel('Predicted label')
ax.set_ylabel('Actual label')
plt.show()
print(classification_report(y_test,y_pred_lr))
print('Logistic AUC: {:.3f}'.format(roc_auc_score(y_test, y_pred_lr)))
print("Accuracy of The Model :",accuracy_score(y_test,y_pred_lr)*100)
precision recall f1-score support 0 0.98 0.74 0.84 968 1 0.14 0.78 0.24 54 accuracy 0.74 1022 macro avg 0.56 0.76 0.54 1022 weighted avg 0.94 0.74 0.81 1022 Logistic AUC: 0.757 Accuracy of The Model : 73.77690802348337
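The AUC values in this chapter are computed from the predicted class labels. As an optional, hedged alternative (not part of the original evaluation), a threshold-independent AUC can be obtained from the model's predicted probabilities:

#Optional check: AUC computed from predicted probabilities rather than hard labels
y_prob_lr = lg_model.predict_proba(X_test)[:, 1]  #probability of the stroke class
print('Logistic AUC (from probabilities): {:.3f}'.format(roc_auc_score(y_test, y_prob_lr)))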
from sklearn.svm import SVC
svm_model = SVC()
svm_model.fit(X_train_res,y_train_res)
SVC()
y_pred_svm = svm_model.predict(X_test)
ac_svm = svm_model.score(X_test,y_test)*100
ac_svm
76.22309197651663
auc_svm = roc_auc_score(y_test,y_pred_svm)*100
auc_svm
64.7172482399755
plt.figure(figsize=(6, 6))
ax = plt.subplot()
cm = confusion_matrix(y_test,y_pred_svm)
sns.heatmap(cm, annot=True, ax = ax, fmt = 'g' ,cmap=plt.cm.Blues)
ax.set_xlabel('Predicted label')
ax.set_ylabel('Actual label')
plt.show()
print(classification_report(y_test,y_pred_svm))
print('AUC_SVM: {:.3f}'.format(roc_auc_score(y_test, y_pred_svm)))
print("Accuracy of The Model :",accuracy_score(y_test,y_pred_svm)*100)
precision recall f1-score support 0 0.97 0.78 0.86 968 1 0.11 0.52 0.19 54 accuracy 0.76 1022 macro avg 0.54 0.65 0.52 1022 weighted avg 0.92 0.76 0.83 1022 AUC_SVM: 0.647 Accuracy of The Model : 76.22309197651663
from sklearn.neighbors import KNeighborsClassifier
knn_model=KNeighborsClassifier()
knn_model.fit(X_train_res,y_train_res) #train the model
KNeighborsClassifier()
y_pred_knn=knn_model.predict(X_test) #create a prediction function
ac_knn=accuracy_score(y_test,y_pred_knn)*100
ac_knn
82.38747553816047
auc_knn = roc_auc_score(y_test,y_pred_knn)*100
auc_knn
59.22865013774105
plt.figure(figsize=(6, 6))
ax = plt.subplot()
cm = confusion_matrix(y_test, y_pred_knn)
sns.heatmap(cm, annot=True, ax = ax, fmt = 'g' ,cmap=plt.cm.Greens)
ax.set_xlabel('Predicted label')
ax.set_ylabel('Actual label')
plt.show()
print(classification_report(y_test,y_pred_knn))
print('AUC_KNN: {:.3f}'.format(roc_auc_score(y_test, y_pred_knn)))
print("Accuracy of The Model :",accuracy_score(y_test,y_pred_knn)*100)
precision recall f1-score support 0 0.96 0.85 0.90 968 1 0.11 0.33 0.17 54 accuracy 0.82 1022 macro avg 0.53 0.59 0.53 1022 weighted avg 0.91 0.82 0.86 1022 AUC_KNN: 0.592 Accuracy of The Model : 82.38747553816047
from sklearn.ensemble import RandomForestClassifier
rf_model=RandomForestClassifier()
rf_model.fit(X_train_res,y_train_res)
RandomForestClassifier()
y_pred_rf=rf_model.predict(X_test)
ac_rf=accuracy_score(y_test,y_pred_rf)*100
ac_rf
#print('AUC_rf: {:.3f}'.format(roc_auc_score(y_test, Y_pred)))
91.48727984344423
auc_rf = roc_auc_score(y_test,y_pred_rf)*100
auc_rf
50.91827364554638
plt.figure(figsize=(6, 6))
ax = plt.subplot()
cm = confusion_matrix(y_test,y_pred_rf)
sns.heatmap(cm, annot=True, ax = ax, fmt = 'g' ,cmap=plt.cm.Oranges)
ax.set_xlabel('Predicted label')
ax.set_ylabel('Actual label')
plt.show()
print(classification_report(y_test,y_pred_rf))
print('AUC_rf: {:.3f}'.format(roc_auc_score(y_test, y_pred_rf)))
print("Accuracy of The Model :",accuracy_score(y_test, y_pred_rf)*100)
precision recall f1-score support 0 0.95 0.96 0.96 968 1 0.08 0.06 0.06 54 accuracy 0.91 1022 macro avg 0.51 0.51 0.51 1022 weighted avg 0.90 0.91 0.91 1022 AUC_rf: 0.509 Accuracy of The Model : 91.48727984344423
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
X_train_res.shape
(7786, 10)
Build ANN
ann_model=Sequential()
ann_model.add(Dense(32,input_dim=X_train_res.shape[1],activation='relu'))
ann_model.add(Dense(16,activation='relu'))
ann_model.add(Dense(8,activation='relu'))
ann_model.add(Dense(4,activation='relu'))
ann_model.add(Dense(1,activation='sigmoid'))
ann_model.compile(optimizer="adam",loss="binary_crossentropy",metrics=['accuracy'])
ann_model.fit(X_train_res,y_train_res,batch_size=20,epochs = 20)
Epoch 1/20 390/390 [==============================] - 3s 3ms/step - loss: 0.5022 - accuracy: 0.7670 Epoch 2/20 390/390 [==============================] - 1s 3ms/step - loss: 0.4159 - accuracy: 0.8162 Epoch 3/20 390/390 [==============================] - 1s 3ms/step - loss: 0.3927 - accuracy: 0.8258 Epoch 4/20 390/390 [==============================] - 1s 3ms/step - loss: 0.3702 - accuracy: 0.8402 Epoch 5/20 390/390 [==============================] - 1s 3ms/step - loss: 0.3512 - accuracy: 0.8472 Epoch 6/20 390/390 [==============================] - 1s 3ms/step - loss: 0.3312 - accuracy: 0.8523 Epoch 7/20 390/390 [==============================] - 1s 3ms/step - loss: 0.3147 - accuracy: 0.8693 Epoch 8/20 390/390 [==============================] - 1s 4ms/step - loss: 0.2988 - accuracy: 0.8753 Epoch 9/20 390/390 [==============================] - 2s 4ms/step - loss: 0.2880 - accuracy: 0.8833 Epoch 10/20 390/390 [==============================] - 2s 4ms/step - loss: 0.2771 - accuracy: 0.8876 Epoch 11/20 390/390 [==============================] - 2s 4ms/step - loss: 0.2678 - accuracy: 0.8917 Epoch 12/20 390/390 [==============================] - 2s 5ms/step - loss: 0.2585 - accuracy: 0.8961 Epoch 13/20 390/390 [==============================] - 1s 3ms/step - loss: 0.2511 - accuracy: 0.9025 Epoch 14/20 390/390 [==============================] - 1s 3ms/step - loss: 0.2426 - accuracy: 0.9065 Epoch 15/20 390/390 [==============================] - 1s 3ms/step - loss: 0.2357 - accuracy: 0.9087 Epoch 16/20 390/390 [==============================] - 1s 3ms/step - loss: 0.2339 - accuracy: 0.9096 Epoch 17/20 390/390 [==============================] - 1s 3ms/step - loss: 0.2254 - accuracy: 0.9170 Epoch 18/20 390/390 [==============================] - 1s 3ms/step - loss: 0.2175 - accuracy: 0.9201 Epoch 19/20 390/390 [==============================] - 1s 3ms/step - loss: 0.2151 - accuracy: 0.9205 Epoch 20/20 390/390 [==============================] - 1s 4ms/step - loss: 0.2113 - accuracy: 0.9226
<keras.src.callbacks.History at 0x1260d8c78d0>
y_pred_ann=ann_model.predict(X_test)
y_pred_ann=(y_pred_ann>0.5)
y_pred_ann=y_pred_ann.astype(int)
y_pred_ann[:4]
32/32 [==============================] - 0s 3ms/step
array([[0], [0], [0], [0]])
ac_ann=accuracy_score(y_test, y_pred_ann)*100
ac_ann
83.75733855185909
auc_ann = roc_auc_score(y_test,y_pred_ann)*100
auc_ann
56.45469850015304
#print('AUC_ann: {:.3f}'.format(roc_auc_score(y_test, y_pred)))
#print("Test set Accuracy: ", metrics.accuracy_score(y_test, y_pred))
plt.figure(figsize=(6, 6))
ax = plt.subplot()
cm = confusion_matrix(y_test,y_pred_ann)
sns.heatmap(cm, annot=True, ax = ax, fmt = 'g' ,cmap=plt.cm.Purples)
ax.set_xlabel('Predicted label')
ax.set_ylabel('Actual label')
plt.show()
print(classification_report(y_test,y_pred_ann))
print('ANN_AUC: {:.3f}'.format(roc_auc_score(y_test, y_pred_ann)))
print("Accuracy of The Model :",accuracy_score(y_test,y_pred_ann)*100)
precision recall f1-score support 0 0.95 0.87 0.91 968 1 0.10 0.26 0.14 54 accuracy 0.84 1022 macro avg 0.53 0.56 0.53 1022 weighted avg 0.91 0.84 0.87 1022 ANN_AUC: 0.565 Accuracy of The Model : 83.75733855185909
# plt.bar(['Logistic Regression','SVM','KNN','Random Forest','ANN'],[ac_lr,ac_svm,ac_knn,ac_rf,ac_ann])
# plt.xlabel("Algorithms")
# plt.ylabel("Accuracy")
# plt.show()
!pip install xgboost
Requirement already satisfied: xgboost in c:\users\user\anaconda3\lib\site-packages (2.0.2) Requirement already satisfied: numpy in c:\users\user\anaconda3\lib\site-packages (from xgboost) (1.24.3) Requirement already satisfied: scipy in c:\users\user\anaconda3\lib\site-packages (from xgboost) (1.11.1)
from xgboost import XGBClassifier
xgb_model = XGBClassifier()
xgb_model.fit(X_train_res, y_train_res)
XGBClassifier(base_score=None, booster=None, callbacks=None, colsample_bylevel=None, colsample_bynode=None, colsample_bytree=None, device=None, early_stopping_rounds=None, enable_categorical=False, eval_metric=None, feature_types=None, gamma=None, grow_policy=None, importance_type=None, interaction_constraints=None, learning_rate=None, max_bin=None, max_cat_threshold=None, max_cat_to_onehot=None, max_delta_step=None, max_depth=None, max_leaves=None, min_child_weight=None, missing=nan, monotone_constraints=None, multi_strategy=None, n_estimators=None, n_jobs=None, num_parallel_tree=None, random_state=None, ...)
y_pred_xgb =xgb_model.predict(X_test) #Make predictions for values of y
ac_xgb=accuracy_score(y_test, y_pred_xgb)*100 #confirm accuracy of predictions
ac_xgb
92.27005870841487
auc_xgb = roc_auc_score(y_test,y_pred_xgb)*100
auc_xgb
53.08004285277013
plt.figure(figsize=(6, 6))
ax = plt.subplot()
cm = confusion_matrix(y_test,y_pred_xgb)
sns.heatmap(cm, annot=True, ax = ax, fmt = 'g' ,cmap=plt.cm.Reds)
ax.set_xlabel('Predicted label')
ax.set_ylabel('Actual label')
plt.show()
print(classification_report(y_test,y_pred_xgb))
print('AUC_XGB: {:.3f}'.format(roc_auc_score(y_test, y_pred_xgb)))
print("Accuracy of The Model :",accuracy_score(y_test, y_pred_xgb)*100)
precision recall f1-score support 0 0.95 0.97 0.96 968 1 0.14 0.09 0.11 54 accuracy 0.92 1022 macro avg 0.55 0.53 0.54 1022 weighted avg 0.91 0.92 0.91 1022 AUC_XGB: 0.531 Accuracy of The Model : 92.27005870841487
evaluation= pd.DataFrame({'Algorithms':['Logistic Regression','SVM','KNN','Random Forest','ANN', 'XGBoost'],
'Accuracy':[ac_lr, ac_svm, ac_knn, ac_rf, ac_ann, ac_xgb],
'AUC': [auc_lr, auc_svm, auc_knn, auc_rf, auc_ann, auc_xgb]})
evaluation
Algorithms | Accuracy | AUC | |
---|---|---|---|
0 | Logistic Regression | 73.776908 | 75.665748 |
1 | SVM | 76.223092 | 64.717248 |
2 | KNN | 82.387476 | 59.228650 |
3 | Random Forest | 91.487280 | 50.918274 |
4 | ANN | 83.757339 | 56.454699 |
5 | XGBoost | 92.270059 | 53.080043 |
evaluation.columns
Index(['Algorithms', 'Accuracy', 'AUC'], dtype='object')
fig = plt.figure(figsize = (11,8))
colors = sns.color_palette("Set2")
ax=sns.barplot(x=evaluation['Algorithms'],y=evaluation['Accuracy'],palette=colors, label='Accuracy')
sns.lineplot(x=evaluation['Algorithms'], y=evaluation['AUC'], label='AUC')
plt.xlabel("Machine Learning Models",fontsize = 20)
plt.ylabel("Performance (%)",fontsize = 20)
plt.title("Model Evaluation - Accuracy and AUC",fontsize = 20)
plt.xticks(fontsize = 12, horizontalalignment = 'center', rotation = 8)
plt.yticks(fontsize = 12)
for i, v in enumerate(evaluation['Accuracy']):
plt.text(i, v + 0.01, f"{v:.2f}%", ha='center', va='bottom', fontsize=20)
plt.show()
Save Logistic Regression
It has the best AUC, but lowest accuracy
pickle is part of the Python standard library, so it does not need to be installed with pip; it can simply be imported.
import pickle
with open('lg_model.pkl' , 'wb') as file : #lg_model.pkl is my pickle file in binary write mode('wb')
pickle.dump(lg_model, file)
Save XGBoost Model
Best Accuracy
import pickle
with open('xgb_model.pkl' , 'wb') as file : #xgb_model.pkl is my pickle file in binary write mode('wb')
pickle.dump(xgb_model, file)
Save Random Forest Model
Next best accuracy
import pickle
with open('rf_model.pkl' , 'wb') as file : #rf_model.pkl is my pickle file in binary write mode('wb')
pickle.dump(rf_model, file)
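As an optional sanity check (not part of the original pipeline), the saved XGBoost model can be reloaded and verified against the in-memory model:

#Reload the pickled XGBoost model and confirm it reproduces the same test-set predictions
with open('xgb_model.pkl', 'rb') as file:
    xgb_loaded = pickle.load(file)
assert (xgb_loaded.predict(X_test) == y_pred_xgb).all()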
print(stroke_df[:5].to_csv())
print(stroke_df[5104:5109].to_csv())
,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke 0,1,67.0,0,1,1,2,1,228.69,36.6,1,1 1,0,61.0,0,0,1,3,0,202.21,31.0,2,1 2,1,80.0,0,1,1,2,0,105.92,32.5,2,1 3,0,49.0,0,0,1,2,1,171.23,34.4,3,1 4,0,79.0,1,0,1,3,0,174.12,24.0,2,1 ,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke 5104,0,13.0,0,0,0,4,0,103.08,18.6,0,0 5105,0,80.0,1,0,1,2,1,83.75,30.0,2,0 5106,0,81.0,0,0,1,3,1,125.2,40.0,2,0 5107,0,35.0,0,0,1,3,0,82.99,30.6,2,0 5108,1,51.0,0,0,1,2,0,166.29,25.6,1,0
Logistic Regression
#line 1 from given data set was correctly predicted
print(lg_model.predict([[1,67.0,0,1,1, 2,1, 228.69, 36.6, 1]]))
#serial Number 2, Correct
print(lg_model.predict([[0,61.0,0,0,1,3,0,202.21,31.0,2]]))
#serial Number 3, correctly predicted
print(lg_model.predict([[1,80.0,0,1,1,2,0,105.92,32.5,2]]))
[1]
#Serial number 5104, wrongly predicted
print(lg_model.predict([[0,13.0,0,0,0,4,0,103.08,18.6,0]]))
[1]
#Serial Number 5108, wrongly predicted
print(lg_model.predict([[1,51.0,0,0,1,2,0,166.29,25.6,1]]))
[1]
#Serial Number 5106, wrongly predicted
print(lg_model.predict([[0,81.0,0,0,1,3,1,125.2,40.0,2]]))
[1]
Random Prediction with XGBoost
#line 1 from given data set was correctly predicted
print(xgb_model.predict([[1,67.0,0,1,1, 2,1, 228.69, 36.6, 1]]))
[1]
#serial Number 2, wrongly predicted (true label is 1)
print(xgb_model.predict([[0,61.0,0,0,1,3,0,202.21,31.0,2]]))
[0]
#serial Number 3, wrongly predicted (true label is 1)
print(xgb_model.predict([[1,80.0,0,1,1,2,0,105.92,32.5,2]]))
[0]
#Serial number 5104, wrongly predicted
print(xgb_model.predict([[0,13.0,0,0,0,4,0,103.08,18.6,0]]))
[1]
#Serial Number 5108, wrongly predicted
print(xgb_model.predict([[1,51.0,0,0,1,2,0,166.29,25.6,1]]))
[1]
#Serial Number 5106, correctly predicted
print(xgb_model.predict([[0,81.0,0,0,1,3,1,125.2,40.0,2]]))
[0]
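Note that the models were trained on features standardized with StandardScaler, whereas the spot checks above pass raw feature values. A minimal, hedged sketch of how a raw row could instead be transformed with the fitted scaler before prediction:

#Illustrative only: scale a raw input row with the previously fitted StandardScaler
#so that it matches the feature space the models were trained on
raw_row = np.array([[1, 67.0, 0, 1, 1, 2, 1, 228.69, 36.6, 1]])  #first record of the dataset
print(xgb_model.predict(std.transform(raw_row)))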
To access the Flask and Heroku files, please visit my GitHub page.
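The full application code is on GitHub, but the following minimal sketch illustrates the general shape of such a Flask app (file names, form field names, and the omission of input scaling and encoding are assumptions here, not the exact deployed code):

#Minimal illustrative Flask app: load the pickled model and serve predictions
from flask import Flask, request, render_template_string
import numpy as np
import pickle

app = Flask(__name__)

with open('xgb_model.pkl', 'rb') as f:  #the XGBoost model saved earlier
    model = pickle.load(f)

FEATURES = ['gender', 'age', 'hypertension', 'heart_disease', 'ever_married',
            'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status']

@app.route('/predict', methods=['POST'])
def predict():
    #collect the ten encoded inputs from the submitted form, in training order
    values = [float(request.form[name]) for name in FEATURES]
    prediction = int(model.predict(np.array([values]))[0])
    return render_template_string('Predicted stroke risk class: {{ p }}', p=prediction)

if __name__ == '__main__':
    app.run(debug=True)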
The study has shown that a researcher can use a primary or secondary data source for stroke risk prediction and produce an outstanding study.
The study has shown the effectiveness of developing a secure and user-friendly web application using a simple and efficient Python framework, Flask.
Stroke is a serious medical condition that should be prevented to avoid permanent disability or death. Building and deploying a stroke risk prediction model can help reduce the severe impact of stroke. This dissertation investigated the effectiveness of logistic regression, support vector machines, k-nearest neighbours, random forest, artificial neural network, and extreme gradient boosting in predicting the risk of stroke using patients' biomedical and lifestyle features: age, presence of heart disease and hypertension, body mass index, average glucose level, marital status, residence type, smoking status, gender, and work type. The study evaluated the models against established metrics: accuracy score, precision, recall, F1 score, and area under the curve. It was discovered that eXtreme Gradient Boosting (XGBoost) performed best at predicting stroke risk: it correctly predicted 943 cases and wrongly predicted 79 cases, giving a 92% accuracy score, with a precision of 91%, recall of 92%, and F1 score of 91%. XGBoost was therefore chosen as the best model for this dissertation and integrated into a web framework, Flask. Flask proved very efficient for building secure and user-friendly web applications. Deployment of the web application to Heroku was hitch-free, and the cloud platform provided the security, scalability, and accessibility needed for this research.
This dissertation lays the groundwork for a future machine learning framework that leverages deep learning techniques for stroke risk prediction. One potential area for future research is gathering and examining brain CT scan images to assess the predictive power of deep learning models in forecasting stroke events. Subsequent studies could focus on compiling a comprehensive dataset of brain imaging data, such as CT scans, which could then be leveraged to develop and evaluate advanced deep learning techniques for stroke prediction.