diff --git a/src/GRHD/cons2prim.jl b/src/GRHD/cons2prim.jl
index 07a708d6b81c828d33ba8813f699cee9753cb97e..ebe938ba7dd3744a4312ae318d9ef6e8e6ce7b0a 100644
--- a/src/GRHD/cons2prim.jl
+++ b/src/GRHD/cons2prim.jl
@@ -395,3 +395,10 @@ end
   c2p_reset_atm = Float64(D<ρmin)
   @returns c2p_reset_atm
 end
+
+
+@with_signature function determine_freeze_mask(equation::Equation)
+  @accepts c2p_freeze_atm, c2p_reset_atm
+  freeze_mask = Float64(c2p_freeze_atm > 0.0 && c2p_reset_atm > 0.0)
+  @returns freeze_mask
+end
diff --git a/src/GRHD/rhs.jl b/src/GRHD/rhs.jl
index abe5a67b080438695720fe7fd0a2285a7b52ebb3..634be58e3f19263656d139f0cdccc8fd3bd88041 100644
--- a/src/GRHD/rhs.jl
+++ b/src/GRHD/rhs.jl
@@ -1225,14 +1225,16 @@ function rhs!(mesh::Mesh1d, P::Project{:spherical1d}, hrsc::Nothing)
   @unpack cache = mesh
   @unpack equation = P
 
-  @unpack D, Sr, Ï„                 = get_dynamic_variables(cache)
-  @unpack rhs_D, rhs_Sr, rhs_Ï„     = get_rhs_variables(cache)
+  @unpack D, Sr, Ï„                    = get_dynamic_variables(cache)
+  @unpack rhs_D, rhs_Sr, rhs_Ï„        = get_rhs_variables(cache)
   @unpack flr_D, flr_Sr, flr_Ï„,
           max_v, vr, p,
-          src_D, src_Sr, src_Ï„     = get_static_variables(cache)
+          src_D, src_Sr, src_Ï„,
+          freeze_mask, c2p_reset_atm  = get_static_variables(cache)
   @unpack nflr_D, nflr_Sr, nflr_Ï„,
           bdry_D, bdry_Sr, bdry_Ï„,
-          bdry_max_v, bdry_vr, bdry_p = get_bdry_variables(cache)
+          bdry_max_v, bdry_vr, bdry_p,
+          bdry_c2p_reset_atm          = get_bdry_variables(cache)
 
   if P.prms.c2p_enforce_causal_atm || P.prms.c2p_enforce_atm
     broadcast_volume!(cons2prim_spherical1d_freeze_flags, equation, mesh)
@@ -1247,6 +1249,13 @@ function rhs!(mesh::Mesh1d, P::Project{:spherical1d}, hrsc::Nothing)
   dg1d.interpolate_face_data!(mesh, max_v, bdry_max_v)
   dg1d.interpolate_face_data!(mesh, vr,    bdry_vr)
   dg1d.interpolate_face_data!(mesh, p,     bdry_p)
+  dg1d.interpolate_face_data!(mesh, c2p_reset_atm, bdry_c2p_reset_atm)
+
+  broadcast_volume!(determine_atmosphere, P.equation, mesh)
+  if P.prms.atm_equalize_on_interface
+    broadcast_faces!(equalize_atmosphere_spherical1d, P.equation, mesh)
+  end
+  broadcast_volume!(determine_freeze_mask, P.equation, mesh)
 
   broadcast_volume!(flux_source_spherical1d, equation, mesh)
   impose_symmetry_sources!(P, mesh)
@@ -1274,17 +1283,39 @@ function rhs!(mesh::Mesh1d, P::Project{:spherical1d}, hrsc::Nothing)
 
   if :D ∉ P.prms.freeze_vars
     compute_rhs_weak_form!(rhs_D,  flr_D,  src_D,  nflr_D,  mesh)
+    compute_rhs_weak_form!(rhs_D,  flr_D,  src_D,  nflr_D,  mesh, P.prms.bernstein, freeze_mask)
   end
   if :Sr ∉ P.prms.freeze_vars
     compute_rhs_weak_form!(rhs_Sr, flr_Sr, src_Sr, nflr_Sr, mesh)
+    compute_rhs_weak_form!(rhs_Sr, flr_Sr, src_Sr, nflr_Sr, mesh, P.prms.bernstein, freeze_mask)
   end
   if :τ ∉ P.prms.freeze_vars
     compute_rhs_weak_form!(rhs_Ï„,  flr_Ï„,  src_Ï„,  nflr_Ï„,  mesh)
+    compute_rhs_weak_form!(rhs_Ï„,  flr_Ï„,  src_Ï„,  nflr_Ï„,  mesh, P.prms.bernstein, freeze_mask)
   end
 
   if !P.prms.atm_evolve
-    broadcast_volume!(determine_atmosphere, P.equation, mesh)
+    # broadcast_volume!(determine_atmosphere, P.equation, mesh)
     broadcast_volume!(stop_atmosphere_evolution_spherical1d, P.equation, mesh)
+    # for k in 1:K
+    #   cell = mesh.tree.cells[k]
+    #   cidxs = cellindices(mesh, k)
+    #   v_mask = view(mask, cidxs)
+    #   s_mask = sum(v_mask)
+    #   if s_mask ≈ Npts || s_mask ≈ 0.0
+    #     continue
+    #   end
+    #   v_rhs_D  = view(rhs_D, cidxs)
+    #   v_rhs_Sr = view(rhs_Sr, cidxs)
+    #   v_rhs_Ï„  = view(rhs_Ï„, cidxs)
+    #   if dg1d.has_neighbor(cell,Cart1d.Xmin)
+    #     v_ngb_mask = dg1d.view_neighbor(mask, cell, Cart1d.Xmin, mesh)
+    #     s_ngb_mask = sum(v_ngb_mask)
+    #     if !(s_ngb_mask ≈ 0.0)
+    #
+    #     end
+    #   end
+    # end
   end
 
   if bc == "tov_symmetric_domain"
diff --git a/src/GRHD/setup.jl b/src/GRHD/setup.jl
index 3c76c8bb3320a4cf0c4dbaa76d53108b3e526a44..e411951c5af1e4d5d3a42e34480a99ea8e4b056f 100644
--- a/src/GRHD/setup.jl
+++ b/src/GRHD/setup.jl
@@ -32,7 +32,8 @@ function Project(env::Environment, prms)
   slope_limiter_method = Symbol(prms["GRHD"]["slope_limiter_method"])
   slope_limiter_tvb_M = prms["GRHD"]["slope_limiter_tvb_M"]
   c2p_dynamic_atm = prms["GRHD"]["c2p_dynamic_atm"]
-  bernstein = BernsteinReconstruction(env.mesh)
+  bernstein_rec = HRSC.BernsteinReconstruction(env.mesh)
+  bernstein = dg1d.Bernstein.BernsteinElement(env.mesh.element.Npts)
   cold_K, cold_Γ = prms["GRHD"]["c2p_cold_eos_parameters"]
   cold_eos = Polytrope(cold_K, cold_Γ)
   c2p_set_atmosphere_on_failure = prms["GRHD"]["c2p_set_atmosphere_on_failure"]
@@ -45,7 +46,7 @@ function Project(env::Environment, prms)
   id = prms["GRHD"]["id"]
   bc = prms["GRHD"]["bc"]
   fixedprms = (; av_regularization=:covariant, id_smooth=true,
-                 bernstein, slope_limiter_method, slope_limiter_tvb_M,
+                 bernstein, bernstein_rec, slope_limiter_method, slope_limiter_tvb_M,
                  c2p_dynamic_atm, atm_evolve, atm_equalize_on_interface,
                  c2p_set_atmosphere_on_failure, c2p_enforce_causal_atm, c2p_enforce_atm,
                  av_drag, av_sensor_abslog_D, problem, freeze_vars,
@@ -480,7 +481,8 @@ end
 function register_analysis!(mesh::Mesh1d, P)
   register_variables!(mesh,
       static_variablenames = (:c2p_reset_ϵ, :c2p_reset_atm, :c2p_limit_vr,
-                              :c2p_freeze_atm, :c2p_init_admissible, :v),
+                              :c2p_freeze_atm, :c2p_init_admissible, :v,
+                              :freeze_mask),
       bdry_variablenames = (:bdry_c2p_reset_atm,)
   )
   if "Mtot" in P.prmsdb["GRHD"]["variables0d_analyze"]
@@ -491,7 +493,8 @@ function register_analysis!(mesh::Mesh2d, P)
   register_variables!(mesh,
       static_variablenames = (:c2p_reset_ϵ, :c2p_reset_atm, :c2p_limit_vr,
                               :c2p_freeze_atm, :c2p_init_admissible, :v,
-                              :buf_c2p_reset_atm,),
+                              :buf_c2p_reset_atm,
+                              :freeze_mask),
       bdry_variablenames = (:bdry_c2p_reset_atm,)
   )
   if "Mtot" in P.prmsdb["GRHD"]["variables0d_analyze"]
diff --git a/src/ScalarEq/rhs.jl b/src/ScalarEq/rhs.jl
index 756811d0cb7222a5366fca46a87f0bf7419e2eba..e63abb068676979500f1d9e6b18aea3d8c24752c 100644
--- a/src/ScalarEq/rhs.jl
+++ b/src/ScalarEq/rhs.jl
@@ -282,7 +282,9 @@ function rhs!(env, mesh::Mesh1d, P::Project, hrsc::Nothing)
   broadcast_faces!(llf_1d, equation, mesh)
   # broadcast_bdry!(bdryllf_1d, equation, P.bdrycond, mesh)
 
-  compute_rhs_weak_form!(rhs_u, flx_u, src_u, nflx_u, mesh)
+  # compute_rhs_weak_form!(rhs_u, flx_u, src_u, nflx_u, mesh)
+  compute_rhs_weak_form!(rhs_u, flx_u, src_u, nflx_u, mesh, P.prms.bernstein_element)
+  # TODO()
 
   return
 end
