From 16bcbcc58c7a4aeddcd73bb43966e8ed0c225978 Mon Sep 17 00:00:00 2001
From: Florian Atteneder <florian.atteneder@uni-jena.de>
Date: Wed, 21 Aug 2024 15:18:04 +0200
Subject: [PATCH] wip

---
 grhd_tov_spherical1d.toml |  40 +++++++++++++
 src/GRHD/callbacks.jl     |  15 +++++
 src/GRHD/rhs.jl           | 122 +++++++++++++++++++++++++++++++-------
 src/GRHD/setup.jl         |  12 +++-
 src/GRHD/spherical1d.jl   |   6 +-
 src/callbacks.jl          |   4 +-
 src/dg_rhs.jl             |  30 +++++-----
 src/dgelement.jl          |  12 ++--
 src/mesh.jl               |  42 +++++++++----
 9 files changed, 225 insertions(+), 58 deletions(-)
 create mode 100644 grhd_tov_spherical1d.toml

diff --git a/grhd_tov_spherical1d.toml b/grhd_tov_spherical1d.toml
new file mode 100644
index 00000000..82b9ae45
--- /dev/null
+++ b/grhd_tov_spherical1d.toml
@@ -0,0 +1,40 @@
+[EquationOfState]
+polytrope_gamma = 2.0
+polytrope_k = 100.0
+eos = "polytrope"
+
+[Evolution]
+cfl = 0.1
+tend = 1000.0
+
+[Output]
+# aligned_ts = "$(collect(range(1.0,1000.0,step=1.0)))"
+every_iteration = 1
+variables1d = [ "D", "Sr", "Ï„",
+                "rhs_D", "rhs_Sr", "rhs_Ï„",
+                "D_modal", "Sr_modal", "τ_modal",
+                "rhs_D_modal", "rhs_Sr_modal", "rhs_τ_modal",
+                "c2p_reset_atm", "c2p_freeze_atm", "c2p_near_surface", "c2p_failed"]
+interpolate_every_iteration = 1
+interpolate_nodes = """$(collect(range(-20.0,20.0,step=0.05)))"""
+interpolate_variables = ["D", "Sr", "Ï„", "rhs_D", "rhs_Sr", "rhs_Ï„", "c2p_reset_atm", "c2p_freeze_atm" ]
+
+[GRHD]
+c2p_set_atmosphere_on_failure = true
+atm_threshold_factor = 100.0
+id_filename = "$(joinpath(ROOTDIR,\"initialdata\",\"TOV_stable.h5\"))"
+atm_factor = 1.0e-8
+id = "tov"
+formulation = "spherical1d"
+bc = "tov_symmetric_domain"
+freeze_vars = ["Ï„"]
+atm_evolve = false
+c2p_enforce_causal_atm = true
+c2p_enforce_atm = false
+
+[Mesh]
+periodic = false
+range = [-20.0, 20.0]
+k = 23
+basis = "lgl"
+n = 5
diff --git a/src/GRHD/callbacks.jl b/src/GRHD/callbacks.jl
index 66cf7115..1ffb9581 100644
--- a/src/GRHD/callbacks.jl
+++ b/src/GRHD/callbacks.jl
@@ -84,6 +84,7 @@ function compute_atmosphere_mask(env, P, mesh::Mesh1d{DGElement})
   if P.prms.c2p_dynamic_atm
     broadcast_volume!(determine_atmosphere, P.equation, env.mesh)
   end
+
   # this controls the effect of the stop_atmosphere_evolution_* methods
   # which trigger when c2p_reset_atm > 0.0 && c2p_freeze_atm > 0.0 at a point
   if enforce
@@ -92,6 +93,20 @@ function compute_atmosphere_mask(env, P, mesh::Mesh1d{DGElement})
   elseif enforce_causal
     update_atm_domain_of_dependence!(env.mesh)
   end
+
+
+  @unpack c2p_reset_atm, c2p_freeze_atm = get_static_variables(mesh)
+  @unpack c2p_near_surface = get_cell_variables(mesh)
+  # 0. determine c2p_near_surface
+  fill!(c2p_near_surface, 0.0)
+  for (k,(v_reset,v_freeze)) in enumerate(zip(eachcell(mesh,c2p_reset_atm),
+                                              eachcell(mesh,c2p_freeze_atm)))
+    # if any(>(0.0), v_reset) && any(==(0.0), v_freeze)
+    if any(==(0.0), v_freeze)
+      c2p_near_surface[k] = 1.0
+    end
+  end
+
 end
 
 function update_atm_domain_of_dependence!(mesh::Mesh1d)
diff --git a/src/GRHD/rhs.jl b/src/GRHD/rhs.jl
index 71bed973..3e7c20fb 100644
--- a/src/GRHD/rhs.jl
+++ b/src/GRHD/rhs.jl
@@ -1241,13 +1241,15 @@ function rhs!(mesh::Mesh1d, P::Project{:spherical1d}, hrsc::Nothing)
   @unpack c2p_near_surface = get_cell_variables(cache)
   @unpack c2p_reset_atm, c2p_freeze_atm = get_static_variables(cache)
 
-  # 0. determine c2p_near_surface
-  for (k,(v_reset,v_freeze)) in enumerate(zip(eachcell(mesh,c2p_reset_atm),
-                                              eachcell(mesh,c2p_freeze_atm)))
-    if any(>=(0.0), v_reset) && any(>=(0.0), v_freeze)
-      c2p_near_surface[k] = 1.0
-    end
-  end
+  # # 0. determine c2p_near_surface
+  # fill!(c2p_near_surface, 0.0)
+  # for (k,(v_reset,v_freeze)) in enumerate(zip(eachcell(mesh,c2p_reset_atm),
+  #                                             eachcell(mesh,c2p_freeze_atm)))
+  #   # if any(>(0.0), v_reset) && any(==(0.0), v_freeze)
+  #   if any(==(0.0), v_freeze)
+  #     c2p_near_surface[k] = 1.0
+  #   end
+  # end
 
   # 1. map nodal to modal coefficients near surface
   Npts = mesh.element.Npts
@@ -1256,38 +1258,61 @@ function rhs!(mesh::Mesh1d, P::Project{:spherical1d}, hrsc::Nothing)
     cidxs = cellindices(mesh, k)
     for (var,var_modal) in ((D,D_modal), (Sr,Sr_modal), (τ,τ_modal))
       v_var       = view(var, cidxs)
+      vv_var      = view(var, 2:Npts-1)
       v_var_modal = view(var_modal, cidxs)
-      vv_var       = view(v_var, 2:Npts-1)
       vv_var_modal = view(v_var_modal, 1:Npts-1)
-      # compute modal coefficients from nodal data
-      @show size(vv_var_modal), size(P.prms.Tavg), size(vv_var)
-      mul!(vv_var_modal, P.prms.Tavg, vv_var)
+      # only compute 'cell interior' modal coefficients from nodal data
+      mul!(vv_var_modal, P.prms.Tavg, v_var)
+      # mul!(vv_var_modal, P.prms.pinvVsmpl, v_var)
     end
   end
 
   # 2. resample bspline to collocation points
+  niter = 0
   @label resample_bspline
+  niter += 1
+  if niter > 10
+    error("too many iterations in bspline adjustment")
+  end
+  @unpack r = get_static_variables(mesh)
   for (k,isnear) in enumerate(c2p_near_surface)
     isnear > 0.0 || continue
     cidxs = cellindices(mesh, k)
+    # ii = 1
     for (var,var_modal) in ((D,D_modal), (Sr,Sr_modal), (τ,τ_modal))
       v_var       = view(var, cidxs)
       v_var_modal = view(var_modal, cidxs)
       vv_var       = view(v_var, 2:Npts-1)
       vv_var_modal = view(v_var_modal, 1:Npts-1)
-      # sample Bspline Ansatz on collocation points
+      # tmp_modal = P.prms.workspace.buf1
+      # tmp_modal[1] = v_var[1]; tmp_modal[end] = v_var[end]
+      # @views tmp_modal[2:end-1] .= vv_var_modal
+      # mul!(vv_var, P.prms.Vsmpl, tmp_modal)
       mul!(vv_var, P.prms.Vsmpl, vv_var_modal)
+      # if ii == 1
+      #   @info k, Main.N, "before c2p"
+      #   v_r = view(r, cidxs)
+      #   @show v_r
+      #   # @show v_var
+      #   @show vv_var
+      #   # @show v_var_modal
+      #   @show vv_var_modal
+      # end
+      # ii += 1
     end
   end
+  # if Main.N > 1
+  #   error()
+  # end
+  # Main.N += 1
 
   # 3. cons2prim
   broadcast_volume!(cons2prim_spherical1d_freeze_flags, equation, mesh)
 
-  # 4. if cons2prim succeeded everywhere continue to 6
-  # 5. map nodal to modal coefficients and backpropagate any atmosphere resets by locally adjusting
+  # 4. map nodal to modal coefficients and backpropagate any atmosphere resets by locally adjusting
   #    modal coefficients, then go to 2.
   c2p_success = true
-  @unpack c2p_failed = get_static_variables(cache)
+  @unpack r, c2p_failed = get_static_variables(cache)
   for (k,v_c2p_failed) in enumerate(eachcell(mesh, c2p_failed))
     cidxs = cellindices(mesh, k)
     nodal_rng = 2:Npts-1
@@ -1298,20 +1323,29 @@ function rhs!(mesh::Mesh1d, P::Project{:spherical1d}, hrsc::Nothing)
     v_D_modal  = view(D_modal, modal_rng)
     v_Sr_modal = view(Sr_modal, modal_rng)
     v_τ_modal  = view(τ_modal, modal_rng)
+    v_r = view(r, cidxs)
     isnear = c2p_near_surface[k]
     for (i,failed) in enumerate(v_c2p_failed)
       failed > 0.0 || continue
-      isnear > 0.0 || error("c2p failed not near the surface, this should not have happened!")
+      @show Main.N, v_r[i]
+      (i == 1 || i == length(v_c2p_failed)) && TODO()
+      isnear > 0.0 || error("c2p failed at r=$(v_r[i]) which is not near the surface, this should not have happened!")
+      TODO()
       c2p_success = false
       for (v_var, v_var_modal) in zip((v_D,v_D_modal),
                                       (v_Sr,v_Sr_modal),
                                       (v_τ,v_τ_modal))
+
         for j in 1:size(P.prms.pinvVsmpl,2)
           v_var_modal[i] = P.prms.pinvVsmpl[i,j] * v_var[j]
         end
       end
     end
   end
+  Main.N += 1
+
+  # 5. if cons2prim succeeded everywhere we did not backpropagate any changes
+  #    in this case continue, otherwise redo cons2prim
   !c2p_success && @goto resample_bspline
 
   # 6. evaluate fluxes and sources
@@ -1358,39 +1392,72 @@ function rhs!(mesh::Mesh1d, P::Project{:spherical1d}, hrsc::Nothing)
   for (k,isnear) in enumerate(c2p_near_surface)
     isnear > 0.0 || continue
     cidxs = cellindices(mesh, k)
+    # ii = 1
     for (var,var_modal) in ((flr_D,flr_D_modal), (flr_Sr,flr_Sr_modal), (flr_τ,flr_τ_modal),
                             (src_D,src_D_modal), (src_Sr,src_Sr_modal), (src_τ,src_τ_modal))
       v_var       = view(var, cidxs)
       v_var_modal = view(var_modal, cidxs)
-      vv_var       = view(v_var, 2:Npts-1)
       vv_var_modal = view(v_var_modal, 1:Npts-1)
-      mul!(vv_var_modal, P.prms.pinvVsmpl, vv_var)
+      # only compute 'cell interior' modal coefficients from nodal data
+      mul!(vv_var_modal, P.prms.Tavg, v_var)
+      # if ii == 1
+      #   @info k, Main.N, "after c2p"
+      #   @show v_var
+      #   @show v_var_modal
+      #   @show vv_var_modal
+      # end
+      # ii += 1
     end
   end
+  # if Main.N > 1
+  #   error()
+  # end
+  # Main.N += 1
 
   # 8. compute RHSs of nodal and modal method
   if :D ∉ P.prms.freeze_vars
     compute_rhs_weak_form!(rhs_D, flr_D, src_D, nflr_D, mesh)
     compute_rhs_weak_form!(rhs_D,       flr_D,       src_D,
                            rhs_D_modal, flr_D_modal, src_D_modal,
-                           nflr_D, mesh, c2p_near_surface, P.prms.workspace)
+                           nflr_D, mesh, P.prms.bspline_element, c2p_near_surface, P.prms.workspace)
   end
   if :Sr ∉ P.prms.freeze_vars
     compute_rhs_weak_form!(rhs_Sr, flr_Sr, src_Sr, nflr_Sr, mesh)
     compute_rhs_weak_form!(rhs_Sr,        flr_Sr,       src_Sr,
                            rhs_Sr_modal,  flr_Sr_modal, src_Sr_modal,
-                           nflr_Sr, mesh, c2p_near_surface, P.prms.workspace)
+                           nflr_Sr, mesh, P.prms.bspline_element, c2p_near_surface, P.prms.workspace)
   end
   if :τ ∉ P.prms.freeze_vars
     compute_rhs_weak_form!(rhs_Ï„, flr_Ï„, src_Ï„, nflr_Ï„, mesh)
     compute_rhs_weak_form!(rhs_Ï„,        flr_Ï„,       src_Ï„,
                            rhs_τ_modal,  flr_τ_modal, src_τ_modal,
-                           nflr_Ï„, mesh, c2p_near_surface, P.prms.workspace)
+                           nflr_Ï„, mesh, P.prms.bspline_element, c2p_near_surface, P.prms.workspace)
   end
 
   # 9. impose static atmosphere
   broadcast_volume!(determine_atmosphere, P.equation, mesh)
   broadcast_volume!(stop_atmosphere_evolution_spherical1d, P.equation, mesh)
+  for (k,isnear) in enumerate(c2p_near_surface)
+    isnear > 0.0 || continue
+    cidxs = cellindices(mesh, k)
+    v_reset_atm = view(c2p_reset_atm, cidxs)
+    v_freeze_atm = view(c2p_freeze_atm, cidxs)
+    vv_reset_atm = view(v_reset_atm, 2:Npts-1)
+    vv_freeze_atm = view(v_freeze_atm, 2:Npts-1)
+    v_rhs_D_modal = view(rhs_D_modal, cidxs)
+    v_rhs_Sr_modal = view(rhs_Sr_modal, cidxs)
+    v_rhs_τ_modal = view(rhs_τ_modal, cidxs)
+    vv_rhs_D_modal = view(v_rhs_D_modal, 1:Npts-1)
+    vv_rhs_Sr_modal = view(v_rhs_Sr_modal, 1:Npts-1)
+    vv_rhs_τ_modal = view(v_rhs_τ_modal, 1:Npts-1)
+    for (i,(rst,frz)) in enumerate(zip(vv_reset_atm,vv_freeze_atm))
+      if rst > 0.0 && frz > 0.0
+        vv_rhs_D_modal[i]  = vv_rhs_D_modal[i+1]  = 0.0
+        vv_rhs_Sr_modal[i] = vv_rhs_Sr_modal[i+1] = 0.0
+        vv_rhs_τ_modal[i]  = vv_rhs_τ_modal[i+1]  = 0.0
+      end
+    end
+  end
 
   # if !P.prms.atm_evolve
   #   broadcast_volume!(determine_atmosphere, P.equation, mesh)
@@ -1402,11 +1469,24 @@ function rhs!(mesh::Mesh1d, P::Project{:spherical1d}, hrsc::Nothing)
     isnear > 0.0 || continue
     cidxs = cellindices(mesh, k)
     for (var,var_modal) in ((rhs_D,rhs_D_modal), (rhs_Sr,rhs_Sr_modal), (rhs_τ,rhs_τ_modal))
