Skip to content
Snippets Groups Projects
Commit 285a652c authored by Florian Atteneder's avatar Florian Atteneder
Browse files

Insert original line number nodes for @with_signature functions (!90)

* insert line number nodes for better stack traces
parent a0a40a16
No related branches found
No related tags found
1 merge request!90Insert original line number nodes for @with_signature functions
Pipeline #6116 passed
......@@ -47,6 +47,7 @@ The following options are available to customize `@with_signature` definitions:
if the function parsing works correctly.
- `inbounds::Bool=true`: Controls `@inbounds` access for the grid loop.
- `simd::Bool=true`: Controls `@simd` optimization for the grid loop.
Use `simd=false` for debugging to restore correct line numbers in stack traces.
- `debug::Bool=false`: To disable `inbounds, simd` optimizations.
- `print_idxs=false`: To enable printing of volume and boundary indices.
......@@ -494,6 +495,7 @@ parse_returns(line, ex) = parse_options_variables("@returns", line, ex)
function with_signature(expr, options=(;))
# display(expr)
ex = MacroTools.prewalk(MacroTools.rmlines, expr)
# display(ex)
......@@ -506,6 +508,19 @@ function with_signature(expr, options=(;))
push!(fn_signature.args, :(mesh::Any))
end
fn_name, first_arg, second_arg = fn_signature.args
lno_func = expr.args[2].args[1] # line number node
# sanitize function arguments
(first_arg isa Expr && first_arg.head === :(::)) || error(
"@with_signature: function argument '$first_arg' must be type annotated")
dispatch_type_first = last(first_arg.args)
dispatch_type_second = if second_arg isa Symbol
:Any
elseif second_arg isa Expr && second_arg.head === :(::)
last(second_arg.args)
else
error("@with_signature: invalid argument '$second_arg'")
end
ex.args[2].head === :block || error("@with_singature: failed to extract function body")
fn_body = ex.args[2].args
......@@ -521,6 +536,9 @@ function with_signature(expr, options=(;))
end
length(accepts_lines) == 0 && error(
"@with_signature: expected '@accepts [options] a, b, ...' at the start of the function body")
lno_accepts = [ expr.args[2].args[2+2*(i-1)] for i in 1:length(accepts_lines) ] # line number nodes
# @show lno_accepts
returns_lines, returns_exs = Any[], Any[]
for line in reverse(fn_body) # search for @returns from bottom to top
matched_returns = MacroTools.@capture(line, @returns returns_ex__)
......@@ -533,6 +551,8 @@ function with_signature(expr, options=(;))
# reverse returns expressions so that we emmit code from top to bottom
returns_exs = returns_exs[end:-1:1]
returns_lines = returns_lines[end:-1:1]
lno_returns = [ expr.args[2].args[end-1-2*(i-1)] for i in 1:length(returns_lines) ] # line number nodes
# @show lno_returns
# @accepts and @returns can only occur right at the start or end
if any(line -> MacroTools.@capture(line, @accepts _), fn_body[length(accepts_lines)+1:end])
......@@ -554,6 +574,7 @@ function with_signature(expr, options=(;))
init_fn_body = expr.args[2].args
init_fn_body[2+2*length(accepts_lines):end-2*length(returns_lines)]
catch
@warn "@with_signature: failed to extract LineNumberNodes for function body"
fn_body[1+length(accepts_lines):end-length(returns_lines)]
end
else
......@@ -573,70 +594,74 @@ function with_signature(expr, options=(;))
# 1) resolve the prefix,
# 2) make a gen'ed symbol for array unpacking
# 3) construct the array unpacking calls
acc_vars, prefix_acc_vars, gen_acc_vars, bdry_acc = Symbol[], Symbol[], Symbol[], Bool[]
acc_vars, prefix_acc_vars, gen_acc_vars, bdry_acc =
Vector{Symbol}[], Vector{Symbol}[], Vector{Symbol}[], Vector{Bool}[]
unpack_acc = Expr[]
n = 1
for (i,(opts, vars)) in enumerate(accepts_options_vars)
append!(acc_vars, vars)
append!(prefix_acc_vars, isnothing(opts[:prefix]) ? vars : Symbol.(opts[:prefix],:_,vars))
append!(bdry_acc, fill(opts[:bdry], length(vars)))
push!(acc_vars, vars)
push!(prefix_acc_vars, isnothing(opts[:prefix]) ? vars : Symbol.(opts[:prefix],:_,vars))
push!(bdry_acc, fill(opts[:bdry], length(vars)))
gens = Symbol[]
for v in vars
gv = gensym(v)
push!(unpack_acc, :($gv = $acc[$n]))
push!(gen_acc_vars, gv)
push!(gens, gv)
n += 1
end
push!(gen_acc_vars, gens)
end
# for each variable in @returns
# 1) resolve the prefix,
# 2) make a gen'ed symbol for array unpacking
# 3) construct the array unpacking calls
ret_vars, prefix_ret_vars, gen_ret_vars, bdry_ret = Symbol[], Symbol[], Symbol[], Bool[]
ret_vars, prefix_ret_vars, gen_ret_vars, bdry_ret = Vector{Symbol}[], Vector{Symbol}[], Vector{Symbol}[], Vector{Bool}[]
unpack_ret = Expr[]
n = 1
for (i,(opts, vars)) in enumerate(returns_options_vars)
append!(ret_vars, vars)
append!(prefix_ret_vars, isnothing(opts[:prefix]) ? vars : Symbol.(opts[:prefix],:_,vars))
append!(bdry_ret, fill(opts[:bdry], length(vars)))
push!(ret_vars, vars)
push!(prefix_ret_vars, isnothing(opts[:prefix]) ? vars : Symbol.(opts[:prefix],:_,vars))
push!(bdry_ret, fill(opts[:bdry], length(vars)))
gens = Symbol[]
for v in vars
gv = gensym(v)
push!(unpack_ret, :($gv = $ret[$n]))
push!(gen_ret_vars, gv)
push!(gens, gv)
n += 1
end
push!(gen_ret_vars, gens)
end
# error on duplicated variable names (after resolving prefixes)
red_prefix_acc_vars = reduce(vcat, prefix_acc_vars)
if !allunique(prefix_acc_vars)
uvars = unique(prefix_acc_vars)
idxdups = findall(v -> count(Base.Fix1(===,v),prefix_acc_vars) > 1, uvars)
dups = prefix_acc_vars[idxdups]
uvars = unique(red_prefix_acc_vars)
idxdups = findall(v -> count(Base.Fix1(===,v),red_prefix_acc_vars) > 1, uvars)
dups = red_prefix_acc_vars[idxdups]
error("duplicated arguments in all @accepts found (after resolving prefixes): $(join(dups,", "))")
end
if !allunique(prefix_ret_vars)
uvars = unique(prefix_ret_vars)
idxdups = findall(v -> count(Base.Fix1(===,v),prefix_ret_vars) > 1, uvars)
dups = prefix_ret_vars[idxdups]
red_prefix_ret_vars = reduce(vcat, prefix_ret_vars)
if !allunique(red_prefix_ret_vars)
uvars = unique(red_prefix_ret_vars)
idxdups = findall(v -> count(Base.Fix1(===,v),red_prefix_ret_vars) > 1, uvars)
dups = red_prefix_ret_vars[idxdups]
error("duplicated arguments in all @returns found (after resolving prefixes): $(join(dups,", "))")
end
# generate getindex calls (v = var"##v#359"[i]) for all @accepts variables to load data inside loop
getvars_lines = [ :( $pv = $gv[$(bdry ? bi : vi)] )
for (pv,gv,bdry) in zip(prefix_acc_vars,gen_acc_vars,bdry_acc) ]
# generate setindex! calls (var"##v#359"[i] = v) for all @returns variables to write data inside loop
setvars_lines = [ :( $gv[$(bdry ? bi : vi)] = $pv )
for (pv,gv,bdry) in zip(prefix_ret_vars,gen_ret_vars,bdry_ret) ]
# construct array getter and setters for loop body, e.g.
# var"##vi#2802" = var"##vidxs#2804"[var"##ii#2801"]
getvars_lines = [ [ :( $pv = $gv[$(bdry ? bi : vi)] ) for (pv,gv,bdry) in zip(prefixs,gens,bdrys) ]
for (prefixs,gens,bdrys) in zip(prefix_acc_vars,gen_acc_vars,bdry_acc) ]
setvars_lines = [ [ :( $gv[$(bdry ? bi : vi)] = $pv ) for (pv,gv,bdry) in zip(prefixs,gens,bdrys) ]
for (prefixs,gens,bdrys) in zip(prefix_ret_vars,gen_ret_vars,bdry_ret) ]
# need to use @macroexpand1 @with_signature when debugging to not expand @inbounds macro
# loop skeleton
loop_ex = if options[:print_idxs]
quote
for $ii in eachindex($vidxs, $bidxs)
$vi = $vidxs[$ii]
$bi = $bidxs[$ii]
$(getvars_lines...)
$(compute_body...)
$(setvars_lines...)
println((;ii=$ii, vi=$vi, bi=$bi))
end
end
......@@ -645,15 +670,28 @@ function with_signature(expr, options=(;))
for $ii in eachindex($vidxs, $bidxs)
$vi = $vidxs[$ii]
$bi = $bidxs[$ii]
$(getvars_lines...)
$(compute_body...)
$(setvars_lines...)
end
end
end
# strip artificial LineNumberNodes and unwrap a begin block
loop_ex = MacroTools.prewalk(MacroTools.rmlines, loop_ex)
loop_ex = loop_ex.args[1] # unwrap a begin block
loop_body = loop_ex.args[2]
loop_ex = loop_ex.args[1]
loop_body = loop_ex.args[2].args
# insert getvars_lines with original LineNumberNodes
for (lno,getvars) in zip(lno_accepts,getvars_lines)
push!(loop_body, lno)
append!(loop_body, getvars)
end
# contains original LineNumberNodes
append!(loop_body, compute_body)
# insert setvars_lines with original LineNumberNodes
for (lno,setvars) in zip(lno_returns,setvars_lines)
push!(loop_body, lno)
append!(loop_body, setvars)
end
# prepend macros to for loop
if options[:simd]
......@@ -665,47 +703,39 @@ function with_signature(expr, options=(;))
loop_ex = Expr(:block, loop_ex) # wrap into begin
new_ex = deepcopy(ex)
new_ex.args[2].args = vcat(unpack_acc, unpack_ret, loop_ex)
# sanitize function arguments
(first_arg isa Expr && first_arg.head === :(::)) || error(
"@with_signature: function argument '$first_arg' must be type annotated")
dispatch_type_first = last(first_arg.args)
dispatch_type_second = if second_arg isa Symbol
:Any
elseif second_arg isa Expr && second_arg.head === :(::)
last(second_arg.args)
else
error("@with_signature: invalid argument '$second_arg'")
end
# combine loop with accepts, returns array unpacking
new_func = deepcopy(ex)
new_func.args[2].args = vcat(unpack_acc, unpack_ret, loop_ex)
# adjust function signature
if dispatch_type_second === :Any
new_fn_signature = Expr(:call, fn_name, first_arg, acc, ret, vidxs, bidxs,
Expr(:kw, second_arg, nothing))
else
new_fn_signature = Expr(:call, fn_name, first_arg, acc, ret, vidxs, bidxs, second_arg)
end
new_ex.args[1] = new_fn_signature
tpl_acc_vars = tuple(acc_vars...)
tpl_ret_vars = tuple(ret_vars...)
new_func.args[1] = new_fn_signature
# construct @with_signature interface functions signature(), hassignature()
tpl_acc_vars = Tuple(reduce(vcat, acc_vars))
tpl_ret_vars = Tuple(reduce(vcat, ret_vars))
thismod = @__MODULE__
# TODO What about docstrings for these methods?
# defs = if dispatch_type_second === :Any
# quote
# $new_ex
# $(thismod).signature(::typeof($(fn_name)), dispatch::$(dispatch_type_first), mesh::Any=nothing) = ($(tpl_acc_vars), $(tpl_ret_vars))
# $(thismod).has_signature(::typeof($(fn_name)), dispatch::Type{<:$(dispatch_type_first)}, mesh::Any=nothing) = true
# end
# else
defs = quote
$new_ex
ex_sig = quote
$(thismod).signature(::typeof($(fn_name)), dispatch::$(dispatch_type_first), mesh::$(dispatch_type_second)) = ($(tpl_acc_vars), $(tpl_ret_vars))
end
# strip artificial LineNumberNodes and unwrap a begin block
ex_sig = MacroTools.prewalk(MacroTools.rmlines, ex_sig)
ex_sig = ex_sig.args[1]
ex_hassig = quote
$(thismod).has_signature(::typeof($(fn_name)), dispatch::Type{<:$(dispatch_type_first)}, mesh::$(dispatch_type_second)) = true
end
# strip artificial LineNumberNodes and unwrap a begin block
ex_hassig = MacroTools.prewalk(MacroTools.rmlines, ex_hassig)
ex_hassig = ex_hassig.args[1]
# TODO What about docstrings for these methods?
defs = Expr(:block, lno_func, new_func, lno_func, ex_sig, lno_func, ex_hassig)
# display(MacroTools.prewalk(MacroTools.rmlines, new_ex))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment