Define functions

Read in data

# list every CSV export under the project folder
# (the pattern is a regular expression, so it is escaped and anchored to the
# end of the name; a bare "csv" also matches names that merely contain "csv")
files_in = list.files("S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/",
                      recursive = TRUE,
                      full.names = TRUE,
                      pattern = "\\.csv$")

# read in files
# every column is read as character, so nothing is coerced on import
data = purrr::map(files_in, function(x){
  message("Reading in ", x)
  readr::read_csv(x,
                  col_types = readr::cols(.default="c"))
})

# shorten names of files: drop the project-folder prefix and the extension
# ("\\.csv$" rather than ".csv": an unescaped "." matches any character, so the
# first "<any char>csv" in the path would be removed, not necessarily the extension)
short_names = files_in %>% 
  str_remove("S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/") %>% 
  str_remove("\\.csv$")

names(data) = short_names

# save the named list of raw tables so later sections can reload it
saveRDS(data,"S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/datasets/all_datasets.rds")

Data cleaning

Demographics

# read in data 
data = readRDS("S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/datasets/all_datasets.rds")

# get portal demographics table 
demographics_portal = data$`/datasets/portal_demographics`

demographics_portal %>% glimpse

# count NAs
count_missing(demographics_portal,"Gender")
count_missing(demographics_portal,"YearOfBirth")
count_missing(demographics_portal,"Ethnicity")
count_missing(demographics_portal,"DiagnosisYear")

# calculate current age & age at dx 
data_extract_date = as.Date("16-11-2023",format = "%d-%m-%Y")

# set unrealistic dates to NA 
# the year columns are character (the CSVs are read with every column as text),
# so BOTH bounds are tested on as.numeric(); comparing the raw string with 1900
# is a lexicographic comparison under which e.g. "200" is not below "1900"
demographics_portal = demographics_portal %>%
  mutate(DiagnosisYear = ifelse(as.numeric(DiagnosisYear) < 1900 | as.numeric(DiagnosisYear) > 2023,NA, DiagnosisYear)) %>%
  mutate(SymptomsYear = ifelse(as.numeric(SymptomsYear) < 1900 | as.numeric(SymptomsYear) > 2023,NA, SymptomsYear))

# turn the year columns into dates and derive ages (in years) at diagnosis,
# at the data extract and at symptom onset
# date_from_year / delta_dates_years are project helpers defined elsewhere in the file
demographics_portal = demographics_portal %>%
  mutate(dob = date_from_year(YearOfBirth)) %>%
  mutate(date_of_dx = date_from_year(DiagnosisYear)) %>%
  mutate(date_of_sx = date_from_year(SymptomsYear)) %>%
  mutate(age_at_dx = delta_dates_years(date_of_dx,dob)) %>%
  mutate(age_at_data_extract = delta_dates_years(data_extract_date,dob)) %>%
  mutate(age_at_sx = delta_dates_years(date_of_sx,dob)) 

# define diagnostic lag from sx to dx 
demographics_portal = demographics_portal %>%
  mutate(time_from_sx_to_dx = age_at_dx - age_at_sx)

# keep adult-onset diagnoses only: age at diagnosis from 18 to 90 inclusive
# (the filter_na() result on the next line is not assigned; rows with a missing
# age are removed by filter() anyway, which drops rows where the test is NA)
filter_na(demographics_portal,"age_at_dx")
nrow(demographics_portal)
demographics_portal = demographics_portal %>%
  filter(age_at_dx >= 18, age_at_dx <= 90)
nrow(demographics_portal)

# any gender other than MALE / FEMALE is treated as missing
demographics_portal = demographics_portal %>%
  mutate(Gender = replace(Gender, !(Gender %in% c("MALE","FEMALE")), NA))

# drop rows with a missing value in any of the key demographic columns
demographics_portal = filter_na(demographics_portal,"Gender")
demographics_portal = filter_na(demographics_portal,"DiagnosisYear")
demographics_portal = filter_na(demographics_portal,"SymptomsYear")
demographics_portal = filter_na(demographics_portal,"YearOfBirth")
demographics_portal = filter_na(demographics_portal,"MSAtDiagnosis")
demographics_portal = filter_na(demographics_portal,"Ethnicity")

# tabulate the raw top-level ethnicity answers
table(demographics_portal$EthnicityTopLevel)

# collapse the answers into four analysis groups
# (mixed, "rather not say" and "Other" all go to "Other";
#  any answer not listed here is left as NA)
demographics_portal = demographics_portal %>%
  mutate(ethnicity_simple = case_when(
    EthnicityTopLevel == "I am white (British, Irish, Other)" ~ "White",
    EthnicityTopLevel == "I am Black or Black British (Caribbean, African, Other)" ~ "Black",
    EthnicityTopLevel == "I am Asian or British Asian (Indian / Pakistani / Bangladeshi)" ~ "S_Asian",
    EthnicityTopLevel %in% c(
      "I am mixed (White and Black Caribbean, Black African, Asian)",
      "I would rather not say",
      "Other (Chinese, Another ethnic group)") ~ "Other"))

# check the collapsed groups
table(demographics_portal$ethnicity_simple)

# make histograms
# Six demographic panels built with the project helpers make_hist / make_barchart
# (defined elsewhere in the file); labs() overrides the x-axis title where given.
message("basic demographic plots")
p1=  make_hist(demographics_portal,colname = "age_at_dx")+labs(x="Age at diagnosis")
p2=  make_hist(demographics_portal,colname = "age_at_sx")+labs(x="Age at symptom onset")
p3=  make_hist(demographics_portal,colname = "age_at_data_extract")+labs(x="Age at data extract")
p4=  make_barchart(demographics_portal,colname = "Gender")
p5=  make_barchart(demographics_portal,colname = "MSAtDiagnosis")+labs(x="MS type at diagnosis")
p6=  make_barchart(demographics_portal,colname = "ethnicity_simple")+labs(x="Ethnicity")

# Write the combined panel to demographics.png (8 x 8 in, 600 dpi) in the working directory ...
png("demographics.png",res=600,units="in",height=8,width=8)
print(
  gridExtra::grid.arrange(p1,p2,p3,p4,p5,p6)
)
dev.off()
# ... then repeat the same call with the png device closed, so the panel goes to the active device
print(
  gridExtra::grid.arrange(p1,p2,p3,p4,p5,p6)
)

# summarise
# median and IQR of the three age variables; na.rm=T is forwarded to both functions
# NOTE(review): summarise_at() is superseded by summarise(across(...)); switching
# changes the generated column names, so it is left as is here
demographics_portal %>%
  summarise_at(
    .vars = c("age_at_dx","age_at_sx","age_at_data_extract"),
    .funs = c(median,IQR),
    na.rm=T
  )


# proportions per category (get_prop is a project helper defined elsewhere in the file)
get_prop(demographics_portal,"Gender")
get_prop(demographics_portal,"MSAtDiagnosis")
get_prop(demographics_portal,"ethnicity_simple")

# keep only the columns used downstream; the portal id is renamed to UserId on the way
demographics_portal = demographics_portal %>%
  dplyr::select(UserId = PortalUserId_Enc,
                Gender,ethnicity_simple,MSAtDiagnosis,DiagnosisYear,
                dob,date_of_dx,
                age_at_dx,age_at_sx,age_at_data_extract,
                v3_education)

# quick boxplots by ethnicity and gender (outlier points hidden):
# age at diagnosis
ggplot(demographics_portal,
       aes(x = ethnicity_simple, y = age_at_dx, fill = Gender)) +
  geom_boxplot(outlier.shape = NA)
# age at symptom onset
ggplot(demographics_portal,
       aes(x = ethnicity_simple, y = age_at_sx, fill = Gender)) +
  geom_boxplot(outlier.shape = NA)
# calendar year of diagnosis
ggplot(demographics_portal,
       aes(x = ethnicity_simple, y = as.numeric(DiagnosisYear), fill = Gender)) +
  geom_boxplot(outlier.shape = NA)

# filter to people diagnosed post-mcdonald (diagnosis year 2001 or later)
pre = nrow(demographics_portal)
demographics_portal = demographics_portal %>%
  filter(as.numeric(DiagnosisYear) >= 2001)
post=nrow(demographics_portal)
# rows removed = before - after (post - pre would be negative);
# leading space so the count and the word do not run together
message("removed ",pre-post, " people")

# clean first symptom
# Reshape the symptoms table to one row per person x column, keep the non-"NULL"
# "<symptom>_since" entries, convert their values with datify2() and keep, for
# each person, the single row with the smallest sx_date.
first_symptom = data$`/datasets/portal_symptoms` %>%
  pivot_longer(cols = -1) %>%
  filter(value != "NULL", grepl("_since",name)) %>%
  mutate(
    name = str_remove_all(name,"_since"),
    symptom = str_remove_all(name,"_"),
    sx_date = datify2(value)
  ) %>%
  dplyr::rename(UserId = PortalUserId_Enc) %>%
  group_by(UserId) %>%
  slice_min(sx_date,with_ties = FALSE) %>%
  ungroup() %>%
  dplyr::select(UserId,symptom,sx_date)

# attach first symptom (and its date) to the demographics table
demographics_portal = left_join(demographics_portal,first_symptom,by="UserId")

Outcome measures

# EDSS readings for people in the cleaned demographics table,
# with the completion date parsed and the score coerced to a number
edss = data$`/datasets/portal_edss` %>%
  dplyr::rename(UserId = PortalUserId_Enc) %>%
  filter(UserId %in% demographics_portal$UserId) %>%
  mutate(
    date_completed = as.Date(CompletedDate_webEDSS,format = "%Y-%m-%d"),
    webEDSS = as.numeric(webEDSS)
  )

# count NAs
count_missing(edss,"CompletedDate_webEDSS")
count_missing(edss,"webEDSS")

# scores outside 0-10 are nonsense: set them to NA, then drop missing scores
edss = edss %>%
  mutate(webEDSS = ifelse(webEDSS <0 | webEDSS>10,NA,webEDSS))
edss = filter_na(edss,"webEDSS")

# number of individuals with at least one valid reading
edss %>% distinct(UserId) %>% nrow()

# histogram of all valid readings
p1=make_hist(edss,"webEDSS")
plot_fx("webedss.png",p1)

# add the demographic columns to every EDSS reading and compute the
# time (years) between diagnosis and the reading
edss = left_join(edss,demographics_portal,by="UserId") %>%
  mutate(time_from_dx_to_edss = delta_dates_years(date_completed,date_of_dx))

message("People with EDSS readings")
edss %>% distinct(UserId) %>% nrow()

# sense check: EDSS against age at the reading (age at dx + time since dx),
# with a linear trend per ethnicity
ggplot(edss,
       aes(x = time_from_dx_to_edss + age_at_dx, y = webEDSS, col = ethnicity_simple)) +
  geom_point() +
  geom_smooth(method="lm") +
  labs(x="Age at EDSS") +
  theme_minimal()

# calculate ARMSS 
# global_armss() comes from the ms.sev package loaded here
library(ms.sev)

# age at each EDSS reading (abs() keeps the value non-negative)
edss = edss %>%
  mutate(age_at_edss = abs(delta_dates_years(dob,date_completed)))
# NOTE(review): min()/max() are called without na.rm, so an NA age_at_edss
# (e.g. an unparsed completion date) would make both bounds NA - confirm none occur
min_age = floor(min(edss$age_at_edss))
max_age = ceiling(max(edss$age_at_edss))
# bin ages using cut points every 4 years from min_age-2 to max_age+2
edss$age_bin = Hmisc::cut2(edss$age_at_edss, cuts = seq(min_age-2,max_age+2,by=4))
# relabel the bins with numbers starting at min_age in steps of 4
# NOTE(review): assumes this sequence has exactly as many elements as cut2 produced
# levels - confirm, the two sequences are built from different start points
levels(edss$age_bin) = as.numeric(seq(min_age,max_age+2,by=4))

# global ARMSS; the function is given columns named edss / ageatedss,
# and the $data element of its result is kept
armss = global_armss(edss %>% dplyr::rename("edss" = webEDSS,"ageatedss"= age_at_edss))
armss = armss$data

# manually calculate local ARMSS 
# calculate_armss_score is a project helper defined elsewhere in the file;
# presumably it adds the local_armss column used below - TODO confirm
armss = calculate_armss_score(armss,"edss","ageatedss")

# boxplots of gARMSS and raw EDSS per age bin
plot_dat = armss %>%
  pivot_longer(cols = c("gARMSS","edss"),names_to = "outcome") %>%
  select(age_bin,outcome,value)
p=ggplot(plot_dat,aes(age_bin,value,fill=outcome))+
  geom_boxplot(outlier.shape=NA,alpha=0.5)+
  scale_fill_brewer(palette="Set1",labels = c("EDSS","ARMSS"))+
  theme_minimal()+
  labs(x="Age at EDSS",y = "ARMSS or EDSS")

plot_fx("armss.png",p,plotwidth = 6)

# agreement between the package (global) and locally derived ARMSS, coloured by age
p=ggplot(armss,aes(gARMSS,local_armss,col=ageatedss))+
  geom_point()+
  geom_smooth(method="lm")+
  theme_minimal()+
  labs(x="Global ARMSS",y="Local ARMSS")+
  scale_x_continuous(limits = c(0,10))+
  scale_y_continuous(limits = c(0,10))+
  scale_color_viridis_c(option = "magma")

plot_fx("local_v_global_armss.png",p)
# correlation tests: global vs local ARMSS (cor.test default method), and each vs raw EDSS (Spearman)
cor.test(armss$gARMSS,armss$local_armss)
cor.test(armss$gARMSS,armss$edss,method= "spearman")
cor.test(armss$local_armss,armss$edss,method = "spearman")
p


# baseline EDSS = each person's earliest reading (one row per UserId);
# the outcome columns get a "baseline_" prefix as they are selected
baseline_edss = armss %>%
  group_by(UserId) %>%
  slice_min(date_completed,with_ties = FALSE) %>%
  dplyr::select(UserId,
                baseline_edss = edss,
                baseline_ageatedss = ageatedss,
                baseline_gARMSS = gARMSS,
                baseline_local_armss = local_armss)

# add the baseline values to the demographics table
demographics_portal = left_join(demographics_portal,baseline_edss,by="UserId")

# clean other outcomes 
# EQ5D: keep readings with a HealthState between 0 and 100 and, for each
# person, the reading with the earliest completion date
eq5d = data$`/datasets/portal_eq5d` %>%
  dplyr::rename("UserId" = PortalUserId_Enc) %>%
  mutate(HealthState = as.numeric(HealthState)) %>%
  filter(HealthState >= 0 & HealthState <= 100) %>%
  group_by(UserId) %>%
  mutate(eq5d_date = datify(CompletedDate)) %>%
  slice_min(eq5d_date,with_ties = F) %>%
  # drop the grouping so the saved table is a plain tibble,
  # consistent with the msis / fss / msws tables
  ungroup() %>%
  dplyr::rename("eq5d_vas" = HealthState) %>%
  dplyr::select(UserId,eq5d_vas,eq5d_date)

# MSIS, baseline: version 2 questionnaires only, earliest reading per person.
# Item columns are converted to numeric, incomplete rows dropped, and columns
# 2 to 21 summed; sums outside 20-80 are discarded and the rest rescaled to 0-100.
# NOTE(review): the item columns are addressed by position (2:21) - confirm the
# column order of the export.
msis29 = data$`/datasets/portal_msis` %>% 
  dplyr::rename("UserId" = PortalUserId_Enc) %>%
  filter(version==2) %>%
  group_by(UserId) %>%
  mutate(msis_date = datify(CompletedDate)) %>%
  slice_min(msis_date,with_ties = F) %>%
  dplyr::select(-version,-CompletedDate) %>%
  ungroup() %>% 
  mutate_at(.vars = vars(-1,-msis_date),as.numeric) %>%
  na.omit() %>%
  mutate(msis = rowSums(across(c(2:21)),na.rm=T)) %>%
  dplyr::select(1,msis,msis_date) %>%
  filter(msis >=20 & msis <= 80) %>%
  mutate(msis = (msis - 20)/60 * 100 )


# MSIS, longitudinal: the same cleaning as above but without slice_min(),
# so every reading per person is kept
msis29_all_readings = data$`/datasets/portal_msis` %>% 
  dplyr::rename("UserId" = PortalUserId_Enc) %>%
  filter(version==2) %>%
  group_by(UserId) %>%
  mutate(msis_date = datify(CompletedDate)) %>%
  dplyr::select(-version,-CompletedDate) %>%
  ungroup() %>% 
  mutate_at(.vars = vars(-1,-msis_date),as.numeric) %>%
  na.omit() %>%
  mutate(msis = rowSums(across(c(2:21)),na.rm=T)) %>%
  dplyr::select(1,msis,msis_date) %>%
  filter(msis >=20 & msis <= 80) %>%
  mutate(msis = (msis - 20)/60 * 100 )


# FSS, baseline: earliest reading per person; columns 2 to 10 are summed,
# sums outside 9-63 discarded and the rest rescaled to 0-100
# NOTE(review): item columns are addressed by position - confirm the export's column order
fss = data$`/datasets/portal_fss` %>% 
  dplyr::rename("UserId" = PortalUserId_Enc) %>%
  group_by(UserId) %>%
  mutate(fss_date = datify(CompletedDate)) %>%
  slice_min(fss_date,with_ties = F) %>%
  dplyr::select(-CompletedDate) %>%
  ungroup() %>% 
  mutate_at(.vars = vars(-1,-fss_date),as.numeric) %>%
  na.omit() %>%
  mutate(fss = rowSums(across(c(2:10)),na.rm=T)) %>%
  filter(fss >= 9 & fss <= 63) %>%
  dplyr::select(1,fss,fss_date) %>%
  mutate(fss = (fss-9)/(63-9) * 100 )


# MSWS, baseline: earliest reading per person; columns 3 to 14 are summed,
# sums outside 12-60 discarded and the rest rescaled to 0-100.
# CannotWalk is left unconverted, but na.omit() still drops rows where it is NA.
msws = data$`/datasets/portal_msws` %>% 
  dplyr::rename("UserId" = PortalUserId_Enc) %>%
  group_by(UserId) %>%
  mutate(msws_date = datify(CompletedDate)) %>%
  slice_min(msws_date,with_ties = F) %>%
  dplyr::select(-CompletedDate) %>%
  ungroup() %>% 
  mutate_at(.vars = vars(-1,-CannotWalk,-msws_date),as.numeric) %>%
  na.omit() %>%
  mutate(msws = rowSums(across(c(3:14)),na.rm=T)) %>%
  dplyr::select(1,msws,msws_date) %>%
  filter(msws >= 12 & msws <= 60) %>%
  mutate(msws = (msws - 12 )/48 * 100)

# MSWS, longitudinal: the same cleaning without slice_min(), so all readings are kept
msws_all_readings = data$`/datasets/portal_msws` %>% 
  dplyr::rename("UserId" = PortalUserId_Enc) %>%
  group_by(UserId) %>%
  mutate(msws_date = datify(CompletedDate)) %>%
  dplyr::select(-CompletedDate) %>%
  ungroup() %>% 
  mutate_at(.vars = vars(-1,-CannotWalk,-msws_date),as.numeric) %>%
  na.omit() %>%
  mutate(msws = rowSums(across(c(3:14)),na.rm=T)) %>%
  dplyr::select(1,msws,msws_date) %>%
  filter(msws >= 12 & msws <= 60) %>%
  mutate(msws = (msws - 12 )/48 * 100)

# save the cleaned tables into the datasets folder
# (each list name is the file the object is written to)
root_path = "S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/datasets/"
purrr::iwalk(
  list(
    "demographics_cleaned.rds" = demographics_portal,
    "edss_cleaned.rds" = armss,
    "msis_cleaned.rds" = msis29,
    "msis_cleaned_longitudinal.rds" = msis29_all_readings,
    "edss_cleaned_longitudinal.rds" = edss,
    "eq5d_cleaned.rds" = eq5d,
    "fss_cleaned.rds" = fss,
    "msws_cleaned.rds" = msws,
    "msws_cleaned_longitudinal.rds" = msws_all_readings
  ),
  function(tbl, file_name) saveRDS(tbl,paste0(root_path,file_name))
)

Make matched cohort

# age at diagnosis rounded to whole years, used as a matching variable
demographics_portal = demographics_portal %>%
  mutate(round_age_of_dx = round(age_at_dx,0))

# Black and South Asian participants are the cases;
# White participants are the pool of potential controls
black_participants = demographics_portal %>%
  filter(ethnicity_simple == "Black")
sa_participants = demographics_portal %>%
  filter(ethnicity_simple == "S_Asian")
white_participants = demographics_portal %>%
  filter(ethnicity_simple == "White")

# match 
matched_df = data.frame()
case_df = data.frame()

# Each case is matched exactly on gender, year of dx, rounded age at dx and MS type
# to match_n White controls. Controls must have an MSIS score and must not have
# been used for an earlier case. A case with fewer than match_n eligible controls
# is left out. Black cases are processed first, then South Asian cases, so the
# order of draws from the control pool is unchanged.
# NOTE(review): controls are drawn with sample_n() and no seed is set in this
# section, so the cohort can differ between runs unless a seed is set elsewhere.
match_n = 2
for(cases in list(black_participants, sa_participants)){
  # seq_len() gives an empty loop for a group with no rows (1:0 would run twice)
  for(i in seq_len(nrow(cases))){
    this_person = cases[i,]

    # matching controls
    matches = white_participants %>%
      filter(
        UserId %in% msis29$UserId &
        !UserId %in% matched_df$UserId &
        Gender == this_person$Gender &
        DiagnosisYear == this_person$DiagnosisYear &
        round_age_of_dx == this_person$round_age_of_dx &
        MSAtDiagnosis == this_person$MSAtDiagnosis)

    if(nrow(matches)>=match_n){
      # plain assignment: the loop is not inside a function, so `<<-` is not needed
      matched_df = bind_rows(matched_df,
                             sample_n(matches,match_n))
      case_df = bind_rows(case_df,this_person)
    }
  }
}

# combine controls and cases into one cohort table
combo_dat = bind_rows(
  matched_df,
  case_df
)
saveRDS(combo_dat,paste0(root_path,"matched_cohort.rds"))

Combine datasets

###########################
# Combine with outcomes 
###########################

# reload the cleaned demographics, tabulate ethnicity, then drop the "Other" group
# (rows with a missing ethnicity_simple are dropped by the same filter)
demographics = readRDS(paste0(root_path,"demographics_cleaned.rds"))
table(demographics$ethnicity_simple)
demographics = filter(demographics, ethnicity_simple != "Other")
nrow(demographics)

# reload the cleaned outcome tables
edss = readRDS(paste0(root_path,"edss_cleaned.rds"))
msis = readRDS(paste0(root_path,"msis_cleaned.rds"))
eq5d = readRDS(paste0(root_path,"eq5d_cleaned.rds"))
fss = readRDS(paste0(root_path,"fss_cleaned.rds"))
msws = readRDS(paste0(root_path,"msws_cleaned.rds"))

# left-join each baseline outcome onto the demographics, in this order
population = purrr::reduce(
  list(msis,eq5d,fss,msws),
  function(acc, outcome_tbl) left_join(acc,outcome_tbl,by="UserId"),
  .init = demographics
)

# ethnicity as a factor with White as the reference level
population$ethnicity_simple = relevel(factor(population$ethnicity_simple),ref="White")

# flag membership of the matched cohort and add the age at each baseline outcome
population = population %>%
  mutate(
    in_matched_cohort = ifelse(UserId %in% combo_dat$UserId,"Yes","No"),
    age_at_msis = abs(delta_dates_years(dob,msis_date)),
    age_at_eq5d = abs(delta_dates_years(dob,eq5d_date)),
    age_at_fss = abs(delta_dates_years(dob,fss_date)),
    age_at_msws = abs(delta_dates_years(dob,msws_date))
  )


# add DMT data 
# clean dmt data 
dmt = data$`/datasets/portal_dmt_new`

# filter to those in demog 
dmt = dmt %>% 
  dplyr::rename("UserId" = PortalUserId_Enc) %>%
  filter(UserId %in% population$UserId)