+      # v_var       = view(var, cidxs)
+      # v_var_modal = view(var_modal, cidxs)
+      # vv_var       = view(v_var, 2:Npts-1)
+      # vv_var_modal = view(v_var_modal, 1:Npts-1)
+      # mul!(vv_var, P.prms.Vsmpl, vv_var_modal)
       v_var       = view(var, cidxs)
       v_var_modal = view(var_modal, cidxs)
       vv_var       = view(v_var, 2:Npts-1)
       vv_var_modal = view(v_var_modal, 1:Npts-1)
+      # sample Bspline Ansatz on collocation points
+      tmp_modal = P.prms.workspace.buf1
+      # tmp_modal[1] = v_var[1]; tmp_modal[end] = v_var[end]
+      # @views tmp_modal[2:end-1] .= vv_var_modal
       mul!(vv_var, P.prms.Vsmpl, vv_var_modal)
+      # mul!(vv_var, P.prms.Vsmpl, vv_var_modal)
+      # @show size(vv_var), size(P.prms.pinvTavg), size(vv_var_modal)
+      # mul!(v_var, P.prms.pinvTavg, vv_var_modal)
+      @show v_var, vv_var_modal
     end
   end
 
diff --git a/src/GRHD/setup.jl b/src/GRHD/setup.jl
index cfaff2ea..c2c1c3eb 100644
--- a/src/GRHD/setup.jl
+++ b/src/GRHD/setup.jl
@@ -51,10 +51,16 @@ function Project(env::Environment, prms)
   end
   pinvTavg = pinv(Tavg)
   display(Tavg)
-  Vsmpl = dg1d.Bspline.vandermonde_matrix(bspline_element.z[2:end-1], bspline_element.bs)
+  display(pinvTavg)
+  Vsmpl = dg1d.Bspline.vandermonde_matrix(mesh.element.z[2:end-1], bspline_element.bs)[:,2:end-1]
   pinvVsmpl = pinv(Vsmpl)
-  workspace = (tmp_rhs=zeros(Float64,mesh.element.Npts+1),
-               tmp_f=zeros(Float64,mesh.element.Npts+1))
+  # display(Vsmpl)
+  # display(pinvVsmpl)
+  # # display(Vsmpl*pinvVsmpl)
+  # # display(pinvVsmpl*Vsmpl)
+  # error()
+  workspace = (buf1=zeros(Float64,mesh.element.Npts+1),
+               buf2=zeros(Float64,mesh.element.Npts+1))
   fixedprms = (; av_regularization=:covariant, id_smooth=true,
                  bernstein, slope_limiter_method, slope_limiter_tvb_M,
                  c2p_dynamic_atm, atm_evolve, atm_equalize_on_interface,
diff --git a/src/GRHD/spherical1d.jl b/src/GRHD/spherical1d.jl
index b2eda01c..9f262ae9 100644
--- a/src/GRHD/spherical1d.jl
+++ b/src/GRHD/spherical1d.jl
@@ -433,11 +433,11 @@ end
 
 
 @with_signature function stop_atmosphere_evolution_spherical1d(eq::Equation)
-  @accepts rhs_D, rhs_Sr, rhs_τ, rhs_D_modal, rhs_Sr_modal, rhs_τ_modal, c2p_reset_atm, c2p_freeze_atm
+  @accepts rhs_D, rhs_Sr, rhs_Ï„, c2p_reset_atm, c2p_freeze_atm
   if c2p_reset_atm > 0 && c2p_freeze_atm > 0
-    rhs_D = rhs_Sr = rhs_τ = rhs_D_modal = rhs_Sr_modal = rhs_τ_modal = 0.0
+    rhs_D = rhs_Sr = rhs_Ï„ = 0.0
   end
-  @returns rhs_D, rhs_Sr, rhs_τ, rhs_D_modal, rhs_Sr_modal, rhs_τ_modal
+  @returns rhs_D, rhs_Sr, rhs_Ï„
 end
 
 
diff --git a/src/callbacks.jl b/src/callbacks.jl
index f1c99d86..1624dc4d 100644
--- a/src/callbacks.jl
+++ b/src/callbacks.jl
@@ -627,6 +627,7 @@ mutable struct InterpolationCallback <: AbstractSaveCallback
     h5group_idx = 0 # use internal iteration index to keep contiguous group name increments
     savefn = let h5group_idx=h5group_idx, save_groups_vars=save_groups_vars,
                  interpolator=interpolator, intrp_buf=intrp_buf
+      @unpack c2p_near_surface = get_cell_variables(mesh)
       function savefn(u, t)
         # generate group for new time step
         h5group_idx += 1
@@ -635,7 +636,8 @@ mutable struct InterpolationCallback <: AbstractSaveCallback
         for (group, var) in save_groups_vars
           name = string(var)
           v = get_variable(mesh.cache, var, group)
-          interpolator(intrp_buf, v)
+          # interpolator(intrp_buf, v, c2p_near_surface)
+          interpolator(intrp_buf, v, mask=c2p_near_surface)
           h5group[name] = intrp_buf
         end
         # important: use flush to not loose data
diff --git a/src/dg_rhs.jl b/src/dg_rhs.jl
index ae5e4876..58a5d3b0 100644
--- a/src/dg_rhs.jl
+++ b/src/dg_rhs.jl
@@ -30,45 +30,45 @@ function compute_rhs_weak_form!(rhs, f, s, nf, mesh::Mesh1d{DGElement})
   return
 end
 function compute_rhs_weak_form!(rhs, f, rhs_modal, f_modal, nf, mesh::Mesh1d{DGElement},
-                                mask::AbstractVector, workspace)
-  @unpack element = mesh
-  @unpack invM, MDM, Ml_lhs, Ml_rhs, Npts = element
+                                element::DGElement, mask::AbstractVector, workspace)
+  @unpack Npts = mesh.element
+  @unpack invM, MDM, Ml_lhs, Ml_rhs = element
   # @unpack K = mesh
   K = mesh.tree.dims[1]
