#include <config.h>

// standard library includes
#include <array>
#include <cmath>
#include <exception>
#include <iostream>
#include <string>
#include <tuple>

//local includes
#include "datafunction.hh"
#include "phc.hh"

// include header of adaptive scheme
#include <dune/acfem/common/discretefunctionselector.hh>
#include <dune/acfem/functions/basicfunctions.hh>
#include <dune/acfem/models/basicmodels.hh>
#include <dune/acfem/models/modelexpression.hh>
#include <dune/acfem/algorithms/ellipticfemscheme.hh>
#include <dune/acfem/operators/l2projection.hh>
#include <dune/fem/space/common/adaptmanager.hh>
#include <dune/fem/quadrature/intersectionquadrature.hh>

// include norms
#include <dune/fem/misc/l2norm.hh>
#include <dune/fem/misc/h1norm.hh>

//quadrature
#include <dune/fem/quadrature/cachingquadrature.hh>

// include output
#include <dune/acfem/common/dataoutput.hh>

//include timer
#include <dune/fem/misc/femtimer.hh>


using namespace Dune::ACFem;

//local refinement of the grid
/**
 * \brief Locally refine \a grid where the input data g jumps between neighbours.
 *
 * Performs \a maxLevel mark/adapt sweeps. In each sweep, every leaf element is marked
 * with DGFGridInfo::refineStepsForHalf() if the value of g at its center differs
 * (Euclidean norm) from the value of g at the center of some interior neighbour by more
 * than \a markTolerance. After marking, the grid is adapted.
 *
 * g is the image read from parameter "nsm.image" when NSM_USE_IMAGE is set, otherwise
 * the grid function returned by data().
 *
 * \param grid           hierarchical grid, modified in place
 * \param maxLevel       number of mark/adapt sweeps
 * \param markTolerance  jump threshold above which an element is marked
 */
template<class HGridType>
void adaptGrid(HGridType &grid, int maxLevel, double markTolerance)
{
  //Function Space Types
  typedef Dune::Fem::FunctionSpace<double, double, HGridType::dimensionworld, 1> FunctionSpaceType;

  //piecewise constant DG function type - only used to instantiate the data adapter for g
  typedef Dune::ACFem::DiscontinuousGalerkinLeafFunctionSelector<HGridType, FunctionSpaceType, 0> ForwardDiscreteFunctionSelector;
  typedef typename ForwardDiscreteFunctionSelector::DiscreteFunctionType ForwardDiscreteFunctionType;
  typedef typename ForwardDiscreteFunctionSelector::GridPartType GridPartType;

  typedef typename ForwardDiscreteFunctionType :: RangeType RangeType;
  typedef typename ForwardDiscreteFunctionType :: LocalFunctionType :: DomainType DomainType;

  GridPartType gridPart(grid);

  const int refineBisections = Dune::DGFGridInfo<HGridType>::refineStepsForHalf();

  //get g as a gridfunctionadapter - the DiscreteFunctionType as template parameter
#if NSM_USE_IMAGE
  //get the image filename
  const std::string filename = Dune::Fem::Parameter::getValue< std::string > ("nsm.image");
  auto g = image<ForwardDiscreteFunctionType>(gridPart, filename, "image");
#else
  auto g = data<ForwardDiscreteFunctionType>(gridPart, "image");
#endif

  //maxLevel is the maximum amount of local refinement sweeps
  for(int i = 0; i < maxLevel; ++i)
  {
    const auto end = gridPart.template end<0>();
    for(auto it = gridPart.template begin<0>(); it != end; ++it)
    {
      const auto & entity = *it;

      // local coordinate of the element center; DomainType(0.5) would only be the
      // center for cube elements, on simplices it lies on an edge
      const auto geometry = entity.geometry();
      const DomainType xInside = geometry.local(geometry.center());

      auto gLocal = g.localFunction(entity);
      RangeType gInside;
      gLocal.evaluate(xInside, gInside);

      //iterate over all neighbours
      const auto iend = gridPart.iend(entity);
      for(auto iit = gridPart.ibegin(entity); iit != iend; ++iit)
      {
        const auto & intersection = *iit;
        if(intersection.boundary() || !intersection.neighbor())
          continue;

        const auto neighbour = intersection.outside();
        const auto neighbourGeometry = neighbour.geometry();
        const DomainType xOutside = neighbourGeometry.local(neighbourGeometry.center());

        auto gLocalNeighbour = g.localFunction(neighbour);
        RangeType gOutside;
        gLocalNeighbour.evaluate(xOutside, gOutside);
        gOutside -= gInside;

        //if difference to value at neighbour is bigger than tolerance,
        //mark element for refinement - one mark per element is enough
        if(gOutside.two_norm() > markTolerance)
        {
          grid.mark(refineBisections, entity);
          break;
        }
      }
    }

    //refine the grid following the DUNE adaptation protocol
    //(called on every rank, also if nothing was marked locally)
    grid.preAdapt();
    grid.adapt();
    grid.postAdapt();
  }
}


