From 9d4fa24e5f4f2c61959ab3c3f3419787a8452115 Mon Sep 17 00:00:00 2001
From: Florian Atteneder <florian.atteneder@uni-jena.de>
Date: Fri, 15 Dec 2023 21:09:49 +0100
Subject: [PATCH] wrap variables and derivatives for axisymmetry case

---
 pyloc3d.py | 61 ++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 50 insertions(+), 11 deletions(-)

diff --git a/pyloc3d.py b/pyloc3d.py
index a9366f7..ee88cca 100755
--- a/pyloc3d.py
+++ b/pyloc3d.py
@@ -66,9 +66,6 @@ class Grid:
         self.DM_v = wrap_in_numpy_array(g.DM_v, g.nv**2)
         self.DM_w = wrap_in_numpy_array(g.DM_w, g.nw**2)
 
-        # TODO Reshape arrays
-        # TODO Check if dimensions are correct
-
         # CC=0, CS=1, SS=2, SC=3, CU=4, DF=5
         # grid type
         self.type = g.type
@@ -77,15 +74,53 @@ class Grid:
         # grid orientation
         self.orientation = g.orientation
 
-        # fields
-        all_fields_cart = wrap_in_numpy_array(g.fields_cart, g.nn*lib.n_fields_cart)
-        self.fields_cart = np.split(all_fields_cart, lib.n_fields_cart)
+        self.n_fields_cart = lib.n_fields_cart
+        self.n_derivatives = lib.n_derivatives
 
         # FULL3D=0, AXISYM=1
-        ndim = 3 if lib.sym_mode == 0 else 2
-        # derivatives of fields
-        all_derivatives = wrap_in_numpy_array(g.derivatives, g.nn*lib.n_derivatives*ndim)
-        self.derivatives = np.split(all_fields_cart, lib.n_derivatives*ndim)
+        if lib.sym_mode == 0:
+            self.sym_mode = "full3d"
+        elif lib.sym_mode == 1:
+            self.sym_mode = "axisym"
+        else:
+            raise Exception("Unknown value lib.sym_mode = {lib.sym_mode} encountered")
+
+        # wrap the variables defined in AHloc.h into numpy arrays and reshape them
+        if self.sym_mode == "axisym":
+
+            # in axisym we have nv=1, because of 2d, so we drop that when reshaping
+            # TODO someone needs to test if the reshaping is done correctly!
+
+            # variables; indices from tFields2d
+            self.gxx = wrap_in_numpy_array(lib.getField(0,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.gxz = wrap_in_numpy_array(lib.getField(1,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.gyy = wrap_in_numpy_array(lib.getField(2,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.gzz = wrap_in_numpy_array(lib.getField(3,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.kxx = wrap_in_numpy_array(lib.getField(4,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.kxz = wrap_in_numpy_array(lib.getField(5,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.kyy = wrap_in_numpy_array(lib.getField(6,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.kzz = wrap_in_numpy_array(lib.getField(7,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu))
+            # derivatives; indices from tDerivatives2d
+            self.dgxx_dx = wrap_in_numpy_array(lib.getField(0,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.dgxz_dx = wrap_in_numpy_array(lib.getField(1,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.dgyy_dx = wrap_in_numpy_array(lib.getField(2,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.dgzz_dx = wrap_in_numpy_array(lib.getField(3,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.dgxx_dz = wrap_in_numpy_array(lib.getField(4,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.dgxz_dz = wrap_in_numpy_array(lib.getField(5,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.dgyy_dz = wrap_in_numpy_array(lib.getField(6,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu))
+            self.dgzz_dz = wrap_in_numpy_array(lib.getField(7,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu))
+
+        else: # 3d
+            raise Exception("TODO")
+
+        #  # fields
+        #  all_fields_cart = wrap_in_numpy_array(g.fields_cart, g.nn*lib.n_fields_cart)
+        #  self.fields_cart = np.split(all_fields_cart, lib.n_fields_cart)
+
+        #  ndim = 3 if lib.sym_mode == 0 else 2
+        #  # derivatives of fields
+        #  all_derivatives = wrap_in_numpy_array(g.derivatives, g.nn*lib.n_derivatives*ndim)
+        #  self.derivatives = np.split(all_fields_cart, lib.n_derivatives*ndim)
 
 
 parfile = "./test.par"
@@ -104,6 +139,10 @@ for i in range(n_work):
     lib.load_data_and_derivatives(w)
     ngrids = w.ngrids
     for n in range(ngrids):
-        grids.append(Grid(w.grids[n]))
+        g = Grid(w.grids[n])
+        print(g.gxx)
+        print(g.dgxx_dx)
+        break
+        #  grids.append(Grid(w.grids[n]))
 
 print("wrapped number of grids = ", len(grids))
-- 
GitLab