From 9d4fa24e5f4f2c61959ab3c3f3419787a8452115 Mon Sep 17 00:00:00 2001 From: Florian Atteneder <florian.atteneder@uni-jena.de> Date: Fri, 15 Dec 2023 21:09:49 +0100 Subject: [PATCH] wrap variables and derivatives for axisymmetry case --- pyloc3d.py | 61 ++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/pyloc3d.py b/pyloc3d.py index a9366f7..ee88cca 100755 --- a/pyloc3d.py +++ b/pyloc3d.py @@ -66,9 +66,6 @@ class Grid: self.DM_v = wrap_in_numpy_array(g.DM_v, g.nv**2) self.DM_w = wrap_in_numpy_array(g.DM_w, g.nw**2) - # TODO Reshape arrays - # TODO Check if dimensions are correct - # CC=0, CS=1, SS=2, SC=3, CU=4, DF=5 # grid type self.type = g.type @@ -77,15 +74,53 @@ class Grid: # grid orientation self.orientation = g.orientation - # fields - all_fields_cart = wrap_in_numpy_array(g.fields_cart, g.nn*lib.n_fields_cart) - self.fields_cart = np.split(all_fields_cart, lib.n_fields_cart) + self.n_fields_cart = lib.n_fields_cart + self.n_derivatives = lib.n_derivatives # FULL3D=0, AXISYM=1 - ndim = 3 if lib.sym_mode == 0 else 2 - # derivatives of fields - all_derivatives = wrap_in_numpy_array(g.derivatives, g.nn*lib.n_derivatives*ndim) - self.derivatives = np.split(all_fields_cart, lib.n_derivatives*ndim) + if lib.sym_mode == 0: + self.sym_mode = "full3d" + elif lib.sym_mode == 1: + self.sym_mode = "axisym" + else: + raise Exception("Unknown value lib.sym_mode = {lib.sym_mode} encountered") + + # wrap the variables defined in AHloc.h into numpy arrays and reshape them + if self.sym_mode == "axisym": + + # in axisym we have nv=1, because of 2d, so we drop that when reshaping + # TODO someone needs to test if the reshaping is done correctly! + + # variables; indices from tFields2d + self.gxx = wrap_in_numpy_array(lib.getField(0,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu)) + self.gxz = wrap_in_numpy_array(lib.getField(1,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu)) + self.gyy = wrap_in_numpy_array(lib.getField(2,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu)) + self.gzz = wrap_in_numpy_array(lib.getField(3,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu)) + self.kxx = wrap_in_numpy_array(lib.getField(4,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu)) + self.kxz = wrap_in_numpy_array(lib.getField(5,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu)) + self.kyy = wrap_in_numpy_array(lib.getField(6,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu)) + self.kzz = wrap_in_numpy_array(lib.getField(7,g.fields_cart,g.nn), g.nn).reshape((g.nw,g.nu)) + # derivatives; indices from tDerivatives2d + self.dgxx_dx = wrap_in_numpy_array(lib.getField(0,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu)) + self.dgxz_dx = wrap_in_numpy_array(lib.getField(1,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu)) + self.dgyy_dx = wrap_in_numpy_array(lib.getField(2,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu)) + self.dgzz_dx = wrap_in_numpy_array(lib.getField(3,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu)) + self.dgxx_dz = wrap_in_numpy_array(lib.getField(4,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu)) + self.dgxz_dz = wrap_in_numpy_array(lib.getField(5,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu)) + self.dgyy_dz = wrap_in_numpy_array(lib.getField(6,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu)) + self.dgzz_dz = wrap_in_numpy_array(lib.getField(7,g.derivatives,g.nn), g.nn).reshape((g.nw,g.nu)) + + else: # 3d + raise Exception("TODO") + + # # fields + # all_fields_cart = wrap_in_numpy_array(g.fields_cart, g.nn*lib.n_fields_cart) + # self.fields_cart = np.split(all_fields_cart, lib.n_fields_cart) + + # ndim = 3 if lib.sym_mode == 0 else 2 + # # derivatives of fields + # all_derivatives = wrap_in_numpy_array(g.derivatives, g.nn*lib.n_derivatives*ndim) + # self.derivatives = np.split(all_fields_cart, lib.n_derivatives*ndim) parfile = "./test.par" @@ -104,6 +139,10 @@ for i in range(n_work): lib.load_data_and_derivatives(w) ngrids = w.ngrids for n in range(ngrids): - grids.append(Grid(w.grids[n])) + g = Grid(w.grids[n]) + print(g.gxx) + print(g.dgxx_dx) + break + # grids.append(Grid(w.grids[n])) print("wrapped number of grids = ", len(grids)) -- GitLab