# define baseline age as the earliest age at which any outcome was recorded
# na.rm = TRUE: without it pmin() returns NA as soon as one of the five ages is
# missing, and the age filter below would then drop every DMT record of anyone
# lacking even a single outcome
population = population %>%
  mutate(approx_baseline_age = 
           pmin(baseline_ageatedss,age_at_msis,age_at_eq5d,age_at_fss,age_at_msws,
                na.rm = TRUE))

# remove records from after study entry
# (also drops start dates that are not before today)
dmt = dmt %>%
  mutate(start_date = datify(StartDate)) %>%
  filter(start_date < Sys.Date()) %>%
  left_join(population %>%
              dplyr::select(UserId,dob,approx_baseline_age),
            by="UserId") %>%
  mutate(age_at_dmt_start = abs(delta_dates_years(start_date,dob))) %>%
  filter(age_at_dmt_start < approx_baseline_age)

# clean drug names: DMT is the first space-separated word of Name
dmt = dmt %>%
  separate(Name, sep=" ",into = c("DMT","other")) 

# classify into type
# DMT holds a single word, so two-word drugs are listed by their first word
# ("Dimethyl", "Glatiramer"); the full two-word names are kept but cannot match
high_efficacy_dmt = c("Alemtuzumab","Cladribine","Fingolimod","Natalizumab","Ocrelizumab","Siponimod","Tysabri","Ocrevus","Gilenya","Mavenclad","Lemtrada","Ofatumumab")
low_efficacy_dmt = c("Beta-Interferon","Dimethyl Fumarate","Dimethyl","Glatiramer Acetate","Glatiramer","Teriflunomide","Tecfidera","Copaxone","Avonex","Rebif","Plegridy","Betaferon")
dmt = dmt %>%
  mutate(dmt_type = 
           case_when(
             DMT %in% high_efficacy_dmt ~ "High_efficacy",
             DMT %in% low_efficacy_dmt ~ "Low_efficacy"
           ))
high_eff_records = dmt %>% filter(DMT %in% high_efficacy_dmt)
low_eff_records = dmt %>% filter(DMT %in% low_efficacy_dmt)

# add indicator re whether in DMT
population = population %>%
  mutate(in_dmt_table = ifelse(UserId %in% dmt$UserId,"in_dmt","no_dmt_data"))

# add indicator re whether has had high or low eff DMT prior to study entry
# anyone with a DMT record but no high-efficacy record is "no_high_eff_DMT";
# people with no DMT record at all get a real NA
population = population %>%
  mutate(high_eff_dmt = case_when(
    UserId %in% high_eff_records$UserId ~ "high_eff_DMT",
    !(UserId %in% high_eff_records$UserId) & UserId %in% low_eff_records$UserId ~ "no_high_eff_DMT",
    !(UserId %in% high_eff_records$UserId) & !(UserId %in% low_eff_records$UserId) & UserId %in% dmt$UserId ~ "no_high_eff_DMT",
    !(UserId %in% dmt$UserId) ~ "NA"
  )) %>%
  mutate(high_eff_dmt = ifelse(high_eff_dmt=="NA",NA,high_eff_dmt))
saveRDS(population,paste0(root_path,"population_for_models.rds"))

Analysis

Demographics

# reload the modelling population
population = readRDS(paste0(root_path,"population_for_models.rds"))

# missing values per outcome (count_na is a project helper defined elsewhere in the file)
count_na("msis")
count_na("baseline_edss")
count_na("fss")
count_na("msws")
count_na("eq5d_vas")
count_na("msis")

# availability of each outcome, stacked into one table
plot_dat = bind_rows(
  lapply(c("msis","baseline_edss","eq5d_vas","fss","msws"), make_plot_dat)
)

# table of "n (percent%)" per outcome and ethnicity, without the All column
tbl1b = plot_dat %>%
  mutate(n_perc = paste0(n," (",round(percent,2),"%)")) %>%
  pivot_wider(id_cols = outcome,
              names_from = ethnicity_simple,
              values_from = n_perc) %>%
  dplyr::select(-All)
write_csv(tbl1b,"table1b.csv")

# bar chart of the percentage with non-missing data, with display names for two outcomes
p=ggplot(plot_dat %>%
           mutate(outcome = ifelse(outcome == "baseline_edss","EDSS",
                                   ifelse(outcome == "eq5d_vas","EQ5D",outcome))),
         aes(x = outcome, y = percent, fill = ethnicity_simple)) +
  geom_col(color="black",position=position_dodge()) +
  labs(x="Metric",y="% with available (non-missing) data",fill="Ethnicity") +
  scale_fill_brewer(palette="Set1") +
  theme_minimal()
plot_fx("missingness.png",p)

# plot age at baseline 
# Boxplot of an age column of the global `population` by ethnicity, filled by gender.
#   colname: name (string) of the age column to plot
#   lab:     outcome name appended to the y-axis title
# The y axis is fixed to 0-100 and outlier points are hidden.
make_age_boxplot = function(colname, lab){
  y_title = paste0("Age at baseline ",lab)
  ggplot(population,
         aes(x = ethnicity_simple, y = .data[[colname]], fill = Gender)) +
    geom_boxplot(outlier.shape=NA) +
    theme_minimal() +
    scale_fill_brewer(palette = "Set1") +
    labs(x="Ethnicity",y=y_title) +
    scale_y_continuous(limits=c(0,100))
}

# one panel per outcome: age at the baseline reading, by ethnicity and gender
p1=make_age_boxplot("age_at_msis","MSIS29")
p2=make_age_boxplot("age_at_fss","FSS")
p3=make_age_boxplot("baseline_ageatedss","EDSS")
p4=make_age_boxplot("age_at_eq5d","EQ5D")
# the MSWS panel is labelled MSWS (it plots age_at_msws, not the MSIS29 age)
p5=make_age_boxplot("age_at_msws","MSWS")
# write the combined panel to age_at_severity_measures.png (8 x 8 in, 600 dpi)
png("age_at_severity_measures.png",res=600,units="in",height=8,width=8)
gridExtra::grid.arrange(
  p1,p2,p3,p4,p5 
)
dev.off()


# derived covariates
population = population %>%
  mutate(
    # MS types other than PPMS / RRMS / SPMS are treated as missing
    MSAtDiagnosis = ifelse(MSAtDiagnosis %in% c("PPMS","RRMS","SPMS"),MSAtDiagnosis,NA),
    # progressive onset: "Y" for PPMS at diagnosis, "N" for the other known types
    pms_onset = ifelse(MSAtDiagnosis=="PPMS","Y","N"),
    # education codes 3/4 -> Uni, 0/1/2 -> Non_uni; code 5 and anything else stay NA
    uni_education = case_when(
      v3_education %in% c("3","4") ~ "Uni",
      v3_education %in% c("0","1","2") ~ "Non_uni"),
    # any DMT record before baseline (high or low efficacy) counts as "yes"
    had_dmt = case_when(
      high_eff_dmt %in% c("high_eff_DMT","no_high_eff_DMT") ~ "yes",
      is.na(high_eff_dmt) ~ "no"),
    # missingness flags for each outcome
    `Missing EDSS` = ifelse(is.na(baseline_edss),"Missing","Non-missing"),
    `Missing MSWS` = ifelse(is.na(msws),"Missing","Non-missing"),
    `Missing MSIS` = ifelse(is.na(msis),"Missing","Non-missing"),
    `Missing EQ5D` = ifelse(is.na(eq5d_vas),"Missing","Non-missing"),
    `Missing FSS` = ifelse(is.na(fss),"Missing","Non-missing")
  )

# stats
x="ethnicity_simple"

# Rows per level of column `x` (a string) of the global `population`,
# with the overall total and each level's share of it.
get_count = function(x){
  dplyr::count(population, .data[[x]]) %>%
    mutate(total = sum(n),
           prop = n/total)
}

# Median and IQR of column `x` (a string) of the global `population`;
# missing values are not removed.
get_stat = function(x){
  summarise(population,
            median(.data[[x]]),
            IQR(.data[[x]]))
}

get_count("ethnicity_simple")
get_count("Gender")
get_count("pms_onset")
get_stat("age_at_sx")
get_stat("age_at_dx")

# descriptive table by ethnicity, built with compareGroups
# `method` is given positionally, one code per term on the right-hand side
# (the numeric columns are given 2 and the categorical ones 3)
# NOTE(review): the method vector appears to hold 25 codes while the formula
# lists 24 terms - confirm the intended alignment / drop the surplus code
comp = compareGroups::compareGroups(data = population %>% 
                                      mutate(dx_year = as.numeric(DiagnosisYear)),
                                    formula = ethnicity_simple ~ 
                                      age_at_dx + 
                                      dx_year +
                                      Gender +
                                      uni_education +
                                      age_at_sx + 
                                      in_dmt_table + 
                                      had_dmt+
                                      pms_onset+
                                      baseline_edss+
                                      baseline_ageatedss+
                                      baseline_gARMSS+
                                      msis+
                                      age_at_msis+
                                      eq5d_vas+
                                      age_at_eq5d+
                                      fss +
                                      age_at_fss + 
                                      msws +
                                      age_at_msws+
                                      `Missing EDSS`+
                                      `Missing MSIS`+
                                      `Missing FSS`+
                                      `Missing EQ5D`+
                                      `Missing MSWS`,
                                    method = c(2,2,3,3,2,3,3,3,2,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3))
comp
# note: this reuses (overwrites) the global x set earlier in the script
x=compareGroups::createTable(comp)
#compareGroups::export2word(x,"S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/outputs/demographics_table.doc")

Explore severity measures

# look at raw distributions of severity measures
# one histogram per baseline outcome (make_hist is a project helper defined elsewhere in the file)
p1=make_hist(population,"msis")+labs(x="MSIS29")
p2=make_hist(population,"eq5d_vas")+labs(x="EQ5D")
p3=make_hist(population,"fss")+labs(x="FSS")
p4=make_hist(population,"msws")+labs(x="MSWS")
p5=make_hist(population,"baseline_edss")+labs(x="EDSS")
p6=make_hist(population,"baseline_gARMSS")+labs(x="gARMSS")

# write the combined panel to raw_severity_distros.png (8 x 8 in, 600 dpi)
png("raw_severity_distros.png",res=600,units="in",height=8,width=8)
gridExtra::grid.arrange(
  p1,p2,p3,p4,p5,p6  
)
dev.off()

# correlation between outcomes 
# Boxplot of one outcome at each baseline EDSS value, filled by gender.
#   outcome:      name (string) of the outcome column in the global `population`
#   outcome_name: y-axis title
# Rows missing either the baseline EDSS or the outcome are left out.
make_edss_comparison_plot = function(outcome,outcome_name){
  complete_rows = population %>%
    filter(!is.na(baseline_edss), !is.na(.data[[outcome]]))
  ggplot(complete_rows,
         aes(x = factor(baseline_edss), y = .data[[outcome]], fill = Gender)) +
    geom_boxplot(outlier.shape=NA) +
    theme_minimal() +
    scale_fill_brewer(palette="Set1") +
    labs(x="Baseline EDSS",y=outcome_name)
}

p1=make_edss_comparison_plot("msis","Baseline MSIS29")
p2=make_edss_comparison_plot("eq5d_vas","Baseline EQ5D")
p3=make_edss_comparison_plot("fss","Baseline FSS")
p4=make_edss_comparison_plot("msws","Baseline MSWS")

# save the four panels together
plot_fx("outcome_correlations.png",
        gridExtra::grid.arrange(p1,p2,p3,p4),
        10,10)


# correlations 
# Spearman correlations between the baseline outcomes, using only people
# with all six values present (na.omit)
dat_for_cor = population %>% 
  dplyr::select(baseline_edss,baseline_gARMSS,msis,eq5d_vas,fss,msws) %>% 
  na.omit() %>%
  dplyr::rename("EDSS"=baseline_edss,"gARMSS"=baseline_gARMSS)
cor_mat = cor(dat_for_cor,method = "spearman")
cor_p_mat = corrplot::cor.mtest(dat_for_cor,method="spearman")

# write the correlation plot to cor_plot.png (6 x 6 in, 600 dpi)
png("cor_plot.png",res=600,units="in",height=6,width=6)
corrplot::corrplot.mixed(cor_mat,p.mat = cor_p_mat$p,order="AOE")
dev.off()

# check correlation between numeric predictors 
# DiagnosisYear is converted to numeric in place here;
# contains("age") selects every column whose name contains "age"
population$DiagnosisYear = as.numeric(population$DiagnosisYear)
numeric_predictors = population %>% dplyr::select(contains("age"),DiagnosisYear) %>%
  na.omit()

# correlations with cor()'s default method, complete rows only
cor_mat = cor(numeric_predictors)
cor_p_mat = corrplot::cor.mtest(numeric_predictors)
png("cor_plot_predictors.png",res=600,units="in",height=6,width=6)
corrplot::corrplot.mixed(cor_mat,p.mat=cor_p_mat$p)
dev.off()

Age at onset / diagnosis models

# exploratory plot 
# Density plots by ethnicity, facetted by gender (rows) and onset type (columns).
# In each plot rows with missing pms_onset are dropped and "Y"/"N" is shown as "PPMS"/"RMS".

# age at diagnosis
ggplot(population %>%
         filter(!is.na(pms_onset)) %>%
         mutate(pms_onset = ifelse(pms_onset == "Y","PPMS","RMS")),
       aes(age_at_dx,fill=ethnicity_simple))+
  facet_grid(Gender~pms_onset)+
  geom_density(alpha=0.5)+
  theme_minimal()+
  labs(x="Age at diagnosis",fill="Ethnicity",y="Density")

# age at symptom onset
ggplot(population %>%
         filter(!is.na(pms_onset)) %>%
         mutate(pms_onset = ifelse(pms_onset == "Y","PPMS","RMS")),
       aes(age_at_sx,fill=ethnicity_simple))+
  facet_grid(Gender~pms_onset)+
  geom_density(alpha=0.5)+
  theme_minimal()+
  labs(x="Age at symptom onset",fill="Ethnicity",y="Density")

# diagnostic lag = age at diagnosis minus age at symptom onset (years)
population = population %>% mutate(diagnostic_lag = age_at_dx - age_at_sx)
ggplot(population %>%
         filter(!is.na(pms_onset)) %>%
         mutate(pms_onset = ifelse(pms_onset == "Y","PPMS","RMS")),
       aes(diagnostic_lag,fill=ethnicity_simple))+
  facet_grid(Gender~pms_onset)+
  geom_density(alpha=0.5)+
  theme_minimal()+
  labs(x="Diagnostic delay",fill="Ethnicity",y="Density")

# boxplot 
# Age at symptom onset (p1) and at diagnosis (p2) by gender, filled by ethnicity,
# one facet per onset type; rows with missing pms_onset are dropped.
# NOTE(review): the fill and x-axis labels are given by position, so they assume
# the level order White / Black / S_Asian and FEMALE / MALE - confirm
p1=ggplot(population %>%
         filter(!is.na(pms_onset)) %>%
         mutate(pms_onset = ifelse(pms_onset == "Y","PPMS","RMS")),
       aes(Gender,age_at_sx,fill=ethnicity_simple))+
  geom_boxplot(outlier.shape=NA,position = position_dodge(width=1),show.legend =T) +
  theme_minimal()+
  facet_wrap(~pms_onset)+
  theme(legend.position = "top")+
  scale_fill_brewer(palette="Set1",labels = c("White","Black","South Asian"))+
  scale_x_discrete(labels = c("Female","Male"))+
  labs(y="Age at symptom onset",fill="Ethnicity",x="Gender")
p2=ggplot(population %>%
         filter(!is.na(pms_onset)) %>%
         mutate(pms_onset = ifelse(pms_onset == "Y","PPMS","RMS")),
       aes(Gender,age_at_dx,fill=ethnicity_simple))+
  geom_boxplot(outlier.shape=NA,position = position_dodge(width=1),show.legend = T) +
  theme_minimal()+
  theme(legend.position = "top")+
  facet_wrap(~pms_onset)+
  scale_fill_brewer(palette="Set1",labels = c("White","Black","South Asian"))+
  scale_x_discrete(labels = c("Female","Male"))+
  labs(y="Age at diagnosis",fill="Ethnicity",x="Gender")

# save both panels (diagnosis plot passed first) to the outputs folder
plot_fx("S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/outputs/age_plots.png",gridExtra::grid.arrange(p2,p1),plotwidth = 5,plotheight = 6)

# diagnostic lag by gender and ethnicity, one facet per onset type
ggplot(population %>%
         filter(!is.na(pms_onset)) %>%
         mutate(pms_onset = ifelse(pms_onset == "Y","PPMS","RMS")),
       aes(Gender,diagnostic_lag,fill=ethnicity_simple))+
  geom_boxplot(outlier.shape=NA,alpha=0.5) +
  theme_minimal()+
  facet_wrap(~pms_onset)+
  labs(y="Diagnostic lag",fill="Ethnicity",x="Gender")


# crude histograms 
hist(population$age_at_dx)
hist(population$age_at_sx)
hist(population$diagnostic_lag)

# inspect people whose diagnosis precedes their symptom onset (negative lag);
# the same listing is printed again after the summary
population %>% filter(diagnostic_lag < 0)
get_stat("diagnostic_lag")
population %>% filter(diagnostic_lag < 0)

# compare diagnostic lag across ethnic groups (method "2")
compareGroups::compareGroups(ethnicity_simple ~ diagnostic_lag,population,method="2")

# fit model 
# symptom onset: unadjusted association of ethnicity with age at first symptoms
sx_model = lm(age_at_sx ~ ethnicity_simple, data = population)
summary(sx_model)$coefficients

# symptom onset, adjusted for gender and progressive onset
sx_model_gender = lm(age_at_sx ~ Gender + pms_onset + ethnicity_simple,
                     data = population)

# Tidy the coefficient table of a fitted model for plotting.
#   x: a fitted model whose summary() has a four-column $coefficients matrix
#      (estimate, standard error, statistic, p-value), e.g. an lm fit
# Returns a data frame with columns beta, se, t, pval and variable, without the
# intercept row. Known terms get a readable label; any other term keeps its raw
# coefficient name instead of silently becoming NA.
clean_regression_output = function(x){
  coefs = data.frame(summary(x)$coefficients)
  coefs$variable = rownames(coefs)
  colnames(coefs) = c("beta","se","t","pval","variable")

  # the pipeline is the last expression, so the result is returned visibly
  coefs %>%
    filter(variable != "(Intercept)") %>%
    mutate(variable = case_when(
      variable == "GenderMALE" ~ "Male gender (vs Female)",
      variable == "pms_onsetY" ~ "PPMS (vs RMS)",
      variable == "ethnicity_simpleBlack" ~ "Black ethnicity (vs White)",
      variable == "ethnicity_simpleS_Asian" ~ "South Asian ethnicity (vs White)",
      TRUE ~ variable)
    )
}

# diagnosis: unadjusted association of ethnicity with age at diagnosis
dx_model = lm(age_at_dx ~ ethnicity_simple, data = population)
summary(dx_model)$coefficients

# diagnosis, adjusted for gender and progressive onset
dx_model_gender = lm(age_at_dx ~ Gender + pms_onset + ethnicity_simple,
                     data = population)
summary(dx_model_gender)$coefficients


# plot 
# coefficient plot of the two adjusted models: one row per term, one point per
# outcome, with beta +/- 1.96*se error bars
df1 = clean_regression_output(sx_model_gender) %>% mutate(outcome = "Symptom onset")
df2 = clean_regression_output(dx_model_gender) %>% mutate(outcome = "Diagnosis")
plot_dat = bind_rows(df1,df2)
# keep the terms in their order of first appearance on the y axis
plot_dat$variable = factor(plot_dat$variable,levels = unique(plot_dat$variable),ordered=T)
p=ggplot(plot_dat,aes(beta,variable,col=outcome,fill=outcome))+
    geom_errorbarh(show.legend = F,mapping = aes(xmin = beta- 1.96*se, 
                                 xmax = beta + 1.96*se,
                                 y=variable),
                   height=0.3,
                   color="black",
                   position = ggstance::position_dodgev(height=0.3))+
      geom_point(position = ggstance::position_dodgev(height=0.3),shape=21,color="black",size=2)+
    theme_bw()+
    geom_vline(xintercept=0,linetype="dashed",alpha=0.3)+
  labs(x="Mean effect on age \nat onset / diagnosis (years)",
       y="Variable",
       fill="Outcome")+
  scale_fill_brewer(palette="Set1")+
  theme(legend.position="top")+
  # direction labels placed at y = 4.5; the y limits are expanded to make room for them
  annotate("label",x=7,y=4.5,label="Later onset/diagnosis",size=3)+
  annotate("label",x=-7,y=4.5,label="Earlier onset/diagnosis",size=3)+
  scale_x_continuous(limits = c(-12,12))+
  expand_limits(y = c(0,4.7))
  
  
# save to the outputs folder (plot_fx is a project helper defined elsewhere in the file)
plot_fx(p,file="S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/outputs/age_regression_plot.png",plotheight = 4,plotwidth = 5)


# same adjusted models restricted to people diagnosed after 2010
# (these overwrite the earlier sx_model / dx_model objects)

# symptom onset
sx_model = lm(age_at_sx ~ Gender + pms_onset + ethnicity_simple,
              data = filter(population, as.numeric(DiagnosisYear)>2010))
summary(sx_model)$coefficients

# diagnosis
dx_model = lm(age_at_dx ~ Gender + pms_onset + ethnicity_simple,
              data = filter(population, as.numeric(DiagnosisYear)>2010))
summary(dx_model)$coefficients

#lag 
# Diagnostic lag by ethnicity, restricted to rows with diagnostic_lag > 0.
# NOTE(review): this first unadjusted fit is never summarised and lag_model
# is reassigned inside the loop below - confirm whether it is still needed.
lag_model = lm(
  data = population %>% filter(diagnostic_lag > 0),
  diagnostic_lag ~ ethnicity_simple
)
model_data_input = population %>% 
    filter(diagnostic_lag > 0)
outputs = list()
# Bootstrap: 1000 resamples (with replacement, same size as the input),
# refitting the adjusted model each time and keeping the ethnicity terms.
for(i in c(1:1000)){
  model_dat = model_data_input %>%
    sample_n(size = nrow(model_data_input),
             replace=T)
  lag_model = lm(
    data = model_dat,
    diagnostic_lag ~ ethnicity_simple + Gender + pms_onset
    )
  # columns 5,1,2 of the coefficient table are var, Estimate and Std. Error
  out = summary(lag_model)$coefficients %>%
    data.frame() %>%
    mutate(var = rownames(summary(lag_model)$coefficients)) %>%
    filter(grepl("ethnicity_simple",var)) %>%
    dplyr::select(5,1,2) %>%
    mutate(var = str_remove_all(var,"ethnicity_simple"))
  outputs[[i]] = out

}

# bind 
outputs = do.call("bind_rows",outputs)

# get empirical CI 
# median and 2.5% / 97.5% quantiles of the bootstrap estimates per term
outputs %>%
  group_by(var) %>%
  summarise(med_beta = median(Estimate),
            lower_ci = quantile(Estimate,0.025),
            upper_ci = quantile(Estimate,0.975)
      )
# per term and sign: p = 1 - n / (total + 1), where n is the number of
# bootstrap estimates with that sign
outputs %>%
  mutate(sign_beta = sign(Estimate)) %>%
  group_by(var) %>%
  dplyr::count(sign_beta) %>%
  mutate(p = 1 - ((n)/(sum(n)+1)))

# median diagnostic lag per ethnicity group in the positive-lag subset
model_data_input %>%
  group_by(ethnicity_simple) %>%
  summarise(median(diagnostic_lag))

# NOTE(review): this final fit uses the full population, without the
# diagnostic_lag > 0 filter applied above - confirm this is intended.
lag_model = lm(
  data = population,
  diagnostic_lag ~ Gender + pms_onset + ethnicity_simple
)
summary(lag_model)$coefficients

First symptom

library(nnet)

count_na("symptom")

