#include "cnvscoring.h"

void cnvscoring::troba_cnvscoring()
{
  arma::uvec index1, index2, index3;
  //arma::uvec::fixed<3> n;
  
  index1 = arma::find(genotypes==1);
  index2 = arma::find(genotypes==3);
  index3 = arma::find(genotypes==2);
  
  arma::uvec utemp = arma::find(genotypes == 0);
  nz = utemp.n_elem;
 
  arma::vec scor1 = arma::zeros(INT1.n_elem);
  arma::vec scor2 = arma::zeros(INT2.n_elem);
  
  //Independent analysis for each channel
  channel_analysis(index1, index3, INT1, scor1, Q1);
  channel_analysis(index2, index3, INT2, scor2, Q2);
  
  SC = arma::vec(INT1.n_elem);
  SC.fill(-1.);
  utemp = arma::find(genotypes == 0);
  for(unsigned int it = 0; it < utemp.n_elem; it++)
  {
    SC(utemp(it)) = 0.;
  }
  for(unsigned int it = 0; it < index1.n_elem; it++)
  {
    SC(index1(it)) = scor1(index1(it));
  }
  for(unsigned int it = 0; it < index2.n_elem; it++)
  {
    SC(index2(it)) = scor2(index2(it));
  }
  double temp;
  for(unsigned int it = 0; it < index3.n_elem; it++)
  {
    temp = scor1(index3(it))>scor2(index3(it))?scor1(index3(it)):scor2(index3(it));
    SC(index3(it)) = temp + 1.;
  }
}

