Commit efcfd9ab authored by Connor Hainje's avatar Connor Hainje
Browse files

Add pion scaling weight as command-line option

parent a21fdc3b
......@@ -161,8 +161,7 @@ def train_model(bin_idx, args, use_tqdm=True):
diag_lfn = nn.CrossEntropyLoss()
pion_lfn = nn.BCELoss()
diag_wgt = 1
pion_wgt = 0.1
pion_wgt = args.beta
def compute_accuracies(out, y):
output = out.detach().cpu().numpy()
......@@ -180,7 +179,7 @@ def train_model(bin_idx, args, use_tqdm=True):
pi_out = output[pi_mask, 2]
pi_y = (target[pi_mask] == 2).float()
pion = pion_lfn(pi_out, pi_y)
return diag * diag_wgt + pion * pion_wgt, diag, pion
return diag + pion_wgt * pion, diag, pion
print(f"Training network for {args.n_epochs} epochs.")
......@@ -289,7 +288,8 @@ def parse():
nargs="*",
help=(
"Use only log-likelihood data from a subset of particle "
"types specified by PDG code."
"types specified by PDG code. If left unspecified, all "
"particle types will be used."
),
)
ap.add_argument(
......@@ -300,6 +300,15 @@ def parse():
"distributed with mean of 1 and width of 1."
),
)
ap.add_argument(
"--beta",
type=float,
default=0.1,
help=(
"Scaling factor for the pion binary cross entropy term in "
"the loss function. Defaults to 0.1."
),
)
args = ap.parse_args()
return args
......@@ -325,6 +334,7 @@ def main():
print("A model will be trained over all data in the file.")
print(f"The model will be trained for {args.n_epochs} epochs.")
print(f"The model will be written to {args.output}.")
print(f"Training will use a pion scaling factor beta = {args.beta}.")
os.makedirs(args.output, exist_ok=True)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment