Skip to content
Snippets Groups Projects
Commit 0db2ceac authored by simply-nicky's avatar simply-nicky
Browse files

Merge branch 'dev0'

parents e34d256e 64148ea1
No related branches found
No related tags found
No related merge requests found
......@@ -22,7 +22,7 @@ copyright = '2020, Nikolay Ivanov'
author = 'Nikolay Ivanov'
# The full version, including alpha/beta/rc tags
release = '0.3.5-r1'
release = '0.3.6'
# -- General configuration ---------------------------------------------------
......
......@@ -58,14 +58,14 @@ class DataContainer:
attr_set, init_set = {}, {}
def __init__(self, **kwargs):
self.__dict__['attr_dict'] = {}
self.__dict__['attr_dict'] = {key: None for key in self.attr_set | self.init_set}
for attr in self.attr_set:
if kwargs.get(attr) is None:
raise ValueError('Attribute {:s} has not been provided'.format(attr))
else:
self.attr_dict[attr] = kwargs.get(attr)
self.__setattr__(attr, kwargs.get(attr))
for attr in self.init_set:
self.attr_dict[attr] = kwargs.get(attr)
self.__setattr__(attr, kwargs.get(attr))
def __iter__(self):
return self.attr_dict.__iter__()
......
......@@ -166,7 +166,10 @@ class STData(DataContainer):
def __setattr__(self, attr, value):
if attr in self.attr_set | self.init_set:
value = np.array(value, dtype=self.protocol.get_dtype(attr))
if isinstance(value, np.ndarray):
value = np.array(value, dtype=self.protocol.get_dtype(attr))
elif not value is None:
value = self.protocol.get_dtype(attr)(value)
super(STData, self).__setattr__(attr, value)
else:
super(STData, self).__setattr__(attr, value)
......@@ -491,11 +494,12 @@ class STData(DataContainer):
if not val is None:
if attr in ['data', 'error_frame', 'mask', 'phase',
'pixel_abberations', 'pixel_map', 'whitefield']:
val = val[..., self.roi[0]:self.roi[1], self.roi[2]:self.roi[3]]
val = np.ascontiguousarray(val[..., self.roi[0]:self.roi[1],
self.roi[2]:self.roi[3]])
if attr in ['basis_vectors', 'data', 'mask', 'pixel_translations',
'translations']:
val = val[self.good_frames]
return np.ascontiguousarray(val)
val = np.ascontiguousarray(val[self.good_frames])
return val
else:
return value
......
......@@ -359,9 +359,6 @@ class STConverter:
t_arr = np.zeros((st_params.n_frames, 3), dtype=self.protocol.get_dtype('translations'))
t_arr[:, 0] = -smp_pos
data_dict['translations'] = self.crd_rat * t_arr
for attr in data_dict:
data_dict[attr] = np.asarray(data_dict[attr], dtype=self.protocol.get_dtype(attr))
return data_dict
def export_data(self, data, sim_obj):
......
......@@ -39,7 +39,7 @@ with open('README.md', 'r') as readme:
long_description = readme.read()
setup(name='pyrost',
version='0.3.5-r1',
version='0.3.6',
author='Nikolay Ivanov',
author_email="nikolay.ivanov@desy.de",
long_description=long_description,
......
......@@ -115,8 +115,8 @@ def test_full(converter, ptych, sim_obj):
data = converter.export_data(ptych, sim_obj)
assert data.data.dtype == converter.protocol.known_types['float']
st_obj = data.get_st()
st_res = st_obj.iter_update(sw_fs=10, ls_pm=2.5, ls_ri=15,
verbose=True, n_iter=10)
st_res = st_obj.iter_update_gd(sw_fs=10, ls_pm=2.5, ls_ri=15,
verbose=True, n_iter=10)
data = data.update_phase(st_res)
fit = data.fit_phase(axis=1)
assert (st_obj.pixel_map != st_res.pixel_map).any()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment