diff --git a/src/with_signature.jl b/src/with_signature.jl index 8b53ef797ac5e97b65c0e0701f49d2d5a344df79..98edb3d7fcd1183b28c578816bd40658cc376839 100644 --- a/src/with_signature.jl +++ b/src/with_signature.jl @@ -47,6 +47,7 @@ The following options are available to customize `@with_signature` definitions: if the function parsing works correctly. - `inbounds::Bool=true`: Controls `@inbounds` access for the grid loop. - `simd::Bool=true`: Controls `@simd` optimization for the grid loop. + Use `simd=false` for debugging to restore correct line numbers in stack traces. - `debug::Bool=false`: To disable `inbounds, simd` optimizations. - `print_idxs=false`: To enable printing of volume and boundary indices. @@ -494,6 +495,7 @@ parse_returns(line, ex) = parse_options_variables("@returns", line, ex) function with_signature(expr, options=(;)) + # display(expr) ex = MacroTools.prewalk(MacroTools.rmlines, expr) # display(ex) @@ -506,6 +508,19 @@ function with_signature(expr, options=(;)) push!(fn_signature.args, :(mesh::Any)) end fn_name, first_arg, second_arg = fn_signature.args + lno_func = expr.args[2].args[1] # line number node + + # sanitize function arguments + (first_arg isa Expr && first_arg.head === :(::)) || error( + "@with_signature: function argument '$first_arg' must be type annotated") + dispatch_type_first = last(first_arg.args) + dispatch_type_second = if second_arg isa Symbol + :Any + elseif second_arg isa Expr && second_arg.head === :(::) + last(second_arg.args) + else + error("@with_signature: invalid argument '$second_arg'") + end ex.args[2].head === :block || error("@with_singature: failed to extract function body") fn_body = ex.args[2].args @@ -521,6 +536,9 @@ function with_signature(expr, options=(;)) end length(accepts_lines) == 0 && error( "@with_signature: expected '@accepts [options] a, b, ...' at the start of the function body") + lno_accepts = [ expr.args[2].args[2+2*(i-1)] for i in 1:length(accepts_lines) ] # line number nodes + # @show lno_accepts + returns_lines, returns_exs = Any[], Any[] for line in reverse(fn_body) # search for @returns from bottom to top matched_returns = MacroTools.@capture(line, @returns returns_ex__) @@ -533,6 +551,8 @@ function with_signature(expr, options=(;)) # reverse returns expressions so that we emmit code from top to bottom returns_exs = returns_exs[end:-1:1] returns_lines = returns_lines[end:-1:1] + lno_returns = [ expr.args[2].args[end-1-2*(i-1)] for i in 1:length(returns_lines) ] # line number nodes + # @show lno_returns # @accepts and @returns can only occur right at the start or end if any(line -> MacroTools.@capture(line, @accepts _), fn_body[length(accepts_lines)+1:end]) @@ -554,6 +574,7 @@ function with_signature(expr, options=(;)) init_fn_body = expr.args[2].args init_fn_body[2+2*length(accepts_lines):end-2*length(returns_lines)] catch + @warn "@with_signature: failed to extract LineNumberNodes for function body" fn_body[1+length(accepts_lines):end-length(returns_lines)] end else @@ -573,70 +594,74 @@ function with_signature(expr, options=(;)) # 1) resolve the prefix, # 2) make a gen'ed symbol for array unpacking # 3) construct the array unpacking calls - acc_vars, prefix_acc_vars, gen_acc_vars, bdry_acc = Symbol[], Symbol[], Symbol[], Bool[] + acc_vars, prefix_acc_vars, gen_acc_vars, bdry_acc = + Vector{Symbol}[], Vector{Symbol}[], Vector{Symbol}[], Vector{Bool}[] unpack_acc = Expr[] n = 1 for (i,(opts, vars)) in enumerate(accepts_options_vars) - append!(acc_vars, vars) - append!(prefix_acc_vars, isnothing(opts[:prefix]) ? vars : Symbol.(opts[:prefix],:_,vars)) - append!(bdry_acc, fill(opts[:bdry], length(vars))) + push!(acc_vars, vars) + push!(prefix_acc_vars, isnothing(opts[:prefix]) ? vars : Symbol.(opts[:prefix],:_,vars)) + push!(bdry_acc, fill(opts[:bdry], length(vars))) + gens = Symbol[] for v in vars gv = gensym(v) push!(unpack_acc, :($gv = $acc[$n])) - push!(gen_acc_vars, gv) + push!(gens, gv) n += 1 end + push!(gen_acc_vars, gens) end # for each variable in @returns # 1) resolve the prefix, # 2) make a gen'ed symbol for array unpacking # 3) construct the array unpacking calls - ret_vars, prefix_ret_vars, gen_ret_vars, bdry_ret = Symbol[], Symbol[], Symbol[], Bool[] + ret_vars, prefix_ret_vars, gen_ret_vars, bdry_ret = Vector{Symbol}[], Vector{Symbol}[], Vector{Symbol}[], Vector{Bool}[] unpack_ret = Expr[] n = 1 for (i,(opts, vars)) in enumerate(returns_options_vars) - append!(ret_vars, vars) - append!(prefix_ret_vars, isnothing(opts[:prefix]) ? vars : Symbol.(opts[:prefix],:_,vars)) - append!(bdry_ret, fill(opts[:bdry], length(vars))) + push!(ret_vars, vars) + push!(prefix_ret_vars, isnothing(opts[:prefix]) ? vars : Symbol.(opts[:prefix],:_,vars)) + push!(bdry_ret, fill(opts[:bdry], length(vars))) + gens = Symbol[] for v in vars gv = gensym(v) push!(unpack_ret, :($gv = $ret[$n])) - push!(gen_ret_vars, gv) + push!(gens, gv) n += 1 end + push!(gen_ret_vars, gens) end # error on duplicated variable names (after resolving prefixes) + red_prefix_acc_vars = reduce(vcat, prefix_acc_vars) if !allunique(prefix_acc_vars) - uvars = unique(prefix_acc_vars) - idxdups = findall(v -> count(Base.Fix1(===,v),prefix_acc_vars) > 1, uvars) - dups = prefix_acc_vars[idxdups] + uvars = unique(red_prefix_acc_vars) + idxdups = findall(v -> count(Base.Fix1(===,v),red_prefix_acc_vars) > 1, uvars) + dups = red_prefix_acc_vars[idxdups] error("duplicated arguments in all @accepts found (after resolving prefixes): $(join(dups,", "))") end - if !allunique(prefix_ret_vars) - uvars = unique(prefix_ret_vars) - idxdups = findall(v -> count(Base.Fix1(===,v),prefix_ret_vars) > 1, uvars) - dups = prefix_ret_vars[idxdups] + red_prefix_ret_vars = reduce(vcat, prefix_ret_vars) + if !allunique(red_prefix_ret_vars) + uvars = unique(red_prefix_ret_vars) + idxdups = findall(v -> count(Base.Fix1(===,v),red_prefix_ret_vars) > 1, uvars) + dups = red_prefix_ret_vars[idxdups] error("duplicated arguments in all @returns found (after resolving prefixes): $(join(dups,", "))") end - # generate getindex calls (v = var"##v#359"[i]) for all @accepts variables to load data inside loop - getvars_lines = [ :( $pv = $gv[$(bdry ? bi : vi)] ) - for (pv,gv,bdry) in zip(prefix_acc_vars,gen_acc_vars,bdry_acc) ] - # generate setindex! calls (var"##v#359"[i] = v) for all @returns variables to write data inside loop - setvars_lines = [ :( $gv[$(bdry ? bi : vi)] = $pv ) - for (pv,gv,bdry) in zip(prefix_ret_vars,gen_ret_vars,bdry_ret) ] + # construct array getter and setters for loop body, e.g. + # var"##vi#2802" = var"##vidxs#2804"[var"##ii#2801"] + getvars_lines = [ [ :( $pv = $gv[$(bdry ? bi : vi)] ) for (pv,gv,bdry) in zip(prefixs,gens,bdrys) ] + for (prefixs,gens,bdrys) in zip(prefix_acc_vars,gen_acc_vars,bdry_acc) ] + setvars_lines = [ [ :( $gv[$(bdry ? bi : vi)] = $pv ) for (pv,gv,bdry) in zip(prefixs,gens,bdrys) ] + for (prefixs,gens,bdrys) in zip(prefix_ret_vars,gen_ret_vars,bdry_ret) ] - # need to use @macroexpand1 @with_signature when debugging to not expand @inbounds macro + # loop skeleton loop_ex = if options[:print_idxs] quote for $ii in eachindex($vidxs, $bidxs) $vi = $vidxs[$ii] $bi = $bidxs[$ii] - $(getvars_lines...) - $(compute_body...) - $(setvars_lines...) println((;ii=$ii, vi=$vi, bi=$bi)) end end @@ -645,15 +670,28 @@ function with_signature(expr, options=(;)) for $ii in eachindex($vidxs, $bidxs) $vi = $vidxs[$ii] $bi = $bidxs[$ii] - $(getvars_lines...) - $(compute_body...) - $(setvars_lines...) end end end + # strip artificial LineNumberNodes and unwrap a begin block loop_ex = MacroTools.prewalk(MacroTools.rmlines, loop_ex) - loop_ex = loop_ex.args[1] # unwrap a begin block - loop_body = loop_ex.args[2] + loop_ex = loop_ex.args[1] + loop_body = loop_ex.args[2].args + + # insert getvars_lines with original LineNumberNodes + for (lno,getvars) in zip(lno_accepts,getvars_lines) + push!(loop_body, lno) + append!(loop_body, getvars) + end + + # contains original LineNumberNodes + append!(loop_body, compute_body) + + # insert setvars_lines with original LineNumberNodes + for (lno,setvars) in zip(lno_returns,setvars_lines) + push!(loop_body, lno) + append!(loop_body, setvars) + end # prepend macros to for loop if options[:simd] @@ -665,47 +703,39 @@ function with_signature(expr, options=(;)) loop_ex = Expr(:block, loop_ex) # wrap into begin - new_ex = deepcopy(ex) - new_ex.args[2].args = vcat(unpack_acc, unpack_ret, loop_ex) - - # sanitize function arguments - (first_arg isa Expr && first_arg.head === :(::)) || error( - "@with_signature: function argument '$first_arg' must be type annotated") - dispatch_type_first = last(first_arg.args) - dispatch_type_second = if second_arg isa Symbol - :Any - elseif second_arg isa Expr && second_arg.head === :(::) - last(second_arg.args) - else - error("@with_signature: invalid argument '$second_arg'") - end + # combine loop with accepts, returns array unpacking + new_func = deepcopy(ex) + new_func.args[2].args = vcat(unpack_acc, unpack_ret, loop_ex) + # adjust function signature if dispatch_type_second === :Any new_fn_signature = Expr(:call, fn_name, first_arg, acc, ret, vidxs, bidxs, Expr(:kw, second_arg, nothing)) else new_fn_signature = Expr(:call, fn_name, first_arg, acc, ret, vidxs, bidxs, second_arg) end - new_ex.args[1] = new_fn_signature - - tpl_acc_vars = tuple(acc_vars...) - tpl_ret_vars = tuple(ret_vars...) + new_func.args[1] = new_fn_signature + # construct @with_signature interface functions signature(), hassignature() + tpl_acc_vars = Tuple(reduce(vcat, acc_vars)) + tpl_ret_vars = Tuple(reduce(vcat, ret_vars)) thismod = @__MODULE__ - # TODO What about docstrings for these methods? - # defs = if dispatch_type_second === :Any - # quote - # $new_ex - # $(thismod).signature(::typeof($(fn_name)), dispatch::$(dispatch_type_first), mesh::Any=nothing) = ($(tpl_acc_vars), $(tpl_ret_vars)) - # $(thismod).has_signature(::typeof($(fn_name)), dispatch::Type{<:$(dispatch_type_first)}, mesh::Any=nothing) = true - # end - # else - defs = quote - $new_ex + ex_sig = quote $(thismod).signature(::typeof($(fn_name)), dispatch::$(dispatch_type_first), mesh::$(dispatch_type_second)) = ($(tpl_acc_vars), $(tpl_ret_vars)) + end + # strip artificial LineNumberNodes and unwrap a begin block + ex_sig = MacroTools.prewalk(MacroTools.rmlines, ex_sig) + ex_sig = ex_sig.args[1] + ex_hassig = quote $(thismod).has_signature(::typeof($(fn_name)), dispatch::Type{<:$(dispatch_type_first)}, mesh::$(dispatch_type_second)) = true end + # strip artificial LineNumberNodes and unwrap a begin block + ex_hassig = MacroTools.prewalk(MacroTools.rmlines, ex_hassig) + ex_hassig = ex_hassig.args[1] + + # TODO What about docstrings for these methods? + defs = Expr(:block, lno_func, new_func, lno_func, ex_sig, lno_func, ex_hassig) # display(MacroTools.prewalk(MacroTools.rmlines, new_ex))