# Collapse the recorded first symptom into broader categories using a
# lookup table. Values absent from the table (including NA) become NA.
population = local({
  symptom_groups = c(
    AlteredSensation        = "sensory",
    Ataxia                  = "brainstem",
    BladderProblems         = "sphincter",
    BowelProblems           = "sphincter",
    BriefRepetitiveSymptoms = "paroxysmal",
    CognitiveDifficulties   = "cognitive/mood/fatigue",
    Depression              = "cognitive/mood/fatigue",
    DifficultySpeaking      = "bulbar",
    DifficultySwallowing    = "brainstem",
    DoubleVision            = "brainstem",
    Dysarthia               = "brainstem",
    Fatigue                 = "cognitive/mood/fatigue",
    Gait                    = "motor",
    MotorControl            = "motor",
    MusclePain              = "pain",
    Nystagmus               = "brainstem",
    OpticNeuritis           = "optic neuritis",
    Pain                    = "pain",
    Parasthesia             = "sensory",
    SensoryLoss             = "sensory",
    SexualDysfunction       = "sphincter",
    Spasticity              = "pain",
    Tremors                 = "paroxysmal",
    TrigeminalNeuralgia     = "paroxysmal",
    Weakness                = "motor"
  )
  population %>%
    mutate(symptom_simple = unname(symptom_groups[as.character(symptom)]))
})

# First symptom by ethnicity: counts and within-ethnicity proportions,
# excluding rows with no simplified symptom
sx_counts = population %>%
  group_by(ethnicity_simple) %>%
  dplyr::count(symptom_simple) %>%
  filter(!is.na(symptom_simple)) %>%
  arrange(desc(n)) %>%
  mutate(prop = n / sum(n),
         total = sum(n))

# Pool categories with fewer than 5 people into "other", then recompute
# the count and proportion of each (possibly pooled) category
sx_counts = sx_counts %>%
  mutate(symptom_simple = ifelse(n < 5, "other", symptom_simple)) %>%
  group_by(ethnicity_simple, symptom_simple) %>%
  mutate(n_new = sum(n),
         prop_new = n_new / total) %>%
  distinct(ethnicity_simple, symptom_simple, .keep_all = TRUE)

# running proportion within each ethnicity
sx_counts = sx_counts %>%
  group_by(ethnicity_simple) %>%
  mutate(cumprop = cumsum(prop_new))

# order categories by first appearance
sx_counts$symptom_simple = factor(
  sx_counts$symptom_simple,
  levels = unique(sx_counts$symptom_simple),
  ordered = TRUE
)

# Plot data: title-case category names, drop the "other" ethnicity group,
# label only segments of at least 5%, and spell out South Asian
plot_dat = sx_counts %>%
  mutate(symptom_simple = str_to_title(symptom_simple)) %>%
  filter(ethnicity_simple != "other") %>%
  mutate(
    label = ifelse(prop_new >= 0.05,
                   paste0(round(prop_new * 100, 0), "%"),
                   " "),
    ethnicity_simple = as.character(ethnicity_simple),
    ethnicity_simple = ifelse(ethnicity_simple == "S_Asian",
                              "South Asian",
                              ethnicity_simple)
  )

# stacked bars of first-symptom proportions, one bar per ethnicity
p = ggplot(
  plot_dat,
  aes(ethnicity_simple, prop_new, fill = symptom_simple, label = label)
) +
  geom_col(color = "black") +
  geom_text(position = position_stack(vjust = 0.5)) +
  scale_fill_brewer(palette = "Paired") +
  labs(x = "Ethnicity", y = "Proportion", fill = "First symptom") +
  theme_bw()

plot_fx(
  "S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/outputs/first_sx.png",
  p,
  plotwidth = 5
)

# raw first-symptom counts
population %>%
  dplyr::count(symptom)

# Keep symptom categories that have more than one person in each of three
# ethnicity groups (and are not missing)
sx_to_keep = population %>%
  dplyr::count(ethnicity_simple, symptom_simple) %>%
  filter(n > 1) %>%
  dplyr::count(symptom_simple) %>%
  filter(n == 3 & !is.na(symptom_simple))

# Restrict to the retained categories and use "motor" as the reference level
model_dat = population %>%
  filter(symptom_simple %in% sx_to_keep$symptom_simple)
model_dat$symptom_simple = relevel(factor(model_dat$symptom_simple),
                                   ref = "motor")

# Multinomial regression of first symptom on ethnicity, gender and onset type
model = multinom(symptom_simple ~ ethnicity_simple + Gender + pms_onset,
                 data = model_dat)

# Reshape one matrix from the multinom summary (rows = symptom category,
# columns = model terms) into long format, keeping only the ethnicity
# columns and naming the value column `value_name`.
ethnicity_terms_long = function(mat, value_name, symptom_labels){
  long = mat %>%
    data.frame() %>%
    dplyr::select(contains("ethnicity")) %>%
    mutate(symptom = symptom_labels) %>%
    pivot_longer(cols = contains("ethnicity"))
  names(long)[names(long) == "value"] = value_name
  long %>%
    mutate(ethnicity = str_remove_all(name, "ethnicity_simple"))
}

summ = summary(model)

# Wald z scores and two-sided p-values for every coefficient
z_scores = summ$coefficients / summ$standard.errors
pvals = 2 * (1 - pnorm(abs(z_scores)))

betas = ethnicity_terms_long(summ$coefficients, "beta", rownames(z_scores))
ses = ethnicity_terms_long(summ$standard.errors, "se", rownames(z_scores))
pvals = ethnicity_terms_long(pvals, "p", rownames(z_scores))

# combine estimates, standard errors and p-values into one table
plot_dat = betas %>%
  left_join(ses, by = c("symptom", "ethnicity")) %>%
  left_join(pvals, by = c("symptom", "ethnicity")) %>%
  dplyr::select(-contains("name"))

plot_dat = plot_dat %>%
  mutate(ethnicity = ifelse(ethnicity == "S_Asian", "South Asian", ethnicity))

# Coefficient plot of the multinomial model's ethnicity terms, one row per
# symptom category (title-cased for display).
p=ggplot(plot_dat,aes(beta,str_to_title(symptom),
                      col=ethnicity,fill=ethnicity))+
    # horizontal bars span beta +/- 1.96 * se
    geom_errorbarh(show.legend = F,mapping = aes(xmin = beta- 1.96*se, 
                                 xmax = beta + 1.96*se,
                                 y=str_to_title(symptom)),
                   height=0.3,
                   color="black",
                   position = ggstance::position_dodgev(height=0.3))+
      geom_point(position = ggstance::position_dodgev(height=0.3),shape=21,color="black",size=2)+
    theme_bw()+
    # dashed reference line at a log odds ratio of zero
    geom_vline(xintercept=0,linetype="dashed",alpha=0.3)+
  labs(x="Log odds Ratio for first symptom \ncompared with White ethnicity",
       y="Symptom",
       fill="Ethnicity")+
  # two fill colours are supplied, one per plotted ethnicity
  scale_fill_manual(values = c("#377EB8","#4DAF4A"))+
  theme(legend.position="top")+
  # direction labels; y = 5.5 presumably sits above the plotted rows - TODO confirm
  annotate("label",x=1,y=5.5,label="More common",size=3)+
  annotate("label",x=-1,y=5.5,label="Less common",size=3)

plot_fx("S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/outputs/first_sx_regression_coefs.png",
        p,plotwidth=4)


  

# overall distribution of simplified first symptom (missing excluded),
# most common first
population %>% 
  dplyr::count(symptom_simple) %>%
  filter(!is.na(symptom_simple)) %>%
  mutate(prop = n/sum(n)) %>%
  arrange(desc(n))

# multinomial model coefficients as a data frame
summary(model)$coefficients %>% data.frame()
# distribution of simplified first symptom within each ethnicity group
population %>% 
    group_by(ethnicity_simple) %>%
    dplyr::count(symptom_simple) %>%
    filter(!is.na(symptom_simple)) %>%
    mutate(prop = n/sum(n)) %>%
    arrange(desc(n))


# distribution of high_eff_dmt within each ethnicity group
population %>%
  group_by(ethnicity_simple) %>%
  dplyr::count(high_eff_dmt) %>%
  mutate(prop = n/sum(n))

Cross-sectional analysis

# Report missingness for the model covariates via the get_missing_prop helper.
# NOTE(review): pms_onset is queried twice (positionally and via x=).
get_missing_prop("uni_education")
get_missing_prop("pms_onset")
get_missing_prop("Gender")
get_missing_prop("age_at_dx")
get_missing_prop(x="pms_onset")
get_missing_prop(x="had_dmt")


# make models 
# bootstrap_linreg: linear regression of a severity outcome on pms_onset,
# age at outcome, Gender and ethnicity_simple, with bootstrap uncertainty.
#
# Arguments:
#   dat          data frame containing the outcome, outcome_age, pms_onset,
#                Gender and ethnicity_simple columns
#   outcome_age  name (string) of the age-at-outcome column
#   outcome      name (string) of the outcome column
#   n_boot       number of bootstrap resamples
#
# Returns a data frame with one row per model term: bootstrap median, mean,
# sd, 2.5% / 97.5% percentile limits, z (mean / sd) and its two-sided normal
# p-value, the full-data estimate (Main_estimate), the outcome name, the
# full-data adjusted R2 and the bootstrap summary of adjusted R2.
# The age term is reported under the name "outcome_age".
# Requires a registered foreach backend for %dopar%.
bootstrap_linreg = function(
    dat = population,
    outcome_age ="age_at_msis",
    outcome="MSIS29",
    n_boot = 1000){
  
  # fail early with a readable message if a required column is absent
  required_cols = c(outcome, outcome_age, "pms_onset", "Gender", "ethnicity_simple")
  missing_cols = setdiff(required_cols, colnames(dat))
  if(length(missing_cols) > 0){
    stop("Column(s) not found in dat: ",
         paste(missing_cols, collapse = ", "),
         call. = FALSE)
  }
  
  # remove NAs 
  dat = dat %>% filter(!is.na(.data[[outcome]]) &
                         !is.na(.data[[outcome_age]]) &
                         !is.na(pms_onset) &
                         !is.na(Gender) &
                         !is.na(ethnicity_simple)
  )
                       
  # get N 
  n_sample = nrow(dat)
     
  # main model with full dataset 
  model = lm(data=dat,
              dat[[outcome]] ~  pms_onset + dat[[outcome_age]] + Gender + ethnicity_simple
  )
  
  summ = summary(model)
  coefs = data.frame(summ$coefficients)
  coefs$var = rownames(coefs)
  coefs = dplyr::select(coefs,var,Estimate)
  
  bootstrapped_coefs = foreach(i=1:n_boot,.packages = c("dplyr","ggplot2")) %dopar% {
  
  # resample data
  resampled_dat = dplyr::sample_n(dat, size = n_sample, replace = TRUE)
  
  model = lm(data=resampled_dat,
              resampled_dat[[outcome]] ~  pms_onset + resampled_dat[[outcome_age]] + Gender + ethnicity_simple
  )
  
  coefs = data.frame(summary(model)$coefficients)
  coefs$var = rownames(coefs)
  coefs = dplyr::select(coefs,var,Estimate)
  # give the age term a stable name so it can be matched to the main model
  coefs = coefs %>% mutate(var = ifelse(var=="resampled_dat[[outcome_age]]","outcome_age",var))
  coefs$adj_r2 = summary(model)$adj.r.squared
  coefs$model_num = i
  coefs
  }
  
  # combine coefs 
  bootstrapped_coefs = do.call("bind_rows",bootstrapped_coefs)
  coefs = coefs %>% mutate(var = ifelse(var=="dat[[outcome_age]]","outcome_age",var))
  
  # get CI for R2
  r2_vals = bootstrapped_coefs %>%
    distinct(model_num,adj_r2) %>%
    summarise("median_r2" = median(adj_r2),
              "lower_ci_r2" = quantile(adj_r2,0.025),
              "upper_ci_r2" = quantile(adj_r2,0.975))
  # get CI
  bootstrapped_coefs = bootstrapped_coefs %>%
    dplyr::select(-adj_r2) %>%
    group_by(var) %>%
    summarise(median = median(Estimate),
              mean = mean(Estimate),
              sd = sd(Estimate),
              lower_ci = quantile(Estimate,0.025),
              upper_ci = quantile(Estimate,0.975)
              ) %>%
    mutate(z = mean/sd) %>%
    # two-sided p-value, consistent with the two-sided percentile limits
    # (the previous 1 - pnorm(abs(z)) was one-sided, i.e. half this value)
    mutate(pval = 2 * pnorm(-abs(z))) %>%
    left_join(coefs %>% 
                dplyr::rename("Main_estimate" = Estimate),
              by="var") %>%
    mutate(outcome_var = outcome) %>%
    mutate(main_adjusted_r2 = summ$adj.r.squared) %>%
    cbind(r2_vals)
  
  bootstrapped_coefs
}

