Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • dg/dg1d.jl
1 result
Show changes
Showing with 295 additions and 727 deletions
function make_BoundaryConditions(env, equation, rsolver, prms)
@unpack bc = prms
lhs_bc, rhs_bc = if bc == "periodic"
PeriodicBC(), PeriodicBC()
elseif bc == "no_inflow"
DirichletBC((0.0,)), OutflowBC()
elseif bc == "from_id"
@unpack u = get_dynamic_variables
u_lhs, u_rhs = deepcopy(u[1]), deepcopy(u[end])
lhs_bdry = t -> (u_lhs,)
rhs_bdry = t -> (u_rhs,)
DirichletBC(lhs_bdry), DirichletBC(rhs_bdry)
else
error("Unknown boundary condition requested: '$bc'")
end
return BoundaryConditions(lhs_bc, rhs_bc, rsolver)
end
function make_BoundaryConditions_LDG(env, equation, ldg_rsolver, prms)
@unpack bc = prms
isnothing(ldg_rsolver) && return nothing
lhs_bc, rhs_bc = if bc == "periodic"
PeriodicBC(), PeriodicBC()
elseif bc == "no_inflow"
OutflowBC(), OutflowBC()
elseif bc == "from_id"
OutflowBC(), OutflowBC()
else
error("Unknown boundary condition requested: '$bc'")
end
return BoundaryConditions(lhs_bc, rhs_bc, ldg_rsolver)
end
function make_callback(env, P::Project, bdryconds::BoundaryConditions)
cbfn_equation(u, t) = callback_equation(u, t, env, P, isperiodic(bdryconds), P.equation, env.mesh)
function make_callback(env, P::Project, isperiodic::Bool)
cbfn_equation(u, t) = callback_equation(u, t, env, P, isperiodic, P.equation, env.mesh)
cb_equation = FunctionCallback(cbfn_equation,
CallbackTiming(every_iteration=1,every_dt=0))
cbfn_tci(u, t) = callback_tci(u, t, env, P, isperiodic(bdryconds), P.tci, env.mesh)
cbfn_tci(u, t) = callback_tci(u, t, env, P, isperiodic, P.tci, env.mesh)
cb_tci = FunctionCallback(cbfn_tci,
CallbackTiming(every_iteration=1, every_dt=0))
cbfn_hrsc(u, t) = callback_hrsc(u, t, env, P, isperiodic(bdryconds), P.hrsc, env.mesh)
cbfn_hrsc(u, t) = callback_hrsc(u, t, env, P, isperiodic, P.hrsc, env.mesh)
cb_hrsc = FunctionCallback(cbfn_hrsc,
CallbackTiming(every_iteration=1,every_dt=0))
callbackset = CallbackSet(cb_equation, cb_tci, cb_hrsc)
......@@ -35,8 +35,8 @@ function callback_equation(state_u, state_t, env, P, isperiodic, eq::AbstractSca
@. tm1 = t
t .= state_t
# compute new values
broadcast_volume!(entropy, equation, cache)
broadcast_volume!(entropy_flux, equation, cache)
broadcast_volume_2!(entropy, equation, mesh)
broadcast_volume_2!(entropy_flux, equation, mesh)
end
......@@ -62,7 +62,7 @@ function callback_hrsc(state_u, state_t, env, P, isperiodic, hrsc::HRSC.Abstract
@unpack v = get_static_variables(cache)
@unpack mu, max_v = get_cell_variables(cache)
@unpack u = get_dynamic_variables(cache)
broadcast_volume!(speed, equation, cache)
broadcast_volume_2!(speed, equation, mesh)
Npts, K = layout(mesh)
mat_max_v = reshape(view(v, :), (Npts, K))
for k = 1:K
......@@ -85,7 +85,6 @@ end
function callback_hrsc(state_u, state_t, env, P, isperiodic, hrsc::HRSC.SmoothedArtificialViscosity, mesh::Mesh1d)
callback_hrsc(state_u, state_t, env, P, isperiodic, hrsc.av, mesh)
@unpack cache = env
@unpack rsolver = P
@unpack mu = get_cell_variables(cache)
@unpack smoothed_mu = get_static_variables(cache)
hrsc.smoother(smoothed_mu,mu,mesh,isperiodic)
......@@ -95,7 +94,6 @@ end
function callback_hrsc(state_u, state_t, env, P, isperiodic, hrsc::HRSC.SmoothedArtificialViscosity, mesh::Mesh2d)
callback_hrsc(state_u, state_t, env, P, isperiodic, hrsc.av, mesh)
@unpack cache = env
@unpack rsolver = P
@unpack mu = get_cell_variables(cache)
@unpack smoothed_mu = get_static_variables(cache)
hrsc.smoother(smoothed_mu,mu,mesh,true,true)
......@@ -109,7 +107,7 @@ function callback_hrsc(state_u, state_t, env, P, isperiodic, hrsc::HRSC.EntropyV
@unpack E, F, Em1, Fm1, v = get_static_variables(cache)
@unpack mu, max_v = get_cell_variables(cache)
@unpack t, tm1 = get_global_variables(cache)
broadcast_volume!(speed, equation, cache)
broadcast_volume_2!(speed, equation, mesh)
Npts, K = layout(mesh)
mat_max_v = reshape(view(v, :), (Npts, K))
for k = 1:K
......@@ -206,8 +204,8 @@ end
function callback_tci(state_u, state_t, env, P, isperiodic,
tci::Union{TCI.Diffusion, TCI.Minmod, TCI.Threshold,
TCI.ModalDecayAverage, TCI.ModalDecayHighest}, mesh)
@unpack cache, mesh = env
@unpack u = get_dynamic_variables(cache)
@unpack flag = get_cell_variables(cache)
@unpack cache, mesh = env
@unpack u = get_dynamic_variables(cache)
@unpack flag = get_cell_variables(cache)
TCI.compute_indicator!(flag, u, tci)
end
......@@ -84,13 +84,20 @@ end
@with_signature function flux(eq::Advection)
@accepts State(u)
@accepts u
flx_u = eq.v * u
@returns flx_u
end
@with_signature function fv_bdry_flux(eq::Advection)
@accepts u, flx_u, init_u
bdry_u = -u + 2*init_u
nflx_u = eq.v * bdry_u
@returns bdry_u, nflx_u
end
@with_signature function speed(eq::Advection)
@accepts State(u)
@accepts u
v = eq.v
@returns v
end
......@@ -113,13 +120,13 @@ end
@with_signature function flux(eq::Advection2d)
@accepts State(u)
@accepts u
flx_x, flx_y = eq.v_x * u, eq.v_y * u
@returns flx_x, flx_y
end
@with_signature function speed(eq::Advection2d)
@accepts State(u)
@accepts u
v = sqrt(eq.v_x^2 + eq.v_y^2)
@returns v
end
......@@ -145,13 +152,20 @@ end
@with_signature function flux(eq::Burgers)
@accepts State(u)
@accepts u
flx_u = u^2 / 2
@returns flx_u
end
@with_signature function fv_bdry_flux(eq::Burgers)
@accepts u, flx_u, init_u
bdry_u = -u + 2*init_u
nflx_u = bdry_u^2 / 2
@returns bdry_u, nflx_u
end
@with_signature function speed(eq::Burgers)
@accepts State(u)
@accepts u
v = u
@returns v
end
......@@ -175,13 +189,13 @@ end
@with_signature function flux(eq::Burgers2d)
@accepts State(u)
@accepts u
flx_x, flx_y = u^2/2, u^2/2
@returns flx_x, flx_y
end
@with_signature function speed(eq::Burgers2d)
@accepts State(u)
@accepts u
v = sqrt(2) * u
@returns v
end
......
......@@ -2,40 +2,40 @@
# Timestep #
#######################################################################
function timestep(env, P::Project, ::Mesh1d, hrsc::Maybe{HRSC.AbstractHRSC})
@unpack cache, mesh = env
@unpack equation = P
@unpack v = get_static_variables(cache)
broadcast_volume!(speed, equation, cache)
function maxspeed(mesh, equation, cache)
@unpack v = get_static_variables(cache)
broadcast_volume_2!(speed, equation, mesh)
vmax = dg1d.absolute_maximum(v)
vmax_limit = 1e4
if vmax > vmax_limit
@warn "Limiting timestep due to maximum speed exceeding $vmax_limit"
vmax = vmax_limit
end
return vmax
end
@unpack CFL, element = mesh
@unpack N, z = element
dz = z[2]-z[1]
dx, = minimum(dg1d.widths(mesh.boxes[1]))
dl = dx * dz
dt = CFL * dl / (N * vmax)
function timestep(env, P::Project, mesh::Mesh1d, hrsc::Maybe{HRSC.AbstractHRSC})
vmax = maxspeed(mesh, P.equation, mesh.cache)
dl = min_grid_spacing(mesh)
dt = dl / (vmax * dtfactor(mesh))
return dt
end
dtfactor(mesh::Mesh1d{FVElement}) = 2
dtfactor(mesh::Mesh1d{SpectralElement}) = mesh.element.N
function timestep(env, P::Project, ::Mesh2d, hrsc::Maybe{HRSC.AbstractHRSC})
function timestep(env, P::Project, ::Mesh2d{SpectralElement}, hrsc::Maybe{HRSC.AbstractHRSC})
@unpack cache, mesh = env
@unpack equation = P
broadcast_volume!(speed, equation, cache)
broadcast_volume_2!(speed, equation, mesh)
@unpack v = get_static_variables(cache)
vmax = dg1d.absolute_maximum(v)
@unpack CFL, element = mesh
@unpack element = mesh
@unpack N, z = element
dx, dy = dg1d.widths(mesh.boxes[1])
......@@ -47,7 +47,7 @@ function timestep(env, P::Project, ::Mesh2d, hrsc::Maybe{HRSC.AbstractHRSC})
dz = z[2]-z[1]
dl = min(dx, dy) * dz
dt = CFL * dl / (N * vmax)
dt = dl / (N * vmax)
return dt
end
......@@ -66,11 +66,30 @@ function rhs!(env, P::Project, hrsc, bdryconds...)
end
function rhs!(env, mesh::Mesh1d{FVElement}, P::Project, hrsc::Nothing,
bdryconds, ldg_bdryconds::Nothing, av_bdryconds::Nothing)
@unpack cache, mesh = env
@unpack equation = P
@unpack flx_u = get_static_variables(cache)
@unpack u = get_dynamic_variables(cache)
@unpack rhs_u = get_rhs_variables(cache)
@unpack nflx_u, bdry_u = get_bdry_variables(cache)
broadcast_volume_2!(flux, equation, mesh)
dt = timestep(env, P, mesh, hrsc)
fv_update_step!(rhs_u, u, flx_u, bdry_u, nflx_u, dt, mesh)
return
end
function rhs!(env, mesh::Mesh1d, P::Project, hrsc::Nothing,
bdryconds, ldg_bdryconds::Nothing, av_bdryconds::Nothing)
@unpack cache, mesh = env
@unpack equation, rsolver = P
@unpack equation = P
broadcast_volume_2!(flux, equation, mesh)
broadcast_faces_2!(llf_1d, equation, mesh)
......@@ -90,13 +109,11 @@ function rhs!(env, mesh::Mesh2d, P::Project, hrsc::Nothing,
bdryconds, ldg_bdryconds::Nothing, av_bdryconds::Nothing)
@unpack cache = env
@unpack equation, rsolver = P
@unpack equation = P
broadcast_volume_2!(flux, equation, mesh)
broadcast_faces_2!(llf_2d, equation, mesh)
dg1d.broadcast_bdry_2!(bdryllf_2d, equation, P.bdrycond, mesh)
# broadcast_boundaryconditions!(central_flux, bdryconds_x, cache, mesh, 0.0)
# broadcast_boundaryconditions!(central_flux, bdryconds_y, cache, mesh, 0.0)
@unpack flx_x, flx_y = get_static_variables(cache)
@unpack u = get_dynamic_variables(cache)
......@@ -114,7 +131,7 @@ function rhs!(env, mesh::Mesh1d, P::Project, hrsc::Union{BernsteinReconstruction
TODO()
@unpack cache, mesh = env
@unpack equation, rsolver = P
@unpack equation = P
@unpack flag = get_cell_variables(cache)
@unpack flx_u = get_static_variables(cache)
@unpack u = get_dynamic_variables(cache)
......@@ -123,9 +140,9 @@ function rhs!(env, mesh::Mesh1d, P::Project, hrsc::Union{BernsteinReconstruction
HRSC.reconstruct!(u, flag, hrsc, isperiodic=isperiodic(bdryconds))
broadcast_volume!(flux, equation, cache)
broadcast_faces!(lax_friedrich_flux, rsolver, cache, mesh)
broadcast_boundaryconditions!(lax_friedrich_flux, bdryconds, cache, mesh, 0.0)
broadcast_volume_2!(flux, equation, mesh)
broadcast_faces_2!(llf, equation, mesh)
broadcast_bdry_2!(bdry_llf, bdryconds, mesh)
compute_rhs_weak_form!(rhs_u, flx_u, lhs_numflx_u, rhs_numflx_u, mesh)
......@@ -137,7 +154,7 @@ function rhs!(env, mesh::Mesh1d, P::Project, hrsc::HRSC.AbstractArtificialViscos
bdryconds, ldg_bdryconds, av_bdryconds)
@unpack cache, mesh = env
@unpack equation, rsolver, ldg_rsolver, av_rsolver = P
@unpack equation = P
@unpack flx_u, flx_q, q = get_static_variables(cache)
@unpack u = get_dynamic_variables(cache)
......@@ -181,17 +198,14 @@ function rhs!(env, mesh::Mesh2d, P::Project, hrsc::HRSC.AbstractArtificialViscos
broadcast_volume_2!(ldg_flux_2d, equation, mesh)
broadcast_faces_2!(ldg_nflux_qx, equation, mesh)
broadcast_faces_2!(ldg_nflux_qy, equation, mesh)
# broadcast_boundaryconditions!(central_flux, ldg_bdryconds, cache, mesh, 0.0)
compute_rhs_weak_form!(qx, flx_qx_x, flx_qx_y, nflx_qx, mesh)
compute_rhs_weak_form!(qy, flx_qy_x, flx_qy_y, nflx_qy, mesh)
# ## compute rhs of regularized equation: ∂t u + ∂x f + ∂y f + ∂x μ qx + ∂y μ qy = 0
broadcast_volume_2!(av_flux_2d, equation, mesh)
broadcast_faces_2!(av_nflux_2d, equation, mesh)
# broadcast_boundaryconditions!(mean_flux, av_bdryconds, cache, mesh, 0.0)
broadcast_volume_2!(flux, equation, mesh)
broadcast_faces_2!(llf_2d, equation, mesh)
# broadcast_boundaryconditions!(lax_friedrich_flux, bdryconds, cache, mesh, 0.0)
# add up fluxs and numerical fluxes
@. flx_x += flx_g_x
......
......@@ -9,10 +9,9 @@ function Project(env::Environment, mesh::Mesh1d, prms)
hrsc = HRSC.make_HRSC(env.mesh, prms["HRSC"])
# TODO remove this
rsolver, ldg_rsolver, av_rsolver = nothing, nothing, nothing
bdryconds, ldg_bdryconds, av_bdryconds = nothing, nothing, nothing
P = Project(equation, rsolver, hrsc, tci, ldg_rsolver, av_rsolver, bdryconds)
P = Project(equation, hrsc, tci, bdryconds)
# register variables
# TODO Somehow replace _register_variables! with register_variables!
......@@ -45,11 +44,9 @@ function Project(env::Environment, mesh::Mesh2d, prms)
equation = make_Equation(prms["ScalarEq"], dimension(mesh))
tci = TCI.make_TCI(mesh, prms["TCI"])
hrsc = HRSC.make_HRSC(mesh, prms["HRSC"])
rsolver = ApproxRiemannSolver2d(flux, speed, equation)
ldg_rsolver, av_rsolver = nothing, nothing
bdrycond = dg1d.DirichletBC2()
P = Project(equation, rsolver, hrsc, tci, ldg_rsolver, av_rsolver, bdrycond)
P = Project(equation, hrsc, tci, bdrycond)
# register variables
# TODO Somehow replace _register_variables! with register_variables!
......@@ -58,19 +55,18 @@ function Project(env::Environment, mesh::Mesh2d, prms)
_register_variables!(mesh, tci)
_register_variables!(mesh, hrsc)
_register_variables!(cache, bdrycond)
register_variables!(cache, rsolver)
display(cache)
# setup initial data
initialdata!(env, P, prms["ScalarEq"])
# setup boundary conditions
bdryconds_x = BoundaryConditions(PeriodicBC(), PeriodicBC(), rsolver)
bdryconds_y = BoundaryConditions(PeriodicBC(), PeriodicBC(), rsolver)
# # setup boundary conditions
bdryconds_x = nothing
bdryconds_y = nothing
ldg_bdryconds, av_bdryconds = nothing, nothing
# setup callbacks
projectcb = make_callback(env, P, bdryconds_x)
projectcb = make_callback(env, P, isperiodic(mesh))
append!(env.callbacks, CallbackSet(projectcb.callbacks...))
return P, ((bdryconds_x, bdryconds_y), ldg_bdryconds, av_bdryconds)
......@@ -99,7 +95,7 @@ function _register_variables!(cache, eq::AbstractScalarEq)
rhs_variablenames = (:rhs_u,),
static_variablenames = (:flx_u, :E, :Em1, :F, :Fm1, :v),
cell_variablenames = (:max_v,),
bdry_variablenames = (:nflx_u,),
bdry_variablenames = (:nflx_u,:bdry_u),
global_variablenames = (:t, :tm1))
end
......@@ -159,6 +155,7 @@ function _register_variables!(cache, eq::Union{Advection2d,Burgers2d})
rhs_variablenames = (:rhs_u,),
static_variablenames = (:flx, :flx_x, :flx_y, :v, :v_x, :v_y),
cell_variablenames = (:max_v,),
bdry_variablenames = (:nflx_u,),
global_variablenames = (:t, :tm1))
end
......
......@@ -28,18 +28,12 @@ struct Burgers2d <: AbstractScalarEq end
struct Project{T_Equation <:AbstractScalarEq,
T_RSolver <:Maybe{AbstractRiemannSolver},
T_HRSC <:Maybe{HRSC.AbstractHRSC},
T_TCI <:Maybe{TCI.AbstractTCI},
T_LDG_RSolver <:Maybe{AbstractRiemannSolver},
T_AV_RSolver <:Maybe{AbstractRiemannSolver},
T_BC <:Maybe{dg1d.AbstractBC}} <: dg1d.AbstractProject
equation::T_Equation
rsolver::T_RSolver
hrsc::T_HRSC
tci::T_TCI
ldg_rsolver::T_LDG_RSolver
av_rsolver::T_AV_RSolver
bdrycond::T_BC
end
......@@ -8,7 +8,7 @@ end
function compute_indicator!(
flag, # mutated outputs
u, # inputs
tci::Threshold{Mesh1d})
tci::Threshold{<:Mesh1d})
@unpack mesh, threshold = tci
L = layout(mesh)
......@@ -34,7 +34,7 @@ end
function compute_indicator!(
flag, # mutated outputs
u, # inputs
tci::Threshold{Mesh2d})
tci::Threshold{<:Mesh2d})
@unpack mesh, threshold = tci
@unpack Npts = mesh.element
......
......@@ -34,74 +34,3 @@ DirichletBC(bdry_vals::NTuple) = DirichletBC((t)->bdry_vals)
struct OutflowBC <: AbstractBC end
struct InflowBC <: AbstractBC end
struct PeriodicBC <: AbstractBC end
#######################################################################
# BoundaryConditions interface #
#######################################################################
"""
BoundaryConditions
Boundary condition interface for computing numerical fluxes at the domain boundaries.
Conditions are enforced weakly or strongly using an `AbstractRiemannSolver` and
its interface.
See: `ApproxRiemannSolver`.
---
BoundaryConditions(lhs_bc<:AbstractBC, rhs_bc<:AbstractBC, rsolver<:AbstractRiemannSolver)
Assumes that `lhs_bc, lhs_bc` return the boundary state at the lhs/rhs domain ends.
See `AbstractBC` and its subtypes for available boundary condition types.
# Example
```julia
julia> # mesh, cache, rsolver already defined
julia> lhs_bc = DirichletBC(t -> 0.0)
julia> rhs_bc = OutflowBC()
julia> bdryconds = BoundaryConditions(lhs_bc, rhs_bc, rsolver)
```
"""
struct BoundaryConditions{T_BC_LHS<:AbstractBC,
T_BC_RHS<:AbstractBC,
T_RSolver<:AbstractRiemannSolver}
lhs_bc::T_BC_LHS
rhs_bc::T_BC_RHS
rsolver::T_RSolver
isperiodic::Bool
function BoundaryConditions(lhs_bc, rhs_bc, rsolver)
@assert !xor(lhs_bc isa PeriodicBC, rhs_bc isa PeriodicBC)
isperiodic = lhs_bc isa PeriodicBC && rhs_bc isa PeriodicBC
if !isperiodic
n_accepts = length(accepts(rsolver.flux, rsolver.equation)[1])
lhs_bc_iscomptabiel = if lhs_bc isa OutflowBC
true
else
lhs_state = lhs_bc(0)
@assert lhs_state isa Tuple "Boundary state must be a tuple"
n_accepts == length(lhs_state)
end
rhs_bc_iscomptabiel = if rhs_bc isa OutflowBC
true
else
rhs_state = rhs_bc(0)
@assert rhs_state isa Tuple "Boundary state must be a tuple"
n_accepts == length(rhs_state)
end
@assert lhs_bc_iscomptabiel && rhs_bc_iscomptabiel "Boundary conditions \
must return the same number of arguments the flux function '$(rsolver.flux)' accepts"
end
return new{typeof(lhs_bc), typeof(rhs_bc), typeof(rsolver)}(lhs_bc, rhs_bc,
rsolver, isperiodic)
end
end
isperiodic(bdry::BoundaryConditions) = bdry.isperiodic
......@@ -72,6 +72,9 @@ include("lgl.jl")
export SpectralElement
include("spectralelement.jl")
export FVElement
include("fvelement.jl")
include("tensorbasis.jl")
include("box.jl")
include("tree.jl")
......@@ -88,13 +91,14 @@ export Cache, register_variables!,
get_global_variables, get_dynamic_variables, get_cell_variables,
get_bdry_variables, get_rhs_variables, get_static_variables,
get_variable, wrap_dynamic_variables!, wrap_rhs_variables!,
broadcast_volume!, variablenames, variablegroups,
variablenames, variablegroups,
broadcast_volume_2!, broadcast_faces_2!, broadcast_bdry_2!
include("cache.jl")
export Mesh, Mesh1d, Mesh2d, MeshInterpolator, layout, n_points, n_cells, dimension,
grid_average, cellwise_average, cellwise_inner_product, broken_inner_product,
locate_point, find_cell_index, eachcell, cellindices
locate_point, find_cell_index, eachcell, cellindices,
min_grid_spacing
include("mesh.jl")
export AbstractEquation
......@@ -110,13 +114,7 @@ export CallbackSet, CallbackTiming, FunctionCallback, SaveCallback, PlotCallback
TimeAlignedSaveCallback
include("callbacks.jl")
export AbstractRiemannSolver, ApproxRiemannSolver, ApproxRiemannSolver2d,
lax_friedrich_flux, avg_lax_friedrich_flux,
central_flux, mean_flux, broadcast_faces!
include("riemannsolver.jl")
export DirichletBC, OutflowBC, InflowBC, PeriodicBC, BoundaryConditions, isperiodic,
broadcast_boundaryconditions!
export DirichletBC, OutflowBC, InflowBC, PeriodicBC, BoundaryConditions, isperiodic
include("bdryconditions.jl")
export @with_signature, @accepts, @returns,
......@@ -126,6 +124,9 @@ include("with_signature.jl")
export compute_rhs_weak_form!
include("dg_rhs.jl")
export fv_update_step!
include("fv_rhs.jl")
# TODO Move this to top; also requires gather all type defintions
# like SpectralElement, AbstractMesh, Tree etc. into types.jl
export Maybe, Environment
......
"""
compute_rhs_weak_form!(rhs, f, nf_lhs, nf_rhs, mesh)
compute_rhs_weak_form!(rhs, f, s, nf_lhs, nf_rhs, mesh)
compute_rhs_weak_form!(rhs, f, nf, mesh::Mesh1d{SpectralElement})
compute_rhs_weak_form!(rhs, f, s, nf, mesh::Mesh1d{SpectralElement})
Compute the `rhs` of the weak DG formulation using
the (bulk) flux `f`, the numerical flux `nf` and sources `s` for a `mesh`.
"""
function compute_rhs_weak_form!(rhs, f, nf, mesh::Mesh1d)
function compute_rhs_weak_form!(rhs, f, nf, mesh::Mesh1d{SpectralElement})
@unpack invjac, element = mesh
@unpack invM, MDM, Ml_lhs, Ml_rhs, Npts = element
@unpack K = mesh
......@@ -18,13 +21,21 @@ function compute_rhs_weak_form!(rhs, f, nf, mesh::Mesh1d)
end
return
end
function compute_rhs_weak_form!(rhs, f, s, nf, mesh::Mesh1d)
function compute_rhs_weak_form!(rhs, f, s, nf, mesh::Mesh1d{SpectralElement})
compute_rhs_weak_form!(rhs, f, nf, mesh)
rhs .+= s
return
end
function compute_rhs_weak_form!(rhs, fx, fy, nf, mesh::Mesh2d)
"""
compute_rhs_weak_form!(rhs, fx, fy, nf, mesh::Mesh2d{SpectralElement})
compute_rhs_weak_form!(rhs, fx, fy, s, nf, mesh::Mesh2d{SpectralElement})
Compute the `rhs` of the weak DG formulation using
the (bulk) flux `fx, fy`, the numerical flux `nf` and sources `s` for a `mesh`.
"""
function compute_rhs_weak_form!(rhs, fx, fy, nf, mesh::Mesh2d{SpectralElement})
@unpack element = mesh
@unpack dxdX, dydY = get_static_variables(mesh.cache)
@unpack invM, MDM, Ml_lhs, Ml_rhs, Npts = element
......@@ -93,7 +104,15 @@ function compute_rhs_weak_form!(rhs, fx, fy, nf, mesh::Mesh2d)
end
return
end
function compute_rhs_weak_form!(rhs, fx, fy::Real, nf, mesh::Mesh2d)
"""
compute_rhs_weak_form!(rhs, fx, _::Real, nf, mesh::Mesh2d{SpectralElement})
compute_rhs_weak_form!(rhs, _::Real, fy, nf, mesh::Mesh2d{SpectralElement})
Specialized version of `compute_rhs_weak_form!` where only the fluxs
`fx` or `fy` are applied, respectively.
"""
function compute_rhs_weak_form!(rhs, fx, _::Real, nf, mesh::Mesh2d{SpectralElement})
@unpack element = mesh
@unpack dxdX, dydY = get_static_variables(mesh.cache)
@unpack invM, MDM, Ml_lhs, Ml_rhs, Npts = element
......@@ -149,7 +168,7 @@ function compute_rhs_weak_form!(rhs, fx, fy::Real, nf, mesh::Mesh2d)
end
return
end
function compute_rhs_weak_form!(rhs, fx::Real, fy, nf, mesh::Mesh2d)
function compute_rhs_weak_form!(rhs, _::Real, fy, nf, mesh::Mesh2d{SpectralElement})
@unpack element = mesh
@unpack dxdX, dydY = get_static_variables(mesh.cache)
@unpack invM, MDM, Ml_lhs, Ml_rhs, Npts = element
......@@ -205,7 +224,7 @@ function compute_rhs_weak_form!(rhs, fx::Real, fy, nf, mesh::Mesh2d)
end
return
end
function compute_rhs_weak_form!(rhs, fx, fy, s, nf, mesh::Mesh2d)
function compute_rhs_weak_form!(rhs, fx, fy, s, nf, mesh::Mesh2d{SpectralElement})
compute_rhs_weak_form!(rhs, fx, fy, nf, mesh)
rhs .+= s
return
......
"""
Evolution(rhs_fn!, u0::Vector{Float64}, timestep, tend, alg::TimeStepAlgorithm;
cfl = 1.0,
callback_fullstep::Union{Function,CallbackSet,Nothing}=nothing,
callback_substep::Union{Function,CallbackSet,Nothing}=nothing)
......@@ -16,7 +15,6 @@ the current iteration index. The use of `i` is deprecated.
- time stepping algorithm `alg`; see the function `algorithm`
Optional values
- `cfl` is a global fudge factor to control the time step size; we require `cfl ≥ 0`
- `callback_fullstep` to run a set of functions after a full time step; possible values
- `callback_fullstep = nothing` is default
- `callback_fullstep <: CallbackSet` to run a set of functions in a specified order;
......@@ -28,9 +26,9 @@ Optional values
`callback_fullstep` for allowed values
"""
function Evolution(rhs_fn!, u0::Vector{Float64}, timestep, tend, alg::TimeStepAlgorithm;
cfl = 1.0,
callback_fullstep::Union{Function,CallbackSet,Nothing}=nothing,
callback_substep::Union{Function,CallbackSet,Nothing}=nothing)
callback_substep::Union{Function,CallbackSet,Nothing}=nothing,
mesh::Union{Mesh,Nothing}=nothing)
# enforce interface
timestep_fn = if timestep isa Number
......@@ -59,18 +57,19 @@ function Evolution(rhs_fn!, u0::Vector{Float64}, timestep, tend, alg::TimeStepAl
end
callback_substep
end
if cfl <= 0
throw(ArgumentError("require cfl ≥ 0"))
end
stages = make_RK_stages(size(u0), alg.nstages)
stepper!(up1, u0, t, dt) = step!(up1, rhs_fn!, u0, t, dt, stages, alg, cb_substep)
stepper! = if mesh isa Mesh1d{FVElement}
(up1, u0, t, dt) -> rhs_fn!(up1, u0, t)
else
(up1, u0, t, dt) -> step!(up1, rhs_fn!, u0, t, dt, stages, alg, cb_substep)
end
up1 = deepcopy(u0) # we need to initialize memory anyways
# and the stepper will swap up1, u before he steps
return Evolution{typeof(rhs_fn!), typeof(timestep_fn), typeof(cb_fullstep),
typeof(cb_substep), typeof(alg), typeof(stepper!)}(
rhs_fn!, timestep_fn, cb_fullstep, cb_substep, tend, cfl, alg,
rhs_fn!, timestep_fn, cb_fullstep, cb_substep, tend, alg,
stepper!,
up1, u0, 0.0, 0, stages)
end
......@@ -86,13 +85,10 @@ function step!(evo::Evolution)
@unpack up1, u, t, tend, timestep, stepper!, cb_fullstep = evo
# TODO Remove cfl from all projects and enable it here
# dt = evo.cfl * timestep(u, t, 0)
dt = timestep(u, t, 0)
if isnothing(dt)
println()
@info """Termination requested by timestep function!"""
return :timestep_termination
if isinvalid(dt)
@info """Invalid timestep proposed at t = $(t)!"""
return :timestep_invalid
end
# shorten time step to hit tend exactly
......@@ -115,7 +111,7 @@ function step!(evo::Evolution)
evo.it += 1
if cb_fullstep(up1, evo.t, evo.it) == false
@info """Callback after step failed at t = $(t)!"""
@info """Callback failed after step at t = $(t)!"""
return :callback_failed
end
......@@ -299,6 +295,7 @@ function isinvalid(u)
end
return false
end
isinvalid(u::Real) = isnan(u) || isinf(u)
"""
......
"""
fv_update_step!(up1, u, f, mesh::Mesh1d{FVElement})
fv_update_step!(up1, u, f, s, mesh::Mesh1d{FVElement})
Update the state `up1` of the FV formulation using
the (bulk) flux `f` and sources `s` for a `mesh`.
"""
function fv_update_step!(up1, u, f, bdry_u, bdry_f, dt, mesh::Mesh1d{FVElement})
@unpack invjac = mesh
@unpack K = mesh
dl = widths(mesh)[1] / K
dtdl = dt/dl
@turbo for j = 2:K-1
up1[j] = (u[j+1] + 2*u[j] + u[j-1])/4 - dtdl/2 * (f[j+1] - f[j-1])
end
if mesh.tree.periodic[1]
up1[1] = (u[2] + 2*u[1] + u[end])/4 - dtdl/2 * (f[2] - f[end])
up1[end] = (u[1] + 2*u[end] + u[end-1])/4 - dtdl/2 * (f[1] - f[end-1])
else
up1[1] = (u[2] + 2*u[1] + bdry_u[1])/4 - dtdl/2 * (f[2] - bdry_f[1])
up1[end] = (bdry_u[end] + 2*u[end] + u[end-1])/4 - dtdl/2 * (bdry_f[end] - f[end-1])
end
return
end
function fv_update_step!(up1, u, f, s, bdry_u, bdry_f, bdry_s, dt, mesh::Mesh1d{FVElement})
TODO()
@unpack invjac = mesh
@unpack K = mesh
dl = widths(mesh)[1] / K
dtdl = dt/dl
@turbo for j = 2:K-1
up1[j] = (u[j+1] + 2*u[j] + u[j-1])/4 - dtdl/2 * (f[j+1] - f[j-1]) + s[j]
end
if mesh.tree.periodic[1]
up1[1] = (u[2] + 2*u[1] + u[end])/4 - dtdl/2 * (f[2] - f[end]) + s[1]
up1[end] = (u[1] + 2*u[end] + u[end-1])/4 - dtdl/2 * (f[1] - f[end-1]) + s[end]
else
TODO()
end
return
end
struct FVElement
N::Int64 # polynomial order
Npts::Int64 # number of quadrature points = N + 1
z::Vector{Float64} # quadrature points
function FVElement()
N = 0
Npts = N + 1
z = [0.0]
return new(N, Npts, z)
end
end
@inline function inner_product(u1, u2, el::FVElement)
TODO()
@unpack quadr = el
@toggled_assert quadr in (:LGL, :LGL_lumped, :GLGL)
if quadr === :LGL || quadr === :LGL_lumped
return LGL.integrate(el.w, u1, u2)
else # quadr === :GLGL
return GLGL.integrate(el.w, el.v, el.D, u1, u2)
end
end
@inline function integrate(u, el::FVElement)
TODO()
@unpack quadr = el
@toggled_assert quadr in (:LGL, :LGL_lumped, :GLGL)
if quadr === :LGL || quadr === :LGL_lumped
return LGL.integrate(el.w, u)
else # quadr === :GLGL
return GLGL.integrate(el.w, el.v, el.D, u)
end
end
......@@ -174,7 +174,8 @@ function setup(project_name, prms, outputdir)
mprms = prms["Mesh"]
mesh = Mesh(; N=mprms["n"], K=mprms["k"], range=mprms["range"],
basis=mprms["basis"], CFL=mprms["cfl"], periodic=mprms["periodic"])
basis=mprms["basis"], periodic=mprms["periodic"],
scheme=mprms["scheme"])
callbacks = CallbackSet()
callbacks_substeps = CallbackSet()
env = Environment(mesh, mesh.cache, callbacks, callbacks_substeps)
......@@ -204,13 +205,17 @@ function setup(project_name, prms, outputdir)
rhs!(env, project, bdryconds)
end
end
CFL = prms["Evolution"]["cfl"]
function wrapped_timestep!(state_u, t, n)
@timeit TO "Wrap dynamic" begin
wrap_dynamic_variables!(env.cache, state_u)
end
# provided by project interface
@timeit TO "timestep" begin
timestep!(env, project)
# need to apply CFL for all projects here,
# because different places might call this function (e.g. SaveCallback, Evolution)
# and compute a timestep multiple times, although, that should be deprecated in the future
CFL * timestep!(env, project)
end
end
function wrap_dynamic(state_u, t)
......@@ -275,14 +280,13 @@ function evolve!(env, rhs_fn, timestep_fn, prms)
# TODO Move cfl parameter from Mesh to Evolution
# @unpack tend, method, cfl = prms["Evolution"]
@unpack tend, method = prms["Evolution"]
@unpack cfl = prms["Mesh"]
@unpack tend, method, cfl = prms["Evolution"]
@unpack mesh, cache, callbacks, callbacks_substep = env
# TODO This should be hidden behind an interface
u0 = reduce(vcat, fields(cache.dynamic_variables))
alg = algorithm(Symbol(method))
evolution = Evolution(rhs_fn, u0, timestep_fn, tend, alg; cfl=cfl,
evolution = Evolution(rhs_fn, u0, timestep_fn, tend, alg; mesh=mesh,
callback_fullstep=callbacks)#, callback_substep=callbacks_substep)
# run once on initial data
......
......@@ -6,24 +6,22 @@
abstract type AbstractMesh end
struct Mesh{N_Dim,N_Sides} <: AbstractMesh
struct Mesh{Element,N_Dim,N_Sides} <: AbstractMesh
tree::Tree{N_Dim,N_Sides} # abstract representation of mesh
boxes::Vector{Box{N_Dim}} # physical extends of each cell
extends::NTuple{N_Dim,Tuple{Float64,Float64}} # total extend of mesh
element::SpectralElement
element::Element
cache::Cache
bulkfaceindices::Vector{Int64}
bulkbdryindices::Vector{Int64}
faceindices::Vector{Int64}
bdryindices::Vector{Int64}
offsets::Vector{Int64} # data index offsets for each cell
# TODO Move to evolution part!
CFL::Float64
end
const Mesh1d = Mesh{1,2}
const Mesh2d = Mesh{2,4}
const Mesh1d{Element} = Mesh{Element,1,2}
const Mesh2d{Element} = Mesh{Element,2,4}
#######################################################################
......@@ -31,21 +29,25 @@ const Mesh2d = Mesh{2,4}
#######################################################################
function Mesh1d(; N=5, K=10, range=[-1.0,1.0], CFL=0.5, basis="lgl_lumped", periodic=(true,))
function Mesh1d(; N=5, K=10, range=[-1.0,1.0], basis="lgl_lumped", periodic=(true,), scheme="DG")
@toggled_assert N > 0
@toggled_assert K > 0
@toggled_assert CFL > 0
@toggled_assert length(range) == 2
@toggled_assert range[1] < range[2]
@toggled_assert basis in ["lgl", "glgl", "lgl_lumped"]
# TODO Remove this after updating parameter names in SpectralElement
quadrature_method = Dict("lgl" => :LGL, "glgl" => :GLGL, "lgl_lumped" => :LGL_lumped)[basis]
tree = Tree1d(K,periodic=tuple(periodic...))
boxes = place_boxes(tree, Float64.(range))
element = SpectralElement(N, Symbol(quadrature_method))
element = if scheme == "DG"
# TODO Remove this after updating parameter names in SpectralElement
quadrature_method = Dict("lgl" => :LGL, "glgl" => :GLGL, "lgl_lumped" => :LGL_lumped)[basis]
SpectralElement(N, Symbol(quadrature_method))
elseif scheme == "FV"
FVElement()
else
error("unknown approximation scheme $scheme")
end
@unpack Npts, z = element
offsets = [ (i-1)*Npts for i = 1:length(tree) ]
extends = (Tuple(range),)
......@@ -83,13 +85,14 @@ function Mesh1d(; N=5, K=10, range=[-1.0,1.0], CFL=0.5, basis="lgl_lumped", peri
Int64[1, Npts*K], Int64[1, 2*K]
end
return Mesh1d(tree, boxes, extends, element, cache,
bulkfaceindices, bulkbdryindices, faceindices, bdryindices, offsets, CFL)
return Mesh1d{typeof(element)}(tree, boxes, extends, element, cache,
bulkfaceindices, bulkbdryindices, faceindices, bdryindices, offsets)
end
function Mesh2d(; N=[5,5], K=[4,4], range=[-1.0,1.0, -1.0,1.0],
CFL=0.5, basis="lgl_lumped", periodic=(true,true))
basis="lgl_lumped", periodic=(true,true),
scheme="DG")
@toggled_assert length(N) == 2
@toggled_assert length(K) == 2
......@@ -102,18 +105,24 @@ function Mesh2d(; N=[5,5], K=[4,4], range=[-1.0,1.0, -1.0,1.0],
@toggled_assert Ny > 0
@toggled_assert Kx > 0
@toggled_assert Ky > 0
@toggled_assert CFL > 0
@toggled_assert xrange[1] < xrange[2]
@toggled_assert yrange[1] < yrange[2]
@toggled_assert basis in ["lgl", "glgl", "lgl_lumped"]
# TODO Remove this after updating parameter names in SpectralElement
quadrature_method = Dict("lgl" => :LGL, "glgl" => :GLGL, "lgl_lumped" => :LGL_lumped)[basis]
tree = Tree2d(Kx,Ky,periodic=tuple(periodic...))
boxes = place_boxes(tree, Float64.(xrange), Float64.(yrange))
element_x = SpectralElement(Nx, Symbol(quadrature_method))
element_y = SpectralElement(Ny, Symbol(quadrature_method))
if scheme == "DG"
# TODO Remove this after updating parameter names in SpectralElement
quadrature_method = Dict("lgl" => :LGL, "glgl" => :GLGL, "lgl_lumped" => :LGL_lumped)[basis]
element_x = SpectralElement(Nx, Symbol(quadrature_method))
element_y = SpectralElement(Ny, Symbol(quadrature_method))
elseif scheme == "FV"
TODO()
element_x = FVElement()
element_y = FVElement()
else
error("unknown approximation scheme $scheme")
end
Nptsx, Nptsy = element_x.Npts, element_y.Npts
offsets = [ (i-1)*Nptsx*Nptsy for i = 1:length(tree) ]
extends = (Tuple(xrange),Tuple(yrange))
......@@ -203,16 +212,16 @@ function Mesh2d(; N=[5,5], K=[4,4], range=[-1.0,1.0, -1.0,1.0],
bo += Nptsx
end
return Mesh2d(tree, boxes, extends, element_x, cache,
bulkfaceindices, bulkbdryindices, faceindices, bdryindices, offsets, CFL)
return Mesh2d{typeof(element_x)}(tree, boxes, extends, element_x, cache,
bulkfaceindices, bulkbdryindices, faceindices, bdryindices, offsets)
end
# TODO Deprecate and remove this
function Mesh(; N, K, range, CFL, basis, periodic)
function Mesh(; N, K, range, basis, periodic, scheme)
dims = Int(length(range)/2)
MType = dims == 1 ? Mesh1d : Mesh2d
return MType(; N, K, range, CFL, basis, periodic)
return MType(; N, K, range, basis, periodic, scheme)
end
......@@ -269,11 +278,31 @@ end
#######################################################################
function widths(m::Mesh{N}) where N
function widths(m::Mesh{Element,N}) where {Element,N}
return NTuple{N,Float64}(abs(r-l) for (l,r) in m.extends)
end
function min_grid_spacing(m::Mesh{SpectralElement})
@unpack z = m.element
ws = widths(m.boxes[1])
minw = minimum(ws)
dz = z[2]-z[1]
return minw .* dz
end
function min_grid_spacing(m::Mesh{FVElement})
return minimum(widths(m)) / m.K
end
function isperiodic(m::Mesh)
if !all(p -> p == first(m.tree.periodic), m.tree.periodic)
TODO("mixed periodic domain")
end
return all(m.tree.periodic)
end
Base.broadcastable(mesh::Mesh) = Ref(mesh)
......@@ -454,7 +483,7 @@ Base.extrema(mesh::Mesh) = mesh.extends
npoints(mesh::Mesh)::Int = mesh.element.Npts
ncells(mesh::Mesh) = prod(mesh.tree.dims)
layout(mesh::Mesh)::Tuple{Int,Int} = (n_points(mesh), n_cells(mesh))
dimension(mesh::Mesh{N}) where N = N
dimension(mesh::Mesh{Element,N}) where {Element,N} = N
cache(mesh::Mesh) = mesh.cache
@deprecate n_points(mesh::Mesh) npoints(mesh)
......@@ -572,14 +601,14 @@ end
Base.length(it::CellDataIterator) = ncells(it.mesh)
Base.size(it::CellDataIterator) = size(it.mesh)
@inline function Base.iterate(it::CellDataIterator{Mesh1d}, state=1)
@inline function Base.iterate(it::CellDataIterator{<:Mesh1d}, state=1)
state > length(it) && return nothing
idx = it.mesh.offsets[state]
Npts = it.mesh.element.Npts
v = view(it.data, idx+1:idx+Npts)
return v, state+1
end
@inline function Base.iterate(it::CellDataIterator{Mesh2d}, state=1)
@inline function Base.iterate(it::CellDataIterator{<:Mesh2d}, state=1)
state > length(it) && return nothing
idx = it.mesh.offsets[state]
Npts = it.mesh.element.Npts
......
......@@ -486,11 +486,6 @@ end
@check length(n) in [ 1, 2 ]
@check all(n .>= 0)
# TODO Move to Evolution
"Courant-Friedrichs-Lewy factor"
cfl = 0.5
@check cfl > 0.0
"""
Nodal basis type used for the polynomial approximation. Available options
- `lgl_lumped` ... mass-lumped Legendre-Gauss-Lobatto grid
......@@ -500,6 +495,14 @@ end
basis = "lgl_lumped"
@check basis in [ "lgl", "glgl", "lgl_lumped" ]
"""
Cellwise approximation scheme of solution. Available options
- `DG` ... Discontinuous Galerkin
- `FV` ... Finite Volume (central scheme)
"""
scheme = "DG"
@check scheme in [ "DG", "FV" ]
"""
Periodicity in cartesian directions `x,y`
- 1d: [ `x_dir` ]
......@@ -533,6 +536,10 @@ end
@check method in [ "midpt", "lserk4", "rk4", "rk4_twothirds", "rk_ralston",
"ssprk3", "rkf10" ]
"Courant-Friedrichs-Lewy factor"
cfl = 0.5
@check cfl > 0.0
end
......
"""
abstract AbstractRiemannSolver
Supertype for `RiemannSolver` interface.
Any Subtype `Solver <: RiemannSolver` must implement the following methods
- `register_variables!(cache, rsolver::Solver)`,
- `broadcast_faces!(f, rsolver::Solver, cache::Cache, mesh::Mesh)` where `f` is a
`@with_signature` function.
- `broadcast_boundaries!(f, rsolver::Solver, cache::Cache, mesh::Mesh)` where `f` is a
`@with_signature` function.
See: `ApproxRiemannSolver`
"""
abstract type AbstractRiemannSolver end
"""
ApproxRiemannSolver <: AbstractRiemannSolver
Approximate Riemann solver interface for numerical flux computation.
Available implementations: `lax_friedrich_flux`, `central_flux`.
---
ApproxRiemannSolver(flux, speed, equation; lhs_bc=PeriodicBC(), rhs_bc=PeriodicBC())
Assumes an `equation::SomeEquation` that implements two `@with_signature` functions
- `flux` - compute fluxes given a state,
- `speed` - compute maximum wave speed given a state.
# Example
```julia
julia> mesh = Mesh1d(); cache = Cache(mesh);
julia> rsolver = ApproxRiemannSolver(flux, speed, equation)
# register numerical fluxes as bdry variables in cache
julia> register_variables!(cache, rsolver);
# compute numerical fluxes inplace in cache
julia> broadcast_faces!(lax_friedrich_flux, rsolver, cache, mesh)
```
"""
struct ApproxRiemannSolver{T_Flux<:Function,
T_Speed<:Function,
T_Equation<:AbstractEquation,
N,M} <: AbstractRiemannSolver
flux::T_Flux
speed::T_Speed
equation::T_Equation
accepts::NTuple{N,Symbol}
returns::NTuple{M,Symbol}
function ApproxRiemannSolver(flux, speed, equation)
# TODO Verify that accepts and returns have been registered
@assert has_signature(flux, equation) begin
"ApproxRiemannSolver: Require a @with_signature function for '$flux'"
end
@assert has_signature(speed, equation) begin
"ApproxRiemannSolver: Require a @with_signature function for '$speed'"
end
_accepts_flux, returns_flux = signature(flux, equation)
accepts_flux = _accepts_flux[1]
_accepts_speed, returns_speed = signature(speed, equation)
accepts_speed = _accepts_speed[1]
state_indices_flux = state_indices(flux, equation)
state_indices_speed = state_indices(speed, equation)
@assert all(accepts_flux .=== accepts_speed) && state_indices_flux == state_indices_speed begin
"ApproxRiemannSolver: '$flux' and '$speed' must accept the same variables"
end
@assert length(returns_flux) == length(state_indices_flux) begin
"ApproxRiemannSolver: '$flux' must return one value per state variable!"
end
@assert length(returns_speed) == 1 begin
"ApproxRiemannSolver: '$speed' must return exaclty one value!"
end
rsolve_accepts = accepts_flux
rsolve_returns = Tuple( Symbol(:nflx_, accepts_flux[i]) for i in state_indices_flux )
return new{typeof(flux), typeof(speed), typeof(equation),
length(rsolve_accepts), length(rsolve_returns)}(
flux, speed, equation, rsolve_accepts, rsolve_returns)
end
end
struct ApproxRiemannSolver2d{T_Flux<:Function,
T_Speed<:Function,
T_Equation<:AbstractEquation,
N,M} <: AbstractRiemannSolver
flux::T_Flux
speed::T_Speed
equation::T_Equation
accepts::NTuple{N,Symbol}
returns::NTuple{M,Symbol}
function ApproxRiemannSolver2d(flux, speed, equation)
# TODO Verify that accepts and returns have been registered
@assert has_signature(flux, equation) begin
"ApproxRiemannSolver2d: Require a @with_signature function for '$flux'"
end
@assert has_signature(speed, equation) begin
"ApproxRiemannSolver2d: Require a @with_signature function for '$speed'"
end
_accepts_flux, returns_flux = signature(flux, equation)
accepts_flux = _accepts_flux[1]
_accepts_speed, returns_speed = signature(speed, equation)
accepts_speed = _accepts_speed[1]
state_indices_flux = state_indices(flux, equation)
state_indices_speed = state_indices(speed, equation)
@assert all(accepts_flux .=== accepts_speed) && state_indices_flux == state_indices_speed begin
"ApproxRiemannSolver2d: '$flux' and '$speed' must accept the same variables"
end
@assert length(returns_flux) == 2 * length(state_indices_flux) begin
"ApproxRiemannSolver2d: '$flux' must return two values per state variable!"
end
@assert length(returns_speed) == 1 begin
"ApproxRiemannSolver2d: '$speed' must return exaclty one value!"
end
rsolve_accepts = accepts_flux
rsolve_returns = Tuple( Symbol(:nflx_, accepts_flux[i]) for i in state_indices_flux )
return new{typeof(flux), typeof(speed), typeof(equation),
length(rsolve_accepts), length(rsolve_returns)}(
flux, speed, equation, rsolve_accepts, rsolve_returns)
end
end
# handmade definitions for @with_signature flux functions
# - accepts the same variables as rsolve.flux, rsolve.speed
# - state_indices are taken from rsolve.flux, rsolve.speed
# - returns a tuple of tuples of numerical fluxes, e.g. ( (lhs_numflx_u,), (rhs_numflx_u,) )
# this is different from the standard @with_signature interface, but we need it to separate
# the StructArrays belonging to the numerical fluxes on the left and right interfaces.
state_indices(f, rsolve::AbstractRiemannSolver) = state_indices(rsolve.flux, rsolve.equation)
function signature(f, rsolve::AbstractRiemannSolver)
return rsolve.accepts, rsolve.returns
end
state(tpl, rsolve::AbstractRiemannSolver) = state(tpl, state_indices(rsolve.flux, rsolve.equation))
function register_variables!(mesh::Mesh, rsolver::AbstractRiemannSolver; dont_register=false)
register_variables!(mesh.cache, rsolver; dont_register)
end
function register_variables!(cache, rsolver::AbstractRiemannSolver; dont_register=false)
@unpack flux, equation = rsolver
returns_vars = returns(lax_friedrich_flux, rsolver)
if !dont_register
register_variables!(cache, bdry_variablenames=returns_vars)
end
end
#######################################################################
# Abstractimate Riemann Solver implementations #
#######################################################################
project(normal::NTuple{1}, flxs::NTuple{1}) = (dot(normal,flxs),)
project(normal::NTuple{2}, flxs::NTuple{2}) = (dot(normal,flxs),)
project(normal::NTuple{1}, flxs::NTuple{N}) where N = dot.(Ref(normal),flxs)
project(normal::NTuple{2}, flxs::NTuple{N,<:NTuple}) where N = dot.(Ref(normal),flxs)
maxspeed(vl::NTuple{1}, vr::NTuple{1}) = max(abs(vl[1]), abs(vr[1]))
"""
@with_signature lax_friedrich_flux(args, rsolver::AbstractRiemannSolver)
See: `AbstractRiemannSolver`
"""
function lax_friedrich_flux(args, rsolver::AbstractRiemannSolver, normal=(1.0,))
args_lhs, args_rhs = args
@unpack flux, speed, equation = rsolver
ul, ur = state(args_lhs, rsolver), state(args_rhs, rsolver)
vl, vr = speed(args_lhs, equation), speed(args_rhs, equation)
fl, fr = flux(args_lhs, equation), flux(args_rhs, equation)
nfl, nfr = project(normal, fl), project(normal, fr)
C = maxspeed(vl,vr)
nf = @. 0.5 * (nfl + nfr) + 0.5 * C * (ur - ul)
nf
end
has_signature(::typeof(lax_friedrich_flux), ::AbstractRiemannSolver) = true
"""
@with_signature lax_friedrich_flux(args, rsolver::AbstractRiemannSolver)
See: `AbstractRiemannSolver`
"""
function avg_lax_friedrich_flux(args, rsolver::AbstractRiemannSolver, normal=(1.0,))
args_lhs, args_rhs = args
@unpack flux, speed, equation = rsolver
ul, ur = state(args_lhs, rsolver), state(args_rhs, rsolver)
vl, vr = speed(args_lhs, equation), speed(args_rhs, equation)
avg_args = @. 0.5 * (args_lhs + args_rhs)
avg_f = flux(avg_args, equation)
avg_nf = project(normal, avg_f)
C = maxspeed(vl,vr)
nf = @. avg_nf + 0.5 * C * (ur - ul)
nf
end
has_signature(::typeof(avg_lax_friedrich_flux), ::AbstractRiemannSolver) = true
"""
@with_signature central_flux(args, rsolver::AbstractRiemannSolver)
See: `AbstractRiemannSolver`
"""
function central_flux(args, rsolver::AbstractRiemannSolver, normal=(1.0,))
args_lhs, args_rhs = args
ul, ur = state(args_lhs, rsolver), state(args_rhs, rsolver)
nf = @. 0.5 * (ul + ur)
nf
end
has_signature(::typeof(central_flux), ::AbstractRiemannSolver) = true
"""
@with_signature mean_flux(args, rsolver::AbstractRiemannSolver)
See: `AbstractRiemannSolver`
"""
function mean_flux(args, rsolver::AbstractRiemannSolver, normal=(1.0,))
args_lhs, args_rhs = args
@unpack flux, equation = rsolver
fl, fr = flux(args_lhs, equation), flux(args_rhs, equation)
nfl, nfr = project(normal, fl), project(normal, fr)
nf = @. 0.5 * (nfl + nfr)
nf
end
has_signature(::typeof(mean_flux), ::AbstractRiemannSolver) = true
......@@ -15,7 +15,6 @@ mutable struct Evolution{T_RHS, T_Timestep, T_CB_Fullstep, T_CB_Substep, T_Algor
cb_fullstep::T_CB_Fullstep
cb_substep::T_CB_Substep
tend::Float64
cfl::Float64
alg::T_Algorithm
stepper!::T_Stepper
......
......@@ -270,7 +270,7 @@ end
TODO() = error("Not implemented yet!")
TODO(msg::String) = error(msg)
TODO(msg::String) = error("Not implemented: $msg")
TODO(fn::Function) = error("'$fn': Not implemented yet!")
TODO(fn::Function, msg) = error("""
'$fn'' Not implemented yet!
......
......@@ -4,7 +4,7 @@
This is the current default behavior of `@with_signature` but it is deprecated.
Macro to extract `accepted` and `returned` arguments from a function and make them
programmatically accessible. Functions defined using this macro can be used with
the `broadcast_volume!, broadcast_faces!, broadcast_bdrys` interfaces.
the `[new_]broadcast_[volume|faces|bdrys]_2!` interfaces.
# Example
......@@ -36,7 +36,7 @@ has_signature(::typeof(fn), ::Equation) = true
Macro to extract `accepted` and `returned` arguments from a function and make them
programmatically accessible. Functions defined using this macro can be used with
the `broadcast_volume!, broadcast_faces!, broadcast_bdrys` interfaces.
the `[new_]broadcast_[volume|faces|bdrys]_2!` interfaces.
To test a `@with_signature` in standalone mode, see below for an example.
......@@ -736,242 +736,6 @@ has_signature(fn, dispatch, any) = has_signature(fn, typeof(dispatch), nothing)
#######################################################################
#######################################################################
# version 1 #
#######################################################################
"""
broadcast_volume!(f::Function, D::T_Dispatch, C::Cache) where T_Dispatch
Broadcast a `@with_signature` function `f` dispatched on `D` over variables of
cache `C`. The arguments and return values are determined by the `accepts, returns` overloads
of `(f, D)` from the `@with_signature` interface.
See: `@with_signature`.
"""
function broadcast_volume!(f::Function, D::T_Dispatch, C::Cache) where T_Dispatch
has_signature(f, D) || error("broadcast_volume!: Function '$f' has no signature!")
accepts_syms, returns_syms = accepts(f, D)[1], returns(f, D)
n_accepts, n_returns = length(accepts_syms), length(returns_syms)
v = get_variable(C, first(accepts_syms))
accepts_vars = NTuple{n_accepts,Vector{Float64}}( get_variable(C, a) for a in accepts_syms )
returns_vars = NTuple{n_returns,Vector{Float64}}( get_variable(C, r) for r in returns_syms )
f_args = args -> f(args, D)
_broadcast_volume!(f_args, C, accepts_vars, returns_vars)
return
end
function _broadcast_volume!(f, C::Cache,
accepts_vars::NTuple{N,Vector{Float64}},
returns_vars::NTuple{M,Vector{Float64}}) where {N,M}
soa_accepts = StructArray{NTuple{N,Float64}}( accepts_vars )
soa_returns = StructArray{NTuple{M,Float64}}( returns_vars )
npts = length(first(accepts_vars))
for idx = 1:npts
soa_returns[idx] = f(soa_accepts[idx])
end
return
end
@deprecate broadcast_volume!(f, rsolver, cache, mesh) broadcast_faces!(f, rsolver, cache, mesh)
"""
broadcast_faces!(f, rsolver::ApproxRiemannSolver, cache::Cache, mesh::Mesh)
Broadcast `@with_signature` function `f` with dispatch `rsolver` over `cache` and
faces of `mesh`.
"""
function broadcast_faces!(f, rsolver::AbstractRiemannSolver, cache::Cache, mesh::Mesh)
has_signature(f, rsolver) || error("broadcast_faces!: Function '$f' has no signature!")
accepts_vars = accepts(f, rsolver)
returns_vars = returns(f, rsolver)
n_accepts = length(accepts_vars)
n_returns = length(returns_vars)
accepts_data = NTuple{n_accepts,Vector{Float64}}( get_variable(cache, a)
for a in accepts_vars )
returns_data = NTuple{n_returns,Vector{Float64}}( get_variable(cache, r)
for r in returns_vars )
f_args = (args, normal) -> f(args, rsolver, normal)
_broadcast_face!(f_args, mesh, accepts_data, returns_data)
return
end
function broadcast_faces!(f, equation, mesh::Mesh)
has_signature(f, equation) || error("broadcast_faces!: Function '$f' has no signature!")
accepts_vars = accepts(f, equation)
returns_vars = returns(f, equation)
n_accepts = length(accepts_vars)
n_returns = length(returns_vars)
accepts_data = NTuple{n_accepts,Vector{Float64}}( get_variable(cache, a)
for a in accepts_vars )
returns_data = NTuple{n_returns,Vector{Float64}}( get_variable(cache, r)
for r in returns_vars )
f_args = (args, normal) -> f(args, rsolver, normal)
_broadcast_face!(f_args, mesh, accepts_data, returns_data)
return
end
function _broadcast_face!(f, mesh::Mesh1d,
state_data::NTuple{Na,Vector{Float64}},
nflx_data::NTuple{Nr,Vector{Float64}}) where {Na,Nr}
soa_state = StructArray{NTuple{Na,Float64}}( state_data )
soa_nflx = StructArray{NTuple{Nr,Float64}}( nflx_data )
Npts, K = layout(mesh)
Idx = LinearIndices((Npts, K))
n_out = (-1.0,) # 'normal' in negative direction
for cell = 1:K
# each loop deals with the lhs' interface of the current cell
cell_out, cell_in = periodic_index(cell-1, K), cell
idx_out, idx_in = Idx[end,cell_out], Idx[1,cell_in]
accepts_out, accepts_in = soa_state[idx_out], soa_state[idx_in]
nflx_out = f((accepts_out, accepts_in), n_out)
# n_out * f(u_out, u_in, n_out) = - n_in * f(u_in, u_out, n_in) with n_out = -n_in
nflx_in = @. -nflx_out
face_idx_rhs = cell
face_idx_lhs = K+periodize(cell-K-1, K)
soa_nflx[face_idx_rhs] = nflx_out
soa_nflx[face_idx_lhs] = nflx_in
end
return
end
function _broadcast_face!(f, mesh::Mesh2d,
state_data::NTuple{Na,Vector{Float64}},
nflx_data::NTuple{Nr,Vector{Float64}}) where {Na,Nr}
soa_state = StructArray{NTuple{Na,Float64}}( state_data )
soa_nflx = StructArray{NTuple{Nr,Float64}}( nflx_data )
@unpack Npts = mesh.element
Kx, Ky = mesh.tree.dims
Nx, Ny = Npts, Npts
@unpack nx, ny = get_bdry_variables(mesh.cache)
for cell = 1:ncells(mesh)
bulk_in = cellindices(mesh, cell)
rng_lhs, rng_rhs, rng_down, rng_up = faceindices(mesh, cell)
# xmin face
cell_out = mesh.tree.cells[cell].neighbors[Cart2d.Xmin]
bulk_out = cellindices(mesh, cell_out)
for (fidx, face_idx) in enumerate(rng_lhs)
idx_out = bulk_out[end,fidx]
idx_in = bulk_in[1,fidx]
n_out = (nx[face_idx], ny[face_idx])
state_out, state_in = soa_state[idx_out], soa_state[idx_in]
soa_nflx[face_idx] = f((state_out, state_in), n_out)
end
# xmax face
cell_out = mesh.tree.cells[cell].neighbors[Cart2d.Xmax]
bulk_out = cellindices(mesh, cell_out)
for (fidx, face_idx) in enumerate(rng_rhs)
idx_out = bulk_out[1,fidx]
idx_in = bulk_in[end,fidx]
n_out = (nx[face_idx], ny[face_idx])
state_out, state_in = soa_state[idx_out], soa_state[idx_in]
soa_nflx[face_idx] = f((state_out, state_in), n_out)
end
# ymin face
cell_out = mesh.tree.cells[cell].neighbors[Cart2d.Ymin]
bulk_out = cellindices(mesh, cell_out)
for (fidx, face_idx) in enumerate(rng_down)
idx_out = bulk_out[fidx,end]
idx_in = bulk_in[fidx,1]
n_out = (nx[face_idx], ny[face_idx])
state_out, state_in = soa_state[idx_out], soa_state[idx_in]
soa_nflx[face_idx] = f((state_out, state_in), n_out)
end
# ymax face
cell_out = mesh.tree.cells[cell].neighbors[Cart2d.Ymax]
bulk_out = cellindices(mesh, cell_out)
for (fidx, face_idx) in enumerate(rng_up)
idx_out = bulk_out[fidx,1]
idx_in = bulk_in[fidx,end]
n_out = (nx[face_idx], ny[face_idx])
state_out, state_in = soa_state[idx_out], soa_state[idx_in]
soa_nflx[face_idx] = f((state_out, state_in), n_out)
end
end
return
end
"""
broadcast_boundaryconditions!(f, bc::BoundaryConditions, cache::Cache, mesh::Mesh, t)
Impose boundary conditions using the flux method `f` and boundary conditions `bc`
for fields in `cache` on `mesh` at time `t`.
"""
function broadcast_boundaryconditions!(f, bc::BoundaryConditions, cache::Cache, mesh::Mesh1d, t=0.0)
@unpack lhs_bc, rhs_bc, rsolver = bc
has_signature(f, rsolver) || error("broadcast_volume!: Function '$f' has no signature!")
bc.isperiodic && return
accepts_vars = accepts(f, rsolver)
returns_vars = returns(f, rsolver)
n_accepts = length(accepts_vars)
n_returns = length(returns_vars)
accepts_data = NTuple{n_accepts,Vector{Float64}}( get_variable(cache, a)
for a in accepts_vars )
returns_data = NTuple{n_returns,Vector{Float64}}( get_variable(cache, r)
for r in returns_vars )
f_args = (args, normal) -> f(args, rsolver, normal)
_broadcast_boundaries!(f_args, mesh, accepts_data, returns_data, lhs_bc, rhs_bc, t)
return
end
boundary_state(bc::AbstractBC, u_inner, t) = bc(t)
boundary_state(bc::OutflowBC, u_inner, t) = u_inner
# similar to _broadcast_volume!, but only iterating domain bdrys and broadcasting
# over two returned data arrays
function _broadcast_boundaries!(f, mesh::Mesh,
state_data::NTuple{Na,Vector{Float64}},
nflx_data::NTuple{Nr,Vector{Float64}},
lhs_bc::AbstractBC,
rhs_bc::AbstractBC,
t) where {Na,Nr}
soa_state = StructArray{NTuple{Na,Float64}}( state_data )
soa_nflx = StructArray{NTuple{Nr,Float64}}( nflx_data )
Npts, K = layout(mesh)
# lhs bdry
cell = 1
state = soa_state[cell]
bdry_state = boundary_state(lhs_bc, state, t)
n_out = (-1.0,)
nf = f((bdry_state, state), n_out)
soa_nflx[1] = f((bdry_state, state), n_out)
# rhs bdry
cell = K
state = soa_state[end]
bdry_state = boundary_state(rhs_bc, state, t)
n_out = (1.0,)
soa_nflx[end] = f((bdry_state, state), n_out)
return
end
#######################################################################
# version 2 #
#######################################################################
......