## File Name: bpcm.R
## File version: 1.0


#' Compute Bayesian Partial Credit Model (BPCM) for polytomous and dichotomous items
#'
#' This function fits a Bayesian partial credit model with JAGS, optionally
#' including a group effect and differential item functioning (DIF) on
#' specified items. Chains are run through \pkg{dclone} on a cluster created
#' with \code{parallel::makeForkCluster} (fork clusters are not available on
#' Windows).
#'
#' @param df data.frame containing the response data, one column per item.
#'   Responses are shifted by +1 before being passed to JAGS, so they are
#'   expected to be coded starting at 0.
#' @param grp optional vector of group membership, of length \code{nrow(df)}.
#'   When supplied, a group effect \code{beta} is estimated.
#' @param is.dif indicator vector containing 1 at the index of each DIF item in
#'   df and 0 otherwise. Requires \code{grp}.
#' @param is.unif indicator vector containing 1 at the index of each uniform
#'   DIF item in df and 0 otherwise. Required whenever DIF is requested; only
#'   items that are also flagged in \code{is.dif} may be flagged here.
#' @param priors named list of prior hyperparameters: \code{m.delta} and
#'   \code{s.delta} (mean and SD of the delta prior, always required),
#'   \code{m.beta} and \code{s.beta} (required when \code{grp} is given) and
#'   \code{m.gamma} and \code{s.gamma} (required when DIF is requested).
#' @param param named list of MCMC settings, any of: \code{n.burn} (burn-in
#'   iterations, default 4000), \code{n.thin} (thinning interval, default 20),
#'   \code{n.chains} (default 2), \code{n.sim} (sampled iterations, default
#'   20000) and \code{nb.cores} (number of cluster workers, default 6).
#' @param verbose logical; print progress messages.
#' @param diag.plots logical; display the coda traceplots and autocorrelation
#'   plots, waiting for [enter] after each set (interactive use only).
#' @return A list with elements \code{mcmc.res} (data.frame of the first
#'   chain's draws, parameters that are zero in every draw removed),
#'   \code{dif.items} (names of the DIF items), \code{beta} (one-row matrix
#'   summarising the group effect, or \code{NA} when no group is modelled) and
#'   \code{thresholds} (matrix with posterior mean, SD and 90\% credible
#'   interval of each remaining parameter).
#' @import rjags
#' @import dclone
#' @export
bpcm <- function(df = NULL, grp = NULL, is.dif = NULL, is.unif = NULL,
                 priors = NULL, param = list(), verbose = TRUE,
                 diag.plots = FALSE) {

  ##### Detecting errors
  # sum(NULL) is 0, so NULL indicators count as "nothing flagged".
  has_grp <- !is.null(grp)
  has_dif <- sum(is.dif) != 0
  has_unif <- sum(is.unif) != 0

  if ((has_dif || has_unif) && !has_grp) {
    stop('ERROR: no group variable provided, but is.dif or is.unif are not NULL')
  }
  if (!has_dif && has_unif) {
    warning('WARNING: no DIF item specified, but is.unif is not NULL. Ignoring is.unif')
  }
  if (has_dif) {
    if (length(is.dif) != ncol(df)) {
      stop('ERROR: is.dif must have one element per item in df')
    }
    if (length(is.unif) != ncol(df)) {
      stop('ERROR: is.unif must have one element per item in df')
    }
    # Items are sorted by is.dif + is.unif below. A uniform flag on a non-DIF
    # item would be placed among the DIF columns while pnodif still counts it
    # as non-DIF, so reject that combination up front.
    if (any(is.unif == 1 & is.dif == 0)) {
      stop('ERROR: is.unif can only flag items that are also flagged in is.dif')
    }
    if (!("m.gamma" %in% names(priors))) {
      stop('ERROR: DIF requested. Please provide mean of prior gamma distribution as priors$m.gamma')
    }
    if (!("s.gamma" %in% names(priors))) {
      stop('ERROR: DIF requested. Please provide SD of prior gamma distribution as priors$s.gamma')
    }
  }
  if (has_grp) {
    if (!("m.beta" %in% names(priors))) {
      stop('ERROR: Group effect requested. Please provide mean of prior beta distribution as priors$m.beta')
    }
    if (!("s.beta" %in% names(priors))) {
      stop('ERROR: Group effect requested. Please provide SD of prior beta distribution as priors$s.beta')
    }
    if (nrow(df) != length(grp)) {
      stop('ERROR: grp must be of length nrow(df)')
    }
  }
  if (!("m.delta" %in% names(priors))) {
    stop('ERROR: Please provide mean of prior delta distribution as priors$m.delta')
  }
  if (!("s.delta" %in% names(priors))) {
    stop('ERROR: Please provide SD of prior delta distribution as priors$s.delta')
  }

  if (verbose) {
    cat('\n')
    cat("#################################################################################################\n")
    cat("######################################### FITTING MODEL #########################################\n")
    cat("#################################################################################################\n")
  }

  settings <- .bpcm_mcmc_settings(param)

  nam_o <- colnames(df)
  nam <- nam_o
  Y <- matrix(unlist(df), nrow = nrow(df)) + 1
  n <- nrow(Y)
  p <- ncol(Y)
  # Highest (shifted) category observed per item.
  K <- apply(Y, 2, max, na.rm = TRUE)

  if (!has_grp) {
    ############### NO GROUP EFFECT ###############
    model_file <- "bpcm.jags"
    params <- c("delta")
    data <- list(Y = Y, n = n, p = p, K = K,
                 m.delta = priors[["m.delta"]], s.delta = priors[["s.delta"]])
    namlist <- .bpcm_node_names("delta", K)
    namlist2 <- .bpcm_labels(nam, K)
  } else if (!has_dif) {
    ############### GROUP EFFECT / NO DIF ###############
    Z <- matrix(unlist(grp), nrow = length(grp))
    model_file <- "bpcm_beta.jags"
    params <- c("delta", "beta")
    data <- list(Y = Y, Z = Z, n = n, p = p, K = K,
                 m.beta = priors[["m.beta"]], s.beta = priors[["s.beta"]],
                 m.delta = priors[["m.delta"]], s.delta = priors[["s.delta"]])
    namlist <- c(.bpcm_node_names("delta", K), "beta")
    namlist2 <- c(.bpcm_labels(nam, K), "beta")
  } else {
    ############### GROUP EFFECT / DIF ###############
    Z <- matrix(unlist(grp), nrow = length(grp))

    pnodif <- p - sum(is.dif)
    pnodif1 <- pnodif + 1
    pdif <- sum(is.dif)
    pnounif <- pnodif + pdif - sum(is.unif)
    pnounif1 <- pnounif + 1

    # Column order expected by the DIF model: non-DIF items first, then
    # non-uniform DIF items, then uniform DIF items. K must follow the same
    # order so that K[j] still describes column j of Y.
    dif_level <- is.dif + is.unif
    ord <- c(which(dif_level == 0), which(dif_level == 1),
             which(dif_level == 2))
    Y <- Y[, ord, drop = FALSE]
    nam <- nam[ord]
    K <- K[ord]

    model_file <- "bpcm_dif.jags"
    params <- c("delta", "gamma", "beta")
    data <- list(Y = Y, Z = Z, n = n, p = p,
                 pnounif = pnounif, pnounif1 = pnounif1, pdif = pdif,
                 pnodif1 = pnodif1, pnodif = pnodif, K = K,
                 m.beta = priors[["m.beta"]], s.beta = priors[["s.beta"]],
                 m.gamma = priors[["m.gamma"]], s.gamma = priors[["s.gamma"]],
                 m.delta = priors[["m.delta"]], s.delta = priors[["s.delta"]],
                 difff = as.factor(is.dif), unif = as.factor(is.unif))
    namlist <- c(.bpcm_node_names("delta", K), .bpcm_node_names("gamma", K),
                 "beta")
    namlist2 <- c(.bpcm_labels(nam, K), .bpcm_labels(nam, K, ":grp"), "beta")
  }

  samples <- .bpcm_sample(data, params, model_file, settings, verbose)
  if (diag.plots) {
    .bpcm_diag_plots(samples)
  }
  fit <- .bpcm_summarise(samples, namlist, namlist2)

  list(mcmc.res = fit$mcmc.res,
       dif.items = nam_o[which(is.dif == 1)],
       beta = fit$beta,
       thresholds = fit$thresholds)
}