# Primary analysis: one bootstrapped model per severity measure
edss_model = bootstrap_linreg(outcome = "baseline_edss",
                              outcome_age = "baseline_ageatedss")
msis_model = bootstrap_linreg(outcome = "msis",
                              outcome_age = "age_at_msis")
eq5d_model = bootstrap_linreg(outcome = "eq5d_vas",
                              outcome_age = "age_at_eq5d")
fss_model = bootstrap_linreg(outcome = "fss",
                             outcome_age = "age_at_fss")
msws_model = bootstrap_linreg(outcome = "msws",
                              outcome_age = "age_at_msws")

# stack the per-outcome results and save them before any relabelling
res_df = bind_rows(edss_model, msis_model, eq5d_model, fss_model, msws_model)
write_csv(res_df, "model_outputs.csv")

# Tidy for tables and plots: display names for two outcomes, an ethnicity
# label with the term prefix stripped, and a copy of the main estimate
res_df = res_df %>%
  mutate(
    outcome_var = case_when(
      outcome_var == "eq5d_vas" ~ "EQ5D",
      outcome_var == "baseline_edss" ~ "EDSS",
      TRUE ~ outcome_var
    ),
    ethnicity = str_remove(var, "ethnicity_simple"),
    Beta = Main_estimate
  )
rownames(res_df) = NULL

# make table 2
# simplify_p: format p-values for display. Always returns a character vector:
#   p < 0.001          -> "<0.001"
#   0.001 <= p < 0.01  -> rounded to 3 decimal places (so 0.003 is not shown as "0")
#   otherwise          -> rounded to 2 decimal places
# NA inputs stay NA.
simplify_p = function(x){
  out = as.character(round(x, 2))
  small = !is.na(x) & x < 0.01
  out[small] = as.character(round(x[small], 3))
  out[!is.na(x) & x < 0.001] = "<0.001"
  out
}

# Build table 2: one row per non-intercept term and outcome, with formatted
# "estimate (lower - upper)" strings for the coefficient and the adjusted R2.
table_2= res_df %>% 
  filter(var != "(Intercept)") %>%
  dplyr::select(var,outcome_var,Main_estimate,lower_ci,upper_ci,main_adjusted_r2,lower_ci_r2,upper_ci_r2,pval) %>%
  mutate("Model fit Adjusted R2" = 
           paste0(round(main_adjusted_r2,3)," (",round(lower_ci_r2,2)," - ",round(upper_ci_r2,2),")")) %>%
  dplyr::select(-main_adjusted_r2,-lower_ci_r2,-upper_ci_r2) %>%
  mutate("Beta" = 
           paste0(round(Main_estimate,2)," (",round(lower_ci,2)," - ",round(upper_ci,2),")")) %>%
  dplyr::select(-Main_estimate,-lower_ci,-upper_ci) %>%
  mutate(variable = str_remove(var,"ethnicity_simple")) %>%
  # positional reorder: variable, outcome_var, Beta, pval, Model fit Adjusted R2
  # (this drops the raw var column)
  dplyr::select(6,2,5,3,4) %>%
  mutate(pval = simplify_p(pval)) %>%
  # human-readable term names; terms not listed here become NA
  mutate(variable = dplyr::case_when(
    variable == "GenderMALE" ~ "Male gender",
    variable == "outcome_age" ~ "Age",
    variable == "pms_onsetY" ~ "PPMS",
    variable =="Black" ~ "Black ethnicity",
    variable =="S_Asian" ~ "South Asian ethnicity"
  ))

write_csv(table_2,"table_2.csv")
glimpse(res_df)

# plot all coefficients 
# One point per term and severity measure (intercept excluded), with the
# bootstrap percentile limits as error bars; axes are flipped so measures
# run down the page.


p=ggplot(res_df %>% filter(var != "(Intercept)"),
         aes(outcome_var,Beta,color=var))+
  geom_errorbar(mapping = aes(ymin = lower_ci,ymax=upper_ci,x=outcome_var),width=0.1,position = position_dodge(width = 0.4))+
  geom_point(position = position_dodge(width = 0.4))+
  coord_flip()+
  scale_color_brewer(palette="Set1")+
  theme_minimal()+
  geom_hline(yintercept=0,alpha=1,linetype="dashed")+
  labs(x="Severity measure",y="Beta (estimated between-group difference)")
plot_fx("S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/outputs/all_coefs.png",p,plotwidth=8)

# clean results table
# from here on res_df holds only the ethnicity terms
res_df = res_df %>%
  filter(grepl("ethnicity",var))
res_df %>% dplyr::select(
  ethnicity,outcome_var,Main_estimate,contains("ci"),pval,-contains("r2")
)
# plot results 
# Ethnicity coefficients only, one row per severity measure (upper-cased);
# this p is combined with the raw score boxplots further down.

p=ggplot(res_df %>%
           mutate(outcome_var = toupper(outcome_var),
                  ethnicity = ifelse(ethnicity == "S_Asian","South Asian",ethnicity)),
         aes(Beta,outcome_var,
                      col=ethnicity,fill=ethnicity))+
    # bars use the bootstrap percentile limits rather than beta +/- 1.96 * se
    geom_errorbarh(show.legend = F,mapping = aes(xmin = lower_ci, 
                                 xmax = upper_ci,
                                 y=outcome_var),
                   height=0.3,
                   color="black",
                   position = ggstance::position_dodgev(height=0.3))+
      geom_point(position = ggstance::position_dodgev(height=0.3),shape=21,color="black",size=2)+
    theme_bw()+
    geom_vline(xintercept=0,linetype="dashed",alpha=0.3)+
  labs(x="Mean difference in severity score\nvs White ethnicity",
       y="Severity score",
       fill="Ethnicity")+
  scale_fill_manual(values = c("#377EB8","#4DAF4A"))+
  theme(legend.position="top")

# Raw score distributions: reshape the six severity measures to long format,
# with ethnicity as an ordered factor (White, Black, South Asian)
plot_dat = population %>%
  mutate(
    ethnicity_simple = as.character(ethnicity_simple),
    ethnicity_simple = ifelse(ethnicity_simple == "S_Asian",
                              "South Asian",
                              ethnicity_simple)
  )
plot_dat$ethnicity_simple = factor(
  plot_dat$ethnicity_simple,
  levels = c("White", "Black", "South Asian"),
  ordered = TRUE
)
plot_dat = plot_dat %>%
  dplyr::select(UserId, ethnicity_simple, eq5d_vas, msis,
                baseline_edss, baseline_gARMSS, fss, msws) %>%
  pivot_longer(cols = c(3:8)) %>%
  mutate(name = case_when(
    name == "eq5d_vas" ~ "EQ5D",
    name == "msis" ~ "MSIS29",
    name == "baseline_edss" ~ "EDSS",
    name == "baseline_gARMSS" ~ "gARMSS",
    name == "fss" ~ "FSS",
    name == "msws" ~ "MSWS"
  ))

# one boxplot panel per measure, outlier points hidden
plots = list()
for (measure in unique(plot_dat$name)) {
  measure_dat = plot_dat %>%
    filter(name == measure)
  panel = ggplot(measure_dat,
                 aes(ethnicity_simple, value, fill = ethnicity_simple)) +
    geom_boxplot(outlier.shape = NA) +
    theme_bw() +
    scale_fill_brewer(palette = "Set1") +
    labs(y = measure, x = "Ethnicity") +
    theme(legend.position = "none")
  plots = c(plots, list(panel))
}
library(gridExtra)
p2 = do.call("grid.arrange", plots)

# coefficient plot (p) next to the grid of boxplots (p2)
plot_fx(
  "S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/outputs/severity_metrics.png",
  grid.arrange(p, p2, nrow = 1),
  plotheight = 5,
  plotwidth = 10
)

Cross-sectional analysis - sensitivity analyses

##########################
# sensitivity analysis
##########################

# bootstrap_linreg2: linear regression of `outcome` on an arbitrary set of
# covariates plus ethnicity_simple, with bootstrap uncertainty.
#
# Arguments:
#   dat      data frame containing the outcome, the covariates and
#            ethnicity_simple
#   outcome  name (string) of the outcome column
#   covars   character vector of covariate column names
#   n_boot   number of bootstrap resamples
#
# Returns a data frame with one row per model term: bootstrap median, mean,
# sd, 2.5% / 97.5% percentile limits, z (mean / sd) and its two-sided normal
# p-value, the full-data estimate (Main_estimate), the outcome name, the
# full-data adjusted R2 and the bootstrap summary of adjusted R2.
# Requires a registered foreach backend for %dopar%.
bootstrap_linreg2 = function(
    dat = population,
    outcome,
    covars,
    n_boot = 1000){
  
  # fail early with a readable message if a required column is absent
  missing_cols = setdiff(c(outcome, covars, "ethnicity_simple"), colnames(dat))
  if(length(missing_cols) > 0){
    stop("Column(s) not found in dat: ",
         paste(missing_cols, collapse = ", "),
         call. = FALSE)
  }
  
  # remove NAs 
  dat = dat %>% filter(!is.na(.data[[outcome]]) &
                         !is.na(ethnicity_simple))
                       
  # Drop rows with a missing value in any covariate. This must be a local
  # assignment: the previous `<<-` wrote the filtered copy to a global `dat`
  # and left the data used below unfiltered, so resamples still contained
  # rows that lm() then silently dropped.
  for(covar in covars){
    message("Filtering NAs for ", covar)
    dat = dat %>%
      filter(!is.na(.data[[covar]]))
  }
  
  
  # get N 
  n_sample = nrow(dat)
  
  # model formula as text; converted to a formula at each fit
  formula_text = paste0(outcome," ~ ",
                        paste0(covars,collapse="+"),
                        " + ethnicity_simple"
  )
  
  # main model with full dataset 
  model = lm(data=dat,
             formula = stats::as.formula(formula_text)
  )
  
  summ = summary(model)
  coefs = data.frame(summ$coefficients)
  coefs$var = rownames(coefs)
  coefs = dplyr::select(coefs,var,Estimate)
  
  bootstrapped_coefs = foreach(i=1:n_boot,.packages = c("dplyr","ggplot2")) %dopar% {
    
    # resample data
    resampled_dat = dplyr::sample_n(dat, size = n_sample, replace = TRUE)
    
    model = lm(data=resampled_dat,
               formula = stats::as.formula(formula_text)
    )
    
    coefs = data.frame(summary(model)$coefficients)
    coefs$var = rownames(coefs)
    coefs = dplyr::select(coefs,var,Estimate)
    coefs$adj_r2 = summary(model)$adj.r.squared
    coefs$model_num = i
    coefs
  }
  
  # combine coefs 
  bootstrapped_coefs = do.call("bind_rows",bootstrapped_coefs)
  
  # get CI for R2
  r2_vals = bootstrapped_coefs %>%
    distinct(model_num,adj_r2) %>%
    summarise("median_r2" = median(adj_r2),
              "lower_ci_r2" = quantile(adj_r2,0.025),
              "upper_ci_r2" = quantile(adj_r2,0.975))
  # get CI
  bootstrapped_coefs = bootstrapped_coefs %>%
    dplyr::select(-adj_r2) %>%
    group_by(var) %>%
    summarise(median = median(Estimate),
              mean = mean(Estimate),
              sd = sd(Estimate),
              lower_ci = quantile(Estimate,0.025),
              upper_ci = quantile(Estimate,0.975)
    ) %>%
    mutate(z = mean/sd) %>%
    # two-sided p-value, consistent with the two-sided percentile limits
    # (the previous 1 - pnorm(abs(z)) was one-sided, i.e. half this value)
    mutate(pval = 2 * pnorm(-abs(z))) %>%
    left_join(coefs %>% 
                dplyr::rename("Main_estimate" = Estimate),
              by="var") %>%
    mutate(outcome_var = outcome) %>%
    mutate(main_adjusted_r2 = summ$adj.r.squared) %>%
    cbind(r2_vals)
  
  bootstrapped_coefs
}

# run for different outcomes
# Sensitivity analysis 1: primary covariates plus age_at_dx
edss_model = bootstrap_linreg2(outcome = "baseline_edss",
                               covars = c("baseline_ageatedss","Gender","pms_onset","age_at_dx"))
msis_model = bootstrap_linreg2(outcome = "msis",
                              covars = c("age_at_msis","Gender","pms_onset","age_at_dx"))
eq5d_model =  bootstrap_linreg2(outcome = "eq5d_vas",
                               covars = c("age_at_eq5d","Gender","pms_onset","age_at_dx"))
fss_model =  bootstrap_linreg2(outcome = "fss",
                               covars = c("age_at_fss","Gender","pms_onset","age_at_dx"))
msws_model = bootstrap_linreg2(outcome = "msws",
                              covars = c("age_at_msws","Gender","pms_onset","age_at_dx"))

# combine results 
# saved to disk so the forest plot further down can read every analysis back in
res_df = bind_rows(edss_model,msis_model,eq5d_model,fss_model,msws_model)
write_csv(res_df,"model_outputs_sensitivity1.csv")
res_df %>%
  filter(grepl("ethnicity",var)) %>% 
  dplyr::select(
  var,outcome_var,Main_estimate,contains("ci"),pval,-contains("r2")
)

# run for different outcomes
# Sensitivity analysis 2: primary covariates plus had_dmt
edss_model = bootstrap_linreg2(outcome = "baseline_edss",
                               covars = c("baseline_ageatedss","Gender","pms_onset","had_dmt"))
msis_model = bootstrap_linreg2(outcome = "msis",
                               covars = c("age_at_msis","Gender","pms_onset","had_dmt"))
eq5d_model =  bootstrap_linreg2(outcome = "eq5d_vas",
                                covars = c("age_at_eq5d","Gender","pms_onset","had_dmt"))
fss_model =  bootstrap_linreg2(outcome = "fss",
                               covars = c("age_at_fss","Gender","pms_onset","had_dmt"))
msws_model = bootstrap_linreg2(outcome = "msws",
                               covars = c("age_at_msws","Gender","pms_onset","had_dmt"))

# combine results 
res_df = bind_rows(edss_model,msis_model,eq5d_model,fss_model,msws_model)
write_csv(res_df,"model_outputs_sensitivity2.csv")
res_df %>%
  filter(grepl("ethnicity",var)) %>% 
  dplyr::select(
    var,outcome_var,Main_estimate,contains("ci"),pval,-contains("r2")
  )

# run for different outcomes
# Sensitivity analysis 3: primary covariates plus uni_education
edss_model = bootstrap_linreg2(outcome = "baseline_edss",
                               covars = c("baseline_ageatedss","Gender","pms_onset","uni_education"))
msis_model = bootstrap_linreg2(outcome = "msis",
                               covars = c("age_at_msis","Gender","pms_onset","uni_education"))
eq5d_model =  bootstrap_linreg2(outcome = "eq5d_vas",
                                covars = c("age_at_eq5d","Gender","pms_onset","uni_education"))
fss_model =  bootstrap_linreg2(outcome = "fss",
                               covars = c("age_at_fss","Gender","pms_onset","uni_education"))
msws_model = bootstrap_linreg2(outcome = "msws",
                               covars = c("age_at_msws","Gender","pms_onset","uni_education"))

# combine results 
res_df = bind_rows(edss_model,msis_model,eq5d_model,fss_model,msws_model)
write_csv(res_df,"model_outputs_sensitivity3.csv")
res_df %>%
  filter(grepl("ethnicity",var)) %>% 
  dplyr::select(
    var,outcome_var,Main_estimate,contains("ci"),pval,-contains("r2")
  )

# run for different outcomes
# Sensitivity analysis 4: primary covariates plus DiagnosisYear
edss_model = bootstrap_linreg2(outcome = "baseline_edss",
                               covars = c("baseline_ageatedss","Gender","pms_onset","DiagnosisYear"))
msis_model = bootstrap_linreg2(outcome = "msis",
                               covars = c("age_at_msis","Gender","pms_onset","DiagnosisYear"))
eq5d_model =  bootstrap_linreg2(outcome = "eq5d_vas",
                                covars = c("age_at_eq5d","Gender","pms_onset","DiagnosisYear"))
fss_model =  bootstrap_linreg2(outcome = "fss",
                               covars = c("age_at_fss","Gender","pms_onset","DiagnosisYear"))
msws_model = bootstrap_linreg2(outcome = "msws",
                               covars = c("age_at_msws","Gender","pms_onset","DiagnosisYear"))

# combine results 
res_df = bind_rows(edss_model,msis_model,eq5d_model,fss_model,msws_model)
write_csv(res_df,"model_outputs_sensitivity4.csv")
res_df %>%
  filter(grepl("ethnicity",var)) %>% 
  dplyr::select(
    var,outcome_var,Main_estimate,contains("ci"),pval,-contains("r2")
  )


##########################
# matched analysis
##########################
# Sensitivity analysis 5: rerun the primary models on the saved matched cohort
matched_population = readRDS(
  "S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/datasets/matched_cohort.rds"
)

# one bootstrapped model per severity measure
edss_model = bootstrap_linreg(
  dat = matched_population,
  outcome = "baseline_edss",
  outcome_age = "baseline_ageatedss",
  n_boot = 1000
)
msis_model = bootstrap_linreg(
  dat = matched_population,
  outcome = "msis",
  outcome_age = "age_at_msis",
  n_boot = 1000
)
eq5d_model = bootstrap_linreg(
  dat = matched_population,
  outcome = "eq5d_vas",
  outcome_age = "age_at_eq5d",
  n_boot = 1000
)
fss_model = bootstrap_linreg(
  dat = matched_population,
  outcome = "fss",
  outcome_age = "age_at_fss",
  n_boot = 1000
)
msws_model = bootstrap_linreg(
  dat = matched_population,
  outcome = "msws",
  outcome_age = "age_at_msws",
  n_boot = 1000
)

# stack and save
res_df = bind_rows(edss_model, msis_model, eq5d_model, fss_model, msws_model)
write_csv(res_df, "model_outputs_sensitivity5.csv")

# Read every saved sensitivity analysis back in, labelling each by its model.
# model_list[i] names the analysis stored in model_outputs_sensitivity<i>.csv.
model_list = c("Age at diagnosis", "DMT", "University education",
               "Year of diagnosis", "Matched", "Primary analysis")
overall_res = list()
for (i in c(1:5)) {
  sensitivity_file = paste0("model_outputs_sensitivity", i, ".csv")
  overall_res[[i]] = read_csv(sensitivity_file) %>%
    mutate(model_name = model_list[i])
}
# the primary analysis is appended last
overall_res[[length(overall_res) + 1]] = read_csv("model_outputs.csv") %>%
  mutate(model_name = "Primary analysis")

# stack into one data frame
overall_res = do.call("bind_rows", overall_res)

# keep the ethnicity terms and tidy the labels used in the plot
plot_dat = overall_res %>%
  filter(grepl("ethnicity_simple", var)) %>%
  mutate(variable = str_remove(var, "ethnicity_simple")) %>%
  dplyr::select(variable, outcome_var, Main_estimate,
                lower_ci, upper_ci, model_name) %>%
  dplyr::rename("outcome" = outcome_var) %>%
  mutate(outcome = case_when(
    outcome == "baseline_edss" ~ "EDSS",
    outcome == "eq5d_vas" ~ "EQ5D",
    TRUE ~ outcome
  ))

# fixed display order for the models on the y axis
plot_dat$model_name = factor(
  plot_dat$model_name,
  levels = c("Age at diagnosis",
             "DMT",
             "Matched",
             "University education",
             "Year of diagnosis",
             "Primary analysis"),
  ordered = TRUE
)

# forest plot: one facet per outcome, ethnicity terms dodged within each model
p = ggplot(plot_dat,
           aes(Main_estimate, model_name, color = variable)) +
  geom_point(position = ggstance::position_dodgev(height = 0.5)) +
  facet_wrap(~outcome, nrow = 1) +
  theme_minimal() +
  geom_errorbarh(
    mapping = aes(xmin = lower_ci, xmax = upper_ci, y = model_name),
    height = 0.3,
    position = ggstance::position_dodgev(height = 0.5)
  ) +
  scale_color_brewer(palette = "Set1") +
  geom_vline(xintercept = 0, linetype = "dashed") +
  labs(x = "Beta coefficient (95% CI)", y = "Model", color = "Ethnicity")
plot_fx("forest_sensitivity.png", p, plotwidth = 8)

Power calcs

# group sizes per ethnicity; presumably the source of the sample sizes used
# in the power simulation below - TODO confirm
population %>% dplyr::count(ethnicity_simple)
# simulate_msis: draw n scores from a normal(mean, sd) truncated to the open
# interval (0, 100), using rejection sampling.
#
# Out-of-range draws are redrawn until all n values lie inside the interval.
# Previously the replacement draws were not checked, so values outside
# (0, 100) could still be returned.
#
# Returns a numeric vector of length n.
simulate_msis = function(n,mean,sd){
  vals = numeric(0)
  attempts = 0
  while(length(vals) < n){
    attempts = attempts + 1
    # guard against parameters for which (almost) no draw can be accepted
    if(attempts > 1000){
      stop("simulate_msis: could not draw ", n,
           " values inside (0, 100); check mean and sd",
           call. = FALSE)
    }
    draws = rnorm(n - length(vals),mean, sd)
    # which() also discards NA/NaN draws
    vals = c(vals, draws[which(draws > 0 & draws < 100)])
  }
  vals
}

# NOTE(review): this global is not used by simulate_power (which has its own
# `diff` argument); kept in case later code relies on it.
diff = 10

# simulate_power: estimate by simulation the power to detect a difference of
# `diff` in mean score between a reference group and a smaller comparison
# group, using a linear model of score on group.
#
# Arguments:
#   diff      difference in mean score between the groups
#   n_alt     size of the comparison group
#   n_ref     size of the reference group
#   n_sims    number of simulated datasets
#   alpha     significance threshold applied to the group p-value
#   ref_mean  mean score in the reference group
#   score_sd  standard deviation of the score in both groups
#
# Returns the proportion of simulations with p < alpha.
simulate_power = function(diff = 10, n_alt = 151, n_ref = 13884,
                          n_sims = 1000, alpha = 0.05,
                          ref_mean = 40, score_sd = 26){
  pvals = vapply(seq_len(n_sims), function(i){
    # reference scores are drawn first, then comparison scores
    df1 = data.frame(
      score = simulate_msis(n_ref,ref_mean,score_sd),
      group = "ref"
    )
    df2 = data.frame(
      score = simulate_msis(n_alt,ref_mean+diff,score_sd),
      group = "alt"
    )
    
    combo = df1 %>% bind_rows(df2)
    # row 2 of the coefficient table is the group term; column 4 its p-value
    summary(lm(data=combo,score ~ group))$coefficients[2,4]
  }, numeric(1))
  
  sum(pvals < alpha) / n_sims
}

Prospective (longitudinal) analysis

# filter to ppl with >1 reading
# number each person's MSIS readings in date order (sequence 1 = earliest)
msis29_all_readings = msis29_all_readings %>%
  dplyr::group_by(UserId) %>%
  arrange(msis_date) %>%
  mutate(sequence = row_number()) %>%
  ungroup()


# fu 
# every reading after the first
fu_msis = msis29_all_readings %>% 
  filter(sequence != 1)


# baseline msis
# first reading of each person who has at least one follow-up reading;
# min_increment is the baseline score plus 10 points
baseline_msis = msis29_all_readings %>% 
  filter(UserId %in% fu_msis$UserId) %>%
  filter(sequence == 1) %>%
  dplyr::rename("baseline_msis" = msis, "baseline_date" = msis_date) %>%
  mutate(min_increment = baseline_msis + 10)

# join 
# one row per follow-up reading; columns shared by both tables get .x
# (baseline) and .y (follow-up) suffixes, e.g. sequence.y below
all_msis = baseline_msis %>%
  left_join(fu_msis,by="UserId")

# filter to those in population file 
all_msis = all_msis %>%
  filter(UserId %in% population$UserId)

nrow(all_msis %>% distinct(UserId))

# find those with >=2 readings 
# NOTE(review): follow-up rows all have sequence != 1, so this looks like it
# keeps everyone already in all_msis - confirm.
counts = all_msis %>%
  group_by(UserId) %>%
  slice_max(sequence.y)
