diff --git a/src/callbacks.jl b/src/callbacks.jl index ae191044cd39ed79e7f4073aaea573f6c89a35dd..6cde9e85361f15aa7d8416af3e9e64c8e7f65c47 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -40,12 +40,14 @@ Base.show(io::IO, @nospecialize(cbt::AbstractCallback)) = print(io, state(cbt)) abstract type AbstractCallbackTrigger end """ - CallbackTiming(; every_iteration=0, every_dt=0, every_dt_walltime=0) + CallbackTiming(; every_iteration=0, every_dt=0, every_dt_walltime=0, always_on=false) Object to time whether callbacks should be run after `every_iteration` iterations, or after `every_dt` (simulation) time has passed, or after `every_dt_walltime` time (in sec) has passed. +If `always_on=true` then this always fires, but the timer's internal state is no longer updated! + See also: `istimed!` """ mutable struct CallbackTiming <: AbstractCallbackTrigger @@ -58,15 +60,17 @@ mutable struct CallbackTiming <: AbstractCallbackTrigger iter::Int64 # iteration count t::Float64 # iteration time t_walltime::Float64 # iteration walltime (in secs) + always_on::Bool - function CallbackTiming(; every_iteration=0, every_dt=0, every_dt_walltime=0) + function CallbackTiming(; every_iteration::Int64=0, every_dt=0.0, + every_dt_walltime=0.0, always_on::Bool=false) @toggled_assert every_iteration >= 0 - @toggled_assert every_dt >= 0 - @toggled_assert every_dt_walltime >= 0 + @toggled_assert every_dt >= 0.0 + @toggled_assert every_dt_walltime >= 0.0 t_walltime = time() return new(every_iteration, every_dt, every_dt_walltime, every_iteration, every_dt, t_walltime + every_dt_walltime, - 0, 0, t_walltime) + 0, 0, t_walltime, always_on) end end @@ -83,6 +87,7 @@ function state(cbt::CallbackTiming) iter = $(cbt.iter) next_iter = $(cbt.next_iter) every_iteration = $(cbt.every_iteration) + always_on = $(cbt.always_on) """ end @@ -96,7 +101,9 @@ Base.show(io::IO, cbt::CallbackTiming) = print(io, state(cbt)) Check if callback `timing` should trigger for given time `t`. """ function istimed!(timing::CallbackTiming, t) - @unpack every_iteration, every_dt, every_dt_walltime = timing + @unpack every_iteration, every_dt, every_dt_walltime, always_on = timing + + always_on && return true if every_iteration == 0 && every_dt == 0 && every_dt_walltime == 0 return false diff --git a/test/UnitTests/src/test_callbacks.jl b/test/UnitTests/src/test_callbacks.jl index 25ec299254318cf6b36e1a787bdd8f97752aed95..ae2ace89cc4d680b65c6aa4dd42c7a988855205c 100644 --- a/test/UnitTests/src/test_callbacks.jl +++ b/test/UnitTests/src/test_callbacks.jl @@ -48,6 +48,16 @@ @test dg1d.istimed!(cbt, 4) == true # triggered by every_dt @test dg1d.istimed!(cbt, 5) == true # triggered by every_iteration + ### always_on makes timing fire on every check + ### note: this is a hack that I just happen to need + cbt = CallbackTiming(always_on = true, every_iteration = 3, every_dt = 2) + @test dg1d.istimed!(cbt, 0) == true + @test dg1d.istimed!(cbt, 1) == true + @test dg1d.istimed!(cbt, 2) == true + @test dg1d.istimed!(cbt, 3) == true + @test dg1d.istimed!(cbt, 4) == true + @test dg1d.istimed!(cbt, 5) == true + # test constructor @test_throws AssertionError CallbackTiming(every_iteration = -1, every_dt = 0) @test_throws AssertionError CallbackTiming(every_iteration = 0, every_dt = -1)