void cnvscoring::channel_analysis(arma::uvec index, arma::uvec index3, arma::vec INT, arma::vec & scor, arma::vec::fixed<3> & Qr)
{
  vmodel model1;
  vmodel model2;
  bool n = index.n_elem >= SMIN;
  bool n3 = index3.n_elem >= SMIN;
  arma::vec pr;
  
  model1.ncomp = -1;
  model1.Q.fill(-1);
  model2.ncomp = -1;
  model2.Q.fill(-1);
  Qr.fill(-1);
  
  //Homozygote and heterozygote models
  if(n)
  {
    model1 = model_selection( INT.elem(index), scg.elem(index), 0.6, 2, 0.8, 0.04 );
    Qr(0) = model1.Q(0);
    Qr(1) = model1.Q(1);
    Qr(2) = model1.Q(2);
  }
  if(n3)
  {
    model2 = model_selection( INT.elem(index3), scg.elem(index3), 0., 6, 0.4, 0.2 );
  }
  
  //Homozygote analysis
  if(n)
  {
    arma::vec INTindex = INT.elem(index);
    int kmod = 0;
    if(model1.ncomp == 1)
    {
      kmod = 2;
      if(n3)
      {
        //double dtemp1 = (double)index.n_elem;
        double dtemp2 = fabs(model2.Mean(0) - model1.Mean(0))/model1.Mean(0);
        if( ( (10*nz) > index.n_elem ) && ( dtemp2 < 0.1 ) )
        {
          kmod = 1;
        }
      }
      pr = INTindex - model1.Mean(0);
      pr = pr/arma::stddev(INTindex);
      pr = pr/stdCNP;
      arma::uvec utemp = arma::find(pr>=0.5);
      if(utemp.n_elem != 0)
      {
        if( arma::median( INT.elem( index.elem(utemp) ) ) < 1.3*model1.Mean(0) )
        {
          for(unsigned int it=0; it < utemp.n_elem; it++)
          {
            pr(utemp(it)) = 0.499;
          }
        }
      }
      utemp = arma::find(pr<=-0.5);
      if(utemp.n_elem != 0)
      {
        if( arma::median( INT.elem( index.elem(utemp) ) ) > 0.7*model1.Mean(0) )
        {
          for(unsigned int it=0; it < utemp.n_elem; it++)
          {
            pr(utemp(it)) = -0.499;
          }
        }
      }
      for(unsigned int it=0; it < pr.n_elem; it++)
      {
        if(pr(it) < -1.)
        {
          pr(it) = -1.;
        }
        if(pr(it) > 1.)
        {
          pr(it) = 1.;
        }
      }
      for(unsigned int it=0; it < index.n_elem; it++)
      {
        scor(index(it)) = (double)(double(kmod) + pr(it));
      }
    }
    else
    {
      int kmod1;
      int kmod2;
      if( model1.Prior(1) > model1.Prior(0) )
      {
        kmod1 = 1;
        kmod2 = 2;
      }
      else
      {
        if( (10*nz) > index.n_elem)
        {
          kmod1 = 1;
          kmod2 = 2;
        }
        else
        {
          kmod1 = 2;
          kmod2 = 3;
        }
      }
      arma::vec pr1 = model1.Prior(0)*pdfgauss(INTindex, model1.Mean(0), model1.Cov(0));
      arma::vec pr2 = model1.Prior(1)*pdfgauss(INTindex, model1.Mean(1), model1.Cov(1));
      arma::uvec i1 = arma::find( (INTindex >= model1.Mean(0)) % (INTindex <= model1.Mean(1)) );
      arma::vec temp1 = pr1.elem(i1);
      arma::vec temp2 = pr2.elem(i1);
      arma::vec temps = temp1 + temp2;
      arma::vec sc1 = temp1/temps;
      arma::vec sc2 = temp2/temps;
      
      arma::uvec utemp = index.elem(i1);
      for(unsigned int it=0; it < utemp.n_elem; it++)
      {
        scor(utemp(it)) = (double)(double(kmod1) + sc2(it));
      }
      if(kmod1==1)
      {
        for(unsigned int it=0; it < INTindex.n_elem; it++)
        {
          if(INTindex(it)<model1.Mean(0))
          {
            scor(index(it)) = 1;
          }
        }
        i1 = arma::find(INTindex > model1.Mean(1));
        arma::vec temp = INTindex.elem(i1);
        arma::vec g = temp - model1.Mean(1);
        g = join_cols(model1.Mean(1) + g, model1.Mean(1) - g);
        pr = temp - arma::mean(g);
        pr = pr/arma::stddev(g);
        pr = pr/stdCNP;
        arma::uvec utemp = arma::find(pr>=0.5);
        if(utemp.n_elem != 0)
        {
          if( arma::median( INTindex.elem( i1.elem(utemp) ) ) < 1.3*model1.Mean(1) )
          {
            for(unsigned int it=0; it < utemp.n_elem; it++)
            {
              pr(utemp(it)) = 0.499;
            }
          }
        }
        for(unsigned int it=0; it < pr.n_elem; it++)
        {
          if(pr(it) < -1.)
          {
            pr(it) = -1.;
          }
          if(pr(it) > 1.)
          {
            pr(it) = 1.;
          }
        }
        utemp = index.elem(i1);
        for(unsigned int it=0; it < utemp.n_elem; it++)
        {
          scor(utemp(it)) = (double)(double(kmod2) + pr(it));
        }
      }
      else
      {
        for(unsigned int it=0; it < INTindex.n_elem; it++)
        {
          if(INTindex(it)>model1.Mean(1))
          {
            scor(index(it)) = 3;
          }
        }
        i1 = arma::find(INTindex < model1.Mean(0));
        arma::vec temp = INTindex.elem(i1);
        arma::vec g = model1.Mean(0) - temp;
        g = join_cols(model1.Mean(0) + g, model1.Mean(0) - g);
        pr = temp - arma::mean(g);
        pr = pr/arma::stddev(g);
        pr = pr/stdCNP;
        arma::uvec utemp = arma::find(pr<=-0.5);
        if(utemp.n_elem != 0)
        {
          if( arma::median( INTindex.elem( i1.elem(utemp) ) ) > 0.7*model1.Mean(0) )
          {
            for(unsigned int it=0; it < utemp.n_elem; it++)
            {
              pr(utemp(it)) = -0.499;
            }
          }
        }
        for(unsigned int it=0; it < pr.n_elem; it++)
        {
          if(pr(it) < -1.)
          {
            pr(it) = -1.;
          }
          if(pr(it) > 1.)
          {
            pr(it) = 1.;
          }
        }
        utemp = index.elem(i1);
        for(unsigned int it=0; it < utemp.n_elem; it++)
        {
          scor(utemp(it)) = 2. + pr(it);
        }
      }
    }
  }
  else
  {
    //int kmod = 0;
    for(unsigned int it=0; it < index.n_elem; it++)
    {
      scor(index(it)) = 2.;
    }
  }
  
  //Heterozygote analysis
  if(n3)
  {
    if( model2.ncomp == 1 )
    {
      for(unsigned int it=0; it < index3.n_elem; it++)
      {
        scor(index3(it)) = 1.;
      }
    }
    else
    {
      arma::vec INTindex3 = INT.elem(index3);
      arma::vec pr1 = model2.Prior(0)*pdfgauss(INTindex3, model2.Mean(0), model2.Cov(0));
      arma::vec pr2 = model2.Prior(1)*pdfgauss(INTindex3, model2.Mean(1), model2.Cov(1));
      
      for(unsigned int it=0; it < index3.n_elem; it++)
      {
        if(INTindex3(it)<model2.Mean(0))
        {
          scor(index3(it)) = 1.;
        }
      }
      for(unsigned int it=0; it < index3.n_elem; it++)
      {
        if(INTindex3(it)>model2.Mean(1))
        {
          scor(index3(it)) = 2.;
        }
      }
      arma::uvec i1 = arma::find( (INTindex3 >= model2.Mean(0)) % (INTindex3 <= model2.Mean(1)) );
      arma::vec temp1 = pr1.elem(i1);
      arma::vec temp2 = pr2.elem(i1);
      arma::vec temps = temp1 + temp2;
      arma::vec sc1 = temp1/temps;
      arma::vec sc2 = temp2/temps;
      
      arma::uvec utemp = index3.elem(i1);
      for(unsigned int it=0; it < utemp.n_elem; it++)
      {
        scor(utemp(it)) = sc1(it) + 2.*sc2(it);
      }
      
    }
  }
  else
  {
    for(unsigned int it=0; it < index3.n_elem; it++)
    {
      scor(index3(it)) = scor(index3(it)) + 1.;
    }
  }
}

/**
 * Decides whether the intensities in Iprev are better described by one
 * Gaussian or by a two-component Gaussian mixture.
 *
 * Fitting:
 *   - only samples with sc > scT are used;
 *   - the one-component model is centred on their median;
 *   - the two-component model is fitted with emgmm() (IT++ EM, 20 iterations).
 *
 * The two-component model is accepted only when all of these hold:
 *   - on a 50-point grid over [0, 2], the two weighted densities form two
 *     separate peaks with a deep enough valley between them;
 *   - both mixture weights exceed klT3;
 *   - kld() exceeds klT1.
 *
 * @param Iprev intensities to model
 * @param sc    per-sample score, indexed like Iprev (assumed — TODO confirm)
 * @param scT   samples with sc <= scT are ignored
 * @param klT1  minimum kld() divergence for a two-component model
 * @param klT2  maximum valley/peak ratio accepted when the peak-position
 *              ratio is below 0.85 (a ratio below 0.8 is always accepted when
 *              the peak-position ratio is below 0.6)
 * @param klT3  minimum weight each mixture component must exceed
 * @return the chosen model with ncomp set to 1 or 2 and
 *         Q = (valley/peak ratio c, divergence kl, ncomp).
 *         Q(0) = Q(1) = -1 when no sample passes scT or both EM attempts throw.
 *         Q(0) = 1 when the two-peak/weight test is not met.
 */
