From 15746b470306bb6673a941a6b4262709f62c48d9 Mon Sep 17 00:00:00 2001
From: Nikolay Ivanov <simply_nicky@nikolays-mbp.localdomain>
Date: Tue, 24 Nov 2020 00:25:12 +0100
Subject: [PATCH] bug fixes

---
 docs/conf.py               |  8 +-------
 pyrst/simulation/st_sim.py | 10 +++++-----
 test_rst.py                |  9 +++++----
 3 files changed, 11 insertions(+), 16 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index 0328658..ce4ee67 100755
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -13,13 +13,7 @@
 import os
 import sys
 
-
-if (os.environ.get("READTHEDOCS") == "True") is False:
-    sys.path.insert(0, os.path.abspath('..'))
-else:
-    import site
-    p = site.getsitepackages()[0]
-    sys.path.insert(0, p)
+sys.path.insert(0, os.path.abspath('..'))
 
 # -- Project information -----------------------------------------------------
 
diff --git a/pyrst/simulation/st_sim.py b/pyrst/simulation/st_sim.py
index bad2a66..f3292ca 100755
--- a/pyrst/simulation/st_sim.py
+++ b/pyrst/simulation/st_sim.py
@@ -161,7 +161,7 @@ class STSim:
         self.logger.info("Generating wavefields at the sample's plane")
         self.wf0_x = lens_wp(x_arr=self.x_arr, wl=self.wl, ap=self.ap_x,
                              focus=self.focus, defoc=self.defocus, alpha=self.alpha,
-                             x0=(self.x0 - 0.5) * self.ap_x)
+                             xc=(self.x0 - 0.5) * self.ap_x)
         self.wf0_y = aperture_wp(x_arr=self.y_arr, z=self.focus + self.defocus,
                                  wl=self.wl, ap=self.ap_y)
         self.i0 = self.p0 / self.ap_x / self.ap_y
@@ -176,16 +176,16 @@ class STSim:
                                    br_dx=self.bar_size, rd=self.rnd_dev)
         self.bsteps = bsteps
         self.bs_t = barcode_profile(x_arr=self.x_arr, bx_arr=self.bsteps, sgm=self.bar_sigma,
-                               atn0=self.bulk_atn, atn=self.bar_atn, ss=self.step_size,
-                               nf=self.n_frames)
+                                    atn0=self.bulk_atn, atn=self.bar_atn, ss=self.step_size,
+                                    nf=self.n_frames)
         self.logger.info("The coefficients have been generated")
 
     def _init_detector(self):
         self.logger.info("Generating wavefields at the detector's plane")
         self.wf1_y = fraunhofer_1d(wf0=self.wf0_y, x_arr=self.y_arr, xx_arr=self.yy_arr,
-                                   dist=self.det_dist, wl=self.wl)
+                                   z=self.det_dist, wl=self.wl)
         self.wf1_x = fraunhofer_1d_scan(wf0=self.wf0_x * self.bs_t, x_arr=self.x_arr,
-                                   xx_arr=self.xx_arr, dist=self.det_dist, wl=self.wl)
+                                        xx_arr=self.xx_arr, z=self.det_dist, wl=self.wl)
         self.det_c = self.smp_c / self.wl / self.det_dist
         self.logger.info("The wavefields have been generated")
 
diff --git a/test_rst.py b/test_rst.py
index 1982da8..4cf3147 100755
--- a/test_rst.py
+++ b/test_rst.py
@@ -43,7 +43,7 @@ def loader(request):
     """
     Return a default cxi protocol
     """
-    return rst.loader(request.param)
+    return pyrst.loader(request.param)
 
 @pytest.fixture(scope='function')
 def ini_path():
@@ -74,7 +74,7 @@ def test_st_sim(st_params):
 def test_loader_exp(exp_data, loader):
     assert os.path.isfile(exp_data['path'])
     data_dict = loader._load(**exp_data)
-    for attr in rst.STData.attr_set:
+    for attr in pyrst.STData.attr_set:
         assert not data_dict[attr] is None
 
 @pytest.mark.rst
@@ -83,7 +83,7 @@ def test_loader_sim(sim_data, loader):
     data_path = os.path.join(sim_data, 'data.cxi')
     assert os.path.isfile(data_path)
     data_dict = loader._load(data_path)
-    for attr in rst.STData.attr_set:
+    for attr in pyrst.STData.attr_set:
         assert not data_dict[attr] is None
 
 @pytest.mark.rst
@@ -94,7 +94,8 @@ def test_iter_update(sim_data, loader):
     st_data = loader.load(data_path)
     st_obj = st_data.get_st()
     pixel_map0 = st_obj.pixel_map.copy()
-    st_obj.iter_update([0, 150], ls_pm=2.5, ls_ri=15, verbose=True, n_iter=5)
+    st_obj.iter_update(sw_ss=0, sw_fs=150, ls_pm=2.5, ls_ri=15,
+                       verbose=True, n_iter=5)
     assert (st_obj.pixel_map == pixel_map0).all()
     assert st_obj.pixel_map.dtype == loader.protocol.known_types['float']
 
-- 
GitLab