//class to calcutate the energy that is minimized
template<class ForwardDiscreteFunctionType, class GridFunctionType>
class EnergyNorm
{

public:
  typedef typename ForwardDiscreteFunctionType::LocalFunctionType LocalFunctionType;
  typedef typename GridFunctionType::LocalFunctionType GridLocalFunctionType;
  typedef typename ForwardDiscreteFunctionType::GridPartType GridPartType;
  typedef typename GridPartType::template Codim<0>::EntityType EntityType;

  typedef typename GridLocalFunctionType::RangeType GridRangeType;
  typedef typename ForwardDiscreteFunctionType::RangeType RangeType;
  typedef typename ForwardDiscreteFunctionType::JacobianRangeType JacobianRangeType;
  typedef typename ForwardDiscreteFunctionType::DomainType DomainType;

  typedef Dune::Fem::CachingQuadrature< GridPartType, 0 > QuadratureType;

  typedef std::array<double, 3> EnergyType;

  //constructor gets g and the cost factors for L1 and L2 term
  EnergyNorm(const double lambda_1, const double lambda_2, const GridFunctionType &g)
    : lambda_1_(lambda_1), lambda_2_(lambda_2), g_(g)
  {}

  //calculating the energy
  EnergyType energy(ForwardDiscreteFunctionType & u )
  {
    EnergyType energy {0.,0.,0.};

    const auto end = u.gridPart().template end<0>();
    for(auto it = u.gridPart().template begin<0>(); it != end; ++it)
    {
      const EntityType & entity  = *it;
      LocalFunctionType uLocal = u.localFunction(entity);
      GridLocalFunctionType gLocal = g_.localFunction(entity);

      GridRangeType gValue(0.);
      gLocal.evaluate(DomainType(0.5), gValue);

      JacobianRangeType uJacobian(0.);
      uLocal.jacobian(DomainType(0.5), uJacobian);

      double absGradu = uJacobian[0].two_norm();

      //now quadrature loop

      // obtain quadrature order
      const int quadOrder = 2*uLocal.order() +1;

      QuadratureType quadrature(entity, quadOrder);
      const size_t numQuadraturePoints = quadrature.nop();
      for (size_t pt = 0; pt < numQuadraturePoints; ++pt)
      {
        const typename QuadratureType::CoordinateType &x = quadrature.point(pt);
        const double weight = quadrature.weight(pt) * entity.geometry().integrationElement(x);

        // compute the source contribution
        RangeType vu;
        uLocal.evaluate(quadrature[pt], vu);

        //if u is not continuous, we have to add jump contributions
        if(!Dune::Fem::Capabilities::isContinuous < typename ForwardDiscreteFunctionType::DiscreteFunctionSpaceType >::v)
        {
          typedef Dune::Fem::CachingQuadrature<GridPartType, 1> FaceQuadratureType;
          // use IntersectionQuadrature to create appropriate face quadratures
          typedef Dune::Fem::IntersectionQuadrature<FaceQuadratureType, true> IntersectionQuadratureType;
          typedef typename IntersectionQuadratureType::FaceQuadratureType Quadrature ;
          const auto iend = u.gridPart().iend(entity);
          //iterate over intersections
          for (auto iit = u.gridPart().ibegin(entity); iit != iend; ++iit) {
            const auto &intersection = *iit;
            if(intersection.neighbor() && ! intersection.boundary())
            {
              // get outside entity pointer
              const auto &outside = *intersection.outside();
              LocalFunctionType uOutside = u.localFunction(outside);
              IntersectionQuadratureType intersectionQuad(u.gridPart(), intersection, quadOrder);
              // get appropriate quadrature references
              const Quadrature &quadInside  = intersectionQuad.inside();
              const Quadrature &quadOutside = intersectionQuad.outside();
              const int numQuadraturePoints = intersectionQuad.nop();
              //iterate over intersectionquadrature
              for (int qp = 0; qp < numQuadraturePoints; ++qp)
              {
                // quadrature weight
                const double iweight = intersectionQuad.weight(qp);
                // Get the un-normalized outer normal
                DomainType integrationNormal
                  = intersection.integrationOuterNormal(intersectionQuad.localPoint(qp));
                //get integration element
                const double integrationElement = integrationNormal.two_norm();

                RangeType vInside, vOutside;
                uLocal.evaluate(quadInside[qp], vInside);
                uOutside.evaluate(quadOutside[qp],vOutside);
                //finally add jump contribution to TV term
                //halved because intersections are visited twice
                energy[0] += std::abs(vInside - vOutside) * iweight * integrationElement/ 2.;
              }
            }
          }
        }
        //add TV term
        energy[0] += weight * absGradu;

        //add L1 term
        energy[1] += weight * lambda_1_ * std::abs(vu - gValue);

        //add L2 term
        energy[2] += weight * lambda_2_ / 2. * (vu - gValue) * (vu -gValue);
      }  // end quad loop
    }
    return energy;
  }

private:
  const double lambda_1_, lambda_2_;
  const GridFunctionType & g_;
};

//the projection that comes from the L1 term of the functional.
//can be done dofwise as all functions are of the same type
/**
 * \brief Dof-wise soft thresholding of \a z around \a projG.
 *
 * For every degree of freedom k (with g = projG):
 *   u_k = z_k - beta   if z_k - g_k >  beta,
 *   u_k = z_k + beta   if z_k - g_k < -beta,
 *   u_k = g_k          otherwise.
 *
 * \param beta   threshold, expected to be non-negative
 * \param z      input function
 * \param projG  function the thresholding is centered at
 * \param u      output, overwritten dof by dof; must have at least as many dofs as z
 */
template<class ForwardDiscreteFunctionType>
void dofwiseProjection(const double beta, const ForwardDiscreteFunctionType & z, const ForwardDiscreteFunctionType & projG, ForwardDiscreteFunctionType & u)
{
  // The const dof iterators of z/projG and the mutable dof iterator of u are in general
  // different types, so they must not be declared in one shared `auto` declaration.
  auto gIt = projG.dbegin();
  auto uIt = u.dbegin();
  const auto zEnd = z.dend();

  for(auto zIt = z.dbegin(); zIt != zEnd; ++zIt, ++gIt, ++uIt)
  {
    const auto diff = *zIt - *gIt;
    if(diff > beta)
      *uIt = *zIt - beta;
    else if(diff < -beta)
      *uIt = *zIt + beta;
    else
      *uIt = *gIt;
  }
}

//the complete algorithm
/**
 * \brief Run the primal-dual iteration on the leaf grid of \a grid.
 *
 * Each iteration updates the dual variable p from the extrapolated iterate uBar,
 * computes z, applies the dof-wise L1 projection to obtain u and extrapolates
 * uBar = (1+theta) u - theta uOld.
 *
 * Parameters read: nsm.theta, nsm.tau/nsm.sigma (NSM_SET_STEPSIZE) or nsm.constantL,
 * nsm.maxIt, nsm.tolerance, nsm.outputStep, nsm.lambda_1, nsm.lambda_2 and, with
 * NSM_USE_IMAGE, nsm.image (plus the noise parameters with NSM_USE_NOISE).
 *
 * The iteration stops when ||u - uOld||_L2 < tolerance * ||uOld||_L2 or after maxIt
 * iterations; in both cases the final energies, the number of iterations and the
 * elapsed time are printed.
 */
template<class HGridType>
void algorithm(HGridType &grid)
{
  //Function Space Types
  typedef Dune::Fem::FunctionSpace<double, double, HGridType::dimensionworld, 1> FunctionSpaceType;

  typedef Dune::Fem::FunctionSpace<double, double, HGridType::dimensionworld, HGridType::dimensionworld> AdjointFunctionSpaceType;

  //the Forward discrete function type
#if NSM_U_DISCONT
  typedef Dune::ACFem::DiscontinuousGalerkinLeafFunctionSelector<HGridType, FunctionSpaceType, NSM_U_POLORDER> ForwardDiscreteFunctionSelector;
#else
  typedef Dune::ACFem::LagrangeLeafFunctionSelector<HGridType, FunctionSpaceType, NSM_U_POLORDER> ForwardDiscreteFunctionSelector;
#endif
  typedef typename ForwardDiscreteFunctionSelector::DiscreteFunctionType ForwardDiscreteFunctionType;
  typedef typename ForwardDiscreteFunctionSelector::GridPartType GridPartType;
  typedef typename ForwardDiscreteFunctionSelector::DiscreteFunctionSpaceType ForwardDiscreteFunctionSpaceType;

  //The adjoint discrete function type
#if NSM_P_DISCONT
  typedef Dune::ACFem::DiscontinuousGalerkinLeafFunctionSelector<HGridType, AdjointFunctionSpaceType, NSM_P_POLORDER> AdjointDiscreteFunctionSelector;
#else
  typedef Dune::ACFem::LagrangeLeafFunctionSelector<HGridType, AdjointFunctionSpaceType, NSM_P_POLORDER> AdjointDiscreteFunctionSelector;
#endif
  typedef typename AdjointDiscreteFunctionSelector::DiscreteFunctionType AdjointDiscreteFunctionType;
  typedef typename AdjointDiscreteFunctionSelector::DiscreteFunctionSpaceType AdjointDiscreteFunctionSpaceType;

  //the norm for the stopping criteria
  typedef Dune::Fem::L2Norm< GridPartType > NormType;

  //type of energy output - to get it from energynorm
  typedef std::array<double, 3> EnergyType;

  GridPartType gridPart(grid);

  //the Forward discrete Space
  ForwardDiscreteFunctionSpaceType forwardDiscreteSpace(gridPart);
  //the Forward functions we need for the algorithm
  ForwardDiscreteFunctionType u("u", forwardDiscreteSpace);
  ForwardDiscreteFunctionType uOld("uOld", forwardDiscreteSpace);
  ForwardDiscreteFunctionType uBar("uBar", forwardDiscreteSpace);
  ForwardDiscreteFunctionType z("z", forwardDiscreteSpace);
  ForwardDiscreteFunctionType projG("Projection of g", forwardDiscreteSpace);

  //initialize to zero
  u.clear();
  uOld.clear();
  uBar.clear();
  z.clear();
  projG.clear();

  //the adjoint (dual) space
  AdjointDiscreteFunctionSpaceType adjointDiscreteSpace(gridPart);
  //the dual variable
  AdjointDiscreteFunctionType p("p", adjointDiscreteSpace);

  //initialize to zero
  p.clear();

  //initialize the norm for the stopping criteria with quadrature of order 2*polorder
  NormType norm(gridPart, 2*forwardDiscreteSpace.order());


  const double theta = Dune::Fem::Parameter::getValue< double > ("nsm.theta",1.);

#if NSM_SET_STEPSIZE
  const double tau   = Dune::Fem::Parameter::getValue< double > ("nsm.tau",1.);
  const double sigma = Dune::Fem::Parameter::getValue< double > ("nsm.sigma",1.);
#else
  const double constantL = Dune::Fem::Parameter::getValue< double >("nsm.constantL", 1);
  const int maxLevel = grid.maxLevel();
  // NOTE(review): assumes a macro mesh width of 1 and one halving per level; for grids
  // with refineStepsForHalf() != 1 the grid level does not count halvings - confirm.
  const double h =  1./std::pow(2,maxLevel);
  const double sigma = constantL * h ;
  const double tau = sigma;
#endif


  //the number of iterations
  const int maxIt = Dune::Fem::Parameter::getValue< int > ("nsm.maxIt",100);
  const double tolerance = Dune::Fem::Parameter::getValue< double > ("nsm.tolerance",1e-10);
  const int outputStep = Dune::Fem::Parameter::getValue< int > ("nsm.outputStep",100);

  //get problem parameters - no default value, has to be set in parameter file
  const double lambda_1 = Dune::Fem::Parameter::getValue< double > ("nsm.lambda_1");
  const double lambda_2 = Dune::Fem::Parameter::getValue< double > ("nsm.lambda_2");

  //calculate beta, the threshold of the dof-wise L1 projection
  const double beta = tau * lambda_1 / (1. + tau * lambda_2);

  //get ProjectionHelperclass
  ProjectionHelperClass< ForwardDiscreteFunctionType,  AdjointDiscreteFunctionType, Dune::Fem::Capabilities::isContinuous < AdjointDiscreteFunctionSpaceType >::v, Dune::Fem::Capabilities::isContinuous < ForwardDiscreteFunctionSpaceType >::v  > phc(projG, sigma, tau, lambda_2);


#if NSM_USE_IMAGE
  //get the image filename
  const std::string filename = Dune::Fem::Parameter::getValue< std::string > ("nsm.image");
  std::cout << filename << std::endl;
#if NSM_USE_NOISE
  const double noiseLevel1 = Dune::Fem::Parameter::getValue< double > ("nsm.noiseLevel1",-5);
  const int noiseType1 = Dune::Fem::Parameter::getValue< int > ("nsm.noiseType1",0);
  const double noiseLevel2 = Dune::Fem::Parameter::getValue< double > ("nsm.noiseLevel2",10);
  const int noiseType2 = Dune::Fem::Parameter::getValue< int > ("nsm.noiseType2",2);
  //get g as a gridfunctionadapter - the DiscreteFunctionType as template parameter
  auto g = noiseImage<ForwardDiscreteFunctionType>(gridPart, filename, noiseLevel1, noiseType1, noiseLevel2, noiseType2, "image");
#else
  //get g as a gridfunctionadapter - the DiscreteFunctionType as template parameter
  auto g = image<ForwardDiscreteFunctionType>(gridPart, filename, "image");
#endif //NSM_USE_NOISE

#else // NSM_USE_IMAGE
  //get g as a gridfunctionadapter - the DiscreteFunctionType as template parameter
  auto g = data<ForwardDiscreteFunctionType>(gridPart, "image");
#endif //NSM_USE_IMAGE

  //! type of input/output tuple
  typedef std::tuple<ForwardDiscreteFunctionType *, const ForwardDiscreteFunctionType *, AdjointDiscreteFunctionType *,const decltype(g) *>
  IOTupleType;

  //! type of data writer
  typedef DataOutput<HGridType, IOTupleType> DataOutputType;

  // io tuple
  IOTupleType io(&u,&projG,&p,&g);
  //initialize data output
  DataOutputType dataOutput(grid, io, DataOutputParameters());


  //calculate projection of g onto Forward space
  L2Projection(g, projG);

  //one could start with projG as u
  // u.assign(projG);

  //initialize uBar to u
  uBar.assign(u);

  //initialize EnergyNorm with g
  EnergyNorm<ForwardDiscreteFunctionType, decltype(g)> energyNorm(lambda_1, lambda_2, g);
  //can also easily be initialized from projG
  EnergyNorm<ForwardDiscreteFunctionType, decltype(projG)> projectionEnergyNorm(lambda_1, lambda_2, projG);

  //prints one row of the energy table (energies w.r.t. g and w.r.t. projG) for the current u
  auto printEnergies = [&]()
  {
    const EnergyType energy = energyNorm.energy(u);
    const EnergyType projEnergy = projectionEnergyNorm.energy(u);
    std::cout << energy[0] << " ,\t" << energy[1] << " ,\t" << energy[2] << " ,\t" << energy[0] + energy[1] + energy[2] << " ,\t"
              << projEnergy[0] << " ,\t" << projEnergy[1] << " ,\t" << projEnergy[2] << " ,\t" << projEnergy[0] + projEnergy[1] + projEnergy[2] << std::endl;
  };

  //start timer
  Dune::Fem::Timer<true>::start();

  std::cout << "\t \t Energy of g \t \t \t \t, \t \t Energy of Pg" << std::endl;
  std::cout << " TVg \t , \t L1g \t , \t L2g \t, \t Sumg \t, \t TVPg \t, \t L1Pg \t, \t L2Pg \t, \t SumPg" << std::endl;

  //number of completed iterations
  int iterations = 0;
  bool converged = false;

  //now the main part of the algorithm
  for(int i = 0; i < maxIt; ++i)
  {
    //update p
    phc.entitywiseProjection(uBar, p);

    //calculateZ
    phc.calculateZ(p, uOld, z);

    //update u
    dofwiseProjection(beta, z, projG, u);
    iterations = i + 1;

    //output (and calculate energy norm) only every n-th step; outputStep <= 0 disables it
    if(outputStep > 0 && i % outputStep == 0)
    {
      // write data
      dataOutput.write();
      printEnergies();
    }

    //stopping criterion: relative change, written without division so that the
    //first step (uOld == 0) does not produce 0/0
    if(norm.distance(u, uOld) < tolerance * norm.norm(uOld))
    {
      converged = true;
      break;
    }

    //update uBar = (1+theta) u - theta uOld
    uBar.assign(u);
    uBar *= 1. + theta;
    uBar.axpy(-theta, uOld);

    //update uOld
    uOld.assign(u);
  }

  //always stop the timer, also if maxIt was reached
  const double elapsed = Dune::Fem::Timer<true>::stop();

  printEnergies();
  if(!converged)
    std::cout << "Maximum number of iterations (" << maxIt << ") reached without convergence." << std::endl;
  std::cout << "Number of iterations: " << iterations << std::endl;
  std::cout << "Time Elapsed: " << elapsed << std::endl;

  dataOutput.write();

}

