This script trains a neural network (NN) model on DAQ data. Users supply the SASE number, run name, properties folder, source folder, and label through command-line options or interactive prompts. The script loads the data, prepares features and targets, splits the dataset, normalizes the data, instantiates the NN model, configures the optimizer, applies early stopping, and saves the final model together with its evaluation metrics.
### Key Functions
* `input_params(argv)`: Parses command-line arguments and prompts for any mandatory fields that are missing.
* `rmse(predictions, targets)`: Computes Root Mean Square Error between predicted and ground truth values.
* `train_model(model, epoch, train_loader, OPTIMIZER)`: Executes one epoch of training, updating the model weights and returning losses, scores, and updated model.
* `validation_model(model, valid_loader)`: Evaluates the supplied NN model on a held-out validation dataset.
* `read_folder(data_folder)`: Concatenates the Parquet-formatted DAQ data frames found in the given folder.
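
As an illustration of the `rmse(predictions, targets)` helper listed above, here is a minimal plain-Python sketch. The original implementation is not shown and most likely operates on tensors or arrays; the use of plain sequences and the input checks are assumptions made to keep the example self-contained.

```python
import math
from typing import Sequence


def rmse(predictions: Sequence[float], targets: Sequence[float]) -> float:
    """Root mean square error between two equal-length sequences."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if len(predictions) == 0:
        raise ValueError("cannot compute RMSE of empty sequences")
    # Mean of the squared differences, then the square root.
    squared_errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


example_error = rmse([0.0, 0.0], [3.0, 4.0])
```

RMSE is expressed in the same units as the target, which makes it easier to interpret than the raw mean squared error.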
## Script Overview
### Logging and Suppression of Warnings
The script begins by configuring logging and suppressing warning messages so that console output stays readable during execution.
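
A setup of this kind can be sketched with the standard `logging` and `warnings` modules. The logger name, log level, and format string below are hypothetical; the script's actual settings are not shown.

```python
import logging
import warnings

# Hypothetical level and format; adjust to match the real script.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("nn_training")  # hypothetical logger name

# Ignore all Python warnings for the rest of the run.
warnings.filterwarnings("ignore")

logger.info("Logging configured")
```

Ignoring every warning can also hide genuine problems, so a narrower filter such as `warnings.filterwarnings("ignore", category=FutureWarning)` may be preferable.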
### Neural Network Architecture Definition
The script defines a PyTorch neural network class (`NN`) whose number of hidden layers and nodes per layer are configurable. The hidden layers use the ReLU activation function, and the network is intended for regression tasks.
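
One possible shape for such a class is sketched below. The constructor arguments (`n_features`, `hidden_layers`, `hidden_nodes`), the default sizes, and the single-value output are assumptions; only the class name `NN`, the configurable hidden layers, and the ReLU activations come from the description above.

```python
import torch
from torch import nn


class NN(nn.Module):
    """Fully connected regression network with configurable depth and width."""

    def __init__(self, n_features: int, hidden_layers: int = 2, hidden_nodes: int = 32):
        super().__init__()
        layers = []
        in_size = n_features
        for _ in range(hidden_layers):
            layers.append(nn.Linear(in_size, hidden_nodes))
            layers.append(nn.ReLU())
            in_size = hidden_nodes
        # Linear output without an activation, as is usual for regression.
        layers.append(nn.Linear(in_size, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = NN(n_features=4, hidden_layers=3, hidden_nodes=16)
output = model(torch.randn(8, 4))  # batch of 8 samples, 4 features each
```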
### Custom Dataset Class
A custom dataset class (`MyDataset`) is defined to facilitate the loading of data into PyTorch dataloaders. It extracts features and targets from the provided dataframe.
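
A dataset of this kind might look like the following sketch. It assumes that the target is a single column named by a `label` argument and that every other column is a feature; the actual column selection in the script is not shown.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    """Wraps a dataframe as (features, target) tensor pairs."""

    def __init__(self, frame: pd.DataFrame, label: str):
        # Assumption: all columns except the label column are features.
        feature_columns = [c for c in frame.columns if c != label]
        self.features = torch.tensor(frame[feature_columns].to_numpy(), dtype=torch.float32)
        self.targets = torch.tensor(frame[[label]].to_numpy(), dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.targets)

    def __getitem__(self, index: int):
        return self.features[index], self.targets[index]


# Made-up example data, for illustration only.
frame = pd.DataFrame(
    {"a": [1.0, 2.0, 3.0, 4.0], "b": [0.5, 0.4, 0.3, 0.2], "y": [1.5, 2.4, 3.3, 4.2]}
)
dataset = MyDataset(frame, label="y")
loader = DataLoader(dataset, batch_size=2, shuffle=False)
first_features, first_targets = next(iter(loader))
```

Converting the whole dataframe to tensors once in `__init__` keeps `__getitem__` cheap, at the cost of holding all data in memory.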
### Helper Functions
Several helper functions are defined, including functions for printing help information, displaying fatal errors, calculating root mean squared error (rmse), training the neural network for an epoch, validating the model, resetting model weights, and reading data from a specified folder.
### Input Parameter Parsing
The script uses the standard-library `getopt` module to parse command-line arguments, which describe the DAQ run: SASE number, run name, properties folder, source folder, and label.
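
The parsing step could be sketched as follows. The option names (`--sase`, `--run`, `--properties`, `--source`, `--label`), the usage text, and the returned `missing` list are hypothetical; the script's real option letters and return value are not shown.

```python
import getopt
import sys

USAGE = "usage: script.py --sase N --run NAME --properties DIR --source DIR --label NAME"


def input_params(argv):
    """Parse options into a dict and report which mandatory ones are missing."""
    params = {"sase": None, "run": None, "properties": None, "source": None, "label": None}
    try:
        opts, _ = getopt.getopt(
            argv, "h", ["help", "sase=", "run=", "properties=", "source=", "label="]
        )
    except getopt.GetoptError as err:
        raise SystemExit(f"Invalid arguments: {err}\n{USAGE}")
    for opt, value in opts:
        if opt in ("-h", "--help"):
            print(USAGE)
            sys.exit(0)
        params[opt.lstrip("-")] = value  # "--sase" -> "sase"
    missing = [name for name, value in params.items() if value is None]
    return params, missing


params, missing = input_params(["--sase", "1", "--run", "run_042", "--label", "intensity"])
```

In an interactive session, each name in `missing` could then be filled with `input(f"{name}: ")`.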
### Training Process
The main training process involves reading data from specified folders, normalizing the data, defining the neural network, setting hyperparameters, creating dataloaders, training the model, and saving the trained model and optimizer.
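
The per-epoch training step at the core of this process can be sketched as below. The signature follows `train_model(model, epoch, train_loader, OPTIMIZER)` from the function list; the MSE loss, the RMSE score, the Adam optimizer, and the synthetic data are assumptions for illustration.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_model(model, epoch, train_loader, OPTIMIZER):
    """Run one training epoch; return mean loss, RMSE score, and the model."""
    criterion = nn.MSELoss()  # assumed loss function
    model.train()
    total_loss, n_samples = 0.0, 0
    for features, targets in train_loader:
        OPTIMIZER.zero_grad()
        loss = criterion(model(features), targets)
        loss.backward()
        OPTIMIZER.step()
        # Weight each batch loss by its size so the epoch mean is exact.
        total_loss += loss.item() * len(features)
        n_samples += len(features)
    mean_loss = total_loss / n_samples
    score = mean_loss ** 0.5  # square root of the epoch's mean squared error
    return mean_loss, score, model


# Synthetic regression data, for illustration only.
torch.manual_seed(0)
x = torch.randn(64, 3)
y = x.sum(dim=1, keepdim=True)
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
model = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

history = []
for epoch in range(20):
    loss, score, model = train_model(model, epoch, loader, optimizer)
    history.append(loss)
```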
### Early Stopping Mechanism
The script implements early stopping based on the validation loss, ending training once that loss stops improving in order to prevent overfitting.
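
A common patience-based form of this mechanism is sketched below. The class name `EarlyStopping` and the `patience` and `min_delta` parameters are hypothetical; the script's actual stopping criterion is not shown.

```python
class EarlyStopping:
    """Signals a stop after `patience` epochs without improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, valid_loss: float) -> bool:
        """Record one validation loss; return True when training should stop."""
        if valid_loss < self.best_loss - self.min_delta:
            self.best_loss = valid_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
losses = [0.9, 0.7, 0.71, 0.72, 0.6]  # made-up validation losses
stopped_at = None
for epoch, valid_loss in enumerate(losses):
    if stopper.step(valid_loss):
        stopped_at = epoch
        break
```

In this example training stops at epoch 3, after two consecutive epochs without improvement over the best loss of 0.7, so the later value 0.6 is never reached. A larger `patience` trades longer training for tolerance of such temporary plateaus.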
### Evaluation and Results Logging
The script evaluates the trained model on the validation dataset, calculates the R² score, and logs the relevant information. The results, including training metadata, are serialized into a JSON file.
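
The R² computation and JSON serialization can be illustrated with a plain-Python sketch. The metadata keys (`run`, `epochs_trained`, `r2`) and the example numbers are made up; the script's actual JSON layout is not shown.

```python
import json


def r2_score(targets, predictions):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_target = sum(targets) / len(targets)
    ss_res = sum((t - p) ** 2 for t, p in zip(targets, predictions))
    ss_tot = sum((t - mean_target) ** 2 for t in targets)
    return 1.0 - ss_res / ss_tot  # undefined when all targets are equal


# Made-up values, for illustration only.
targets = [1.0, 2.0, 3.0, 4.0]
predictions = [1.1, 1.9, 3.2, 3.8]

results = {
    "run": "example_run",  # hypothetical metadata fields
    "epochs_trained": 12,
    "r2": r2_score(targets, predictions),
}
serialized = json.dumps(results, indent=2)
# To write a file instead, pass an open file handle to json.dump(results, handle).
```

An R² of 1 means perfect predictions, while 0 means the model does no better than always predicting the mean target.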
### Conclusion
The script covers the full workflow for training PyTorch neural networks on DAQ data. It is configured through command-line arguments and a YAML configuration file, so it can be adapted to different datasets and training scenarios. Its logging and saved results make each training run traceable.
### Usage