## File Name: pcbsm.R
## File version: 1.0

#' Compute Partial Credit Behavioral Selection Model (PCBSM) for polytomous and dichotomous items
#'
#' This function computes a frequentist PCBSM, potentially accounting for DIF on specified items.
#' The model is fitted with \code{vcrpart::olmm} (adjacent-category logit) using the formula
#' \code{resp ~ 0 + ge(u + X + grp + homogeneous DIF terms) + ce(item + heterogeneous DIF terms) + re(0|id)}.
#'
#' @param df data.frame containing the data. Must contain a column named "id" identifying individuals
#' @param items vector containing the names of columns where item responses are stored in df. Responses are expected to be coded 0, 1, ..., max
#' @param grp string containing the name of the column where the group membership variable (coded 0/1) is stored in df
#' @param X vector of strings containing the names of the columns of df storing additional adjustment variables to be included in the model
#' @param u string containing the name of the column of df storing the weights to be included in the model as a covariate to account for unobserved confounding. Can be obtained from the "select_weight" function extracting response residuals from a probit model with grp as dependent variable and confounders and instruments as independent variables. If NULL, no such covariate is included and lambda is returned as NA
#' @param dif.items vector containing the indexes in "items" of the items with DIF
#' @param type.dif vector containing the DIF form for each item specified in dif.items. 1 is homogeneous DIF, 0 is heterogeneous DIF. If NULL, heterogeneous DIF is assumed on all DIF items (with a warning)
#' @param verbose set to TRUE to print a detailed output, FALSE otherwise
#' @param fit string determining the optimization algorithm. Values "ucminf" or "nlminb" are recommended. "ucminf" is replaced by "optim" when a response category is never observed for some item within one of the groups
#' @param method.theta string determining the estimation method for individual latent variable values. Either "eap", "mle" or "wle"
#' @return A list with elements beta (group effect), beta.se, beta.ci (beta +/- 1.96 standard errors), beta.p,
#'   lambda (coefficient of u, NA when u is NULL), dif.items, dif.type, thresholds (items x categories matrix),
#'   dif.param (DIF parameters, heterogeneous DIF items first), theta (one value per row of df) and residuals
#'   (rows of df x items matrix)
#' @import vcrpart
#' @import PP
#' @export
pcbsm <- function(df = NULL, items = NULL, grp = NULL, u = NULL, X = NULL,
                  dif.items = NULL, type.dif = NULL, verbose = TRUE,
                  fit = "ucminf", method.theta = "eap") {
  ##### Detecting errors (all checks happen before any model is fitted)
  if (any(!(items %in% colnames(df)))) {
    stop("ERROR: provided item name does not exist in df")
  }
  if (any(!(grp %in% colnames(df)))) {
    stop("ERROR: provided group variable name does not exist in df")
  }
  if (is.null(grp)) {
    stop("ERROR: group variable required in PCSM. Please use PCM if no group variable is needed")
  }
  if (is.character(u) && any(!(u %in% colnames(df)))) {
    stop("ERROR: provided weight variable name does not exist in df")
  }
  if (any(!(X %in% colnames(df)))) {
    stop("ERROR: provided adjustment variable name does not exist in df")
  }
  if (!is.null(dif.items)) {
    if (is.null(type.dif)) {
      # Documented default: heterogeneous (non-homogeneous) DIF on every DIF item
      warning("WARNING: no type.dif provided, assuming non-homogeneous DIF on all items")
      type.dif <- rep(0, length(dif.items))
    }
    if (length(dif.items) != length(type.dif)) {
      stop('ERROR: type.dif is not the same length as dif.items')
    }
    if (any(!(dif.items %in% seq_along(items)))) {
      stop("ERROR: dif.items must contain indexes of elements of items")
    }
    if (any(!(type.dif %in% c(0, 1)))) {
      stop("ERROR: type.dif must only contain 0 (heterogeneous DIF) or 1 (homogeneous DIF)")
    }
  }
  if (!("id" %in% colnames(df))) {
    stop('ERROR: no column named id provided')
  }
  if (length(method.theta) != 1 || !(method.theta %in% c("eap", "mle", "wle"))) {
    stop('ERROR: method.theta must be one of "eap", "mle" or "wle"')
  }

  # When some response category is never observed for an item within one of
  # the groups, the default "ucminf" optimizer is replaced by "optim"
  maxresp <- max(df[, items])
  resp.grp0 <- df[df[[grp]] == 0, items, drop = FALSE]
  resp.grp1 <- df[df[[grp]] == 1, items, drop = FALSE]
  if (any(vapply(resp.grp0, max, numeric(1)) < maxresp) ||
      any(vapply(resp.grp1, max, numeric(1)) < maxresp)) {
    if (fit == "ucminf") {
      fit <- "optim"
    }
  }

  ##### Analysis
  if (verbose) {
    cat('\n')
    cat("#################################################################################################\n")
    cat("######################################### FITTING MODEL #########################################\n")
    cat("#################################################################################################\n")
  }
  nbitems <- length(items)
  items_o <- items
  items <- paste0("item", seq_len(nbitems))

  # Individual-level variables are captured before df is rebuilt
  grp.values <- df[[grp]]
  uu <- if (is.null(u)) NULL else df[[u]]
  xx <- df[, X, drop = FALSE]

  # Wide working copy: the k-th element of `items` is stored in column "itemk"
  # whatever its position in the original df, so that restab rows and the
  # dif.items indexes both follow the order of `items`
  resp.wide <- df[, items_o, drop = FALSE]
  df <- data.frame(df[["id"]], resp.wide, grp.values)
  colnames(df) <- c("id", items, "grp")

  # Long format: one row per (individual, item), stacked item by item
  # (all individuals for item 1, then item 2, ...)
  df.long <- reshape(df, v.names = "item", direction = "long", varying = items)
  colnames(df.long) <- c("id", "grp", "item", "resp")
  maxmod <- max(df[, items])
  df.long$item <- factor(df.long$item, levels = seq_len(nbitems), ordered = FALSE)
  df.long$resp <- factor(df.long$resp, levels = 0:maxmod, ordered = TRUE)
  df.long$id <- factor(df.long$id)
  # Individual-level covariates are repeated once per item to match the stacking
  if (!is.null(uu)) {
    df.long$u <- rep(uu, nbitems)
  }
  for (x in X) {
    df.long[[x]] <- rep(xx[[x]], nbitems)
  }

  # One 0/1 indicator column per DIF item (dif1, dif2, ... in dif.items order)
  difvar.unif <- character(0)
  difvar.nonunif <- character(0)
  if (length(dif.items) > 0) {
    difvar <- paste0("dif", seq_along(dif.items))
    for (i in seq_along(dif.items)) {
      df.long[[difvar[i]]] <- as.numeric(df.long$item == dif.items[i])
    }
    difvar.unif <- difvar[type.dif == 1]
    difvar.nonunif <- difvar[type.dif == 0]
  }

  # Homogeneous DIF enters as a global (ge) interaction with grp, heterogeneous
  # DIF as a category-specific (ce) interaction with grp. The guards matter:
  # paste0(character(0), ":grp") would yield ":grp" rather than nothing
  ge.terms <- c(if (!is.null(uu)) "u", X, "grp")
  if (length(difvar.unif) > 0) {
    ge.terms <- c(ge.terms, paste0(difvar.unif, ":grp"))
  }
  ce.terms <- "item"
  if (length(difvar.nonunif) > 0) {
    ce.terms <- c(ce.terms, paste0(difvar.nonunif, ":grp"))
  }
  formu <- as.formula(paste0(
    "resp ~ 0 + ge(", paste(ge.terms, collapse = "+"),
    ") + ce(", paste(ce.terms, collapse = "+"), ") + re(0|id)"
  ))

  # fit pcm
  mod <- olmm(formu, data = df.long, family = adjacent(link = "logit"),
              control = olmm_control(fit = fit))
  comod <- coef(mod)

  ##### Output results
  # This code assumes that coef(mod) starts with the category-specific
  # coefficients, category by category: entry (k - 1) * nbcoef + x is ce term x
  # at category k. matrix() fills column-wise, giving an nbcoef x maxmod matrix
  # that stays a matrix even for dichotomous items (maxmod == 1)
  nbcoef <- nbitems + length(difvar.nonunif)
  cecoef <- matrix(comod[seq_len(nbcoef * maxmod)], nrow = nbcoef, ncol = maxmod)
  restab <- cecoef[seq_len(nbitems), , drop = FALSE]
  rownames(restab) <- items_o
  colnames(restab) <- paste0("delta_", seq_len(maxmod))

  restab.dif <- NULL
  restab.diftype <- NULL
  if (length(dif.items) > 0) {
    gamma.names <- paste0("gamma_", seq_len(maxmod))
    difcoef.nonunif <- NULL
    if (length(difvar.nonunif) > 0) {
      # Heterogeneous DIF terms follow the item terms inside the ce block
      difcoef.nonunif <- cecoef[nbitems + seq_along(difvar.nonunif), , drop = FALSE]
      dimnames(difcoef.nonunif) <- list(
        paste0("dif.", items_o[dif.items[type.dif == 0]]), gamma.names
      )
    }
    difcoef.unif <- NULL
    if (length(difvar.unif) > 0) {
      # Homogeneous DIF terms are taken as the last global coefficients, just
      # before the final entry of coef(mod). NOTE(review): this assumes coef(mod)
      # ends with exactly one random-effect parameter; confirm against vcrpart
      nc <- length(comod)
      gamma.hom <- comod[(nc - length(difvar.unif)):(nc - 1)]
      # The same gamma applies to every category threshold
      difcoef.unif <- matrix(rep(gamma.hom, times = maxmod), ncol = maxmod,
                             dimnames = list(paste0("dif.", items_o[dif.items[type.dif == 1]]),
                                             gamma.names))
    }
    restab.dif <- rbind(difcoef.nonunif, difcoef.unif)
    # Labels follow the same row order as restab.dif (heterogeneous rows first)
    restab.diftype <- noquote(matrix(
      c(rep("NON-HOMOGENEOUS", length(difvar.nonunif)), rep("HOMOGENEOUS", length(difvar.unif))),
      ncol = 1, dimnames = list(rownames(restab.dif), "dif.type")
    ))
  }

  # NA when u is NULL (no "u" coefficient in the model)
  lambda <- as.numeric(comod["u"])
  # beta is reported as the negated grp coefficient of the fitted model
  beta <- -1 * as.numeric(comod["grp"])
  se.beta <- as.numeric(sqrt(vcov(mod)["grp", "grp"]))
  beta.ci <- c("2.5%" = beta - 1.96 * se.beta, "97.5%" = beta + 1.96 * se.beta)
  beta.p <- as.numeric(2 * pnorm(-abs(beta / se.beta)))

  if (method.theta == "eap") {
    re.mat <- as.matrix(ranef(mod, norm = FALSE))
    # ranef() rows are assumed to follow the levels of the id factor (sorted),
    # not the row order of df, so they are re-aligned with df rows here
    row.idx <- match(as.character(df$id), levels(df.long$id))
    # Group-1 individuals get beta added to their negated random effect
    theta <- as.numeric(-1 * re.mat[row.idx, 1] + ifelse(grp.values == 1, beta, 0))
  } else {
    theta <- PP::PP_gpcm(as.matrix(df[, items, drop = FALSE]), t(restab), rep(1, nbitems),
                         type = method.theta)$resPP$resPP[, 1]
  }
  # Residuals: one row per individual (row of df), one column per item
  resid <- sapply(seq_len(nbitems), function(k) {
    sapply(seq_len(nrow(df)), function(j) {
      res_ij(theta[j], restab[k, ], df[j, items[k]], beta = 0)
    })
  })
  colnames(resid) <- items_o

  ##### Output
  if (verbose) {
    cat(paste0('Number of individuals: ', nrow(df), "\n"))
    cat(paste0('Number of items: ', length(items), "\n"))
    cat(paste0('Item Thresholds and DIF parameters: ', "\n"))
    print(restab)
    if (!is.null(restab.dif)) {
      print(restab.dif)
    }
  }

  list(
    beta = beta,
    beta.se = se.beta,
    beta.ci = beta.ci,
    beta.p = beta.p,
    lambda = as.numeric(lambda),
    dif.items = dif.items,
    dif.type = restab.diftype,
    thresholds = restab,
    dif.param = restab.dif,
    theta = theta,
    residuals = resid
  )
}