Commit 675c6d65 authored by Maximilian Schuette's avatar Maximilian Schuette 🌃
Browse files

Performance improvements (hopefully)

parent e2f420b0
......@@ -13,7 +13,6 @@
# plt.yticks([e.value for e in DestinationXfel], [e.name for e in DestinationXfel])
from collections.abc import Sequence
from enum import IntEnum, IntFlag, unique
import numpy as np
from ctypes import LittleEndianStructure, c_float, c_uint8, c_uint16, POINTER, cast
......@@ -126,7 +125,7 @@ class SpecialFlagsFlash(IntFlag):
SpecialFlags = {'xfel': SpecialFlagsXfel, 'flash': SpecialFlagsFlash}
class TimingPatternPacked(LittleEndianStructure):
class TimingWordPacked(LittleEndianStructure):
_fields_ = [
('bunch_charge_setting', c_uint8, 4),
('injector_laser_triggers', c_uint8, 4),
......@@ -138,47 +137,54 @@ class TimingPatternPacked(LittleEndianStructure):
]
timing_pattern_type = {linac: np.dtype([('bunch_charge_setting', BunchChargeSetting),
timing_word_type = {linac: np.dtype([('bunch_charge_setting', BunchChargeSetting),
('injector_laser_triggers', InjectorLaserTriggers[linac]),
('seed_user_laser_triggers', SeedUserLaserTriggers[linac]),
('destination', Destination[linac]),
('special_flags', SpecialFlags[linac])]) for linac in LINACS}
def unpack_timing_pattern(value, linac=LINACS[0]):
if not isinstance(linac, str) or linac.lower() not in LINACS:
raise TypeError(f"`mode` must be either of {LINACS}")
def unpack_timing_word(timing_word):
timing_word = np.asarray(timing_word)
if timing_word.dtype.itemsize != 4:
raise TypeError("`timing_word` must have 4-byte (word) datatype convertible to numpy scalar")
if isinstance(value, Sequence):
value = np.asarray(value)
timing_word_bitfield = cast(timing_word.ctypes.data_as(POINTER(c_float)),
POINTER(TimingWordPacked)).contents
if isinstance(value, np.ndarray):
values = value
patterns = cast(values.ctypes.data_as(POINTER(c_float * values.size)),
POINTER(TimingPatternPacked * values.size)).contents
patterns_unpacked = np.empty(values.shape + (5,), dtype=np.uint16)
return np.asarray((
timing_word_bitfield.bunch_charge_setting,
timing_word_bitfield.injector_laser_triggers,
timing_word_bitfield.seed_user_laser_triggers,
timing_word_bitfield.destination,
timing_word_bitfield.special_flags), dtype=np.uint16)
for i, pattern in enumerate(patterns):
patterns_unpacked.flat[i * 5:(i + 1) * 5] = (
pattern.bunch_charge_setting,
pattern.injector_laser_triggers,
pattern.seed_user_laser_triggers,
pattern.destination,
pattern.special_flags)
return patterns_unpacked
else:
pattern = cast(value, POINTER(TimingPatternPacked)).contents
return np.asarray((
pattern.bunch_charge_setting,
pattern.injector_laser_triggers,
pattern.seed_user_laser_triggers,
pattern.destination,
pattern.special_flags), dtype=np.uint16)
def unpack_timing_pattern(pattern):
pattern = np.asarray(pattern)
if pattern.dtype.itemsize != 4:
raise TypeError("`pattern` must have 4-byte (word) datatype convertible to numpy array")
# There can be many more bunches than bunch codes,
# so compute them once and then produce the output array
unique_timing_words = np.unique(pattern)
timing_word_lut = np.vstack([unpack_timing_word(word) for word in unique_timing_words])
pattern_unpacked = np.zeros(pattern.shape + (5,), dtype=np.uint16)
it = np.nditer(pattern, flags=['multi_index'])
for word in it:
if word != 0:
pattern_unpacked[it.multi_index] = timing_word_lut[unique_timing_words == word]
return pattern_unpacked
def decode_timing_pattern(value, linac=LINACS[0]):
raise NotImplementedError('TODO')
# if not isinstance(linac, str) or linac.lower() not in LINACS:
# raise TypeError(f"`mode` must be either of {LINACS}")
# TODO
# patterns_unpacked.flat[i]['bunch_charge_setting'] = BunchChargeSetting(pattern.bunch_charge_setting)
# patterns_unpacked.flat[i]['injector_laser_triggers'] = InjectorLaserTriggers[linac](
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment