#include "../include/CahnHilliard.h"

////////////// Phase Field Initial Condition
/**
  Constructor
    * The initial datum is a two-component function (phase and chemical potential);
      value() returns 0 for every component except component 0
    * dd  : interface thickness, used as the width of the tanh profiles in value()
    * _tc : selects which initial configuration value() produces
**/
template<int dim> Phase0<dim>::Phase0( const double dd, const TestCase _tc ) :
  Function<dim>( 2 ),
  delta( dd ),
  tc( _tc )
{
}

/**
  Destructor
    * Nothing to release explicitly
**/
template<int dim> Phase0<dim>::~Phase0()
{
}

/**
  value
    * Evaluates the initial phase field selected by the test case <tc>
    * Component 0 (phase) is +1 inside the droplet(s) and -1 outside:
      - TEST_CASE_MOVE : one disc of radius 0.5 centered at the origin, tanh profile of width delta
      - TEST_CASE_SPLIT: an elongated ellipse with semi-axes 2.5 along p[0] and 0.5 along p[1],
                         tanh profile of width delta applied to the ellipse level-set function
      - TEST_CASE_MERGE: two discs of radius 0.5 centered at (-0.7,0) and (0.7,0), sharp jump (no smoothing)
    * Any other test case triggers an internal-error assertion (debug mode only)
    * Every component other than 0 evaluates to 0
**/
template<int dim> double Phase0<dim>::value( const Point<dim> &p, const unsigned component ) const{
  double phase = 0.;
  if( tc == TEST_CASE_MOVE ){
    // One droplet at the origin
    const double r = 0.5;
    const Point<dim> origin;
    phase = std::tanh( ( r - p.distance( origin ) )/delta );
  }
  else if( tc == TEST_CASE_SPLIT ){
    // Long droplet: positive inside the ellipse x^2/a^2 + y^2/b^2 = 1, negative outside
    const double a = 2.5, b = 0.5;
    const double level = 1. - p[0]*p[0]/(a*a) - p[1]*p[1]/(b*b);
    phase = std::tanh( level/delta );
  }
  else if( tc == TEST_CASE_MERGE ){
    // Two droplets, sharp indicator
    const double r = 0.5;
    const Point<dim> left( -.7, 0. ), right( .7, 0. );
    const bool inside = ( p.distance( left ) < r ) or ( p.distance( right ) < r );
    phase = inside ? 1. : -1.;
  }
  else{
    Assert( false, ExcInternalError() );
  }
  // Only the phase component carries data
  return ( component == 0 ) ? phase : 0.;
}
////////////////////////////////////////////

/**
  Constructor
    * Initializes the base Problem<dim> with the triangulation and the values 2 and 1e-8
      (NOTE(review): presumably the number of components / transferred vectors and a tolerance — confirm in Problem<dim>)
    * Copies the material parameters (mobility, permitivity, density, W, gamma_fs, gamma, delta, alpha) and the test case
    * Binds the finite element function views: "phi" (component 0) and "mu" (component 1) on this->sol,
      "phiold" on oldsol
    * Builds the hp finite element collection:
      - index 0: FESystem made of two FE_Q<dim>( deg ) components
      - index 1: FESystem made of two FE_Nothing components (no degrees of freedom)
**/
template<int dim> CahnHilliard<dim>::CahnHilliard( const Triangulation<dim> &tria, 
                                                   Material_Parameters &params,
                                                   const unsigned deg,
                                                   const TestCase _tc ) :
    Problem<dim>( tria, 2, 1e-8 ),
    mobility( params.M ),
    permitivity( params.epsilon ),
    density( params.rho ),
    W( params.W ),
    gamma_fs( params.gamma_fs ),
    gamma( params.gamma ),
    delta( params.delta ),
    alpha( params.alpha ),
    tc( _tc ),
    get_phase( this->dh, this->sol, "phi" ),
    get_potential( this->dh, this->sol, "mu", AsFunction<dim>::isScalarValued, 1 ),
    get_old_phase( this->dh, oldsol, "phiold" ){
  // fe index 0: continuous Lagrange elements of degree deg for both unknowns
  const FESystem<dim> active_element( FE_Q<dim>( deg ), 2 );
  // fe index 1: same number of components but no degrees of freedom
  const FESystem<dim> void_element( FE_Nothing<dim>(), 2 );
  this->fe.push_back( active_element );
  this->fe.push_back( void_element );
}

