// Refinement.cpp
#define _USE_MATH_DEFINES
#include <cmath>

#include <pinkIndexer/Refinement.h>
#include <pinkIndexer/WrongUsageException.h>
#include <pinkIndexer/eigenSTLContainers.h>

#include <algorithm>
#include <iostream>
#include <limits>

using namespace std;
using namespace Eigen;

13
namespace pinkIndexer
14
{
15
16
    static void roundTowardsZero(Matrix3Xf& x);
    static void roundAwayFromZero(Matrix3Xf& x);
17

18
19
20
21
22
23
    /**
     * Creates a refinement without a backprojection.
     *
     * Only the refinements that receive precomputed ULS directions/border norms can be used with an
     * object built this way. The center-refining methods go through getCenterShiftedBackprojection(),
     * which needs the backprojection passed to the other constructor.
     *
     * @param tolerance defect threshold; peaks with a defect below it count as fitted
     */
    Refinement::Refinement(float tolerance)
        : tolerance(tolerance)
        , backprojection(nullptr)
    {
        // Pre-size the scratch buffer of candidate Miller indices that getDefects() refills per peak.
        millerIndices.reserve(500);
    }
24

25
26
27
28
29
30
31
    /**
     * Creates a refinement that can also refine the detector center.
     *
     * Only the address of `backprojection` is stored, so the referenced object must outlive this
     * Refinement (NOTE(review): lifetime is the caller's responsibility — confirm at call sites).
     *
     * @param tolerance      defect threshold; peaks with a defect below it count as fitted
     * @param backprojection used to back-project center-shifted detector peaks
     */
    Refinement::Refinement(float tolerance, const Backprojection& backprojection)
        : tolerance(tolerance)
        , backprojection(&backprojection)
    {
        // Initial capacity of the per-peak candidate Miller index buffer filled in getDefects().
        const std::size_t initialMillerIndicesCapacity = 500;
        millerIndices.reserve(initialMillerIndicesCapacity);
    }

32
    /**
     * Refines the full lattice basis (all nine matrix elements, so lengths and angles may change) by
     * normalized gradient descent on getMeanDefect().
     *
     * The gradient is a forward difference with offset `delta` per basis element. Each step moves the basis
     * by `stepSize` against the gradient direction.
     *
     * Whenever a step increases the mean defect, the step size is multiplied by 0.9. Once the recent mean
     * defects have settled, it is additionally multiplied by 0.2.
     *
     * The loop stops when any of these happens:
     *  - the step size drops below `minStepSize`,
     *  - the numerical gradient vanishes,
     *  - `maxStepsCount` steps are done.
     *
     * The basis after the last step is kept even if that step went uphill.
     *
     * @param lattice        lattice to refine; replaced by the refined lattice on return
     * @param ulsDirections  per-peak ULS directions, forwarded to getMeanDefect()
     * @param ulsBorderNorms per-peak inner (row 0) and outer (row 1) ULS border norms, forwarded to getMeanDefect()
     */
    void Refinement::refineVariableLattice(Lattice& lattice, const Matrix3Xf& ulsDirections, const Array2Xf& ulsBorderNorms)
    {
        Matrix3f basis = lattice.getBasis();
        const float delta = 1e-8; // for numerical differentiation, in A^-1

        float stepSize = lattice.getBasisVectorNorms().maxCoeff() * 0.002;
        const float minStepSize = lattice.getBasisVectorNorms().minCoeff() * 0.00001;
        const int maxStepsCount = 200;
        // Iteration i writes meanDefects[i + 1], so maxStepsCount + 1 entries are needed to keep the
        // last iteration in bounds.
        meanDefects.resize(maxStepsCount + 1);
        meanDefects[0] = getMeanDefect(basis, ulsDirections, ulsBorderNorms);
        for (int i = 0; i < maxStepsCount; i++)
        {
            // Forward-difference gradient for a change of each basis matrix element, visited column by column.
            Array33f gradient;
            for (int col = 0; col < 3; col++)
            {
                for (int row = 0; row < 3; row++)
                {
                    Matrix3f offsetBasis = basis;
                    offsetBasis(row, col) += delta;
                    // Columns 0 and 1 pass significantChangesToPreviousCall = false, so getDefects() reuses the
                    // Miller index ranges cached by the last full evaluation; column 2 uses the default argument.
                    const double offsetDefect = (col < 2) ? getMeanDefect(offsetBasis, ulsDirections, ulsBorderNorms, false)
                                                          : getMeanDefect(offsetBasis, ulsDirections, ulsBorderNorms);
                    gradient(row, col) = offsetDefect - meanDefects[i];
                }
            }

            const float norm = gradient.matrix().norm();
            if (norm == 0)
            {
                // Numerical problems: delta is too small for the current lattice, so there is no descent direction.
                break;
            }
            gradient = gradient / norm * stepSize;

            basis = basis - gradient.matrix();
            meanDefects[i + 1] = getMeanDefect(basis, ulsDirections, ulsBorderNorms);

            if (meanDefects[i + 1] > meanDefects[i])
            {
                stepSize = stepSize * 0.9;

                // settled down: the four defects before step i spread by less than 1 % of meanDefects[i]
                if (i > 10 && (meanDefects.segment(i - 4, 4).maxCoeff() - meanDefects.segment(i - 4, 4).minCoeff()) / meanDefects[i] < 0.01)
                    stepSize = stepSize * 0.2;

                if (stepSize < minStepSize)
                    break;
            }
        }

        lattice = Lattice(basis);
    }
100

101
    /**
     * Refines only the orientation of the lattice. The basis is rotated as a rigid body, so basis vector
     * lengths and mutual angles are preserved.
     *
     * The gradient of getMeanDefect() is estimated by applying a tiny probe rotation (0.0001 degrees)
     * about each of the x, y and z axes. Each step rotates the basis by `stepSize` radians against the
     * gradient direction. An uphill step multiplies the step size by 0.9.
     *
     * The loop stops when any of these happens:
     *  - the recent defects have settled (relative spread below 0.1 %),
     *  - the step size drops below `minStepSize`,
     *  - the numerical gradient vanishes,
     *  - `maxStepsCount` steps are done.
     *
     * @param lattice        lattice to refine; replaced by the rotated lattice on return
     * @param ulsDirections  per-peak ULS directions, forwarded to getMeanDefect()
     * @param ulsBorderNorms per-peak inner (row 0) and outer (row 1) ULS border norms, forwarded to getMeanDefect()
     */
    void Refinement::refineFixedLattice(Lattice& lattice, const Matrix3Xf& ulsDirections, const Array2Xf& ulsBorderNorms)
    {
        Matrix3f basis = lattice.getBasis();

        // Probe rotation for the numerical gradient: 0.0001 degrees, converted to radians.
        const float probeAngle = 0.0001 / 180 * M_PI;
        Matrix3f rotX, rotY, rotZ;
        rotX = AngleAxisf(probeAngle, Vector3f::UnitX());
        rotY = AngleAxisf(probeAngle, Vector3f::UnitY());
        rotZ = AngleAxisf(probeAngle, Vector3f::UnitZ());

        float stepSize = 0.1 / 180 * M_PI;            // radians
        const float minStepSize = 0.001 / 180 * M_PI; // radians
        const int maxStepsCount = 200;
        // Iteration i writes meanDefects[i + 1], so maxStepsCount + 1 entries are needed.
        meanDefects.resize(maxStepsCount + 1);
        meanDefects[0] = getMeanDefect(basis, ulsDirections, ulsBorderNorms);

        for (int i = 0; i < maxStepsCount; i++)
        {
            Vector3f gradient;
            gradient(0) = getMeanDefect(rotX * basis, ulsDirections, ulsBorderNorms) - meanDefects[i];
            gradient(1) = getMeanDefect(rotY * basis, ulsDirections, ulsBorderNorms) - meanDefects[i];
            gradient(2) = getMeanDefect(rotZ * basis, ulsDirections, ulsBorderNorms) - meanDefects[i];

            // A vanishing gradient (e.g. when no peak is predictable and every evaluation returns 1) would
            // otherwise produce zero-angle steps for all remaining iterations without changing the basis.
            if (gradient.norm() == 0)
            {
                break;
            }
            gradient = -gradient.normalized() * stepSize;
            basis =
                AngleAxisf(gradient(0), Vector3f::UnitX()) * AngleAxisf(gradient(1), Vector3f::UnitY()) * AngleAxisf(gradient(2), Vector3f::UnitZ()) * basis;
            meanDefects[i + 1] = getMeanDefect(basis, ulsDirections, ulsBorderNorms);

            if (meanDefects[i + 1] > meanDefects[i])
            {
                stepSize = stepSize * 0.9;

                // settled down: the four defects before step i spread by less than 0.1 % of meanDefects[i]
                if (i > 10 && (meanDefects.segment(i - 4, 4).maxCoeff() - meanDefects.segment(i - 4, 4).minCoeff()) / meanDefects[i] < 0.001)
                    break;

                if (stepSize < minStepSize)
                    break;
            }
        }

        lattice = Lattice(basis);
    }
145

146
147
148
149
150
151
152
153
154
155
    /**
     * Refines the full lattice basis together with a 2D shift of the detector center.
     *
     * The member ULS directions/border norms are recomputed by getCenterShiftedBackprojection() from
     * `detectorPeaks_m` shifted by `centerShift`. This requires a Refinement built with a Backprojection.
     *
     * The basis is refined by the same normalized gradient descent as refineVariableLattice(). Every
     * sixth iteration (starting with the first), the center shift is first refined with refineCenter().
     * The start step size passed to refineCenter() decays by a factor of 0.85 each time.
     *
     * If no peak is predictable at the start (mean defect 1), the function returns without touching
     * `lattice`.
     *
     * @param lattice         lattice to refine; replaced by the refined lattice on normal return
     * @param centerShift     in/out shift added to every detector peak column; same units as detectorPeaks_m
     * @param detectorPeaks_m detector peak positions, one column per peak (the _m suffix presumably means meters — confirm)
     */
    void Refinement::refineVariableLatticeWithCenter(Lattice& lattice, Vector2f& centerShift, const Eigen::Matrix2Xf& detectorPeaks_m)
    {
        Matrix3f basis = lattice.getBasis();
        const float delta = 1e-8; // for numerical differentiation of the basis

        float stepSize_basis = lattice.getBasisVectorNorms().maxCoeff() * 0.002;
        const float minStepSize_basis = lattice.getBasisVectorNorms().minCoeff() * 0.00001;
        float startStepSize_center = 10e-6; // = 1e-5
        const int maxStepsCount = 200;
        // Iteration i writes meanDefects[i + 1], so maxStepsCount + 1 entries are needed.
        meanDefects.resize(maxStepsCount + 1);

        getCenterShiftedBackprojection(ulsDirections, ulsBorderNorms, detectorPeaks_m, centerShift);
        meanDefects[0] = getMeanDefect(basis, ulsDirections, ulsBorderNorms);
        if (meanDefects[0] == 1)
        {
            return;
        }
        for (int i = 0; i < maxStepsCount; i++)
        {
            if (i % 6 == 0)
            {
                refineCenter(centerShift, basis, detectorPeaks_m, startStepSize_center);
                startStepSize_center *= 0.85;

                // refineCenter() may leave the members at a probe offset; rebuild them for the current center.
                getCenterShiftedBackprojection(ulsDirections, ulsBorderNorms, detectorPeaks_m, centerShift);
                meanDefects[i] = getMeanDefect(basis, ulsDirections, ulsBorderNorms);
            }

            // Forward-difference gradient for a change of each basis matrix element, visited column by column.
            Array33f basisGradient;
            for (int col = 0; col < 3; col++)
            {
                for (int row = 0; row < 3; row++)
                {
                    Matrix3f offsetBasis = basis;
                    offsetBasis(row, col) += delta;
                    // Columns 0 and 1 reuse the Miller index ranges cached by the last full evaluation
                    // (significantChangesToPreviousCall = false); column 2 uses the default argument.
                    const double offsetDefect = (col < 2) ? getMeanDefect(offsetBasis, ulsDirections, ulsBorderNorms, false)
                                                          : getMeanDefect(offsetBasis, ulsDirections, ulsBorderNorms);
                    basisGradient(row, col) = offsetDefect - meanDefects[i];
                }
            }

            const float norm = basisGradient.matrix().norm();
            if (norm == 0)
            {
                // Numerical problems: delta is too small for the current lattice, so there is no descent direction.
                break;
            }
            basisGradient = basisGradient / norm * stepSize_basis;

            basis = basis - basisGradient.matrix();
            meanDefects[i + 1] = getMeanDefect(basis, ulsDirections, ulsBorderNorms);

            if (meanDefects[i + 1] > meanDefects[i])
            {
                stepSize_basis = stepSize_basis * 0.9;

                if (i > 10 && (meanDefects.segment(i - 4, 4).maxCoeff() - meanDefects.segment(i - 4, 4).minCoeff()) / meanDefects[i] < 0.01)
                { // settled down
                    stepSize_basis = stepSize_basis * 0.2;
                }

                if (stepSize_basis < minStepSize_basis)
                    break;
            }
        }

        lattice = Lattice(basis);
    }
229
230

    /**
     * Refines `centerShift` for a fixed basis by normalized gradient descent on getMeanDefect().
     *
     * The gradient is a forward difference with a fixed center offset of 1e-7 along x and along y. For
     * each probe, the member ULS data is rebuilt via getCenterShiftedBackprojection(). An uphill step
     * multiplies the step size by 0.8, and by a further 0.2 once the recent defects have settled.
     *
     * The loop stops when any of these happens:
     *  - the step size drops below the minimum (5e-7),
     *  - the gradient vanishes,
     *  - 35 steps are done.
     *
     * If the loop stops on a vanishing gradient, the member ULS data still corresponds to the last probe
     * offset. Callers needing it for `centerShift` must recompute it.
     *
     * @param centerShift     in/out center shift; same units as detectorPeaks_m
     * @param basis           fixed lattice basis used to evaluate the defects
     * @param detectorPeaks_m detector peak positions, one column per peak
     * @param startStepSize   initial step length, raised to at least the minimum step size
     */
    void Refinement::refineCenter(Eigen::Vector2f& centerShift, const Matrix3f& basis, const Eigen::Matrix2Xf& detectorPeaks_m, float startStepSize)
    {
        const float deltaCenterShift = 1e-7;

        const float minStepSize_center = 5e-7;
        float stepSize_center = std::max(startStepSize, minStepSize_center);

        const int maxStepsCount = 35;
        // Iteration i writes meanDefects_centerAdjustment[i + 1], so maxStepsCount + 1 entries are needed.
        meanDefects_centerAdjustment.resize(maxStepsCount + 1);

        getCenterShiftedBackprojection(ulsDirections, ulsBorderNorms, detectorPeaks_m, centerShift);
        meanDefects_centerAdjustment[0] = getMeanDefect(basis, ulsDirections, ulsBorderNorms);
        if (meanDefects_centerAdjustment[0] == 1)
        {
            return;
        }

        for (int i = 0; i < maxStepsCount; i++)
        {
            Vector2f centerOffsetGradient;
            Vector2f offsetCenterShift = centerShift + Vector2f(deltaCenterShift, 0);
            getCenterShiftedBackprojection(ulsDirections, ulsBorderNorms, detectorPeaks_m, offsetCenterShift);
            centerOffsetGradient.x() = getMeanDefect(basis, ulsDirections, ulsBorderNorms) - meanDefects_centerAdjustment[i];

            offsetCenterShift = centerShift + Vector2f(0, deltaCenterShift);
            getCenterShiftedBackprojection(ulsDirections, ulsBorderNorms, detectorPeaks_m, offsetCenterShift);
            centerOffsetGradient.y() = getMeanDefect(basis, ulsDirections, ulsBorderNorms) - meanDefects_centerAdjustment[i];

            const float norm = centerOffsetGradient.norm();
            if (norm == 0)
            {
                // Numerical problems: deltaCenterShift is too small to change the mean defect.
                break;
            }
            centerOffsetGradient = centerOffsetGradient / norm * stepSize_center;

            centerShift = centerShift - centerOffsetGradient;

            getCenterShiftedBackprojection(ulsDirections, ulsBorderNorms, detectorPeaks_m, centerShift);
            meanDefects_centerAdjustment[i + 1] = getMeanDefect(basis, ulsDirections, ulsBorderNorms);

            if (meanDefects_centerAdjustment[i + 1] > meanDefects_centerAdjustment[i])
            {
                stepSize_center = stepSize_center * 0.8;

                if (i > 5 && (meanDefects_centerAdjustment.segment(i - 4, 4).maxCoeff() - meanDefects_centerAdjustment.segment(i - 4, 4).minCoeff()) /
                                     meanDefects_centerAdjustment[i] <
                                 0.01)
                { // settled down
                    stepSize_center = stepSize_center * 0.2;
                }

                if (stepSize_center < minStepSize_center)
                    break;
            }
        }
    }
286

287
    /**
     * Counts the peaks whose defect for the given lattice is strictly below the tolerance.
     *
     * Side effect: overwrites the member `defects` with the per-peak defects of this lattice.
     *
     * @return number of fitted peaks
     */
    int Refinement::getFittedPeaksCount(Lattice& lattice, const Eigen::Matrix3Xf& ulsDirections, const Eigen::Array2Xf& ulsBorderNorms)
    {
        const auto& latticeBasis = lattice.getBasis();
        getDefects(defects, latticeBasis, ulsDirections, ulsBorderNorms);

        return static_cast<int>((defects < tolerance).count());
    }
294

295
296
    /**
     * Marks which peaks are fitted by the given lattice (defect strictly below the tolerance).
     *
     * Side effect: overwrites the member `defects` with the per-peak defects of this lattice.
     *
     * @param fittedPeaks output mask, one entry per peak, true where the peak is fitted
     * @return number of fitted peaks
     */
    int Refinement::getFittedPeaks(Lattice& lattice, Eigen::Array<bool, Eigen::Dynamic, 1>& fittedPeaks, const Eigen::Matrix3Xf& ulsDirections,
                                   const Eigen::Array2Xf& ulsBorderNorms)
    {
        const auto& latticeBasis = lattice.getBasis();
        getDefects(defects, latticeBasis, ulsDirections, ulsBorderNorms);

        fittedPeaks = defects < tolerance;
        const int fittedCount = static_cast<int>(fittedPeaks.count());
        return fittedCount;
    }
303

304

305
    void Refinement::getDefects(ArrayXf& defects, const Matrix3f& basis, const Matrix3Xf& ulsDirections, const Array2Xf& ulsBorderNorms,
306
                                bool significantChangesToPreviousCall)
307
    {
308
        Matrix3f basis_inverse = basis.inverse();
309

310
311
        if (significantChangesToPreviousCall)
        {
312
            ulsBorderNormsSquared = ulsBorderNorms.array().square().matrix();
313

314
            millerIndices_close = basis_inverse * (ulsDirections.array().rowwise() * ulsBorderNorms.array().row(0)).matrix();
315
            roundTowardsZero(millerIndices_close);
316
            millerIndices_far = basis_inverse * (ulsDirections.array().rowwise() * ulsBorderNorms.array().row(1)).matrix();
317
318
319
            roundAwayFromZero(millerIndices_far);
        }

320
        int peakCount = ulsDirections.cols();
321
322
        defects.resize(peakCount);
        for (int i = 0; i < peakCount; i++)
323
        {
324
            // create miller indices close to ULS
325
326
327
            millerIndices.clear();
            for (float k = min(millerIndices_close(0, i), millerIndices_far(0, i)), maxK = max(millerIndices_close(0, i), millerIndices_far(0, i)); k <= maxK;
                 k++)
328
            {
329
330
                for (float l = min(millerIndices_close(1, i), millerIndices_far(1, i)), maxL = max(millerIndices_close(1, i), millerIndices_far(1, i));
                     l <= maxL; l++)
331
                {
332
333
334
335
336
                    for (float m = min(millerIndices_close(2, i), millerIndices_far(2, i)), maxM = max(millerIndices_close(2, i), millerIndices_far(2, i));
                         m <= maxM; m++)
                    {
                        millerIndices.emplace_back(k, l, m);
                    }
337
338
339
                }
            }

340
341
342
343
344
            candidatePeaks.noalias() = basis * Map<Matrix3Xf>((float*)millerIndices.data(), 3, millerIndices.size());
            candidatePeaksNormsSquared = candidatePeaks.colwise().squaredNorm();
            // clear peaks that exceed the borders
            validCandidatePeaksCount = 0;
            for (int j = 0, end = candidatePeaksNormsSquared.size(); j < end; j++)
345
            {
346
                if ((candidatePeaksNormsSquared(j) > ulsBorderNormsSquared(0, i)) & (candidatePeaksNormsSquared(j) < ulsBorderNormsSquared(1, i)))
347
348
349
350
                {
                    candidatePeaks.col(validCandidatePeaksCount) = candidatePeaks.col(j);
                    validCandidatePeaksCount++;
                }
351
            }
352
353
354
355
356
357
            if (validCandidatePeaksCount == 0)
            {
                defects(i) = 1;
                continue;
            }

358
            projectedVectorNorms.noalias() = candidatePeaks.leftCols(validCandidatePeaksCount).transpose() * ulsDirections.col(i);
359
            defectVectors_absolute = candidatePeaks.leftCols(validCandidatePeaksCount);
360
            defectVectors_absolute.noalias() -= ulsDirections.col(i) * projectedVectorNorms;
361
362
            defectVectors_relative.noalias() = basis_inverse * defectVectors_absolute;
            defects(i) = defectVectors_relative.cwiseAbs().colwise().maxCoeff().minCoeff();
363
        }
364
365
    }

366
    double Refinement::getMeanDefect(const Matrix3f& basis, const Matrix3Xf& ulsDirections, const Array2Xf& ulsBorderNorms,
367
368
                                     bool significantChangesToPreviousCall)
    {
369
        getDefects(defects, basis, ulsDirections, ulsBorderNorms, significantChangesToPreviousCall);
370
371
372
373
374
375

        notPredictablePeaks = defects > tolerance;
        int16_t notPredictablePeaksCount = notPredictablePeaks.count();
        // cout << "np " << notPredictablePeaksCount << endl;

        if (notPredictablePeaksCount == defects.size())
376
        {
377
            return 1;
378
379
        }

380
381
        sort((float*)defects.data(), (float*)defects.data() + defects.size());
        return defects.head(round(0.9 * (defects.size() - notPredictablePeaksCount))).mean();
382
383
    }

384
    /**
     * Shifts every detector peak by `centerShift` and back-projects the shifted peaks.
     *
     * The shifted peaks are kept in the member `detectorPeaks_m_shifted`. The ULS directions and border
     * norms are written to the two output arguments.
     *
     * @throws WrongUsageException if this Refinement was constructed without a Backprojection
     */
    void Refinement::getCenterShiftedBackprojection(Eigen::Matrix3Xf& ulsDirections_local, Eigen::Array2Xf& ulsBorderNorms_local,
                                                    const Eigen::Matrix2Xf& detectorPeaks_m, const Eigen::Vector2f& centerShift)
    {
        // The tolerance-only constructor leaves backprojection null; dereferencing it would be undefined behavior.
        if (backprojection == nullptr)
        {
            throw WrongUsageException("Refinement was constructed without a Backprojection; center refinement is not possible!\n");
        }

        detectorPeaks_m_shifted = detectorPeaks_m.colwise() + centerShift;
        backprojection->backProject(detectorPeaks_m_shifted, ulsDirections_local, ulsBorderNorms_local);
    }

391
392
393
394
    // Rounds every coefficient of x toward zero in place: floor of the magnitude, with the original sign
    // reapplied (a zero coefficient stays zero).
    static void roundTowardsZero(Matrix3Xf& x)
    {
        const Array3Xf magnitudes = x.array().abs();
        x = magnitudes.floor() * x.array().sign();
    }
395

396
397
398
399
    // Rounds every coefficient of x away from zero in place: ceiling of the magnitude, with the original
    // sign reapplied (a zero coefficient stays zero).
    static void roundAwayFromZero(Matrix3Xf& x)
    {
        const Array3Xf magnitudes = x.array().abs();
        x = magnitudes.ceil() * x.array().sign();
    }
400
} // namespace pinkIndexer