#include "../include/EWOD.h"

template<int dim> EWOD<dim>::EWOD( DataStorage<dim> &data ) :
    t0( data.t_0 ), T( data.T ),
    params( data.epsilon, data.epsilon_plate, data.K, data.M, data.rho, data.beta, data.eta,
            data.lambda, data.gamma, data.theta_s, data.delta, data.alpha ),
    v_cout( std::cout, data.verbose ),
    dt0( data.dt ), dt( data.dt ),
    V( tri, params, data.volt_deg, data.volt_update_prec, data.volt_eps, data.volt_aggregation_threshold,
       data.volt_smoother_sweeps, data.testcase ),
    // NOTE(review): the charge solver is built with data.volt_deg; confirm this
    // is intended rather than a dedicated charge degree.
    q( tri, params, data.volt_deg, data.charge_update_prec, data.charge_eps, data.charge_aggregation_threshold,
       data.charge_smoother_sweeps ),
    chnse( tri, data, v_cout ),
    plot( data.output ),
    max_refs( data.n_of_initial_refines + data.n_of_extra_refines ),
    initial_refs( data.n_of_initial_refines ){
  // Geometry along the last coordinate direction:
  //   lower plate [pl - width, pl], fluid [pl, pu], upper plate [pu, pu + width].
  // Each piece is meshed separately (fluid cells get material id 'f', plate
  // cells 'p') and the three meshes are merged into tri.

  // Subdivisions: the truncated box extent in the first dim-1 directions, so
  // that elements are not too distorted. Use at least one subdivision per
  // direction. Use a single layer of cells across the last direction.
  std::vector<unsigned> repetitions( dim );
  for( unsigned d=0; d<dim-1; ++d )
    repetitions[d] = std::max( 1u, static_cast<unsigned>( data.pu[d] - data.pl[d] ) );
  repetitions[dim-1] = 1;

  // Fluid box.
  Triangulation<dim> fluid;
  GridGenerator::subdivided_hyper_rectangle( fluid, repetitions, data.pl, data.pu );
  {
    typename Triangulation<dim>::active_cell_iterator cell = fluid.begin_active(), end = fluid.end();
    for( ; cell not_eq end; ++cell )
      cell->set_material_id( 'f' );
  }

  // Plates: same horizontal extent as the fluid, thickness data.width below pl
  // and above pu. They are meshed directly at their final position, rather than
  // meshed once and shifted, so their interface vertices coincide exactly with
  // the fluid's for any data.pl (not only when pl[dim-1] == 0).
  Point<dim> lower_pl, lower_pu, upper_pl, upper_pu;
  for( unsigned d=0; d<dim-1; ++d ){
    lower_pl[d] = upper_pl[d] = data.pl[d];
    lower_pu[d] = upper_pu[d] = data.pu[d];
  }
  lower_pl[dim-1] = data.pl[dim-1] - data.width;
  lower_pu[dim-1] = data.pl[dim-1];
  upper_pl[dim-1] = data.pu[dim-1];
  upper_pu[dim-1] = data.pu[dim-1] + data.width;
  Triangulation<dim> lower_plate, upper_plate, tmp;
  GridGenerator::subdivided_hyper_rectangle( lower_plate, repetitions, lower_pl, lower_pu );
  GridGenerator::subdivided_hyper_rectangle( upper_plate, repetitions, upper_pl, upper_pu );
  {
    Triangulation<dim> *plates[2] = { &lower_plate, &upper_plate };
    for( unsigned k=0; k<2; ++k ){
      typename Triangulation<dim>::active_cell_iterator cell = plates[k]->begin_active(), end = plates[k]->end();
      for( ; cell not_eq end; ++cell )
        cell->set_material_id( 'p' );
    }
  }

  // Merge lower plate + fluid + upper plate into the working triangulation.
  GridGenerator::merge_triangulations( lower_plate, fluid, tmp );
  GridGenerator::merge_triangulations( tmp, upper_plate, tri );

  // Boundary ids:
  //   's' - outer faces of the plates (bottom of lower plate, top of upper plate),
  //   'g' - lateral faces of the fluid region (strictly between pl and pu),
  //   'n' - every other boundary face (lateral faces of the plates).
  {
    const double bottom = data.pl[dim-1] - data.width, top = data.pu[dim-1] + data.width;
    // Face centres are averages of vertex coordinates, so compare with a small
    // tolerance relative to the total height instead of exact equality.
    const double tol = 1e-10*( top - bottom );
    typename Triangulation<dim>::active_cell_iterator cell = tri.begin_active(), end = tri.end();
    for( ; cell not_eq end; ++cell ){
      for( unsigned f = 0; f < GeometryInfo<dim>::faces_per_cell; ++f )
        if( cell->face(f)->at_boundary() ){
          const double last_coord = cell->face(f)->center()[dim-1];
          const double d_bottom = last_coord - bottom, d_top = last_coord - top;
          const bool on_bottom = ( d_bottom <= tol ) and ( -d_bottom <= tol );
          const bool on_top = ( d_top <= tol ) and ( -d_top <= tol );
          if( on_bottom or on_top )
            cell->face(f)->set_boundary_indicator( 's' );
          else{
            if( ( last_coord > data.pl[dim-1] ) and ( last_coord < data.pu[dim-1] )  )
              cell->face(f)->set_boundary_indicator( 'g' );
            else
              cell->face(f)->set_boundary_indicator( 'n' );
          }
        }
    }
  }
  tri.refine_global( initial_refs );
  v_cout<<" Number of active cells = "<<tri.n_active_cells()<<std::endl;
  V.init();
  q.init();
  chnse.init();
  // Initial adaptive passes. NOTE(review): the loop bound is inclusive, so this
  // runs n_of_extra_refines + 1 passes; confirm that is intended.
  for( unsigned n = 0; n<=data.n_of_extra_refines; ++n )
    RefineMesh( true );
  v_cout<<" density_1 = "<<params.rho(1.)<<" density_2 = "<<params.rho(-1.)<<std::endl
        <<" n_dofs = "<<( V.size()+q.size()+chnse.size() )
        <<" = "<<V.size()<<"+"<<q.size()<<"+"<<chnse.size()<<std::endl;
}

template<int dim>
EWOD<dim>::~EWOD()
{
  // Intentionally empty: all members are destroyed automatically.
}

template<int dim> void EWOD<dim>::run(){
  // Start 25 steps (of the initial dt) before t0. The step counter is only
  // advanced once the simulated time is positive.
  double time = t0 - 25.*dt;
  unsigned step = 0;
  // Coupling inputs for each sub-problem. The pointers refer to AsFunction
  // members of the solvers, so every solve reads whatever state those objects
  // hold at the moment it is called.
  std::vector< AsFunction<dim> * > V_params(2), q_params(3), ph_mu_params(2), vel_params(6), pres_params(1);
  V_params[0] = &q.get_charge;
  V_params[1] = &chnse.get_phase;
  q_params[0] = &V.get_voltage;
  q_params[1] = &chnse.get_phase;
  q_params[2] = &chnse.get_velocity;
  ph_mu_params[0] = &V.get_voltage;
  ph_mu_params[1] = &chnse.get_velocity;
  vel_params[0] = &chnse.get_phase;
  vel_params[1] = &chnse.get_potential;
  vel_params[2] = &q.get_charge;
  vel_params[3] = &V.get_voltage;
  vel_params[4] = &chnse.get_extrapolated_pressure;
  vel_params[5] = &chnse.get_old_phase;
  pres_params[0] = &chnse.get_velocity;
  Plot( step );
  do{
    V.set_dt( dt );
    q.set_dt( dt );
    chnse.set_dt( dt );
    // Once time >= 0, plot and adapt the mesh at step 0 and every `plot`
    // steps. A zero output frequency is guarded so it cannot cause a modulo
    // by zero; in that case only step 0 is written.
    if( ( time >= 0 ) and ( ( step == 0 ) or ( ( plot not_eq 0 ) and ( step % plot == 0 ) ) ) ){
      v_cout<<" Plotting . . . "<<std::endl;
      Plot( step );
      RefineMesh();
    }
    if( time > 0. )
      step++;
    // Sequential (segregated) solve: voltage, then charge, then the coupled
    // phase-field / Navier-Stokes system.
    v_cout<<" Time = "<<time<<std::endl<<" Solving for voltage . . . "<<std::endl;
    V.solve( V_params, time );
    v_cout<<" Solving for charge . . . "<<std::endl;
    q.solve( q_params, time );
    v_cout<<" Solving for phase field Navier Stokes. . . "<<std::endl;
    chnse.solve( ph_mu_params, vel_params, pres_params, time );
    // Adapt the time step from the current velocity before advancing.
    set_dt();
    time += dt;
    v_cout<<" dt = "<<dt<<std::endl;
  }while( time <= T );
  v_cout<<" Plotting final state . . ."<<std::endl;
  Plot( step );
}

template<int dim> void EWOD<dim>::set_dt(){
  // Adaptive time step: dt = min( dt0, 0.5 / r ), where r is the largest value
  // of |u| / h^{3/2} over all quadrature points (floored and padded by 1e-12).
  //
  // NOTE(review): this works on a *copy* of chnse.get_velocity (Plot binds a
  // reference instead). Whether reset()/reinit() on the copy affect chnse's
  // object depends on AsFunction's copy semantics; confirm.
  AsFunction<dim> velocity = chnse.get_velocity;
  typename hp::DoFHandler<dim>::active_cell_iterator it = velocity.get_dh().begin_active(),
                                                      last = velocity.get_dh().end();

  // A single Gauss rule, exact enough for the higher of the two velocity
  // element degrees, registered once per active FE index.
  const unsigned quad_degree = 3 + std::max( velocity.get_dh().get_fe()[0].degree,
                                             velocity.get_dh().get_fe()[1].degree );
  const QGauss<dim> rule( quad_degree );
  const unsigned n_points = rule.size();
  std::vector< Tensor<1,dim> > values_at_points( n_points );
  hp::QCollection<dim> rules( rule );
  rules.push_back( rule );
  velocity.reset( rules, update_values );

  double max_ratio = 1e-12;
  while( it not_eq last ){
    velocity.reinit( it );
    velocity.get_function_values( values_at_points );
    const double h = it->diameter();
    const double h_pow = h*sqrt( h );
    for( unsigned p = 0; p < n_points; ++p )
      max_ratio = std::max( max_ratio, values_at_points[p].norm()/h_pow );
    ++it;
  }
  max_ratio += 1e-12;

  v_cout<<"CFL = "<<(1./max_ratio)<<std::endl;
  dt = std::min( dt0, 0.5/max_ratio );
  velocity.reset();
}

// Writes the current state of all fields to "solution-<step>.vtk".
//
// The sub-solutions (voltage, charge, phase field, velocity, pressure) are held
// by separate solver objects, each constructed on the same triangulation `tri`.
// This routine:
//   1. builds one combined hp::FECollection whose two entries are FESystems of
//      the sub-elements (FE index 0 on fluid cells 'f', index 1 on all others);
//   2. distributes DoFs on a temporary hp::DoFHandler;
//   3. copies every sub-solution coefficient into the matching combined DoF;
//   4. hands the result to DataOut and writes VTK.
template<int dim> void EWOD<dim>::Plot( const unsigned step ){
  const hp::FECollection<dim> &fe_v = V.get_fe(), &fe_q = q.get_fe(), &fe_ph = chnse.get_ph_mu_fe(),
                              &fe_vel = chnse.get_vel_fe(), &fe_pres = chnse.get_pres_fe();
  // Combined element per FE index. Base-element order is:
  // 0 voltage, 1 charge, 2 phase-field element, 3 velocity element, 4 pressure.
  // Each base appears with multiplicity 1.
  hp::FECollection<dim> fe;
  fe.push_back( FESystem<dim>( fe_v[0], 1, fe_q[0], 1, fe_ph[0], 1, fe_vel[0], 1, fe_pres[0], 1 ) );
  fe.push_back( FESystem<dim>( fe_v[1], 1, fe_q[1], 1, fe_ph[1], 1, fe_vel[1], 1, fe_pres[1], 1 ) );
  hp::DoFHandler<dim> dh( tri );
  // FE index 0 on fluid cells, 1 elsewhere. NOTE(review): the per-cell copy
  // below indexes the sub-elements with this same index, so it must match the
  // assignment used by the sub-solvers' own handlers; confirm.
  {
    typename hp::DoFHandler<dim>::active_cell_iterator cell = dh.begin_active(), end = dh.end();
    for( ; cell not_eq end; ++cell ){
      const unsigned char material_id = cell->material_id();
      cell->set_active_fe_index( ( material_id == 'f' )?0:1 );
    }
  }
  dh.distribute_dofs( fe );
  // Consistency check: the combined system has exactly as many DoFs as the
  // sub-problems together.
  Assert( dh.n_dofs() == V.size() + q.size() + chnse.size(), ExcInternalError() );
  Vector<double> sol( dh.n_dofs() );
  // Local DoF-index buffers, resized per cell to the active element's size.
  std::vector<unsigned> j_ldi( fe[0].dofs_per_cell ), V_ldi( fe_v[0].dofs_per_cell ), q_ldi( fe_q[0].dofs_per_cell ),
                        ph_ldi( fe_ph[0].dofs_per_cell ), vel_ldi( fe_vel[0].dofs_per_cell ),
                        pres_ldi( fe_pres[0].dofs_per_cell );
  AsFunction<dim> &vv = V.get_voltage, &qq = q.get_charge, &phi = chnse.get_phase, &u = chnse.get_velocity,
                  &p = chnse.get_pressure;
  // Walk the combined handler (slot 0) and the five sub-solution handlers
  // (slots 1-5) together. IteratorGroup is defined elsewhere and is presumably
  // advanced and compared as a unit. This relies on all handlers enumerating
  // the active cells of `tri` in the same order.
  IteratorGroup<6, typename hp::DoFHandler<dim>::active_cell_iterator> cell, end;
  cell(0) = dh.begin_active();
  cell(1) = vv.get_begin();
  cell(2) = qq.get_begin();
  cell(3) = phi.get_begin();
  cell(4) = u.get_begin();
  cell(5) = p.get_begin();
  end(0) = dh.end();
  end(1) = vv.get_end();
  end(2) = qq.get_end();
  end(3) = phi.get_end();
  end(4) = u.get_end();
  end(5) = p.get_end();
  for( ; cell not_eq end; ++cell ){
    const unsigned a_i = cell(0)->active_fe_index();
    j_ldi.resize( fe[a_i].dofs_per_cell );
    cell(0)->get_dof_indices( j_ldi );
    V_ldi.resize( fe_v[a_i].dofs_per_cell );
    cell(1)->get_dof_indices( V_ldi );
    q_ldi.resize( fe_q[a_i].dofs_per_cell );
    cell(2)->get_dof_indices( q_ldi );
    ph_ldi.resize( fe_ph[a_i].dofs_per_cell );
    cell(3)->get_dof_indices( ph_ldi );
    vel_ldi.resize( fe_vel[a_i].dofs_per_cell );
    cell(4)->get_dof_indices( vel_ldi );
    pres_ldi.resize( fe_pres[a_i].dofs_per_cell );
    cell(5)->get_dof_indices( pres_ldi );
    // system_to_base_index(i) = ((base element, copy), index within base).
    // Every base has multiplicity 1, so the within-base index addresses the
    // corresponding sub-handler's local DoF directly.
    for( unsigned i = 0; i < fe[a_i].dofs_per_cell; ++i ){
      switch( fe[a_i].system_to_base_index(i).first.first ){
        case 0:
          /// Voltage
          Assert( fe[a_i].system_to_base_index(i).first.second == 0, ExcInternalError() );
          Assert( fe[a_i].system_to_base_index(i).second < fe_v[a_i].dofs_per_cell, ExcInternalError() );
          sol( j_ldi[i] ) = vv( V_ldi[ fe[a_i].system_to_base_index(i).second ] );
          break;
        case 1:
          /// Charge
          Assert( fe[a_i].system_to_base_index(i).first.second == 0, ExcInternalError() );
          Assert( fe[a_i].system_to_base_index(i).second < fe_q[a_i].dofs_per_cell, ExcInternalError() );
          sol( j_ldi[i] ) = qq( q_ldi[ fe[a_i].system_to_base_index(i).second ] );
          break;
        case 2:
          /// Phase Field
          Assert( fe[a_i].system_to_base_index(i).first.second == 0, ExcInternalError() );
          Assert( fe[a_i].system_to_base_index( i ).second < fe_ph[a_i].dofs_per_cell, ExcInternalError() );
          // The commented-out params.rho(...) wrapper would write density
          // instead of the raw phase value.
          sol( j_ldi[i] ) = /*params.rho(*/ phi( ph_ldi[ fe[a_i].system_to_base_index(i).second ] ) /*)*/;
          break;
        case 3:
          /// Velocity
          Assert( fe[a_i].system_to_base_index(i).first.second == 0, ExcInternalError() );
          Assert( fe[a_i].system_to_base_index(i).second < fe_vel[a_i].dofs_per_cell, ExcInternalError() );
          sol( j_ldi[i] ) = u( vel_ldi[ fe[a_i].system_to_base_index(i).second ] );
          break;
        case 4:
          /// Pressure
          Assert( fe[a_i].system_to_base_index(i).first.second == 0, ExcInternalError() );
          Assert( fe[a_i].system_to_base_index(i).second < fe_pres[a_i].dofs_per_cell, ExcInternalError() );
          sol( j_ldi[i] ) = p( pres_ldi[ fe[a_i].system_to_base_index(i).second ] );
          break;
        default:
          Assert( false, ExcInternalError() );
      }
    }
  }
  // Output component names/interpretations, 5 + dim in total: voltage, charge,
  // phase, potential (chnse.get_potential), dim velocity components, pressure.
  // This layout assumes the phase-field element has two components and the
  // velocity element has dim; confirm against the solver elements.
  std::vector<std::string> names( 5 + dim );
  names[0] = vv.name;
  names[1] = qq.name;
  names[2] = phi.name;
  names[3] = chnse.get_potential.name;
  for( unsigned d = 1; d <= dim; ++d )
    names[d+3] = u.name;
  names[4+dim] = p.name;
  DataOut<dim, hp::DoFHandler<dim> > data_out;
  data_out.attach_dof_handler( dh );
  std::vector< DataComponentInterpretation::DataComponentInterpretation > ci( 5 + dim );
  ci[0] = vv.interpretation;
  ci[1] = qq.interpretation;
  ci[2] = phi.interpretation;
  ci[3] = chnse.get_potential.interpretation;
  for ( unsigned d = 1; d <= dim; ++d )
    ci[d+3] = u.interpretation;
  ci[4+dim] = p.interpretation;
  data_out.add_data_vector( sol, names, DataOut<dim,hp::DoFHandler<dim> >::type_dof_data, ci );
  // Patch subdivision = the larger of the FE-index-0 charge and velocity degrees.
  data_out.build_patches( std::max( fe_q[0].degree, fe_vel[0].degree ) );
  std::ostringstream filename;
  filename<<"solution-"<<step<<".vtk";
  std::ofstream output( filename.str().c_str() );
  data_out.write_vtk( output );
}

template<int dim> void EWOD<dim>::RefineMesh( const bool is_the_first_step ){
  v_cout<<" Refining Mesh . . . "<<std::endl;
  // Per-cell Kelly error indicator computed from the phase-field solution,
  // restricted to fluid cells (material id 'f').
  Vector<float> error( tri.n_active_cells() );
  // Component mask over the two components of the phase-field element: only
  // the first component drives refinement.
  std::vector<bool> mask( 2, true );
  mask[1] = false;
  // One face quadrature per FE index.
  hp::QCollection<dim-1> fquad( QGauss<dim-1>( chnse.get_ph_mu_fe()[0].degree+1 ) );
  fquad.push_back( QGauss<dim-1>( chnse.get_ph_mu_fe()[1].degree+1 ) );
  KellyErrorEstimator<dim>::estimate( chnse.get_phase.get_dh(), fquad, typename FunctionMap<dim>::type(),
                                      chnse.get_phase.get_vec(), error, mask, 0, numbers::invalid_unsigned_int,
                                      types::invalid_subdomain_id, 'f' );
  const double eee = error.linfty_norm();
  if( is_the_first_step ){
    // Initial passes: refine only, flagging every cell whose indicator exceeds
    // 25% of the largest one. No level limits are applied here.
    GridRefinement::refine( tri, error, 0.25*eee );
  }
  else{
  /*  Vector<float> error_v( tri.n_active_cells() );
    hp::QCollection<dim-1> fquad_v( QGauss<dim-1>( V.get_fe()[0].degree+1 ));
    fquad_v.push_back( QGauss<dim-1>( V.get_fe()[1].degree+1 ) );
    KellyErrorEstimator<dim>::estimate( V.get_voltage.get_dh(), fquad_v, typename FunctionMap<dim>::type(),
                                        V.get_voltage.get_vec(), error_v );
    error += error_v;*/
    //     GridRefinement::refine_and_coarsen_optimize( tri, error );
    // Fixed-fraction strategy: refine 0.9, coarsen 0.01 of the total indicator.
    GridRefinement::refine_and_coarsen_fixed_fraction( tri, error, 0.9, 0.01 );
    for( typename Triangulation<dim>::active_cell_iterator cell = tri.begin_active();
        cell not_eq tri.end(); ++cell ){
      // Do not refine cells that are already beyond level max_refs.
      if( cell->level() > max_refs )
        cell->clear_refine_flag();
      // Never coarsen below the initial global refinement: coarsening a cell
      // on level initial_refs would make the mesh coarser than the initial one.
      if( cell->level() <= initial_refs )
        cell->clear_coarsen_flag();
    }
  }
  tri.prepare_coarsening_and_refinement();
  tri.execute_coarsening_and_refinement();
}

// Explicit instantiation of EWOD for the dimension given by the DIM macro.
// NOTE(review): DIM is not defined in this file; presumably it comes from the
// build configuration or an included header. Confirm.
template class EWOD<DIM>;