/**
  advance
    * Moves to the next time step: the current solution becomes oldsol,
      which AssembleCell uses as the previous-step phase
**/
template<int dim> void CahnHilliard<dim>::advance()
{
  this->oldsol = this->sol;
}

/**
  it_error
    * Returns the l2 norm of the change between the stored iterate (olditsol) and the current solution
    * Side effect: olditsol is overwritten with sol - olditsol
**/
template<int dim> double CahnHilliard<dim>::it_error(){
  // sadd( s, V ) computes *this = s*(*this) + V, i.e. olditsol <- sol - olditsol
  olditsol.sadd( -1., this->sol );
  const double increment_norm = olditsol.l2_norm();
  return increment_norm;
}

/**
  Destructor
    * Nothing to release explicitly
**/
template<int dim> CahnHilliard<dim>::~CahnHilliard()
{
}

/**
  InitLADataSuffix
    * Resizes the auxiliary vectors to this->n_dofs:
      olditsol (stored inner iterate) and oldsol (previous time step)
**/
template<int dim> void CahnHilliard<dim>::InitLADataSuffix(){
  // The two reinits are independent of each other
  olditsol.reinit( this->n_dofs );
  oldsol.reinit( this->n_dofs );
}

/**
  SetInitialData
    * Computes the initial state by an L2 projection of Phase0 onto the finite element space
    * Interpolating the initial condition directly would not work, because the finite element
      collection contains an element built from FE_Nothing (fe index 1)
    * The local system is the block mass matrix (phase-phase + mu-mu). Only the phase
      equation has a right-hand side, so the mu component is projected from zero data
    * The result is written to this->sol (with constraints distributed) and copied to oldsol
    * Assembly is serial: it only runs a few times at start-up
**/
template<int dim> void CahnHilliard<dim>::SetInitialData(){
  Phase0<dim> i_data( delta, tc );
  this->K = 0.;
  this->rhs = 0.;
  FEValuesExtractors::Scalar ph(0), mu(1);
  hp::QCollection<dim> quad( QGauss<dim>( this->fe[0].degree + 1 ) );
  quad.push_back( QGauss<dim>( this->fe[1].degree + 1 ) );
  hp::FEValues<dim> fv( this->fe, quad, update_values | update_quadrature_points | update_JxW_values );
  typename Problem<dim>::PerTaskData ptdata( this->fe.max_dofs_per_cell() );
  typename hp::DoFHandler<dim>::active_cell_iterator cell = this->dh.begin_active(), end = this->dh.end();
  for( ; cell not_eq end; ++cell ){
    fv.reinit( cell );
    const FEValues<dim> &fe_val = fv.get_present_fe_values();
    const unsigned nqp = fe_val.n_quadrature_points;
    const unsigned dpc = fe_val.get_fe().dofs_per_cell;
    ptdata.loc_m.reinit( dpc, dpc );
    ptdata.loc_rhs.reinit( dpc );
    for( unsigned q=0; q<nqp; ++q ){
      // Loop invariants of this quadrature point: the weight and the initial datum.
      // Evaluating Phase0 here (instead of once per test function) avoids dpc-fold redundant tanh calls.
      const double JxW_q = fe_val.JxW(q);
      const double phase0_q = i_data.value( fe_val.quadrature_point(q) );
      for( unsigned i=0; i<dpc; ++i ){
        const double ph_i = fe_val[ph].value( i, q ),
                     mu_i = fe_val[mu].value( i, q );
        ptdata.loc_rhs(i) += JxW_q*ph_i*phase0_q;
        for( unsigned j=0; j<dpc; ++j )
          ptdata.loc_m( i, j ) += JxW_q*( ph_i*fe_val[ph].value( j, q ) +
                                          mu_i*fe_val[mu].value( j, q ) );
      }
    }
    ptdata.ldi.resize( dpc );
    cell->get_dof_indices( ptdata.ldi );
    CopyToGlob( ptdata );
  }
  SolverControl control;
  TrilinosWrappers::SolverDirect direct( control );
  direct.solve( this->K, this->sol, this->rhs );
  this->constraints.distribute( this->sol );
  oldsol = this->sol;
}

