#include "../include/CHNSE.h"

/**
  Constructor
    * Builds the shared parameter set from the fields of data
    * Constructs the three subproblems (phase field, velocity, pressure) on the same triangulation,
      each one receiving the parameter set plus its own settings taken from data
    * Stores the output stream and initializes the get_* accessors from the corresponding
      members of the subproblems
**/
template<int dim>
CHNSE<dim>::CHNSE( const Triangulation<dim> &tria,
                   const DataStorage<dim> &data,
                   ConditionalOStream &costream )
  :
  // Parameters handed to every subproblem
  params( data.epsilon, data.epsilon_plate, data.K, data.M,
          data.rho, data.beta, data.eta, data.lambda,
          data.gamma, data.theta_s, data.delta, data.alpha ),
  // Phase field subproblem
  ch( tria, params, data.ph_deg, data.testcase ),
  // Velocity subproblem
  // NOTE(review): this is given data.pres_deg, the same degree as the pressure subproblem;
  // presumably the velocity degree is derived from it internally -- confirm
  vel( tria, params, data.pres_deg,
       data.vel_update_prec, data.vel_eps, data.vel_aggregation_threshold,
       data.vel_smoother_sweeps, data.vel_Krylov_size ),
  // Pressure subproblem
  pres( tria, params, data.pres_deg,
        data.penalty_eps, data.penalty_aggregation_threshold,
        data.penalty_smoother_sweeps ),
  // Output stream used for the progress messages
  v_cout( costream ),
  // Accessors initialized from the subproblems' own members
  get_phase( ch.get_phase ),
  get_potential( ch.get_potential ),
  get_old_phase( ch.get_old_phase ),
  get_velocity( vel.get_velocity ),
  get_pressure( pres.get_pressure ),
  get_extrapolated_pressure( pres.get_extrapolated_pressure )
{}

/**
  Destructor
    * Empty body: no explicit cleanup is performed here
**/
template<int dim>
CHNSE<dim>::~CHNSE()
{}

/**
  set_dt
    * forwards the value _dt to set_dt of each subproblem, in the order
      phase field, velocity, pressure
**/
template<int dim>
void CHNSE<dim>::set_dt( const double _dt )
{
  ch.set_dt( _dt );   // phase field
  vel.set_dt( _dt );  // velocity
  pres.set_dt( _dt ); // pressure
}

/**
  init
    * calls init on each subproblem, in the order phase field, velocity, pressure
**/
template<int dim>
void CHNSE<dim>::init()
{
  ch.init();   // phase field
  vel.init();  // velocity
  pres.init(); // pressure
}

/**
  size
    * returns the size of the coupled problem, i.e. the sizes reported by the
      phase field, velocity and pressure subproblems added together
**/
template<int dim>
unsigned CHNSE<dim>::size() const
{
  unsigned total = ch.size();
  total += vel.size();
  total += pres.size();
  return total;
}

/**
  solve
    * iterates the phase field and velocity subproblems: at most MAX_ITS passes, stopping early
      once the larger of their it_error() values is no longer above TOLERANCE
    * then calls advance() on both of them and finally solves the pressure subproblem
    * ph_params, v_params and p_params are forwarded to ch.solve, vel.solve and pres.solve
      respectively; time is forwarded to all three solve calls
    * a warning is printed if the pass budget is used up with the error still above TOLERANCE
    * NOTE(review): v_params is taken by value, unlike the other two vectors, so it is copied on
      every call; left as is so that this definition keeps matching the declaration
**/
template<int dim> void CHNSE<dim>::solve( std::vector< AsFunction<dim> *> &ph_params,
                                          std::vector<AsFunction<dim> *> v_params,
                                          std::vector< AsFunction<dim> *> &p_params, const double time ){
  const unsigned MAX_ITS = 1;
  const double TOLERANCE = 1e-3;
  // Start with "infinite" errors so that the first pass is always executed
  double err_u = HUGE_VAL;
  double err_ph = HUGE_VAL;
  // Number of passes actually performed. The increment lives inside the loop body: bumping the
  // counter in the loop condition also counted the final, failing check, which made the
  // post-loop diagnostic fire on every call.
  unsigned its = 0;
  while( ( its < MAX_ITS ) and ( std::max( err_ph, err_u ) > TOLERANCE ) ){
    ++its;
    v_cout<<" Solving for Phase Field . . ."<<std::endl;
    ch.solve( ph_params, time );
    v_cout<<" Solving for Velocity . . ."<<std::endl;
    vel.solve( v_params, time );
    err_ph = ch.it_error();
    err_u = vel.it_error();
    v_cout<<" Error_ph = "<<err_ph<<" Error_u = "<<err_u<<std::endl;
  }
  // The loop ends either because the error dropped to TOLERANCE or because the pass budget ran
  // out; only the second case deserves a warning, and it is exactly the case in which the
  // error is still above TOLERANCE here.
  if( std::max( err_ph, err_u ) > TOLERANCE )
    v_cout<<" Weird..."<<std::endl;
  ch.advance();
  vel.advance();
  v_cout<<" Solving for Pressure . . ."<<std::endl;
  pres.solve( p_params, time );
}

// Explicit instantiation of the class template for dimension DIM, so that the member
// definitions above are compiled in this translation unit. It is the only instantiation
// visible in this file.
// NOTE(review): DIM is not defined here -- presumably a macro or constant provided by the
// included header or by the build flags; confirm.
template class CHNSE<DIM>;
