From b082b7896d6e81f88ac9a834e3a86b31ec68aa66 Mon Sep 17 00:00:00 2001
From: Florian Atteneder <florian.atteneder@uni-jena.de>
Date: Mon, 9 Sep 2024 16:17:58 +0000
Subject: [PATCH] CallbacksTiming: add `always_on` option
 (https://git.tpi.uni-jena.de/dg/dg1d.jl/-/merge_requests/222)

If `always_on=true` then this always fires, but the timer's internal state is no longer updated!
---
 src/callbacks.jl                     | 19 +++++++++++++------
 test/UnitTests/src/test_callbacks.jl | 10 ++++++++++
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/src/callbacks.jl b/src/callbacks.jl
index ae191044..6cde9e85 100644
--- a/src/callbacks.jl
+++ b/src/callbacks.jl
@@ -40,12 +40,14 @@ Base.show(io::IO, @nospecialize(cbt::AbstractCallback)) = print(io, state(cbt))
 abstract type AbstractCallbackTrigger end
 
 """
-    CallbackTiming(; every_iteration=0, every_dt=0, every_dt_walltime=0)
+    CallbackTiming(; every_iteration=0, every_dt=0, every_dt_walltime=0, always_on=false)
 
 Object to time whether callbacks should be run after `every_iteration` iterations,
 or after `every_dt` (simulation) time has passed, or after `every_dt_walltime` time (in sec)
 has passed.
 
+If `always_on=true` then this always fires, but the timer's internal state is no longer updated!
+
 See also: `istimed!`
 """
 mutable struct CallbackTiming <: AbstractCallbackTrigger
@@ -58,15 +60,17 @@ mutable struct CallbackTiming <: AbstractCallbackTrigger
   iter::Int64                 # iteration count
   t::Float64                  # iteration time
   t_walltime::Float64         # iteration walltime (in secs)
+  always_on::Bool
 
-  function CallbackTiming(; every_iteration=0, every_dt=0, every_dt_walltime=0)
+  function CallbackTiming(; every_iteration::Int64=0, every_dt=0.0,
+                            every_dt_walltime=0.0, always_on::Bool=false)
     @toggled_assert every_iteration >= 0
-    @toggled_assert every_dt >= 0
-    @toggled_assert every_dt_walltime >= 0
+    @toggled_assert every_dt >= 0.0
+    @toggled_assert every_dt_walltime >= 0.0
     t_walltime = time()
     return new(every_iteration, every_dt, every_dt_walltime,
                every_iteration, every_dt, t_walltime + every_dt_walltime,
-               0,               0,        t_walltime)
+               0,               0,        t_walltime, always_on)
   end
 end
 
@@ -83,6 +87,7 @@ function state(cbt::CallbackTiming)
     iter                = $(cbt.iter)
     next_iter           = $(cbt.next_iter)
     every_iteration     = $(cbt.every_iteration)
+    always_on           = $(cbt.always_on)
   """
 end
 
@@ -96,7 +101,9 @@ Base.show(io::IO, cbt::CallbackTiming) = print(io, state(cbt))
 Check if callback `timing` should trigger for given time `t`.
 """
 function istimed!(timing::CallbackTiming, t)
-  @unpack every_iteration, every_dt, every_dt_walltime = timing
+  @unpack every_iteration, every_dt, every_dt_walltime, always_on = timing
+
+  always_on && return true
 
   if every_iteration == 0 && every_dt == 0 && every_dt_walltime == 0
     return false
diff --git a/test/UnitTests/src/test_callbacks.jl b/test/UnitTests/src/test_callbacks.jl
index 25ec2992..ae2ace89 100644
--- a/test/UnitTests/src/test_callbacks.jl
+++ b/test/UnitTests/src/test_callbacks.jl
@@ -48,6 +48,16 @@
   @test dg1d.istimed!(cbt, 4) == true  # triggered by every_dt
   @test dg1d.istimed!(cbt, 5) == true  # triggered by every_iteration
 
+  ### always_on makes timing fire on every check
+  ### note: this is a hack that I just happen to need
+  cbt = CallbackTiming(always_on = true, every_iteration = 3, every_dt = 2)
+  @test dg1d.istimed!(cbt, 0) == true
+  @test dg1d.istimed!(cbt, 1) == true
+  @test dg1d.istimed!(cbt, 2) == true
+  @test dg1d.istimed!(cbt, 3) == true
+  @test dg1d.istimed!(cbt, 4) == true
+  @test dg1d.istimed!(cbt, 5) == true
+
   # test constructor
   @test_throws AssertionError CallbackTiming(every_iteration = -1, every_dt = 0)
   @test_throws AssertionError CallbackTiming(every_iteration = 0, every_dt = -1)
-- 
GitLab