#include "xfe_values.h"


// The constructor simply forwards to the standard FEValues constructor. It additionally makes sure that the update_quadrature_points flag is set, because the quadrature point locations are needed to calculate the additional shape functions.
// Parameters:
//   fe           - finite element handed on to the FEValues base class; its dofs_per_cell is cached here.
//   quadrature   - quadrature rule handed on to the base class; its size is cached as the number of quadrature points.
//   update_flags - flags requested by the caller; stored unchanged for the checks in shape_value() and shape_grad().
template<int dim,int spacedim>
XFEValues_strong<dim,spacedim>::XFEValues_strong(const FiniteElement<dim, spacedim> &fe,
                                                 Quadrature<dim> quadrature,
                                                 UpdateFlags update_flags)
  :
  // The base class is always asked for the quadrature point locations, since reinit() evaluates the level set there.
  FEValues<dim,spacedim>(fe, quadrature, update_quadrature_points | update_flags),
  _dofs_per_cell(fe.dofs_per_cell),
  _n_q_points(quadrature.size()),
  // Only the caller's own flags are remembered, without the extra quadrature point flag.
  _update_flags(update_flags)
{}

//During the reinit function, the values of all shape functions in all quadrature points are saved in a table.
// Re-initializes the base FEValues object on the given cell and then refills the two tables
// _shape_values and _shape_gradients (one row per DoF, one column per quadrature point).
//
// DoFs of component 0 are the standard ones: their entries are copied from the base class.
// All other DoFs are the extended ones: their base shape function N_i is multiplied by
//   sign(phi(x_q)) - sign(phi(x_i)),
// where phi is the level set, x_q the quadrature point and x_i a cell vertex.
// The factor does not depend on differentiation here, so the gradient is scaled by the same factor.
//
// NOTE(review): the vertex x_i is selected with the shape-function index returned by
// system_to_component_index(); this presumably assumes one DoF per vertex and component - confirm.
template <int dim,int spacedim>
void XFEValues_strong<dim,spacedim>::reinit(const typename hp::DoFHandler<dim,spacedim>::cell_iterator &cell)
{
  // Short name for the base class; the qualification is required because this class
  // defines its own shape_value() and shape_grad().
  typedef FEValues<dim,spacedim> Base;

  Base::reinit(cell);

  _shape_values.reinit(_dofs_per_cell,_n_q_points);
  _shape_gradients.reinit(_dofs_per_cell,_n_q_points);

  // First pass: shape values, only if the base class holds them.
  if (Base::update_flags & update_values)
    {
      for (unsigned int dof=0; dof<_dofs_per_cell; ++dof)
        {
          // The first entry of system_to_component_index is the component of the DoF:
          // component 0 is the standard part of the solution.
          if (cell->get_fe().system_to_component_index(dof).first == 0)
            {
              for (unsigned int q=0; q<_n_q_points; ++q)
                _shape_values(dof,q) = Base::shape_value(dof,q);
            }
          else
            {
              // The second entry is the index of the shape function within its component;
              // it is used as the vertex number at which the level set is evaluated.
              const unsigned int vertex_no = cell->get_fe().system_to_component_index(dof).second;

              for (unsigned int q=0; q<_n_q_points; ++q)
                _shape_values(dof,q) = Base::shape_value(dof,q)
                                       * (xfem.sign(ls.level_set(Base::quadrature_point(q)))
                                          - xfem.sign(ls.level_set(cell->vertex(vertex_no))));
            }
        }
    }

  // Second pass: shape gradients, only if the base class holds them.
  if (Base::update_flags & update_gradients)
    {
      for (unsigned int dof=0; dof<_dofs_per_cell; ++dof)
        {
          if (cell->get_fe().system_to_component_index(dof).first == 0)
            {
              for (unsigned int q=0; q<_n_q_points; ++q)
                _shape_gradients(dof,q) = Base::shape_grad(dof,q);
            }
          else
            {
              const unsigned int vertex_no = cell->get_fe().system_to_component_index(dof).second;

              for (unsigned int q=0; q<_n_q_points; ++q)
                _shape_gradients(dof,q) = Base::shape_grad(dof,q)
                                          * (xfem.sign(ls.level_set(Base::quadrature_point(q)))
                                             - xfem.sign(ls.level_set(cell->vertex(vertex_no))));
            }
        }
    }
}

//When the shape_value or shape_grad function is called in the main program, we just need to read the values of the correct position in the table that was already calculated during reinit.
// Returns the value of shape function function_no at quadrature point point_no,
// as stored in the _shape_values table by the last call to reinit().
//   function_no - row index into the table (DoF index on the cell).
//   point_no    - column index into the table (quadrature point index).
template <int dim,int spacedim>
double XFEValues_strong<dim,spacedim>::shape_value(const unsigned int function_no,
                                                   const unsigned int point_no)
{
  // The alias presumably exists to keep the comma of the template argument list
  // out of the Assert macro's argument list.
  typedef FEValuesBase<dim,spacedim> FVB;
  // The check uses the flags handed to the constructor of this class.
  Assert(_update_flags & update_values, typename FVB::ExcAccessToUninitializedField("update_values"));
  return _shape_values(function_no,point_no);
}

// Returns the gradient of shape function function_no at quadrature point point_no,
// as stored in the _shape_gradients table by the last call to reinit().
//   function_no - row index into the table (DoF index on the cell).
//   point_no    - column index into the table (quadrature point index).
template <int dim,int spacedim>
Tensor<1,dim> XFEValues_strong<dim,spacedim>::shape_grad(const unsigned int function_no,
                                                         const unsigned int point_no)
{
  // The alias presumably exists to keep the comma of the template argument list
  // out of the Assert macro's argument list.
  typedef FEValuesBase<dim,spacedim> FVB;
  // The check uses the flags handed to the constructor of this class.
  Assert(_update_flags & update_gradients, typename FVB::ExcAccessToUninitializedField("update_gradients"));
  return _shape_gradients(function_no,point_no);
}

// In the case of XFEM the solution consists of two components, therefore the standard get_function_values calculates the solutions for each component separately and the get_function_gradients their gradients.
// The derived class XFEValues inherits the standard implementation from the base class FEValues which results in a wrong behaviour of the get_function_values and get_function_gradients for the XFEM case. In fact, these class members use only the standard shape functions. We include here an assert to avoid the wrong use of these functions. TODO: overload the get_function_values and get_function_gradients for the XFEM case.
// Deliberately unusable: this function never computes anything and always reports an error.
// It exists only to stop callers from evaluating a solution with the standard shape functions.
//   fe_function - ignored.
//   values      - ignored and left untouched.
// The error is raised with AssertThrow so that the guard is active in every build mode;
// a plain assert(0) is removed by the preprocessor when NDEBUG is defined.
template <int dim,int spacedim>
void XFEValues_strong<dim,spacedim>::get_function_values(const Vector<double> fe_function,
                                                         std::vector<Vector<double> > &values)
{
  // Neither argument is used; the casts silence unused-parameter warnings.
  (void) fe_function;
  (void) values;

  AssertThrow(false, ExcMessage("get_function_values provides wrong results for XFEValues"));
}

// Deliberately unusable: this function never computes anything and always reports an error.
// It exists only to stop callers from evaluating solution gradients with the standard shape functions.
//   fe_function - ignored.
//   values      - ignored and left untouched.
// The error is raised with AssertThrow so that the guard is active in every build mode;
// a plain assert(0) is removed by the preprocessor when NDEBUG is defined.
template <int dim,int spacedim>
void XFEValues_strong<dim,spacedim>::get_function_gradients(const Vector<double> fe_function,
                                                            std::vector<Vector<Tensor<1,dim> > > &values)
{
  // Neither argument is used; the casts silence unused-parameter warnings.
  (void) fe_function;
  (void) values;

  AssertThrow(false, ExcMessage("get_function_gradients provides wrong results for XFEValues"));
}


// XFEValues class for the weak discontinuity
// Forwards to the FEValues constructor and caches the sizes needed for the shape function tables.
// Parameters:
//   fe           - finite element handed on to the FEValues base class; its dofs_per_cell is cached here.
//   quadrature   - quadrature rule handed on to the base class; its size is cached as the number of quadrature points.
//   update_flags - flags requested by the caller; stored unchanged for the checks in shape_value() and shape_grad().
template<int dim,int spacedim>
XFEValues_weak<dim,spacedim>::XFEValues_weak(const FiniteElement<dim, spacedim> &fe,
                                             Quadrature<dim> quadrature,
                                             UpdateFlags update_flags)
  :
  // The base class is always asked for the quadrature point locations, since reinit() evaluates the level set there.
  FEValues<dim,spacedim>(fe, quadrature, update_quadrature_points | update_flags),
  _dofs_per_cell(fe.dofs_per_cell),
  _n_q_points(quadrature.size()),
  // Only the caller's own flags are remembered, without the extra quadrature point flag.
  _update_flags(update_flags)
{}

// Re-initializes the base FEValues object on the given cell and then refills the two tables
// _shape_values and _shape_gradients (one row per DoF, one column per quadrature point).
//
// Shape values are filled only if the base class holds update_values, gradients only if it
// holds update_gradients, so that no base-class field is read that was never requested.
//
// DoFs of component 0 are the standard ones: their entries are copied from the base class.
// All other DoFs are the extended ones. With N_i the base shape function, phi the level set,
// x_q the quadrature point, x_i a cell vertex and
//   psi_i(x) = |phi(x)| - |phi(x_i)|,
// the extended shape function is
//   cut cell      (interface intersects the cell):  N_i * psi_i
//   blending cell (otherwise):                      N_i * psi_i * R
// where the ramp R is the sum of the base shape functions listed by xfem.indices_of_ramp(cell).
// The gradients follow from the product rule, using grad|phi| = sign(phi) * grad(phi).
//
// NOTE(review): the vertex x_i is selected with the shape-function index returned by
// system_to_component_index(); this presumably assumes one DoF per vertex and component - confirm.
template <int dim,int spacedim>
void XFEValues_weak<dim,spacedim>::reinit(const typename hp::DoFHandler<dim,spacedim>::cell_iterator &cell)
{
  // Short name for the base class; the qualification is required because this class
  // defines its own shape_value() and shape_grad().
  typedef FEValues<dim,spacedim> Base;

  Base::reinit(cell);

  _shape_values.reinit(_dofs_per_cell,_n_q_points);
  _shape_gradients.reinit(_dofs_per_cell,_n_q_points);

  const bool fill_values    = (Base::update_flags & update_values) != 0;
  const bool fill_gradients = (Base::update_flags & update_gradients) != 0;

  // Nothing to tabulate if neither values nor gradients are available.
  if (!fill_values && !fill_gradients)
    return;

  // We need to distinguish between cut and blending cells, as the additional shape functions differ.
  // Blending cells are exactly those that are not intersected by the interface.
  const bool cut_cell = xfem.interface_intersects_cell(cell);

  // Ramp function and its gradient at every quadrature point; only needed on blending cells.
  std::vector<double> ramp_value;
  std::vector<Tensor<1,dim> > ramp_grad;
  if (!cut_cell)
    {
      ramp_value.resize(_n_q_points);
      ramp_grad.resize(_n_q_points);

      // indices_of_ramp determines the standard DoFs whose shape functions are summed up.
      const std::vector<unsigned int> ramp_indices = xfem.indices_of_ramp(cell);
      for (unsigned int k=0; k<ramp_indices.size(); ++k)
        for (unsigned int q=0; q<_n_q_points; ++q)
          {
            if (fill_values)
              ramp_value[q] += Base::shape_value(ramp_indices[k],q);
            if (fill_gradients)
              ramp_grad[q] += Base::shape_grad(ramp_indices[k],q);
          }
    }

  for (unsigned int i=0; i<_dofs_per_cell; ++i)
    {
      // Standard part of the solution: plain copy of the base class data.
      if (cell->get_fe().system_to_component_index(i).first == 0)
        {
          for (unsigned int q=0; q<_n_q_points; ++q)
            {
              if (fill_values)
                _shape_values(i,q) = Base::shape_value(i,q);
              if (fill_gradients)
                _shape_gradients(i,q) = Base::shape_grad(i,q);
            }
          continue;
        }

      // Extended part of the solution. The gradient formulas below contain the base shape
      // values (and, on blending cells, the ramp values), so gradients cannot be built without them.
      Assert(fill_values || !fill_gradients,
             ExcMessage("XFEValues_weak: update_gradients requires update_values for the extended shape functions"));

      const unsigned int vertex_no = cell->get_fe().system_to_component_index(i).second;

      for (unsigned int q=0; q<_n_q_points; ++q)
        {
          const Point<spacedim> &x_q = Base::quadrature_point(q);

          // psi_i(x_q) = |phi(x_q)| - |phi(x_i)|
          const double abs_shift = std::fabs(ls.level_set(x_q))
                                   - std::fabs(ls.level_set(cell->vertex(vertex_no)));

          if (cut_cell)
            {
              if (fill_values)
                _shape_values(i,q) = Base::shape_value(i,q) * abs_shift;

              // grad(N_i) * psi_i + N_i * sign(phi) * grad(phi)
              if (fill_gradients)
                _shape_gradients(i,q) = Base::shape_grad(i,q) * abs_shift
                                        + Base::shape_value(i,q)
                                        * xfem.sign(ls.level_set(x_q))
                                        * ls.grad_level_set(x_q);
            }
          else
            {
              if (fill_values)
                _shape_values(i,q) = Base::shape_value(i,q) * abs_shift * ramp_value[q];

              // grad(N_i) * psi_i * R + N_i * sign(phi) * grad(phi) * R + N_i * psi_i * grad(R)
              if (fill_gradients)
                _shape_gradients(i,q) = Base::shape_grad(i,q) * abs_shift * ramp_value[q]
                                        + Base::shape_value(i,q)
                                        * xfem.sign(ls.level_set(x_q))
                                        * ls.grad_level_set(x_q)
                                        * ramp_value[q]
                                        + Base::shape_value(i,q) * abs_shift * ramp_grad[q];
            }
        }
    }
}