counts %>% ungroup() %>% filter(sequence.y<2) %>% distinct(UserId) %>% nrow()
ids_to_keep = counts %>%filter(sequence.y>=2) %>% distinct(UserId)

# filter to people with >=2 readings
all_msis = all_msis %>%
  filter(UserId %in% ids_to_keep$UserId)
nrow(all_msis %>% distinct(UserId))

# filter to people who have a reading >1 yr from baseline
# (all of their readings are kept, not only the ones after 1 year)
fu_readings = all_msis %>%
  filter(as.numeric(msis_date - baseline_date)/365.25 > 1 ) 
all_msis = all_msis %>% filter(UserId %in% fu_readings$UserId)
nrow(all_msis %>% distinct(UserId))

# filter to only readings within 5y of baseline
all_msis = all_msis %>%
  filter(as.numeric(msis_date - baseline_date)/365.25 < 5 ) 

nrow(all_msis %>% distinct(UserId))

# see who has had progression
# flag readings strictly above baseline + 10
all_msis = all_msis %>%
  mutate(above_ceiling = ifelse(msis > min_increment,"y","n")) 

# filter on baseline msis 
# count, then drop, people whose baseline score is above 90
all_msis %>% filter(baseline_msis > 90) %>% distinct(UserId) %>% nrow()
all_msis= all_msis %>% 
  filter(baseline_msis <= 90) 
nrow(all_msis %>% distinct(UserId))

# find people who experience disability progression 
# get earliest date above ceiling
progressors = all_msis %>% 
  filter(above_ceiling=="y") %>%
  group_by(UserId) %>%
  slice_min(msis_date,with_ties = F) %>%
  ungroup()

# join with all readings & filter to readings more than 3/12 of a year
# (3 months) after the ceiling was first breached
# NOTE(review): this comment previously said ">6 months", but the code uses
# 3/12 - confirm which interval is intended.
sustained_progressors = progressors %>%
  mutate(ceiling_date = msis_date, ceiling_msis = msis) %>%
  dplyr::select(UserId,ceiling_msis,ceiling_date) %>%
  left_join(
    all_msis %>%
      filter(UserId %in% progressors$UserId),
    by="UserId") %>%
  filter(as.numeric(msis_date - ceiling_date)/365.25 > (3/12))
sustained_progressors %>% distinct(UserId) %>% nrow()

# confirmed progression: any of those later readings is still at or above
# baseline + 10
# NOTE(review): the first breach uses msis > min_increment, whereas
# confirmation uses >= - confirm the difference is deliberate.
confirmed_msis_progression = sustained_progressors %>% 
  filter(msis >= min_increment) %>%
  distinct(UserId)

# "Yes" = confirmed; "No" = has a later reading but none confirms;
# NA = no reading more than 3 months after the breach
progressors = progressors %>%
  mutate(sustained_progression = case_when(
    UserId %in% confirmed_msis_progression$UserId ~ "Yes",
    UserId %in% sustained_progressors$UserId & !(UserId %in% confirmed_msis_progression$UserId) ~ "No"
  ))

progressors %>% 
  dplyr::count(sustained_progression) %>%
  mutate(total = sum(n), prop = n/sum(n))

# filter 
# keep confirmed progressors only; their row is the first reading above ceiling
progressors = progressors %>% filter(sustained_progression=="Yes")

# everyone else contributes their latest reading and is labelled "No"
non_progressors = all_msis %>% 
  group_by(UserId) %>%
  slice_max(msis_date,with_ties = F) %>%
  ungroup() %>%
  filter(!UserId %in% progressors$UserId) %>%
  mutate(sustained_progression="No")

combo_dat = bind_rows(progressors,non_progressors)

combo_dat %>% 
  dplyr::count(sustained_progression) %>%
  mutate(total = sum(n),prop = n/sum(n))

# get time 
# years from baseline to the selected reading (first breach for confirmed
# progressors, latest reading for everyone else)
combo_dat = combo_dat %>%
  mutate(observed_time = as.numeric(msis_date - baseline_date)/365.25)

# join with main data frame 
# status: 2 = confirmed progression, 1 = no confirmed progression;
# survival_time is observed_time converted back to days
msis_survival = population %>%
  filter(UserId %in% combo_dat$UserId) %>%
  left_join(combo_dat,by="UserId") %>%
  filter(!is.na(observed_time)) %>%
  mutate(status = ifelse(
    sustained_progression=="Yes",2,1)) %>%
  mutate(survival_time = observed_time  * 365.25)

msis_survival %>%
  distinct(UserId)
summary(msis_survival$observed_time)

msis_survival
msis_survival %>% dplyr::count(status) %>%
  mutate(prop = n/sum(n))

# plot fu
# follow-up time by ethnicity, split by status
ggplot(msis_survival,
       aes(ethnicity_simple,observed_time,fill=factor(status)))+
  geom_boxplot(outlier.shape=NA)


library(survival)
library(survminer)

# basic model 
# survival curves by ethnicity; time is in days and status is coded 1/2
surv_model = survfit(Surv(survival_time,status) ~ ethnicity_simple,
        data=msis_survival)

pal = RColorBrewer::brewer.pal(3,"Set1")
# Cumulative event curves (fun="event") with ribbon confidence bands and a
# risk table. xscale / break.x.by of 365.25 show the day-based time axis in
# yearly steps.
# NOTE(review): legend.labs is hard-coded, so it presumably relies on the
# ethnicity_simple levels being ordered White, Black, S_Asian - confirm.
p=survminer::ggsurvplot(surv_model,
                      conf.int = T,
                      fun="event",
                      conf.int.style="ribbon",
                      conf.int.alpha=0.3,
                      risk.table = T,
                      risk.table.y.text=F,
                      legend.labs = c("White","Black","South Asian"),legend.title="Ethnicity",
                      legend = c(0.1,0.9),
                      palette = pal,ylab="Cumulative probability of\ndisability progression",xlab="Time from baseline MSIS (years)",
                      xscale=365.25,
                      break.x.by=365.25,fontsize=4
                        )


# Cox proportional hazards models for time to confirmed MSIS progression,
# adding covariates step by step

# ethnicity only
cox_model = coxph(
  formula = Surv(survival_time, status) ~ ethnicity_simple,
  data = msis_survival
)
summary(cox_model)

# + baseline MSIS
cox_model0 = coxph(
  formula = Surv(survival_time, status) ~ baseline_msis + ethnicity_simple,
  data = msis_survival
)
summary(cox_model0)

# + gender
cox_model1 = coxph(
  formula = Surv(survival_time, status) ~
    ethnicity_simple + Gender + baseline_msis,
  data = msis_survival
)
summary(cox_model1)

# + onset type
cox_model2 = coxph(
  formula = Surv(survival_time, status) ~
    ethnicity_simple + Gender + pms_onset + baseline_msis,
  data = msis_survival
)
summary(cox_model2)

# + age at baseline MSIS
cox_model3 = coxph(
  formula = Surv(survival_time, status) ~
    ethnicity_simple + Gender + pms_onset + age_at_msis + baseline_msis,
  data = msis_survival
)
summary(cox_model3)

# forest_from_cox: forest plot of hazard ratios from a fitted coxph model.
#
# Argument:
#   x  a fitted Cox model whose summary coefficient table has the five
#      columns coef, exp(coef), se(coef), z and p
#
# Returns a ggplot object showing exp(beta) with exp(beta +/- 1.96 * se)
# intervals on a log10 x axis, one row per model term.
forest_from_cox = function(x){
  coef_table = summary(x)$coefficients
  output = data.frame(coef_table) %>% 
    mutate(variable = rownames(coef_table))
  colnames(output) = c("beta","hr","se","z","pval","variable")
  output = output %>% filter(variable != "(Intercept)")
  
  #  rename 
  # Terms without a display name keep their raw name. Previously they became
  # NA, so several unmapped terms collapsed onto a single unlabeled row.
  output = output %>%
    mutate(variable = case_when(
      variable == "GenderMALE" ~ "Male gender (vs Female)",
      variable == "pms_onsetY" ~ "PPMS (vs RMS)",
      variable == "ethnicity_simpleBlack" ~ "Black ethnicity (vs White)",
      variable == "ethnicity_simpleS_Asian" ~ "South Asian ethnicity (vs White)",
      variable == "age_at_msis" ~ "Age at baseline MSIS",
      variable == "baseline_msis" ~ "Baseline MSIS",
      TRUE ~ variable
    ))
  
  # make plot 
  plot_dat = output
  # keep the y-axis in the order in which the terms appear in the model
  plot_dat$variable = factor(plot_dat$variable,levels = unique(plot_dat$variable),ordered=T)
  p=ggplot(plot_dat,aes(exp(beta),variable))+
    geom_errorbarh(show.legend = F,mapping = aes(xmin = exp(beta- 1.96*se), 
                                                 xmax = exp(beta + 1.96*se),
                                                 y=variable),
                   height=0.3,
                   color="black",
                   position = ggstance::position_dodgev(height=0.3))+
    geom_point(position = ggstance::position_dodgev(height=0.3),color="black",size=2)+
    theme_bw()+
    # dashed reference line at a hazard ratio of 1
    geom_vline(xintercept=1,linetype="dashed",alpha=0.3)+
    labs(x="Hazard ratio of 5-year\ndisability progression",
         y="Variable")+
    theme(legend.position="top")+
    scale_x_log10()
  p
}

# Forest plot for cox_model3.
p2 <- forest_from_cox(cox_model3)

# Write both figures to PNG.
# The plots are drawn with explicit print() calls: a bare object name is only
# auto-printed at the interactive top level, so without print() the files come
# out blank when this script is run via source() or from inside a function/loop.

# NOTE(review): `p` is not defined in this section - presumably the
# survival-curve plot created just above; confirm.
png("S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/outputs/survival_curves.png",
    res = 900, height = 6, width = 8, units = "in")
print(p)
dev.off()

png("S:/ADAMS_Study - ADAMS Study/ethnicity_severity_project/outputs/survival_cox_models.png",
    res = 900, height = 4, width = 4, units = "in")
print(p2)
dev.off()




# Refit on the matched population only
# (rows of msis_survival whose UserId appears in matched_population).

# Same covariates as cox_model3
cox_model3_matched <- coxph(
  formula = Surv(survival_time, status) ~
    ethnicity_simple + Gender + pms_onset + age_at_msis + baseline_msis,
  data = msis_survival %>% filter(UserId %in% matched_population$UserId)
)
summary(cox_model3_matched)

# Ethnicity only
cox_model3_matched_simple <- coxph(
  formula = Surv(survival_time, status) ~ ethnicity_simple,
  data = msis_survival %>% filter(UserId %in% matched_population$UserId)
)
summary(cox_model3_matched_simple)

# Number of rows per status within each ethnicity group, with the group total
# and the within-group proportion.
msis_survival %>%
  group_by(ethnicity_simple) %>%
  dplyr::count(status) %>%
  mutate(
    total = sum(n),
    prop = n / total
  )


# Further models.

# cox_model3 covariates + high_eff_dmt, with "no_high_eff_DMT" as the
# reference level.
cox_model4 <- coxph(
  formula = Surv(survival_time, status) ~
    ethnicity_simple + Gender + pms_onset + high_eff_dmt + baseline_msis +
    age_at_msis,
  data = msis_survival %>%
    mutate(high_eff_dmt = relevel(factor(high_eff_dmt), ref = "no_high_eff_DMT"))
)
summary(cox_model4)

# Adds age_at_dx; baseline_msis is not included in this model.
cox_model5 <- coxph(
  formula = Surv(survival_time, status) ~
    ethnicity_simple + Gender + pms_onset + age_at_msis + age_at_dx,
  data = msis_survival
)
summary(cox_model5)

# Ethnicity + high_eff_dmt only.
# NOTE(review): high_eff_dmt is not releveled here, unlike in cox_model4, so
# the reference level may differ between the two models - confirm this is
# intended.
cox_model6 <- coxph(
  formula = Surv(survival_time, status) ~ ethnicity_simple + high_eff_dmt,
  data = msis_survival
)
summary(cox_model6)

# Diagnostics for cox_model3

# Proportional hazards test: print the table, then plot it
zph_model3 <- cox.zph(cox_model3)
zph_model3
ggcoxzph(zph_model3)

# dfbeta residuals
ggcoxdiagnostics(cox_model3, type = "dfbeta")