// main
// ----

int main (int argc, char **argv)
try
{
  // initialize MPI, if necessary
  Dune::Fem::MPIManager::initialize(argc, argv);

  std::cerr << "MPI rank " << Dune::Fem::MPIManager::rank() << std::endl;

  // append overloaded parameters from the command line
  Dune::Fem::Parameter::append(argc, argv);

  // append possible given parameter files
  for (int i = 1; i < argc; ++i)
    Dune::Fem::Parameter::append(argv[ i ]);

  // append default parameter file
  Dune::Fem::Parameter::append("../data/parameter");

  // type of hierarchical grid
  typedef Dune::GridSelector::GridType HGridType ;

  // create grid from DGF file
  const std::string gridkey = Dune::Fem::IOInterface::defaultGridKey(HGridType::dimension);
  const std::string gridfile = Dune::Fem::Parameter::getValue<std::string>(gridkey);

  // the method rank and size from MPIManager are static
  if (Dune::Fem::MPIManager::rank() == 0)
    std::cout << "Loading macro grid: " << gridfile << std::endl;

  // construct macro using the DGF Parser
  Dune::GridPtr< HGridType > gridPtr(gridfile);
  HGridType& grid = *gridPtr ;

  // do initial load balance
  grid.loadBalance();

  // initial grid refinement
  const int level = Dune::Fem::Parameter::getValue< int >("nsm.initialRefinements",4);

  // number of global refinements to bisect grid width
  const int refineStepsForHalf = Dune::DGFGridInfo< HGridType >::refineStepsForHalf();

  // refine grid
  grid.globalRefine(level * refineStepsForHalf);

  //get adaptation parameters
  int numAdaptations = Dune::Fem::Parameter::getValue< int > ("nsm.localRefine",5);
  double adaptTolerance = Dune::Fem::Parameter::getValue< double > ("nsm.adaptTolerance",0.1);  
  
  //locally refine grid
  adaptGrid<HGridType>(grid, numAdaptations, adaptTolerance);

  // let it go ... quasi adapt_method_stat()
  algorithm(grid);


  return 0;
}
catch(const Dune::Exception &exception)
{
  std::cerr << "Error: " << exception << std::endl;
  return 1;
}