/**
  SolveSystemPreffix
    * Keeps a copy of the current iterate in olditsol; it_error later measures this->sol against it
    * NOTE(review): by its name presumably invoked right before each solve — confirm in Problem<dim>
**/
template<int dim> void CahnHilliard<dim>::SolveSystemPreffix()
{
  this->olditsol = this->sol;
}

/**
  ReinitPrec
    * Nothing to do: DoSolve uses a direct solver, so there is no preconditioner to rebuild
**/
template<int dim> void CahnHilliard<dim>::ReinitPrec()
{
  // intentionally empty
}

/**
  DoSolve
    * Solves K sol = rhs with Trilinos' direct solver (TrilinosWrappers::SolverDirect)
    * <control> is handed to the solver object
**/
template<int dim> void CahnHilliard<dim>::DoSolve ( SolverControl &control ){
  TrilinosWrappers::SolverDirect solver( control );
  solver.solve( this->K, this->sol, this->rhs );
}

/**
  AssembleSystem
    * Zeroes the system matrix and right-hand side
    * Prepares the external fields:
      data[0] (voltage) provides cell gradients, data[1] (velocity) provides cell and face values
    * Assembles cell by cell, either serially or through WorkStream when m_threaded is set
    * Releases the external fields afterwards
    * The double argument (time) is not used
**/
template<int dim> void CahnHilliard<dim>::AssembleSystem( std::vector< AsFunction<dim> *> &data, const double,
                                                          const bool m_threaded ){
  Assert( data.size() == 2, ExcDimensionMismatch( data.size(), 2 ) );
  AsFunction<dim> &voltage_fn = *data[0];
  AsFunction<dim> &velocity_fn = *data[1];
  this->rhs = 0.;
  this->K = 0.;
  // One Gauss rule per element of the hp collection, for cells and for faces
  const unsigned n_q_active = this->fe[0].degree + 1,
                 n_q_nothing = this->fe[1].degree + 1;
  hp::QCollection<dim> quad( QGauss<dim>( n_q_active ) );
  quad.push_back( QGauss<dim>( n_q_nothing ) );
  hp::QCollection<dim-1> f_quad( QGauss<dim-1>( n_q_active ) );
  f_quad.push_back( QGauss<dim-1>( n_q_nothing ) );
  voltage_fn.reset( quad, update_gradients );
  velocity_fn.reset( quad, update_values );
  velocity_fn.reset_face( f_quad, update_values );
  typename Problem<dim>::PerTaskData ptdata( this->fe.max_dofs_per_cell() );
  const UpdateFlags cell_flags = update_values | update_gradients | update_JxW_values,
                    face_flags = cell_flags | update_normal_vectors;
  ScratchData scratch( this->fe, data, quad, cell_flags, f_quad, face_flags );
  // Slot 0: own DoF cells, slot 1: voltage cells, slot 2: velocity cells
  Iterator first, last, it;
  first(0) = this->dh.begin_active();
  last(0) = this->dh.end();
  first(1) = voltage_fn.get_begin();
  last(1) = voltage_fn.get_end();
  first(2) = velocity_fn.get_begin();
  last(2) = velocity_fn.get_end();
  if( not m_threaded ){
    for( it = first; it not_eq last; ++it ){
      AssembleCell( it, scratch, ptdata );
      CopyToGlob( ptdata );
    }
  }
  else
    WorkStream::run( first, last,
                     std_cxx1x::bind( &CahnHilliard<dim>::AssembleCell, this, std_cxx1x::_1, std_cxx1x::_2, std_cxx1x::_3 ),
                     std_cxx1x::bind( &CahnHilliard<dim>::CopyToGlob, this, std_cxx1x::_1 ),
                     scratch, ptdata );
  voltage_fn.reset();
  velocity_fn.reset();
  velocity_fn.reset_face();
}

