Commit 0b3c67ab authored by connor.hainje@pnnl.gov
Browse files

Clean, modularize, and improve training script

parent 94b631c2
......@@ -22,94 +22,141 @@ import torch
import torch.nn as nn
import torch.optim as optim
import os
import h5py
from argparse import ArgumentParser
from sklearn.model_selection import train_test_split
from tqdm import tqdm
def _load_split(path, bounds):
    """Read the ``X``/``y`` arrays of one ``.npz`` split, optionally binned.

    Args:
        path (str): Path to the ``.npz`` archive.
        bounds (tuple or None): ``(p_lo, p_hi, t_lo, t_hi)`` inclusive limits
            applied to the archive's ``p`` and ``theta`` arrays. If None, all
            rows are kept and ``p``/``theta`` are not read.

    Returns:
        tuple: ``(X, y)`` as NumPy arrays.
    """
    # The context manager closes the archive's file handle as soon as the
    # needed arrays have been read out of it.
    with np.load(path) as data:
        X, y = data["X"], data["y"]
        if bounds is None:
            return X, y
        p, theta = data["p"], data["theta"]

    p_lo, p_hi, t_lo, t_hi = bounds
    mask = np.logical_and.reduce(
        [
            p >= p_lo,
            p <= p_hi,
            theta >= t_lo,
            theta <= t_hi,
        ]
    )
    return X[mask], y[mask]


def load_data(folder, bin_idx, device):
    """Load the training and validation splits as tensors on ``device``.

    Args:
        folder (str): Directory containing ``train.npz`` and ``val.npz``.
        bin_idx (tuple of ints or None): If not None, specifies a (p, theta)
            bin as indices into ``pidml.P_RANGES`` and ``pidml.THETA_RANGES``;
            only rows inside that bin (bounds inclusive) are kept. If None,
            uses all data.
        device (torch.device): Device the tensors are moved to.

    Returns:
        tuple: ``(X_tr, y_tr, X_va, y_va)``. The ``X`` tensors are cast to
        float and the ``y`` tensors to long.
    """
    bounds = None
    if bin_idx is not None:
        p_lo = pidml.P_RANGES[bin_idx[0]][0]
        p_hi = pidml.P_RANGES[bin_idx[0]][1]
        t_lo = pidml.THETA_RANGES[bin_idx[1]][0]
        t_hi = pidml.THETA_RANGES[bin_idx[1]][1]
        bounds = (p_lo, p_hi, t_lo, t_hi)

    X_tr, y_tr = _load_split(os.path.join(folder, "train.npz"), bounds)
    X_va, y_va = _load_split(os.path.join(folder, "val.npz"), bounds)

    X_tr = torch.tensor(X_tr).float().to(device)
    X_va = torch.tensor(X_va).float().to(device)
    y_tr = torch.tensor(y_tr).long().to(device)
    y_va = torch.tensor(y_va).long().to(device)

    return X_tr, y_tr, X_va, y_va
print("Initializing network.")
if bin_idx is not None:
output_filename = f"net_{bin_idx[0]}_{bin_idx[1]}.pt"
else:
output_filename = f"net.pt"
output_filename = os.path.join(args.output, output_filename)
net = pidml.SimpleNet()
epochs_0 = 0
def random_init(net):
    """Overwrite every layer weight in ``net.fcs`` with Gaussian noise.

    Each weight is redrawn from a normal distribution with mean 1.0 and
    standard deviation 0.5.

    Args:
        net: Network exposing an iterable ``fcs`` of layers, each with a
            ``weight`` tensor. The weights are modified in place.
    """
    # In-place edits of parameters must not be recorded by autograd.
    with torch.no_grad():
        for fc in net.fcs:
            # Sample directly into the existing storage so the values are
            # produced on the weight's own device and dtype, instead of
            # building a default-device temporary and adding it.
            fc.weight.normal_(mean=1.0, std=0.5)
if args.resume:
net, _, _, epochs_0, history = pidml.load_training(output_filename, empty_net=net, empty_opt=False)
loss_t = history['loss_t']
loss_v = history['loss_v']
accu_t = history['accu_t']
accu_v = history['accu_v']
else:
loss_t, loss_v, accu_t, accu_v = [], [], [], []
if args.only is not None:
# freeze parameters that we are not training on
def kill_unused(net, only):
    """Zero and freeze the weights of particle types that are not trained.

    Args:
        net: Network exposing an indexable ``fcs`` of layers, each with a
            ``weight`` tensor; presumably one layer per entry of
            ``pidml.PDG_CODES`` -- confirm against ``pidml.SimpleNet``.
        only (collection of ints or None): PDG codes that remain trainable.
            If None, the network is left untouched.
    """
    if only is None:
        return

    # Particle types that are not being trained: set to zero and freeze.
    for i, pdg in enumerate(pidml.PDG_CODES):
        if pdg in only:
            continue
        weight = net.fcs[i].weight
        # An in-place write to a leaf tensor that still requires grad raises
        # a RuntimeError unless autograd recording is disabled.
        with torch.no_grad():
            weight.fill_(0)
        weight.requires_grad = False
def resume(filename, net):
    """Restore a saved training run from ``filename``.

    Args:
        filename (str): Path of the saved training state.
        net: Empty network handed to ``pidml.load_training`` as ``empty_net``.

    Returns:
        tuple: ``(net, epochs_0, loss_t, loss_v, accu_t, accu_v)`` where the
        last four entries are taken from the stored history under the keys
        of the same names.
    """
    # Only the network, the epoch count and the history are used; the second
    # and third returned values are discarded.
    net, _, _, epochs_0, history = pidml.load_training(
        filename, empty_net=net, empty_opt=False
    )
    loss_t, loss_v, accu_t, accu_v = (
        history["loss_t"],
        history["loss_v"],
        history["accu_t"],
        history["accu_v"],
    )
    return net, epochs_0, loss_t, loss_v, accu_t, accu_v
def initialize(filename, args):
    """Build the network, optimizer and history containers for a run.

    Args:
        filename (str): Path of the saved model, read only when
            ``args.resume`` is set.
        args: Parsed command-line arguments; ``random``, ``resume`` and
            ``only`` are read here.

    Returns:
        tuple: ``(net, opt, epochs_0, loss_t, loss_v, accu_t, accu_v)``.
    """
    net = pidml.SimpleNet()

    if args.random:
        random_init(net)

    if args.resume:
        net, epochs_0, loss_t, loss_v, accu_t, accu_v = resume(filename, net)
    else:
        epochs_0 = 0
        loss_t = {"diag": [], "pion": [], "sum": []}
        loss_v = {"diag": [], "pion": [], "sum": []}
        accu_t = {"net": [], "pion": []}
        accu_v = {"net": [], "pion": []}

    kill_unused(net, args.only)

    # Create the optimizer last: the requires_grad filter can only exclude
    # weights after kill_unused() has frozen them, and the optimizer must
    # track the network object that resume() returned.
    trainable = [p for p in net.parameters() if p.requires_grad]
    opt = optim.Adam(trainable, lr=5e-4)

    return net, opt, epochs_0, loss_t, loss_v, accu_t, accu_v
def train_model(bin_idx, args, use_tqdm=True):
"""Reads data and trains a model for the given bin index.
Args:
bin_idx (tuple of ints or None): If not None, specifies a (p, theta)
bin. If None, uses all data.
use_tqdm (bool, optional): Enables a tqdm progress tracker over the
epochs. Defaults to True.
"""
print("Reading data.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("...and moving to", device)
X_tr, y_tr, X_va, y_va = load_data(args.input, bin_idx, device)
if len(y_tr) < 10 or len(y_va) < 10:
print("Not enough data in the given bin. Skipping...")
return
print(f"{len(y_tr)} train events, {len(y_va)} val events")
# # zero out likelihoods for irrelevant particle types
# if args.only is not None:
# for i, pdg in enumerate(pidml.PDG_CODES):
# if pdg in args.only:
# continue
# X_tr[:, i] = -1e10
# X_va[:, i] = -1e10
print("Initializing network.")
if bin_idx is not None:
filename = f"net_{bin_idx[0]}_{bin_idx[1]}.pt"
else:
filename = f"net.pt"
filename = os.path.join(args.output, filename)
net, opt, epochs_0, loss_t, loss_v, accu_t, accu_v = initialize(filename, args)
net.to(device)
diag_lfn = nn.CrossEntropyLoss()
......@@ -122,30 +169,24 @@ def train_model(bin_idx, args, use_tqdm=True):
target = y.detach().cpu().numpy()
pred = np.squeeze(np.argmax(output, axis=1))
accu = np.count_nonzero(pred == target) / len(pred)
pi_out = output[(target == 2),2]
pi_out = output[(target == 2), 2]
pi_pred = (pi_out > 0.5).astype(float)
pi_accu = pi_pred.sum() / len(pi_pred)
return accu, pi_accu
def lfn(output, target):
diag = diag_lfn(output, target) * diag_wgt
pi_mask = (target == 2)
diag = diag_lfn(output, target)
pi_mask = target == 2
pi_out = output[pi_mask, 2]
pi_y = (target[pi_mask] == 2).float()
pion = pion_lfn(pi_out, pi_y) * pion_wgt
return diag + pion, diag, pion
loss_t = {'diag': [], 'pion': [], 'sum': []}
loss_v = {'diag': [], 'pion': [], 'sum': []}
accu_t = {'net': [], 'pion': []}
accu_v = {'net': [], 'pion': []}
pion = pion_lfn(pi_out, pi_y)
return diag * diag_wgt + pion * pion_wgt, diag, pion
print(f"Training network for {args.n_epochs} epochs.")
iterator = range(args.n_epochs)
if use_tqdm:
iterator = tqdm(range(args.n_epochs))
else:
iterator = range(args.n_epochs)
iterator = tqdm(iterator)
for epoch in iterator:
# train step
......@@ -156,14 +197,14 @@ def train_model(bin_idx, args, use_tqdm=True):
loss.backward()
opt.step()
loss_t['diag'].append(diag.item())
loss_t['pion'].append(pion.item())
loss_t['sum'].append(loss.item())
loss_t["diag"].append(diag.item())
loss_t["pion"].append(pion.item())
loss_t["sum"].append(loss.item())
if epoch % 10 == 0:
accu, pi_accu = compute_accuracies(out, y_tr)
accu_t['net'].append(accu)
accu_t['pion'].append(pi_accu)
accu_t["net"].append(accu)
accu_t["pion"].append(pi_accu)
# val step
net.eval()
......@@ -171,34 +212,100 @@ def train_model(bin_idx, args, use_tqdm=True):
out = net(X_va)
loss, diag, pion = lfn(out, y_va)
loss_v['diag'].append(diag.item())
loss_v['pion'].append(pion.item())
loss_v['sum'].append(loss.item())
loss_v["diag"].append(diag.item())
loss_v["pion"].append(pion.item())
loss_v["sum"].append(loss.item())
if epoch % 10 == 0:
accu, pi_accu = compute_accuracies(out, y_va)
accu_v['net'].append(accu)
accu_v['pion'].append(pi_accu)
accu_v["net"].append(accu)
accu_v["pion"].append(pi_accu)
print("Training complete.")
pidml.save_training(output_filename, net, opt, None, epochs_0 + args.n_epochs,
loss_t=loss_t, loss_v=loss_v, accu_t=accu_t, accu_v=accu_v)
net.cpu()
pidml.save_training(
filename,
net,
opt,
None,
epochs_0 + args.n_epochs,
loss_t=loss_t,
loss_v=loss_v,
accu_t=accu_t,
accu_v=accu_v,
)
print(f"Model saved to {output_filename}.")
print(f"Model saved to {filename}.")
def main():
def parse(argv=None):
    """Parse the command-line arguments of the network trainer.

    Args:
        argv (list of str, optional): Argument strings to parse. Defaults to
            None, in which case ``sys.argv[1:]`` is used.

    Returns:
        argparse.Namespace: Parsed arguments with attributes ``input``,
        ``output``, ``n_epochs``, ``binned``, ``bin_idx``, ``resume``,
        ``only`` and ``random``.
    """
    ap = ArgumentParser(description="", epilog="")
    ap.add_argument(
        "input",
        type=str,
        help=("Path to folder with training files (in .npz format)."),
    )
    ap.add_argument(
        "output",
        type=str,
        help=("Path to output directory where models will be saved."),
    )
    ap.add_argument(
        "-n",
        "--n_epochs",
        type=int,
        default=500,
        help="Number of epochs to train the network(s). Defaults to 500.",
    )
    ap.add_argument(
        "-b",
        "--binned",
        action="store_true",
        help=(
            "Train individual networks in all (p, theta) bins. Overridden "
            "if bin_idx is specified."
        ),
    )
    ap.add_argument(
        "--bin_idx",
        type=int,
        nargs=2,
        help="Train a single network in the given (p, theta) bin.",
    )
    ap.add_argument(
        "-R",
        "--resume",
        action="store_true",
        help=(
            "Load a pre-existing model and resume training instead of "
            "starting a new one. The final trained model will overwrite "
            "the existing one."
        ),
    )
    ap.add_argument(
        "--only",
        type=int,
        nargs="*",
        help=(
            "Use only log-likelihood data from a subset of particle "
            "types specified by PDG code."
        ),
    )
    ap.add_argument(
        "--random",
        action="store_true",
        # The help text states the std actually used by random_init().
        help=(
            "Initialize network weights to random values, normally "
            "distributed with mean of 1 and standard deviation of 0.5."
        ),
    )
    args = ap.parse_args(argv)
    return args
def main():
args = parse()
print("Welcome to the network trainer.")
print(f"Data will be read from {args.input}.")
......@@ -211,21 +318,24 @@ def main():
print(f"Models will be written to {args.output}.")
else:
if args.bin_idx is not None:
print(f"A model will be trained on data in bin ({args.bin_idx[0]}, {args.bin_idx[1]}).")
print(
f"A model will be trained on data in bin ({args.bin_idx[0]}, {args.bin_idx[1]})."
)
else:
print("A model will be trained over all data in the file.")
print(f"The model will be trained for {args.n_epochs} epochs.")
print(f"The model will be written to {args.output}.")
try:
os.makedirs(args.output)
except FileExistsError:
pass
os.makedirs(args.output, exist_ok=True)
print("---")
if args.bin_idx is None and args.binned:
indices = [(i,j) for i in range(len(pidml.P_RANGES)) for j in range(len(pidml.THETA_RANGES))]
indices = [
(i, j)
for i in range(len(pidml.P_RANGES))
for j in range(len(pidml.THETA_RANGES))
]
for bin_idx in indices:
print(f"Now starting training for bin {bin_idx}.")
train_model(bin_idx, args)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment