From 73e636ee10232fd08a420fa35f437851ad2a6e65 Mon Sep 17 00:00:00 2001
From: Florian Atteneder <florian.atteneder@uni-jena.de>
Date: Thu, 29 Aug 2024 20:01:29 +0000
Subject: [PATCH] dg1d: Implement `Workspace`
 (https://git.tpi.uni-jena.de/dg/dg1d.jl/-/merge_requests/219)

This is a new caching mechanism which should be used for small temporary arrays.
---
 src/dg1d.jl                          |  2 +-
 src/workspace.jl                     | 84 ++++++++++++++++++++++++++++
 test/UnitTests/src/test_workspace.jl | 47 ++++++++++++++++
 test/UnitTests/src/tests.jl          |  1 +
 4 files changed, 133 insertions(+), 1 deletion(-)
 create mode 100644 src/workspace.jl
 create mode 100644 test/UnitTests/src/test_workspace.jl

diff --git a/src/dg1d.jl b/src/dg1d.jl
index b9cbe9f4..48be5deb 100644
--- a/src/dg1d.jl
+++ b/src/dg1d.jl
@@ -70,7 +70,7 @@ include("tensorbasis.jl")
 include("box.jl")
 include("tree.jl")
 include("objectcache.jl")
-# include("refvector.jl")
+include("workspace.jl")
 
 export Variables, n_variables, names, isregistered
 include("variables.jl")
diff --git a/src/workspace.jl b/src/workspace.jl
new file mode 100644
index 00000000..964f0b5b
--- /dev/null
+++ b/src/workspace.jl
@@ -0,0 +1,84 @@
+mutable struct Chunk
+  nborrowed::Int64
+  const sz::Int64
+  const arrays::Vector{Vector{Float64}}
+end
+Chunk(sz::Int64) = Chunk(0,sz,Vector{Float64}[])
+
+
+function borrow(chunk::Chunk)
+  chunk.nborrowed += 1
+  if chunk.nborrowed <= length(chunk.arrays)
+    return chunk.arrays[chunk.nborrowed]
+  end
+  new_array = zeros(Float64, chunk.sz)
+  push!(chunk.arrays, new_array)
+  return new_array
+end
+
+
+"""
+  mutable struct Workspace
+
+A lazy cache type to provide temporary `Float64` vector of arbitrary sizes.
+
+See: `enter, borrow`.
+
+# Example
+
+To `borrow` vectors from a `Workspace`, you first have to `enter` it.
+Borrowed vectors must not escape the `enter` block.
+
+```
+ws = Workspace()
+enter(ws) do
+  sz1, sz2 = 3, 4
+  v1, v2 = borrow(ws, sz1), borrow(ws, sz2)
+end
+```
+"""
+mutable struct Workspace
+  entered::Bool
+  const chunks::Dict{Int64,Chunk}
+end
+Workspace() = Workspace(false, Dict{Int64,Chunk}())
+
+
+"""
+  enter(fn::Function, ws::Workspace)
+
+Enter `ws` to enable borrowing vector.
+
+See: `Workspace, borrow`.
+"""
+function enter(fn::Function, ws::Workspace)
+  if ws.entered
+    error("Can only enter workspace once!")
+  end
+  try
+    ws.entered = true
+    fn()
+  finally
+    for chunk in values(ws.chunks)
+      chunk.nborrowed = 0
+    end
+    ws.entered = false
+  end
+  return
+end
+
+
+"""
+  borrow(ws::Workspace, sz::Int64)
+
+Borrow a `Float64` vector of length `sz` from `ws`.
+
+See: `Workspace, borrow`.
+"""
+function borrow(ws::Workspace, sz::Int64)
+  if !ws.entered
+    error("Can only borrow from workspace once entered!")
+  end
+  chunk = get!(ws.chunks, sz, Chunk(sz))
+  return borrow(chunk)
+end
diff --git a/test/UnitTests/src/test_workspace.jl b/test/UnitTests/src/test_workspace.jl
new file mode 100644
index 00000000..a97c07f4
--- /dev/null
+++ b/test/UnitTests/src/test_workspace.jl
@@ -0,0 +1,47 @@
+@testset "Workspace" begin
+
+  ws = dg1d.Workspace()
+  @test isempty(ws.chunks)
+
+  dg1d.enter(ws) do
+    v1 = dg1d.borrow(ws, 1)
+    v2 = dg1d.borrow(ws, 2)
+    v3 = dg1d.borrow(ws, 3)
+    @test length(v1) == 1
+    @test length(v2) == 2
+    @test length(v3) == 3
+    v4 = dg1d.borrow(ws, 1)
+    v5 = dg1d.borrow(ws, 2)
+    v6 = dg1d.borrow(ws, 3)
+    @test length(v4) == 1
+    @test length(v5) == 2
+    @test length(v6) == 3
+  end
+
+  @test length(ws.chunks) == 3
+  for chk in values(ws.chunks)
+    @test length(chk.arrays) == 2
+  end
+
+  dg1d.enter(ws) do
+    v1 = dg1d.borrow(ws, 1)
+    v2 = dg1d.borrow(ws, 2)
+    v3 = dg1d.borrow(ws, 3)
+    @test length(v1) == 1
+    @test length(v2) == 2
+    @test length(v3) == 3
+  end
+
+  @test length(ws.chunks) == 3
+  for chk in values(ws.chunks)
+    @test length(chk.arrays) == 2
+  end
+
+  @test_throws ErrorException dg1d.borrow(ws, 1)
+  dg1d.enter(ws) do
+    @test_throws ErrorException dg1d.enter(ws) do
+      println("sers")
+    end
+  end
+
+end
diff --git a/test/UnitTests/src/tests.jl b/test/UnitTests/src/tests.jl
index a35df0d0..495b5655 100644
--- a/test/UnitTests/src/tests.jl
+++ b/test/UnitTests/src/tests.jl
@@ -21,4 +21,5 @@ include("test_variables.jl")
 include("test_cache.jl")
 include("test_callbacks.jl")
 include("test_output.jl")
+include("test_workspace.jl")
 include("test_EquationOfState.jl")
-- 
GitLab