/**
  AssembleCell
    * Assembles the local matrix and right-hand side of one cell into <data>; CopyToGlob adds them
      to the global system
    * Its(0) is the cell of this->dh, Its(1) the matching voltage cell, Its(2) the matching velocity cell
      (the order in which AssembleSystem fills the Iterator slots)
    * Unknowns: component 0 = phi, component 1 = mu. With v / w the phi- / mu-components of the test
      function, and every coefficient evaluated at the previous-step phase phi_old (read from oldsol):
        (phi/dt, v) + (M(phi_old) grad mu, grad v)
            = (phi_old/dt - grad phi_old . u, v)
        A*gamma/delta (phi, w) + gamma*delta (grad phi, grad w) - (mu, w) + (alpha/dt + B) (phi, w)_faces
            = (A*gamma/delta phi_old - gamma/delta W(phi_old)
               + 0.5 eps'(phi_old) |grad V|^2 - 0.5 rho'(phi_old) |u|^2, w)
            + ((alpha/dt + B) phi_old - alpha (u . grad phi_old - (u.n)(grad phi_old . n)) - gamma_fs(phi_old), w)_faces
      where V is the voltage, u the velocity, and "faces" are the faces for which isNeeded() is true
    * NOTE(review): A = W.A and B = gamma_fs.B appear on both sides multiplying phi, which looks like a
      linear stabilization — confirm against the parameter classes
**/
template<int dim> void CahnHilliard<dim>::AssembleCell( const Iterator Its, ScratchData &scratch,
                                                        typename Problem<dim>::PerTaskData &data ){
  // Component 0 is the phase, component 1 the chemical potential
  FEValuesExtractors::Scalar ph(0), mu(1);
  const double A = W.A, B = gamma_fs.B;
  scratch.fe_val.reinit( Its(0) );
  const FEValues<dim> &fe_val = scratch.fe_val.get_present_fe_values();
  const unsigned nqp = fe_val.n_quadrature_points;
  // 0 on cells using the FE_Nothing element, in which case every i/j loop below is empty
  data.dpc = fe_val.get_fe().dofs_per_cell;
  scratch.loc_grad_voltage.resize( nqp );
  scratch.loc_vel.resize( nqp );
  scratch.phase_grad.resize( nqp );
  scratch.phase_val.resize( nqp );
  // Position the external voltage and velocity fields on the matching cells
  scratch.voltage.reinit( Its(1) );
  scratch.velocity.reinit( Its(2) );
  // Previous-step phase (value and gradient), voltage gradient and velocity at the quadrature points
  fe_val[ph].get_function_values( this->oldsol, scratch.phase_val );
  fe_val[ph].get_function_gradients( this->oldsol, scratch.phase_grad );
  scratch.voltage.get_function_gradients( scratch.loc_grad_voltage );
  scratch.velocity.get_function_values( scratch.loc_vel );
  data.loc_m.reinit( data.dpc, data.dpc );
  data.loc_rhs.reinit( data.dpc );
  for( unsigned q=0; q<nqp; ++q ){
    // Coefficients at this quadrature point; the nonlinear ones use the old phase
    const double gammaddelta = gamma/delta, Axgammaddelta = A*gammaddelta, gammaxdelta = gamma*delta,
                 M_q = mobility( scratch.phase_val[q] ), W_q = W( scratch.phase_val[q] ),
                 h_eps_q = 0.5*permitivity.deriv( scratch.phase_val[q] ),
                 h_dens_q = 0.5*density.deriv( scratch.phase_val[q] );
    for( unsigned i=0; i<data.dpc; ++i ){
      // First bracket: phase equation (tested with the ph component)
      // Second bracket: chemical potential equation (tested with the mu component)
      data.loc_rhs(i) += fe_val.JxW(q)*(
                            (
                              scratch.phase_val[q]/this->dt
                              - scratch.phase_grad[q]*scratch.loc_vel[q]
                            )*fe_val[ph].value( i, q )
                            +(
                              Axgammaddelta*scratch.phase_val[q] - gammaddelta*W_q
                              + h_eps_q*scratch.loc_grad_voltage[q].norm_square()
                              - h_dens_q * scratch.loc_vel[q].norm_square()
                            )*fe_val[mu].value( i, q )
                        );
      // Rows 1-2: phase equation (mass/dt, mobility coupling to mu)
      // Rows 3-5: chemical potential equation (stabilization, gradient energy, -mu mass)
      for( unsigned j=0; j < data.dpc; ++j )
        data.loc_m( i, j ) += fe_val.JxW(q)*(
                                  fe_val[ph].value( j, q )*fe_val[ph].value( i, q )/this->dt
                                  + M_q*fe_val[mu].gradient( j, q )*fe_val[ph].gradient( i, q )
                                  + Axgammaddelta*fe_val[ph].value( j, q )*fe_val[mu].value( i, q )
                                  + gammaxdelta*fe_val[ph].gradient( j, q )*fe_val[mu].gradient( i, q )
                                  - fe_val[mu].value( j, q )*fe_val[mu].value( i, q )
                              );
    }
  }
  // Loop over faces: boundary terms of the chemical potential equation, only on faces selected by isNeeded()
  for( unsigned f=0; f < GeometryInfo<dim>::faces_per_cell; ++f ){
    if( isNeeded( Its(0), f ) ){
      scratch.fe_face_val.reinit( Its(0), f );
      const FEFaceValues<dim> &fe_face_val = scratch.fe_face_val.get_present_fe_values();
      const unsigned nfqp = fe_face_val.n_quadrature_points;
      scratch.loc_face_vel.resize( nfqp );
      scratch.phase_face_grad.resize( nfqp );
      // NOTE(review): this resize is redundant, scratch.normals is overwritten by the assignment below
      scratch.normals.resize( nfqp );
      scratch.phase_face_val.resize( nfqp );
      scratch.velocity.reinit( Its(2), f );
      scratch.normals = fe_face_val.get_normal_vectors();
      fe_face_val[ph].get_function_values( this->oldsol, scratch.phase_face_val );
      fe_face_val[ph].get_function_gradients( this->oldsol, scratch.phase_face_grad );
      scratch.velocity.get_function_face_values( scratch.loc_face_vel );
      for( unsigned q=0; q <nfqp; ++q ){
        // un = u.n, Dphin = grad phi_old . n;
        // transp = alpha * tangential part of u . grad phi_old
        const double alphaddtpB = ( alpha/( this->dt ) ) + B,
                     gamma_fs_q = gamma_fs( scratch.phase_face_val[q] ),
                     un = scratch.loc_face_vel[q]*scratch.normals[q],
                     Dphin = scratch.phase_face_grad[q]*scratch.normals[q],
                     transp = alpha*( scratch.loc_face_vel[q]*scratch.phase_face_grad[q] - un*Dphin );
        for( unsigned i=0; i<data.dpc; ++i ){
          data.loc_rhs(i) += fe_face_val.JxW(q)*(
                                alphaddtpB*scratch.phase_face_val[q]
                                - transp
                                - gamma_fs_q
                            )*fe_face_val[mu].value( i, q );
          // Face mass term coupling phi (trial) to the mu test component
          for( unsigned j=0; j<data.dpc; ++j )
            data.loc_m( i, j ) += fe_face_val.JxW(q)*(
                                      alphaddtpB
                                  )*fe_face_val[ph].value( j, q )*fe_face_val[mu].value( i, q );
        }
      }
    }
  }
  // Global DoF indices used by CopyToGlob
  data.ldi.resize( data.dpc );
  Its(0)->get_dof_indices( data.ldi );
}

/**
  PreRefinementPreffix
    * Stores the vectors that must be carried across the mesh refinement in this->x_sol:
      x_sol[0] = current solution, x_sol[1] = old (previous time step) solution
**/
template<int dim> void CahnHilliard<dim>::PreRefinementPreffix()
{
  this->x_sol[0] = this->sol;
  this->x_sol[1] = this->oldsol;
}

/**
  PostRefinementSuffix
    * Restores the solution and the old solution from the vectors transferred across the refinement
    * sol_tmp is read in the order PreRefinementPreffix stores them: [0] solution, [1] old solution
**/
template<int dim> void CahnHilliard<dim>::PostRefinementSuffix( const std::vector<TrilinosWrappers::Vector> &sol_tmp ){
  // Both entries are read below; catch a short vector in debug mode instead of indexing out of bounds
  Assert( sol_tmp.size() >= 2, ExcLowerRange( static_cast<int>( sol_tmp.size() ), 2 ) );
  this->sol = sol_tmp[0];
  oldsol = sol_tmp[1];
}

// Explicit instantiations for the compile-time space dimension DIM
template class Phase0<DIM>;
template class CahnHilliard<DIM>;