diff --git a/src/ScalarEq/setup.jl b/src/ScalarEq/setup.jl
index f92c621ec692f6e37359fd2bc050993e38fb87b4..3499ccea0074fab16359e1c5bfe9643a640329b3 100644
--- a/src/ScalarEq/setup.jl
+++ b/src/ScalarEq/setup.jl
@@ -11,9 +11,10 @@ function Project(env::Environment, mesh::Mesh1d, prms)
   av_recompute_substeps = prms["ScalarEq"]["av_recompute_substeps"]
   muscl_omega = prms["ScalarEq"]["muscl_omega"]
   bernstein = BernsteinReconstruction(mesh)
+  bernstein_element = dg1d.Bernstein.BernsteinElement(mesh.element.Npts)
   analyze_error = prms["ScalarEq"]["analyze_error"]
   analyze_error_norm = prms["ScalarEq"]["analyze_error_norm"]
-  fixedprms = (; bernstein, av_derivative_scheme, av_drag, hrsc, muscl_omega,
+  fixedprms = (; bernstein, bernstein_element, av_derivative_scheme, av_drag, hrsc, muscl_omega,
                  slope_limiter_method, slope_limiter_tvb_M,
                  av_recompute_substeps,
                  analyze_error, analyze_error_norm)
diff --git a/src/dg_rhs.jl b/src/dg_rhs.jl
index d45c126efb4bdaa2f5f5c7dc04ae5ed9ce8b8fb7..acd9a25323b8f1477f57a2b875951c9fb23ce56d 100644
--- a/src/dg_rhs.jl
+++ b/src/dg_rhs.jl
@@ -31,6 +31,76 @@ function compute_rhs_weak_form!(rhs, f, s, nf, mesh::Mesh1d{SpectralElement})
 end
 
 
+function compute_rhs_weak_form!(rhs, f, nf, mesh::Mesh1d{SpectralElement},
+    element::Bernstein.BernsteinElement)
+  @unpack invM, MDM, MB_lhs, MB_rhs, Npts = element
+  K = mesh.tree.dims[1]
+  @unpack invdetJ = get_static_variables(mesh.cache)
+  shape       = layout(mesh)
+  mat_rhs     = vreshape(rhs,     shape)
+  mat_f       = vreshape(f,       shape)
+  mat_invdetJ = vreshape(invdetJ, shape)
+  mul!(mat_rhs, MDM, mat_f)
+  nf_lhs       = view(nf, 1:2:2*K)
+  nf_rhs       = view(nf, 2:2:2*K)
+  @turbo @. begin # factor x10
+    mat_rhs  -= (nf_rhs' * MB_rhs + #= minus contained in normal vector =# nf_lhs' * MB_lhs)
+    mat_rhs *= mat_invdetJ
+  end
+  return
+end
+function compute_rhs_weak_form!(rhs, f, s, nf, mesh::Mesh1d{SpectralElement},
+    element::Bernstein.BernsteinElement)
+  compute_rhs_weak_form!(rhs, f, nf, mesh, element)
+  rhs .+= s
+  return
+end
+
+
+function compute_rhs_weak_form!(rhs, f, nf, mesh::Mesh1d{SpectralElement},
+    element::Bernstein.BernsteinElement, mask::Vector{Float64})
+  @unpack invM, MDM, MB_lhs, MB_rhs = element
+  @unpack invdetJ = get_static_variables(mesh.cache)
+  Npts, K = layout(mesh)
+  nf_lhs       = view(nf, 1:2:2*K)
+  nf_rhs       = view(nf, 2:2:2*K)
+  for k in 1:K
+    cidxs = cellindices(mesh, k)
+    v_mask = view(mask, cidxs)
+    s_mask = sum(v_mask)
+    if s_mask ≈ Npts || s_mask ≈ 0.0
+      continue
+    end
+    v_rhs  = view(rhs, cidxs)
+    v_f    = view(f, cidxs)
+    v_invJ = view(invdetJ, cidxs)
+    mul!(v_rhs, MDM, v_f)
+    @. v_rhs -= nf_rhs[k] * Ml_rhs + nf_lhs[k] * Ml_lhs
+    @. v_rhs *= v_invJ
+  end
+  return
+end
+function compute_rhs_weak_form!(rhs, f, s, nf, mesh::Mesh1d{SpectralElement},
+    element::Bernstein.BernsteinElement, mask::Vector{Float64})
+  compute_rhs_weak_form!(rhs, f, nf, mesh)
+  @unpack invM, MDM, MB_lhs, MB_rhs = element
+  @unpack invdetJ = get_static_variables(mesh.cache)
+  Npts, K = layout(mesh)
+  for k in 1:K
+    cidxs = cellindices(mesh, k)
+    v_mask = view(mask, cidxs)
+    s_mask = sum(v_mask)
+    if s_mask ≈ Npts || s_mask ≈ 0.0
+      continue
+    end
+    v_rhs  = view(rhs, cidxs)
+    v_s    = view(s, cidxs)
+    @. v_rhs += v_s
+  end
+  return
+end
+
+
 """
     compute_rhs_weak_form!(rhs, fx, fy, nf, mesh::Mesh2d{SpectralElement})
     compute_rhs_weak_form!(rhs, fx, fy, s, nf, mesh::Mesh2d{SpectralElement})