vmodel cnvscoring::model_selection(arma::vec Iprev, arma::vec sc, double scT, double klT1, double klT2, double klT3)
{
  vmodel resultat, model1, model2;

  // Keep only the samples whose score exceeds scT.
  arma::uvec utemp = arma::find(sc>scT);
  if(utemp.n_elem == 0)
  {
    // NOTE(review): model1 has not been fitted at this point, so its Mean/Cov/
    // Prior hold whatever vmodel's default constructor leaves there. Confirm
    // that callers do not read them when Q(0) == -1.
    resultat = model1;
    resultat.ncomp = 1;
    resultat.Q(0) = -1;
    resultat.Q(1) = -1;
    resultat.Q(2) = 1;
    return resultat;
  }

  arma::vec I = Iprev.elem( utemp );
  double medianI = arma::median(I);
  double stdI = arma::stddev(I);

  //Fitting one component model, centred on the median
  if(I.n_elem > 100)
  {
    // Estimate the variance from one side of the median, mirrored around it:
    //   - take the side with fewer samples strictly beyond the median
    //     (the lower side on a tie);
    //   - reflect it around the median;
    //   - use the spread of the resulting symmetric set.
    model1.Mean(0) = medianI;
    arma::uvec i;
    i = arma::find(I>model1.Mean(0));
    arma::uword T1 = i.n_elem;
    i = arma::find(I<model1.Mean(0));
    arma::uword T2 = i.n_elem;
    arma::vec temp, d;
    if(T1>=T2)
    {
      temp = I.elem(arma::find(I<=model1.Mean(0)));
    }
    else
    {
      temp = I.elem(arma::find(I>=model1.Mean(0)));
    }
    d = (2*model1.Mean(0)) - temp;
    temp = join_cols(temp, d );
    model1.Cov(0) = arma::stddev(temp);
    model1.Cov(0) = model1.Cov(0) * model1.Cov(0);
  }
  else
  {
    model1.Mean(0) = medianI;
    model1.Cov(0) = stdI;
    model1.Cov(0) = model1.Cov(0) * model1.Cov(0);
  }

  model1.Prior(0) = 1.;

  //Fitting of two components. The initial guess places the main component
  //just beside the median and a minor one 3 standard deviations away, on the
  //side where the mean lies relative to the median.
  if(arma::sum(I-medianI)>0.)
  {
    model2.Mean(0) = model1.Mean(0) - (0.25*stdI);
    model2.Mean(1) = model1.Mean(0) + (3.*stdI);
    model2.Prior(0) = 0.65;
    model2.Prior(1) = 0.35;
  }
  else
  {
    model2.Mean(0) = model1.Mean(0) - (3.*stdI);
    model2.Mean(1) = model1.Mean(0) + (0.25*stdI);
    model2.Prior(0) = 0.35;
    model2.Prior(1) = 0.65;
  }
  model2.Cov(0) = 0.1*stdI*stdI;
  model2.Cov(1) = model2.Cov(0);

  try
  {
    vmodel modeltemp;
    modeltemp = emgmm(I, model2, 2, 20);
    // If EM drove one of the fitted weights below 1%, restart from a
    // symmetric initialisation. The check must use the fitted weights in
    // modeltemp: model2 still holds the fixed starting weights.
    arma::uvec collapsed = arma::find(modeltemp.Prior < 1.e-2);
    if(collapsed.n_elem > 0)
    {
      model2.Mean(0) = model1.Mean(0) - (0.5*stdI);
      model2.Mean(1) = model1.Mean(0) + (0.5*stdI);
      model2.Prior(0) = 0.5;
      model2.Prior(1) = 0.5;
      modeltemp = emgmm(I, model2, 2, 20);
    }
    model2 = modeltemp;
  }
  catch(...)
  {
    // First attempt failed: retry once from the symmetric initialisation,
    // and fall back to the one-component model if that fails as well.
    try
    {
      model2.Mean(0) = model1.Mean(0) - (0.5*stdI);
      model2.Mean(1) = model1.Mean(0) + (0.5*stdI);
      model2.Prior(0) = 0.5;
      model2.Prior(1) = 0.5;
      vmodel modeltemp = emgmm(I, model2, 2, 20);
      model2 = modeltemp;
    }
    catch(...)
    {
      resultat = model1;
      resultat.ncomp = 1;
      resultat.Q(0) = -1;
      resultat.Q(1) = -1;
      resultat.Q(2) = 1;
      return resultat;
    }
  }

  //Model selection criteria
  double kl = kld(model2);
  // Weighted component densities on a fixed grid of 50 points over [0, 2].
  // TODO (translated from the original Catalan note "Aqui s'ha d'arreglar?",
  // i.e. "does this need fixing here?"): the grid range is hard-coded.
  arma::vec x = arma::linspace(0,2,50);
  arma::vec pdfs1 = model2.Prior(0)*pdfgauss(x, model2.Mean(0), model2.Cov(0));
  arma::vec pdfs2 = model2.Prior(1)*pdfgauss(x, model2.Mean(1), model2.Cov(1));

  // Grid positions (x1, x2) and heights (n1, n2) of the two density peaks.
  // Swap so that pdfs1 is the curve whose peak comes first on the grid.
  // arma::uword is required by .max(index) / .min(index).
  arma::uword x1, x2, uitemp;
  double n1, n2, dtemp;
  n1 = pdfs1.max(x1);
  n2 = pdfs2.max(x2);
  if(x1>x2)
  {
    arma::vec temp;
    temp = pdfs1;
    pdfs1 = pdfs2;
    pdfs2 = temp;
    uitemp = x1;
    x1 = x2;
    x2 = uitemp;
    dtemp = n1;
    n1 = n2;
    n2 = dtemp;
  }
  utemp = arma::find(model2.Prior>klT3);
  if( ( pdfs1(x1)>pdfs2(x1) ) && ( pdfs1(x2)<pdfs2(x2) ) && utemp.n_elem == 2 )
  {
    // c: density where the two curves come closest between the peaks,
    // relative to the lower of the two peaks.
    arma::vec temp = arma::abs(pdfs1.rows(x1,x2)-pdfs2.rows(x1,x2));
    arma::uword m;
    (void)temp.min(m);
    m = m + x1;
    double minPeak = n1<n2?n1:n2;
    double c = pdfs1(m)/minPeak;
    // Ratio of the (1-based) grid positions of the two peaks.
    double posFirst = double(x1) + 1.;
    double posSecond = double(x2) + 1.;
    if( (( (c<klT2) && ((posFirst/posSecond)<0.85) ) || ( (c<0.8) && ((posFirst/posSecond)<0.6) )) && kl>klT1 )
    {
      resultat = model2;
      resultat.ncomp = 2;
      resultat.Q(0) = c;
      resultat.Q(1) = kl;
      resultat.Q(2) = 2;
    }
    else
    {
      resultat = model1;
      resultat.ncomp = 1;
      resultat.Q(0) = c;
      resultat.Q(1) = kl;
      resultat.Q(2) = 1;
    }
  }
  else
  {
    resultat = model1;
    resultat.ncomp = 1;
    resultat.Q(0) = 1;
    resultat.Q(1) = kl;
    resultat.Q(2) = 1;
  }
  return resultat;
}