-  @unpack invdetJ = get_static_variables(mesh.cache)
-  @unpack tmp_rhs, tmp_f = workspace
+  @unpack invdetJ = get_static_variables(mesh)
+  tmp_rhs, tmp_f = workspace.buf1, workspace.buf2
   invJ = invdetJ[1]
   nf_lhs = view(nf, 1:2:2*K)
   nf_rhs = view(nf, 2:2:2*K)
   for k in 1:K
     mask[k] > 0.0 || continue
     cidxs = cellindices(mesh, k)
-    rng_modal = 2:Npts-1
-    v_f       = view(f,   cidxs)
+    rng_modal = 1:Npts-1
+    v_f       = view(f, cidxs)
     v_f_modal = view(f_modal, cidxs)
-    @views tmp_f[rng_modal] .= v_f_modal[rng_modal]
+    @views tmp_f[1 .+ rng_modal] .= v_f_modal[rng_modal]
     tmp_f[1] = v_f[1]; tmp_f[end] = v_f[end]
     mul!(tmp_rhs, MDM, tmp_f)
-    @. tmp_rhs -= (nf_rhs[k] * Ml_rhs + #= minus contained in normal vector =# nf_lhs' * Ml_lhs)
+    @. tmp_rhs -= (nf_rhs[k] * Ml_rhs + #= minus contained in normal vector =# nf_lhs[k] * Ml_lhs)
     @. tmp_rhs *= invJ
     v_rhs       = view(rhs, cidxs)
     v_rhs_modal = view(rhs_modal, cidxs)
-    @views v_rhs_modal[rng_modal] .= tmp_rhs[rng_modal]
+    @views v_rhs_modal[rng_modal] .= tmp_rhs[1 .+ rng_modal]
     v_rhs[1] = tmp_rhs[1]; v_rhs[end] = tmp_rhs[end]
   end
   return
 end
 function compute_rhs_weak_form!(rhs, f, s, rhs_modal, f_modal, s_modal, nf, mesh::Mesh1d{DGElement},
-                                mask::AbstractVector, workspace)
-  compute_rhs_weak_form!(rhs, f, rhs_modal, f_modal, nf, mesh, mask, workspace)
+                                bspline::DGElement, mask::AbstractVector, workspace)
+  compute_rhs_weak_form!(rhs, f, rhs_modal, f_modal, nf, mesh, bspline, mask, workspace)
+  @unpack Npts = mesh.element
   for (k,(v_rhs,v_s)) in enumerate(zip(eachcell(mesh,rhs),eachcell(mesh,s)))
     mask[k] > 0.0 || continue
-    v_rhs[1] += v_s[1]; v_rhs[end] += v_s[end]
     cidxs = cellindices(mesh, k)
     v_rhs_modal = view(rhs_modal, cidxs)
     v_s_modal   = view(s_modal, cidxs)
-    rng_modal = 2:Npts-1
-    @views v_rhs_modal[rng_modal] .+= v_s_modal[rng_modal]
+    rng_modal = 1:Npts-1
+    v_rhs[1] += v_s[1]; v_rhs[end] += v_s[end]
   end
   return
 end
diff --git a/src/dgelement.jl b/src/dgelement.jl
index b5698cf4..a1606176 100644
--- a/src/dgelement.jl
+++ b/src/dgelement.jl
@@ -100,9 +100,9 @@ struct DGElement
       # This is more than requested, so we reduce it by one.
       # bs_z is now the grid across which the Bsplines are C^1 continuous.
       if quadr === :lgl || quadr === :lgl_lumped
-        bs_z, _ = LGL.rule(N)
+        bs_z, _ = LGL.rule(Npts)
       elseif quadr === :glgl
-        bs_z, _ = GLGL.rule(N)
+        bs_z, _ = GLGL.rule(Npts)
       end
       purge_zeros!(bs_z)
       bs = Bspline.Bspline2(bs_z)
@@ -119,8 +119,8 @@ struct DGElement
       S = Bspline.stiffness_matrix(bs)
       invM = inv(M)
       V = Bspline.vandermonde_matrix(z, bs)
-      l_lhs      = zeros(Float64, Npts)
-      l_rhs      = zeros(Float64, Npts)
+      l_lhs      = zeros(Float64, Npts+1)
+      l_rhs      = zeros(Float64, Npts+1)
       l_lhs[1]   = 1.0
       l_rhs[end] = 1.0
     end
@@ -146,7 +146,9 @@ struct DGElement
     purge_zeros!(Ml_rhs)
 
     # TODO Filter operators should likely go into separate struct
-    Λ = diagm([ spectral_filter((i-1)/N, (ηc,sh,α)) for i = 1:Npts ])
+    # TODO Clean this up!
+    Np = kind === :modal_bspline2 ? Npts+1 : Npts
+    Λ = diagm([ spectral_filter((i-1)/N, (ηc,sh,α)) for i = 1:Np ])
     F = V * Λ * invV
 
     return new(N, Npts, kind,
diff --git a/src/mesh.jl b/src/mesh.jl
index 0202180b..01226eee 100644
--- a/src/mesh.jl
+++ b/src/mesh.jl
@@ -653,7 +653,7 @@ Note that for inplace interpolation `intrp_data` and `reference_data` cannot be
 struct MeshInterpolator{T_Mesh<:Mesh1d,T<:AbstractArray}
   reference_mesh::T_Mesh
   sample_nodes::T
-  interpolation_matrices::Vector{Tuple{Int64,Matrix{Float64}}}
+  interpolation_matrices::Vector{Tuple{Int64,Matrix{Float64},Matrix{Float64}}}
   reference_cell_indices::Vector{Int64}
 end
 
@@ -680,21 +680,35 @@ function MeshInterpolator(smpl_x::AbstractArray{<:Float64}, ref_mesh::Mesh1d)
   ref_cell_indices = find_cell_index.(smpl_x, Ref(ref_mesh))
   @assert !any(isnothing, ref_cell_indices)
 
-  interpolation_matrices = Tuple{Int64, Matrix{Float64}}[]
+  @unpack Npts = ref_mesh.element
+  Tavg = zeros(Float64, Npts+1, Npts)
+  Tavg[1,1] = 1.0; Tavg[end,end] = 1.0
+  for i in 2:Npts
+    Tavg[i,i-1:i] .= 0.5
+  end
+  bspline_element = DGElement(ref_mesh.element.N, ref_mesh.element.quadr, :modal_bspline2)
+
+  interpolation_matrices = Tuple{Int64, Matrix{Float64}, Matrix{Float64}}[]
   current_idx = 1 # first index of sample node that falls into a new cell in the reference mesh
   while true
     cell_idx = ref_cell_indices[current_idx]
     # find all sample indices of nodes which fall into the same reference cell
     same_idxs = findall(ref_idx -> ref_idx == cell_idx, ref_cell_indices)
     smpl_idxs = [ current_idx ]
-    if length(same_idxs) == 1
+    if length(same_idxs) != 1
       append!(smpl_idxs, same_idxs[2:end])
     end
     # extract the sample nodes
     smpl_xs = smpl_x[smpl_idxs]
+    box = ref_mesh.boxes[cell_idx]
+    (xmin, xmax), = box.extends
+    ref_xs = @. (2*smpl_xs - (xmax+xmin))/(xmax-xmin)
     # generate interpolation matrix
     push!(interpolation_matrices,
-          (cell_idx, dg1d.barycentric_interp_matrix(smpl_xs, ref_x[:,cell_idx])))
+          (cell_idx,
+           dg1d.barycentric_interp_matrix(smpl_xs, ref_x[:,cell_idx]),
+           dg1d.Bspline.vandermonde_matrix(ref_xs, bspline_element.bs) * Tavg)
+         )
 
     current_idx += length(smpl_idxs)
     if current_idx > length(ref_cell_indices)
@@ -706,14 +720,14 @@ function MeshInterpolator(smpl_x::AbstractArray{<:Float64}, ref_mesh::Mesh1d)
 end
 
 
-function (intrp::MeshInterpolator)(ref_data::T) where {T<:AbstractArray}
+function (intrp::MeshInterpolator)(ref_data::T; mask::Union{Nothing,AbstractArray}=nothing) where {T<:AbstractArray}
   result = similar(intrp.sample_nodes)
-  intrp(result, ref_data)
+  intrp(result, ref_data; mask)
   return result
 end
 
 
-function (intrp::MeshInterpolator{T_Mesh,T})(storage::S, ref_data::T) where {T_Mesh,T<:AbstractArray,S<:AbstractArray}
+function (intrp::MeshInterpolator{T_Mesh,T})(storage::S, ref_data::T; mask::Union{Nothing,AbstractArray}=nothing) where {T_Mesh,T<:AbstractArray,S<:AbstractArray}
 
   if length(ref_data) != length(intrp.reference_mesh)
     error("Size of reference data does not match with reference mesh.")
@@ -721,15 +735,23 @@ function (intrp::MeshInterpolator{T_Mesh,T})(storage::S, ref_data::T) where {T_M
   if length(storage) != length(intrp.sample_nodes)
     error("Size of storage for interpolated result does not match with sample mesh.")
   end
+  if !isnothing(mask) && length(mask) != intrp.reference_mesh.tree.dims[1]
+    error("Size of mask does not match number of cells in reference mesh.")
+  end
 
   mat_ref_data   = vreshape(ref_data, layout(intrp.reference_mesh))
   current_idx = 1 # first index of sample node that falls into a new cell in the reference mesh
-  @views for (ref_cell_idx, intrp_mat) in intrp.interpolation_matrices
-    n_intrp_data = size(intrp_mat)[1]
+  @views for (ref_cell_idx, intrp_mat, bspline_intrp_mat) in intrp.interpolation_matrices
+    if !isnothing(mask) && mask[ref_cell_idx] > 0.0
+      mat = bspline_intrp_mat
+    else
+      mat = intrp_mat
+    end
+    n_intrp_data = size(mat)[1]
     idx_range = current_idx:current_idx+(n_intrp_data-1)
     vec_intrp_data = storage[idx_range]
     vec_ref_data   = mat_ref_data[:, ref_cell_idx]
-    mul!(vec_intrp_data, intrp_mat,  vec_ref_data)
+    mul!(vec_intrp_data, mat, vec_ref_data)
     current_idx += n_intrp_data
   end
 
-- 
GitLab