# Internal: merge user-supplied MCMC settings over the defaults. Unknown
# entries in `param` are ignored.
.bpcm_mcmc_settings <- function(param) {
  settings <- list(n.burn = 4000, n.thin = 20, n.chains = 2, n.sim = 20000,
                   nb.cores = 6)
  user <- intersect(names(param), names(settings))
  settings[user] <- param[user]
  settings
}

# Internal: JAGS node names "node[j,k]" for k in 1..K[j], item by item.
.bpcm_node_names <- function(node, K) {
  unlist(lapply(seq_along(K), function(j) {
    paste0(node, "[", j, ",", seq_len(K[j]), "]")
  }), use.names = FALSE)
}

# Internal: readable labels "<item>_<k-1><suffix>" matching .bpcm_node_names.
.bpcm_labels <- function(nam, K, suffix = "") {
  unlist(lapply(seq_along(K), function(j) {
    paste0(nam[j], "_", seq_len(K[j]) - 1, suffix)
  }), use.names = FALSE)
}

# Internal: compile, burn in and sample the JAGS model on a fork cluster.
# The cluster is stopped when this function exits, including on error.
.bpcm_sample <- function(data, params, model_file, settings, verbose) {
  file <- system.file("jags", model_file, package = "SPT")
  if (!nzchar(file)) {
    stop('ERROR: could not find JAGS model file ', model_file,
         ' in package SPT')
  }
  cl <- parallel::makeForkCluster(settings$nb.cores)
  on.exit(parallel::stopCluster(cl), add = TRUE)

  if (verbose) {
    cat('Initializing Chains...')
  }
  dclone::parJagsModel(cl, name = "mod.jags", data = data, inits = NULL,
                       file = file, n.chains = settings$n.chains)
  if (verbose) {
    cat('DONE \n')
    cat('Applying burn-in iterations...')
  }
  # parUpdate() takes the iteration count as n.iter (default 1); passing it
  # under any other name leaves the burn-in at a single iteration.
  dclone::parUpdate(cl, object = "mod.jags", n.iter = settings$n.burn)
  if (verbose) {
    cat('DONE \n')
    cat('Running Markov Chains...')
  }
  samples <- dclone::parCodaSamples(cl, model = "mod.jags",
                                    variable.names = params,
                                    n.iter = settings$n.sim,
                                    thin = settings$n.thin)
  if (verbose) {
    cat('DONE \n')
  }
  samples
}

# Internal: interactive coda diagnostics (traceplots, then autocorrelation).
.bpcm_diag_plots <- function(samples) {
  cat("Displaying traceplots...")
  traceplot(samples)
  readline(prompt = "Use the arrows to navigate between traceplots. Press [enter] to continue")
  cat("DONE\n")
  cat("Displaying autocorrelation plots...")
  autocorr.plot(samples, ask = FALSE)
  readline(prompt = "Use the arrows to navigate between autocorr plots. Press [enter] to continue")
  cat("DONE\n ")
  invisible(NULL)
}

# Internal: relabel the first chain's draws, drop parameters that are zero in
# every draw, and summarise each remaining parameter. The "beta" row is split
# off (kept as a one-row matrix) from the threshold summaries.
.bpcm_summarise <- function(samples, namlist, namlist2) {
  res <- as.data.frame(samples[[1]])[, namlist, drop = FALSE]
  colnames(res) <- namlist2
  fixed <- vapply(res, function(x) isTRUE(all(x == 0)), logical(1))
  res <- res[, !fixed, drop = FALSE]

  xsi <- vapply(res, function(x) {
    c(mean(x), sd(x), quantile(x, 0.05), quantile(x, 0.95))
  }, numeric(4))
  rownames(xsi) <- c("post.mean", "post.sd", "post.90.cred.low",
                     "post.90.cred.high")
  xsi <- round(t(xsi), 4)

  is_beta <- rownames(xsi) == "beta"
  beta <- if (any(is_beta)) xsi[is_beta, , drop = FALSE] else NA
  list(mcmc.res = res,
       beta = beta,
       thresholds = xsi[!is_beta, , drop = FALSE])
}