/**
 * Evaluates the normal probability density N(tmean, sd2) at every element of v.
 *
 * @param v     points at which to evaluate the density
 * @param tmean mean of the distribution
 * @param sd2   variance (sigma squared, not the standard deviation); it is
 *              not checked for being positive
 * @return a vector, the same length as v, of density values
 */
arma::vec cnvscoring::pdfgauss(arma::vec v, double tmean, double sd2)
{
  // Gaussian exponent: -(x - mean)^2 / (2 * variance).
  arma::vec z = v - tmean;
  z = -(z % z) / (2 * sd2);

  arma::vec density = exp(z) / sqrt(2. * MY_PI * sd2);
  return density;
}

/**
 * Converts the first n entries of an Armadillo vector into an IT++ array of
 * one-element vectors: one itpp::vec of size 1 per scalar.
 *
 * @param v source values (must have at least n elements)
 * @param n number of entries to convert
 * @return an itpp::Array of length n, where element k holds v(k)
 */
itpp::Array<itpp::vec> cnvscoring::arma2itpp(arma::vec v, unsigned int n)
{
  itpp::Array<itpp::vec> out(n);
  unsigned int k = 0;
  while (k < n)
  {
    itpp::vec scalar(1);
    scalar(0) = v(k);
    out(k) = scalar;
    ++k;
  }
  return out;
}

