#include "../include/Voltage.h"

/////////////// Voltage  Boundary Conditions
/**
  Constructor
    * Initialize the Function<dim> base with one component and initial time t
    * Remember the test case that selects the boundary voltage profile
**/
template<int dim>
Vbc<dim>::Vbc( const TestCase testcase, const double t ) :
    Function<dim>( 1, t ),
    tc( testcase )
{}

/**
  Destructor
    * Empty body: nothing is released explicitly here
**/
template<int dim>
Vbc<dim>::~Vbc()
{}

/**
  value
    * Boundary voltage at the point p for the current time of this function
    * Returns 20 when all of the following hold, and 0 otherwise:
        - p(0) lies in the region picked by the test case
            TEST_CASE_MERGE: |p(0)| < 0.5
            TEST_CASE_SPLIT: |p(0)| > 1.5
            TEST_CASE_MOVE:  p(0) > 0
        - p(1) <= 0
        - the current time is strictly positive
    * The component argument is ignored
    * Any other test case trips an ExcInternalError assertion
**/
template<int dim> double Vbc<dim>::value( const Point<dim> &p, const unsigned ) const{
  const double applied_voltage = 20.;
  const double now = this->get_time();
  const double x = p(0);
  bool in_active_region = false;
  switch( tc ){
    case TEST_CASE_MERGE:
      in_active_region = std::fabs( x ) < 0.5;
      break;
    case TEST_CASE_SPLIT:
      in_active_region = std::fabs( x ) > 1.5;
      break;
    case TEST_CASE_MOVE:
      in_active_region = x > 0.;
      break;
    default:
      Assert( false, ExcInternalError() );
  }
  // The second coordinate is only looked at once the first one qualifies
  if( in_active_region and ( p(1) <= 0. ) and ( now > 0. ) )
    return applied_voltage;
  return 0.;
}
/////////////////////////////////

/**
  Constructor
    * Forward tria, ee and u_prec to the Problem<dim> base
    * Keep the permittivities found in params ( epsilon and epsilon_plate )
    * Create the boundary voltage for the requested test case
    * Create the get_voltage accessor from the DoF handler and the solution, named "voltage"
    * Fill in the preconditioner settings:
        - elliptic problem
        - higher order elements whenever deg > 1
        - aggregation threshold thres
        - smoother sweeps sweeps
    * Register the boundary voltage under the boundary indicator 's'
    * Add two FE_Q elements of degree deg to the finite element collection
  NOTE(review): the literal arguments 1 and true handed to Problem<dim> are
  defined by that class; confirm their meaning there.
**/
template<int dim>
Voltage<dim>::Voltage( const Triangulation<dim> &tria, Material_Parameters &params, const unsigned deg,
                       const unsigned u_prec, const double ee, const double thres, const unsigned sweeps,
                       const TestCase testcase ) :
    Problem<dim>( tria, 1, ee, true, u_prec ),
    epsilon( params.epsilon ),
    epsilon_plate( params.epsilon_plate ),
    volt( testcase ),
    get_voltage( this->dh, this->sol, "voltage" )
{
  // Preconditioner settings
  prec_data.elliptic = true;
  prec_data.higher_order_elements = ( deg > 1 );
  prec_data.aggregation_threshold = thres;
  prec_data.smoother_sweeps = sweeps;

  // Boundary values on the part of the boundary tagged 's'
  BCs['s'] = &volt;

  // The same element is stored twice in the collection
  const FE_Q<dim> element( deg );
  this->fe.push_back( element );
  this->fe.push_back( element );
}

/**
  Destructor
    * Empty body: nothing is released explicitly here
**/
template<int dim>
Voltage<dim>::~Voltage()
{}


/**
  AssembleSystem
    * data must hold exactly two coupled functions ( checked with Assert )
      NOTE(review): judging by AssembleCell, data[0] is presumably the charge and
      data[1] the phase field; confirm against ScratchData
    * Set the time of the boundary voltage
    * Zero the global matrix and right hand side
    * Build one Gauss rule per finite element ( degree + 1 points per direction )
    * Pack the cell ranges of the voltage DoF handler and of both coupled functions
      into a pair of synchronized iterators
    * Prepare the coupled functions for evaluation, then create the scratch and
      per-task data
    * Either walk the cells one by one, or hand the work to WorkStream when
      m_threaded is set
    * Finally reset the coupled functions
**/
template<int dim> void Voltage<dim>::AssembleSystem( std::vector< AsFunction<dim> * > &data, const double time,
                                                     const bool m_threaded ){
  Assert( data.size() == 2, ExcDimensionMismatch( data.size(), 2 ) );
  const unsigned n_coupled = 2;

  volt.set_time( time );
  this->K = 0.;
  this->rhs = 0.;

  // One quadrature rule per element of the collection
  hp::QCollection<dim> quad( QGauss<dim>( this->fe[0].degree + 1 ) );
  quad.push_back( QGauss<dim>( this->fe[1].degree + 1 ) );

  // Slot 0 walks our own cells, slots 1 and 2 walk the cells of the coupled functions
  Iterator first, last;
  first(0) = this->dh.begin_active();
  last(0) = this->dh.end();
  for( unsigned k=0; k<n_coupled; ++k ){
    first( k+1 ) = data[k]->get_begin();
    last( k+1 ) = data[k]->get_end();
  }

  for( unsigned k=0; k<n_coupled; ++k )
    data[k]->reset( quad, update_values );

  typename Problem<dim>::PerTaskData ptdata( this->fe.max_dofs_per_cell() );
  ScratchData scratch( this->fe, data, quad, update_values | update_gradients | update_JxW_values );

  if( not m_threaded ){
    // Serial assembly: compute the local contribution and add it right away
    for( Iterator cell = first; cell not_eq last; ++cell ){
      AssembleCell( cell, scratch, ptdata );
      this->CopyToGlob( ptdata );
    }
  }
  else
    WorkStream::run( first, last,
                     std_cxx1x::bind( &Voltage<dim>::AssembleCell, this, std_cxx1x::_1, std_cxx1x::_2, std_cxx1x::_3 ),
                     std_cxx1x::bind( &Voltage<dim>::CopyToGlob, this, std_cxx1x::_1 ),
                     scratch, ptdata );

  for( unsigned k=0; k<n_coupled; ++k )
    data[k]->reset();
}

/**
  AssembleCell
    * Reinit the voltage fe_values on Its(0), the charge on Its(1) and the phase on Its(2)
    * Evaluate charge and phase at the quadrature points
    * Assemble the local matrix and right hand side
        m_ij = int_K vare \grad p_j \grad p_i
        f_i  = int_K q p_i
      where vare = epsilon( phase ) on cells with material id 'f' and
      vare = epsilon_plate on every other cell
    * Store the number of local DoFs and their global indices in data
**/
template<int dim> void Voltage<dim>::AssembleCell( const Iterator Its, ScratchData &scratch,
                                                   typename Problem<dim>::PerTaskData &data ){
  scratch.fe_val.reinit( Its(0) );
  scratch.charge.reinit( Its(1) );
  scratch.phase.reinit( Its(2) );

  const FEValues<dim> &fe_val = scratch.fe_val.get_present_fe_values();
  const unsigned n_q_points = fe_val.n_quadrature_points;
  data.dpc = fe_val.get_fe().dofs_per_cell;

  // Values of the coupled fields at the quadrature points of this cell
  scratch.loc_charge.resize( n_q_points );
  scratch.charge.get_function_values( scratch.loc_charge );
  scratch.loc_phase.resize( n_q_points );
  scratch.phase.get_function_values( scratch.loc_phase );

  data.loc_m.reinit( data.dpc, data.dpc );
  data.loc_rhs.reinit( data.dpc );

  // The material of the cell does not change from one quadrature point to the next
  const bool on_f_material = ( Its(0)->material_id() == 'f' );

  for( unsigned q=0; q<n_q_points; ++q ){
    double vare = epsilon_plate;
    if( on_f_material )
      vare = epsilon( scratch.loc_phase[q] );
    const double jxw = fe_val.JxW( q );
    for( unsigned i=0; i<data.dpc; ++i ){
      data.loc_rhs(i) += jxw*scratch.loc_charge[q]*fe_val.shape_value( i, q );
      for( unsigned j=0; j<data.dpc; ++j ){
        data.loc_m( i, j ) += vare*fe_val.shape_grad( j, q )*fe_val.shape_grad( i, q )*jxw;
      }
    }
  }

  data.ldi.resize( data.dpc );
  Its(0)->get_dof_indices( data.ldi );
}

/**
  SetupDoFsSuffix
    * Interpolate the boundary functions stored in BCs on the DoF handler and
      write the resulting boundary values into the constraints
**/
template<int dim>
void Voltage<dim>::SetupDoFsSuffix()
{
  VectorTools::interpolate_boundary_values( this->dh,
                                            BCs,
                                            this->constraints );
}

/**
  DoSolve
    * Raise the tolerance of control to this->eps whenever it is smaller than that
    * Solve K sol = rhs with the Trilinos conjugate gradient solver and the
      preconditioner prec
**/
template<int dim> void Voltage<dim>::DoSolve( SolverControl &control ){
  // Never request more accuracy than this->eps
  const double tolerance = std::max( control.tolerance(), this->eps );
  control.set_tolerance( tolerance );

  TrilinosWrappers::SolverCG cg( control );
  cg.solve( this->K, this->sol, this->rhs, prec );
}

/**
  SetInitialData
    * Start from a zero voltage: every entry of the solution is set to 0
**/
template<int dim>
void Voltage<dim>::SetInitialData()
{
  this->sol = 0.;
}

/**
  ReinitPrec
    * Initialize the preconditioner from the current matrix K using the settings
      stored in prec_data
**/
template<int dim>
void Voltage<dim>::ReinitPrec()
{
  prec.initialize( this->K, prec_data );
}

/**
  PreRefinementPreffix
    * Copy the current solution into slot 0 of x_sol
      NOTE(review): presumably x_sol is what gets transferred to the refined mesh; confirm in Problem<dim>
**/
template<int dim>
void Voltage<dim>::PreRefinementPreffix()
{
  this->x_sol[0] = this->sol;
}

/**
  PostRefinementSuffix
    * Take the first vector of sol_tmp as the new solution
**/
template<int dim>
void Voltage<dim>::PostRefinementSuffix( const std::vector<TrilinosWrappers::Vector> &sol_tmp )
{
  this->sol = sol_tmp[0];
}

// Explicit instantiations of both class templates for the dimension DIM.
// DIM is not defined in this file; it has to come from a header or the build.
template class Vbc<DIM>;
template class Voltage<DIM>;
