networks.py 13.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
"""
networks.py
Connor Hainje (connor.hainje@pnnl.gov)

Implementations of the likelihood combination networks.
"""

import numpy as np
import matplotlib.pyplot as plt
10
import torch
11
12
13
import torch.nn as nn
import torch.nn.functional as F

14
15
from .const import PARTICLES, DETECTORS

16

17
def visualize_weights(
    x,
    xtick_labels,
    ytick_labels=PARTICLES,
    fig_kw=dict(figsize=(20, 8)),
    image_kw=dict(cmap="RdBu", aspect=1, vmin=-2.5, vmax=2.5),
    grid_kw=dict(color="black"),
    text_kw=dict(
        ha="center", va="center", fontsize=12, bbox=dict(lw=0, fc="white", alpha=0.7)
    ),
    cbar=True,
    ax=None,
):
    """
    Visualizes a network's weights using imshow.

    Each cell of the 2D array is drawn as a coloured square, annotated with
    its value (two decimals), and separated from its neighbours by grid lines.

    Args:
        x: The network's weights, as a 2D array exposing `.shape` and
            `x[row, col]` indexing (e.g. a numpy array).
        xtick_labels: A list of strings corresponding to the x-tick labels,
            one per column of `x`. Required; there is no default.
        ytick_labels: A list of strings corresponding to the y-tick labels,
            one per row of `x`. Defaults to `PARTICLES`.
        fig_kw (dict): Keywords for the figure object. The default sets
            `figsize=(20,8)`.
        image_kw (dict): Keywords for the imshow call. The default sets
            `cmap='RdBu'`, `aspect=1`, `vmin=-2.5`, `vmax=2.5`.
        grid_kw (dict): Keywords for the horizontal and vertical lines making
            up the grid-lines. Default sets `color='black'`.
        text_kw (dict): Keywords for the value labels in each cell. Default
            sets `ha='center'`, `va='center'`, `fontsize=12`,
            `bbox=dict(lw=0, fc='white', alpha=0.7)`.
        cbar (bool): Specify whether to create a colorbar.
        ax (Axes): Optional Axes object. If given, the plot will be made on
            these axes, and fig_kw will be ignored. Otherwise, a new figure and
            axis are created.

    Returns:
        The newly created matplotlib.figure.Figure when `ax` was not given;
        None when the caller supplied `ax`.
    """
    # Track figure ownership up front: `ax` is rebound below, so testing it
    # again at the end of the function could never tell the two cases apart.
    fig = None
    if ax is None:
        fig, ax = plt.subplots(**fig_kw)

    n_rows, n_cols = x.shape[0], x.shape[1]

    # The extent puts cell (row, col) at integer coordinates (col, row), so
    # tick marks and text labels can be placed at plain integer positions.
    im = ax.imshow(
        x,
        origin="lower",
        extent=(-0.5, n_cols - 0.5, -0.5, n_rows - 0.5),
        **image_kw,
    )

    if cbar:
        plt.colorbar(im, ax=ax)

    # Add tick marks
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(xtick_labels, fontsize=12, rotation=90)
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(ytick_labels, fontsize=12)

    # Add values as text
    for row in range(n_rows):
        for col in range(n_cols):
            ax.text(col, row, f"{x[row,col]:.2f}", **text_kw)

    # Make a grid: one line on every interior cell boundary
    for row in range(n_rows - 1):
        ax.axhline(row + 0.5, **grid_kw)
    for col in range(n_cols - 1):
        ax.axvline(col + 0.5, **grid_kw)

    return fig
89
90


91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
class GeneralNet(nn.Module):
    """
    Common base class for our networks. Adds helpers for writing the model
    parameters out and reading them back in.
    """

    def __init__(self):
        """Sets up the underlying nn.Module."""
        super().__init__()

    def save(self, filename):
        """Writes this model's state_dict out via torch.save.

        Args:
            filename (str): The name of the output file.
        """
        weights = self.state_dict()
        torch.save(weights, filename)

    def load(self, filename):
        """Reads a state_dict via torch.load and applies it to this model.

        Args:
            filename (str): The name of the file containing the
                previously-saved model.
        """
        weights = torch.load(filename)
        self.load_state_dict(weights)

117

118
class SimpleNet(GeneralNet):
    """
    Implements the ''simple'' network for likelihood combination and particle
    identification.

    The network holds one bias-free linear unit per class. Each unit maps the
    `n_detector` inputs belonging to its class to a single score, and the
    scores are passed through a softmax.

    Args:
        n_class (int): Number of classes (particle types)
        n_detector (int): Number of detectors
        const_init (number, optional): If not None, initializes all weights in
            the network to this value. If None, uses torch's default random
            initialization. Defaults to 1.
    """

    def __init__(self, n_class=6, n_detector=6, const_init=1):
        """Instantiator"""
        super().__init__()
        self.n_class = n_class
        self.n_detector = n_detector

        # One (n_detector -> 1) linear unit per class, without bias.
        self.fcs = nn.ModuleList(
            [nn.Linear(self.n_detector, 1, bias=False) for _ in range(self.n_class)]
        )

        if const_init is not None:
            with torch.no_grad():
                for fc in self.fcs:
                    fc.weight.fill_(const_init)

        # idx[i] holds the column indices j * n_class + i for every detector j.
        # forward() does not read this table (it takes contiguous slices); the
        # attribute is kept so that code accessing `.idx` keeps working.
        self.idx = [
            [j * self.n_class + i for j in range(self.n_detector)]
            for i in range(self.n_class)
        ]

    def forward(self, x):
        """
        The network's forward method.

        Args:
            x (torch.Tensor): Input likelihood data. Should be of shape
                (N, N_C*N_D), where N is the number of data points, N_C is
                n_class, and N_D is n_detector. Columns are consumed in
                contiguous groups of N_D: columns [i*N_D, (i+1)*N_D) are fed
                to the unit for class i.

        Returns:
            A torch.Tensor of shape (N, N_C), containing the model's computed
            'probabilities' that each particle was of each possible particle
            type.
        """
        n = self.n_detector
        outs = [
            self.fcs[i](x[:, i * n : (i + 1) * n])
            for i in range(self.n_class)
        ]
        out = torch.cat(outs, dim=1)
        return F.softmax(out, dim=1)

    def visualize(
        self,
        xtick_labels=DETECTORS,
        ytick_labels=PARTICLES,
        fig_kw=dict(figsize=(20, 8)),
        image_kw=dict(cmap="RdBu", aspect=1, vmin=-2.5, vmax=2.5),
        grid_kw=dict(color="black"),
        text_kw=dict(
            ha="center",
            va="center",
            fontsize=12,
            bbox=dict(lw=0, fc="white", alpha=0.7),
        ),
        cbar=True,
        ax=None,
    ):
        """
        Visualizes the network's weights using imshow.

        Args:
            xtick_labels: A list of strings corresponding to the x-tick labels.
                Defaults to `DETECTORS`.
            ytick_labels: A list of strings corresponding to the y-tick labels.
                Defaults to `PARTICLES`.
            fig_kw (dict): Keywords for the figure object. The default sets
                `figsize=(20,8)`.
            image_kw (dict): Keywords for the imshow call. The default sets
                `cmap='RdBu'`, `aspect=1`, `vmin=-2.5`, `vmax=2.5`.
            grid_kw (dict): Keywords for the horizontal and vertical lines
                making up the grid-lines. Default sets `color='black'`.
            text_kw (dict): Keywords for the value labels in each cell. Default
                sets `ha='center'`, `va='center'`, `fontsize=12`,
                `bbox=dict(lw=0, fc='white', alpha=0.7)`.
            cbar (bool): Specify whether to create a colorbar.
            ax (Axes): Optional Axes object. If given, the plot will be made
                on these axes, and fig_kw will be ignored. Otherwise, a new
                figure and axis are created.

        Returns:
            A matplotlib.figure.Figure object. (Or, if ax is given, nothing.)
        """
        # Stack the per-class (1, n_detector) weight rows into one
        # (n_class, n_detector) array. `.cpu()` is needed because `.numpy()`
        # refuses tensors that live on another device.
        x = np.concatenate([fc.weight.detach().cpu().numpy() for fc in self.fcs])

        # Create the figure here rather than relying on the helper's return
        # value, so the documented Figure is always handed back to the caller.
        fig = None
        if ax is None:
            fig, ax = plt.subplots(**fig_kw)

        visualize_weights(
            x,
            xtick_labels,
            ytick_labels=ytick_labels,
            ax=ax,
            fig_kw=fig_kw,
            image_kw=image_kw,
            grid_kw=grid_kw,
            text_kw=text_kw,
            cbar=cbar,
        )

        return fig
241

242

243
class DenseNet(GeneralNet):
    """
    Implements the ''dense'' network for likelihood combination and particle
    identification.

    A single bias-free linear layer maps all N_C*N_D inputs to N_C scores,
    which are passed through a softmax.

    Args:
        n_class (int): Number of classes (particle types)
        n_detector (int): Number of detectors
        random_init (bool): Whether to randomly initialize network weights.
            Default: `False`, in which case the weight of class i is 1 on the
            n_detector inputs belonging to class i and 0 everywhere else.
    """

    def __init__(self, n_class=6, n_detector=6, random_init=False):
        """Instantiator"""
        super().__init__()
        self.n_class = n_class
        self.n_detector = n_detector

        # Weight shape is (n_class, n_class * n_detector): one row per class.
        self.fc = nn.Linear(n_class * n_detector, n_class, bias=False)

        if not random_init:
            with torch.no_grad():
                self.fc.weight.fill_(0)
                # Row i (class i) sums the contiguous block of n_detector
                # columns belonging to class i -- the same column layout that
                # SimpleNet.forward slices. Indexing rows by class keeps this
                # valid when n_class != n_detector.
                for i in range(self.n_class):
                    start = i * self.n_detector
                    self.fc.weight[i, start : start + self.n_detector] = 1.0

    def forward(self, x):
        """
        The network's forward method.

        Args:
            x (torch.Tensor): Input likelihood data. Should be of shape
                (N, N_C*N_D), where N is the number of data points, N_C is
                n_class, and N_D is n_detector

        Returns:
            A torch.Tensor of shape (N, N_C), containing the model's computed
            'probabilities' that each particle was of each possible particle
            type.
        """
        return F.softmax(self.fc(x), dim=1)

    def visualize(
        self,
        reorder_weights=True,
        xtick_labels=[f"{d}_{p}" for d in DETECTORS for p in PARTICLES],
        ytick_labels=PARTICLES,
        fig_kw=dict(figsize=(30, 8)),
        image_kw=dict(cmap="RdBu", aspect=1.5, vmin=-2.5, vmax=2.5),
        grid_kw=dict(color="black"),
        text_kw=dict(
            ha="center",
            va="center",
            fontsize=12,
            bbox=dict(lw=0, fc="white", alpha=0.7),
        ),
        extra_grid_mod=len(DETECTORS),
        extra_grid_kw=dict(color="red"),
        cbar=True,
        ax=None,
    ):
        """
        Visualizes the network's weights using imshow.

        Args:
            reorder_weights (bool): Reorder weights such that the x-axis will
                be ordered first by detector, then hypothesis. Default: True.
            xtick_labels: A list of strings corresponding to the x-tick labels.
                Defaults to the `"{detector}_{particle}"` combinations built
                from `DETECTORS` and `PARTICLES`, ordered first by detector,
                then hypothesis.
            ytick_labels: A list of strings corresponding to the y-tick labels.
                Defaults to `PARTICLES`.
            fig_kw (dict): Keywords for the figure object. The default sets
                `figsize=(30,8)`.
            image_kw (dict): Keywords for the imshow call. The default sets
                `cmap='RdBu'`, `aspect=1.5`, `vmin=-2.5`, `vmax=2.5`.
            grid_kw (dict): Keywords for the horizontal and vertical lines
                making up the grid-lines. Default sets `color='black'`.
            text_kw (dict): Keywords for the value labels in each cell. Default
                sets `ha='center'`, `va='center'`, `fontsize=12`,
                `bbox=dict(lw=0, fc='white', alpha=0.7)`.
            extra_grid_mod (int): Frequency of extra vertical grid lines.
                Default is the length of the standard list of detectors.
            extra_grid_kw (dict): Keywords for the vertical lines separating
                detectors. Default sets `color='red'`.
            cbar (bool): Specify whether to create a colorbar.
            ax (Axes): Optional Axes object. If given, the plot will be made on
                these axes, and fig_kw will be ignored. Otherwise, a new figure
                and axis are created.

        Returns:
            A matplotlib.figure.Figure object. (Or, if ax is given, nothing.)
        """
        # Make array from weights. `.cpu()` is needed because `.numpy()`
        # refuses tensors that live on another device.
        x = self.fc.weight.detach().cpu().numpy()

        if reorder_weights:
            # Stored column for (class i, detector j) is i * n_detector + j.
            # Listing those indices with the detector as the outer loop yields
            # a detector-major column order, and is a valid permutation even
            # when n_class != n_detector.
            idx = [
                i * self.n_detector + j
                for j in range(self.n_detector)
                for i in range(self.n_class)
            ]
            x = x[:, idx]

        # Create the figure here rather than relying on the helper's return
        # value, so the documented Figure is always handed back to the caller.
        fig = None
        if ax is None:
            fig, ax = plt.subplots(**fig_kw)

        visualize_weights(
            x,
            xtick_labels,
            ytick_labels=ytick_labels,
            ax=ax,
            fig_kw=fig_kw,
            image_kw=image_kw,
            grid_kw=grid_kw,
            text_kw=text_kw,
            cbar=cbar,
        )

        # Extra separators on the left edge of every extra_grid_mod-th column.
        for i in range(x.shape[1]):
            if i % extra_grid_mod == 0:
                ax.axvline(i - 0.5, **extra_grid_kw)

        return fig

387

388
if __name__ == "__main__":
    # Smoke test: push one random batch through each network, using a
    # deliberately non-square (n_class != n_detector) configuration.
    n_class, n_detector = 2, 3
    n_features = n_class * n_detector

    X = torch.from_numpy(np.random.random(size=(10, n_features))).float()

    simple = SimpleNet(n_class=n_class, n_detector=n_detector)
    out = simple(X)

    dense = DenseNet(n_class=n_class, n_detector=n_detector)
    out = dense(X)