/**
 * Copies the first n entries of an Armadillo vector into an IT++ vector.
 *
 * @param v source values (must have at least n elements)
 * @param n number of entries to copy
 * @return an itpp::vec of length n with the same values
 */
itpp::vec cnvscoring::arma2itppVec(arma::vec v, unsigned int n)
{
  itpp::vec out(n);
  unsigned int k = 0;
  while (k < n)
  {
    out(k) = v(k);
    ++k;
  }
  return out;
}

/**
 * Fits a one-dimensional Gaussian mixture to v with IT++'s maximum-likelihood
 * EM (itpp::MOG_diag_ML). The fit starts from the means, variances and
 * weights stored in model.
 *
 * @param v     samples to fit
 * @param model initial Mean, Cov (variances) and Prior of each component
 * @param ndim  first argument of itpp::MOG_diag(ndim, 1), and the number of
 *              entries copied in and out. Despite its name it is presumably
 *              the number of mixture components; the data dimension is the
 *              constant 1. Confirm against the IT++ MOG_diag(K, d) signature.
 * @param maxit maximum number of EM iterations
 * @return a vmodel whose Mean, Cov and Prior hold the fitted values for the
 *         first ndim components; its other fields are default-constructed
 *
 * The variance and weight floors are both 0.0 and verbose output is off.
 * model_selection() wraps calls in catch(...), so failures are presumably
 * reported as exceptions from IT++.
 */
vmodel cnvscoring::emgmm(arma::vec v, vmodel model, unsigned int ndim, unsigned int maxit)
{
  itpp::MOG_diag mog(ndim, 1);
  itpp::Array< itpp::vec > samples = arma2itpp(v, v.n_elem);

  // Seed the mixture with the initial parameters.
  itpp::Array< itpp::vec > initCovs = arma2itpp(model.Cov, ndim);
  mog.set_diag_covs( initCovs );

  itpp::Array< itpp::vec > initMeans = arma2itpp(model.Mean, ndim);
  mog.set_means( initMeans );

  itpp::vec initWeights = arma2itppVec(model.Prior, ndim);
  mog.set_weights( initWeights );

  itpp::MOG_diag_ML(mog, samples, maxit, 0.0, 0.0, false);

  // Copy the fitted parameters back; each component's mean and variance is a
  // one-element itpp::vec.
  itpp::Array< itpp::vec > fittedMeans = mog.get_means();
  itpp::Array< itpp::vec > fittedCovs = mog.get_diag_covs();
  itpp::vec fittedWeights = mog.get_weights();

  vmodel fitted;
  for(unsigned int c = 0; c < ndim; ++c)
  {
    fitted.Mean(c) = fittedMeans(c)(0);
    fitted.Cov(c) = fittedCovs(c)(0);
    fitted.Prior(c) = fittedWeights(c);
  }
  return fitted;
}

/**
 * Divergence between the two Gaussian components of model, taken as the
 * larger of the two directions.
 *
 * Each direction i -> j is computed as
 *   0.5 * ((Mean_i - Mean_j)^2 + Cov_i - Cov_j) / Cov_j
 *   = 0.5 * ((dMean)^2 / Cov_j + Cov_i / Cov_j - 1).
 * This is the Gaussian KL divergence without its 0.5*ln(Cov_j / Cov_i) term.
 *
 * NOTE(review): confirm that dropping the log term is intentional. The klT1
 * thresholds in model_selection() are presumably tuned to this form.
 */
double cnvscoring::kld(vmodel model)
{
  const double dMean = model.Mean(0) - model.Mean(1);
  const double dMean2 = dMean * dMean;

  const double forward = 0.5 * (dMean2 + model.Cov(0) - model.Cov(1)) / model.Cov(1);
  const double backward = 0.5 * (dMean2 + model.Cov(1) - model.Cov(0)) / model.Cov(0);

  return (forward > backward) ? forward : backward;
}