// Returns the value of shape function function_no at quadrature point point_no,
// as stored in the _shape_values table by the last call to reinit().
//   function_no - row index into the table (DoF index on the cell).
//   point_no    - column index into the table (quadrature point index).
template <int dim,int spacedim>
double XFEValues_weak<dim,spacedim>::shape_value(const unsigned int function_no,
                                                 const unsigned int point_no)
{
  // The alias presumably exists to keep the comma of the template argument list
  // out of the Assert macro's argument list.
  typedef FEValuesBase<dim,spacedim> FVB;
  // The check uses the flags handed to the constructor of this class.
  Assert(_update_flags & update_values, typename FVB::ExcAccessToUninitializedField("update_values"));
  return _shape_values(function_no,point_no);
}

// Returns the gradient of shape function function_no at quadrature point point_no,
// as stored in the _shape_gradients table by the last call to reinit().
//   function_no - row index into the table (DoF index on the cell).
//   point_no    - column index into the table (quadrature point index).
template <int dim,int spacedim>
Tensor<1,dim> XFEValues_weak<dim,spacedim>::shape_grad(const unsigned int function_no,
                                                       const unsigned int point_no)
{
  // The alias presumably exists to keep the comma of the template argument list
  // out of the Assert macro's argument list.
  typedef FEValuesBase<dim,spacedim> FVB;
  // The check uses the flags handed to the constructor of this class.
  Assert(_update_flags & update_gradients, typename FVB::ExcAccessToUninitializedField("update_gradients"));
  return _shape_gradients(function_no,point_no);
}

// Deliberately unusable: this function never computes anything and always reports an error.
// It exists only to stop callers from evaluating a solution with the standard shape functions.
//   fe_function - ignored.
//   values      - ignored and left untouched.
// The error is raised with AssertThrow so that the guard is active in every build mode;
// a plain assert(0) is removed by the preprocessor when NDEBUG is defined.
template <int dim,int spacedim>
void XFEValues_weak<dim,spacedim>::get_function_values(const Vector<double> fe_function,
                                                       std::vector<Vector<double> > &values)
{
  // Neither argument is used; the casts silence unused-parameter warnings.
  (void) fe_function;
  (void) values;

  AssertThrow(false, ExcMessage("get_function_values provides wrong results for XFEValues"));
}

// Deliberately unusable: this function never computes anything and always reports an error.
// It exists only to stop callers from evaluating solution gradients with the standard shape functions.
//   fe_function - ignored.
//   values      - ignored and left untouched.
// The error is raised with AssertThrow so that the guard is active in every build mode;
// a plain assert(0) is removed by the preprocessor when NDEBUG is defined.
template <int dim,int spacedim>
void XFEValues_weak<dim,spacedim>::get_function_gradients(const Vector<double> fe_function,
                                                          std::vector<Vector<Tensor<1,dim> > > &values)
{
  // Neither argument is used; the casts silence unused-parameter warnings.
  (void) fe_function;
  (void) values;

  AssertThrow(false, ExcMessage("get_function_gradients provides wrong results for XFEValues"));
}


// Explicit instantiations of both class templates for dim = 2. Only one template argument is
// given, so the second one (spacedim) takes the default declared with the class templates.
// NOTE(review): the member definitions live in this file, so other dimensions presumably
// need an additional explicit instantiation here - confirm against the build.
template class XFEValues_strong<2>;
template class XFEValues_weak<2>;
