Skip to content
Snippets Groups Projects
Commit 662af9ab authored by Sviatoslav Bilokin's avatar Sviatoslav Bilokin
Browse files

Production dispatcher task

parent 811b6c29
No related branches found
No related tags found
No related merge requests found
......@@ -9,12 +9,12 @@ import basf2 as b2
from syscorrfw.particleid.path_builder import PIDPathBuilder
from syscorrfw.particleid.fit_controller import FitController
input_udst = '.udsts/mdst_000001_prod00018890_task116212000001.root'
input_udst = '.udsts/udst_000001_prod00014654_task114406000001.root'
output_ntuple = 'tests/skim_integration.root'
model_name = 'KShort'
model_name = '__test__'
builder = PIDPathBuilder(model_name, input_name=input_udst, output_name=output_ntuple)
main_path = builder.build_ntuple_path(extra_variables=['nTracks'], dst_type='mdst', online_prescale=3)
main_path = builder.build_ntuple_path(extra_variables=['nTracks'], dst_type='udst')
b2.process(main_path)
......
......@@ -2,7 +2,7 @@ import os
import yaml
import b2luigi
import syscorrfw.constants as const
from syscorrfw.luigi.production_tasks import BenchmarkTask
from syscorrfw.luigi.production_tasks import ProductionDispatcherTask
from syscorrfw.luigi.weight_computing_tasks import WeightDispatcherTask
......@@ -25,7 +25,7 @@ class AggregatorTask(b2luigi.WrapperTask):
if task_dict['proc'] in const.STANDARD_FOLDER_NAMES:
raise ValueError(f'Processing name "{task_dict["proc"]}" cannot be the same as'
f' one of the standard folder names: {const.STANDARD_FOLDER_NAMES}')
yield BenchmarkTask(task_parameters=task_dict)
yield ProductionDispatcherTask(task_parameters=task_dict)
elif self.workflow_type == 'fixed_weights':
hid_weight_dict = self.load_parameters(const.HID_WEIGHT_TASKS_NAME)
remove_tmp_files = False
......@@ -47,7 +47,7 @@ class AggregatorTask(b2luigi.WrapperTask):
task_dict['refit'] = True
for refit_key in hid_refit_parameters:
task_dict[f'refit_{refit_key}'] = hid_refit_parameters[refit_key]
yield BenchmarkTask(task_parameters=task_dict)
yield ProductionDispatcherTask(task_parameters=task_dict)
def load_parameters(self, settings_name: str) -> list:
'''
......
......@@ -23,6 +23,7 @@ from syscorrfw.common.repository import get_repository
import syscorrfw.constants as const
import syscorrfw.common.efficiency_table_methods as etm
import syscorrfw.common.id_vs_misid_curve_methods as idm
from syscorrfw.particleid import pid_model_base
from syscorrfw import __version__ as framework_version
......@@ -310,7 +311,7 @@ class FitterTask(b2luigi.Task):
print(f'No files to remove, {input_files} does not exist.')
class BenchmarkTask(b2luigi.Task):
class WeightedBenchmarkTask(b2luigi.Task):
'''
Task class, which produces benchmark tables and plots
'''
......@@ -462,3 +463,24 @@ class DBLoggerTask(b2luigi.Task):
with get_repository(self.db_destination) as repository:
repository.add_one(dict(self.entry))
repository.complete()
class ProductionDispatcherTask(b2luigi.WrapperTask):
'''
Task class, which directs the production workflow
'''
task_parameters = b2luigi.DictParameter(positional=False)
def requires(self):
'''
Luigi's requires method
'''
model = get_model(self.task_parameters[const.HID_MODEL_KEY])
if isinstance(model, pid_model_base.WeightedModelBase):
yield WeightedBenchmarkTask(task_parameters=self.task_parameters)
elif isinstance(model, pid_model_base.TauModelBase):
yield GMergingTask(task_parameters=self.task_parameters)
elif isinstance(model, pid_model_base.LMModelBase):
yield GMergingTask(task_parameters=self.task_parameters)
else:
raise NotImplementedError(f'Model type {type(self.m_model)} is not propagated correctly!')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment