Skip to content

Commit 1efaaef

Browse files
[wpimath] Fix UnscentedKalmanFilter and improve math docs (#7850)
Throughout the code the state sqrt covariance S and innovation covariance Sy are maintained as upper triangular cholesky factors of those covariance matrices. The original paper defines P=S*S', so S should be lower triangular. The functions in the paper reflect this. In the code implementation, the sqrt covariance matrices are upper triangular, but the algorithm expects them to be lower triangular. This bug was likely missed because the incorrect version of the filter is able to converge for some systems where all the states are observed, and the test case is set up such that all states are observed. To fix the bug, a couple things needed to be changed: all instances of rankUpdate() needed to be changed to use the lower triangular cholesky factor, In the unscented transform, when S is found via QR decomposition, we need to take the transpose because R is upper triangular, P() and SetP() functions need to be modified to be P=S*S' instead of P=S'*S, and P.llt().matrixL() instead of P.llt().matrixU() respectively. Each part of the algorithm has also had the comments changed to clarify exactly which equation from the paper it implements. Co-authored-by: Tyler Veness <[email protected]>
1 parent 71b6e8e commit 1efaaef

File tree

7 files changed

+494
-88
lines changed

7 files changed

+494
-88
lines changed

wpimath/src/main/java/edu/wpi/first/math/estimator/MerweScaledSigmaPoints.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ public int getNumSigmas() {
6969
}
7070

7171
/**
72-
* Computes the sigma points for an unscented Kalman filter given the mean (x) and covariance(P)
73-
* of the filter.
72+
* Computes the sigma points for an unscented Kalman filter given the mean (x) and square-root
73+
* covariance (s) of the filter.
7474
*
7575
* @param x An array of the means.
7676
* @param s Square-root covariance of the filter.
@@ -86,6 +86,8 @@ public int getNumSigmas() {
8686
// 2 * states + 1 by states
8787
Matrix<S, ?> sigmas =
8888
new Matrix<>(new SimpleMatrix(m_states.getNum(), 2 * m_states.getNum() + 1));
89+
90+
// equation (17)
8991
sigmas.setColumn(0, x);
9092
for (int k = 0; k < m_states.getNum(); k++) {
9193
var xPlusU = x.plus(U.extractColumnVector(k));

wpimath/src/main/java/edu/wpi/first/math/estimator/UnscentedKalmanFilter.java

+96-21
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@
3535
* href="https://file.tavsys.net/control/controls-engineering-in-frc.pdf">https://file.tavsys.net/control/controls-engineering-in-frc.pdf</a>
3636
* chapter 9 "Stochastic control theory".
3737
*
38-
* <p>This class implements a square-root-form unscented Kalman filter (SR-UKF). For more
39-
* information about the SR-UKF, see <a
40-
* href="https://www.researchgate.net/publication/3908304">https://www.researchgate.net/publication/3908304</a>.
38+
* <p>This class implements a square-root-form unscented Kalman filter (SR-UKF). The main reason for
39+
* this is to guarantee that the covariance matrix remains positive definite. For more information
40+
* about the SR-UKF, see https://www.researchgate.net/publication/3908304.
4141
*
4242
* @param <States> Number of states.
4343
* @param <Inputs> Number of inputs.
@@ -105,7 +105,7 @@ public UnscentedKalmanFilter(
105105
}
106106

107107
/**
108-
* Constructs an unscented Kalman filter with custom mean, residual, and addition functions. Using
108+
* Constructs an Unscented Kalman filter with custom mean, residual, and addition functions. Using
109109
* custom functions for arithmetic can be useful if you have angles in the state or measurements,
110110
* because they allow you to correctly account for the modular nature of angle arithmetic.
111111
*
@@ -193,12 +193,21 @@ Pair<Matrix<C, N1>, Matrix<C, C>> squareRootUnscentedTransform(
193193
"Wc must be 2 * states + 1 by 1! Got " + Wc.getNumRows() + " by " + Wc.getNumCols());
194194
}
195195

196-
// New mean is usually just the sum of the sigmas * weight:
197-
// n
198-
// dot = Σ W[k] Xᵢ[k]
199-
// k=1
196+
// New mean is usually just the sum of the sigmas * weights:
197+
//
198+
// 2n
199+
// x̂ = Σ Wᵢ⁽ᵐ⁾𝒳ᵢ
200+
// i=0
201+
//
202+
// equations (19) and (23) in the paper show this,
203+
// but we allow a custom function, usually for angle wrapping
200204
Matrix<C, N1> x = meanFunc.apply(sigmas, Wm);
201205

206+
// Form an intermediate matrix S⁻ as:
207+
//
208+
// [√{W₁⁽ᶜ⁾}(𝒳_{1:2L} - x̂) √{Rᵛ}]
209+
//
210+
// the part of equations (20) and (24) within the "qr{}"
202211
Matrix<C, ?> Sbar = new Matrix<>(new SimpleMatrix(dim.getNum(), 2 * s.getNum() + dim.getNum()));
203212
for (int i = 0; i < 2 * s.getNum(); i++) {
204213
Sbar.setColumn(
@@ -214,8 +223,24 @@ Pair<Matrix<C, N1>, Matrix<C, C>> squareRootUnscentedTransform(
214223
throw new RuntimeException("QR decomposition failed! Input matrix:\n" + qrStorage);
215224
}
216225

217-
Matrix<C, C> newS = new Matrix<>(new SimpleMatrix(qr.getR(null, true)));
218-
newS.rankUpdate(residualFunc.apply(sigmas.extractColumnVector(0), x), Wc.get(0, 0), false);
226+
// Compute the square-root covariance of the sigma points
227+
//
228+
// We transpose S⁻ first because we formed it by horizontally
229+
// concatenating each part; it should be vertical so we can take
230+
// the QR decomposition as defined in the "QR Decomposition" passage
231+
// of section 3. "EFFICIENT SQUARE-ROOT IMPLEMENTATION"
232+
//
233+
// The resulting matrix R is the square-root covariance S, but it
234+
// is upper triangular, so we need to transpose it.
235+
//
236+
// equations (20) and (24)
237+
Matrix<C, C> newS = new Matrix<>(new SimpleMatrix(qr.getR(null, true)).transpose());
238+
239+
// Update or downdate the square-root covariance with (𝒳₀-x̂)
240+
// depending on whether its weight (W₀⁽ᶜ⁾) is positive or negative.
241+
//
242+
// equations (21) and (25)
243+
newS.rankUpdate(residualFunc.apply(sigmas.extractColumnVector(0), x), Wc.get(0, 0), true);
219244

220245
return new Pair<>(x, newS);
221246
}
@@ -256,7 +281,7 @@ public void setS(Matrix<States, States> newS) {
256281
*/
257282
@Override
258283
public Matrix<States, States> getP() {
259-
return m_S.transpose().times(m_S);
284+
return m_S.times(m_S.transpose());
260285
}
261286

262287
/**
@@ -280,7 +305,7 @@ public double getP(int row, int col) {
280305
*/
281306
@Override
282307
public void setP(Matrix<States, States> newP) {
283-
m_S = newP.lltDecompose(false);
308+
m_S = newP.lltDecompose(true);
284309
}
285310

286311
/**
@@ -347,14 +372,28 @@ public void predict(Matrix<Inputs, N1> u, double dtSeconds) {
347372
var discQ = Discretization.discretizeAQ(contA, m_contQ, dtSeconds).getSecond();
348373
var squareRootDiscQ = discQ.lltDecompose(true);
349374

375+
// Generate sigma points around the state mean
376+
//
377+
// equation (17)
350378
var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
351379

380+
// Project each sigma point forward in time according to the
381+
// dynamics f(x, u)
382+
//
383+
// sigmas = 𝒳ₖ₋₁
384+
// sigmasF = 𝒳ₖ,ₖ₋₁ or just 𝒳 for readability
385+
//
386+
// equation (18)
352387
for (int i = 0; i < m_pts.getNumSigmas(); ++i) {
353388
Matrix<States, N1> x = sigmas.extractColumnVector(i);
354389

355390
m_sigmasF.setColumn(i, NumericalIntegration.rk4(m_f, x, u, dtSeconds));
356391
}
357392

393+
// Pass the predicted sigmas (𝒳) through the Unscented Transform
394+
// to compute the prior state mean and covariance
395+
//
396+
// equations (18) (19) and (20)
358397
var ret =
359398
squareRootUnscentedTransform(
360399
m_states,
@@ -459,15 +498,27 @@ public <R extends Num> void correct(
459498
final var discR = Discretization.discretizeR(R, m_dtSeconds);
460499
final var squareRootDiscR = discR.lltDecompose(true);
461500

462-
// Transform sigma points into measurement space
501+
// Generate new sigma points from the prior mean and covariance
502+
// and transform them into measurement space using h(x, u)
503+
//
504+
// sigmas = 𝒳
505+
// sigmasH = 𝒴
506+
//
507+
// This differs from equation (22) which uses
508+
// the prior sigma points, regenerating them allows
509+
// multiple measurement updates per time update
463510
Matrix<R, ?> sigmasH = new Matrix<>(new SimpleMatrix(rows.getNum(), 2 * m_states.getNum() + 1));
464511
var sigmas = m_pts.squareRootSigmaPoints(m_xHat, m_S);
465512
for (int i = 0; i < m_pts.getNumSigmas(); i++) {
466513
Matrix<R, N1> hRet = h.apply(sigmas.extractColumnVector(i), u);
467514
sigmasH.setColumn(i, hRet);
468515
}
469516

470-
// Mean and covariance of prediction passed through unscented transform
517+
// Pass the predicted measurement sigmas through the Unscented Transform
518+
// to compute the mean predicted measurement and square-root innovation
519+
// covariance.
520+
//
521+
// equations (23) (24) and (25)
471522
var transRet =
472523
squareRootUnscentedTransform(
473524
m_states,
@@ -481,30 +532,54 @@ public <R extends Num> void correct(
481532
var yHat = transRet.getFirst();
482533
var Sy = transRet.getSecond();
483534

484-
// Compute cross covariance of the state and the measurements
535+
// Compute cross covariance of the predicted state and measurement sigma
536+
// points given as:
537+
//
538+
// 2n
539+
// P_{xy} = Σ Wᵢ⁽ᶜ⁾[𝒳ᵢ - x̂][𝒴ᵢ - ŷ⁻]ᵀ
540+
// i=0
541+
//
542+
// equation (26)
485543
Matrix<States, R> Pxy = new Matrix<>(m_states, rows);
486544
for (int i = 0; i < m_pts.getNumSigmas(); i++) {
487-
// Pxy += (sigmas_f[:, i] - x̂)(sigmas_h[:, i] - ŷ)ᵀ W_c[i]
488545
var dx = residualFuncX.apply(m_sigmasF.extractColumnVector(i), m_xHat);
489546
var dy = residualFuncY.apply(sigmasH.extractColumnVector(i), yHat).transpose();
490547

491548
Pxy = Pxy.plus(dx.times(dy).times(m_pts.getWc(i)));
492549
}
493550

494-
// K = (P_{xy} / S_yᵀ) / S_y
495-
// K = (S_y \ P_{xy}ᵀ)ᵀ / S_y
496-
// K = (S_yᵀ \ (S_y \ P_{xy}ᵀ))ᵀ
551+
// Compute the Kalman gain. We use Eigen's QR decomposition to solve. This
552+
// is equivalent to MATLAB's \ operator, so we need to rearrange to use
553+
// that.
554+
//
555+
// K = (P_{xy} / S_{y}ᵀ) / S_{y}
556+
// K = (S_{y} \ P_{xy})ᵀ / S_{y}
557+
// K = (S_{y}ᵀ \ (S_{y} \ P_{xy}ᵀ))ᵀ
558+
//
559+
// equation (27)
497560
Matrix<States, R> K =
498561
Sy.transpose()
499562
.solveFullPivHouseholderQr(Sy.solveFullPivHouseholderQr(Pxy.transpose()))
500563
.transpose();
501564

502-
// x̂ₖ₊₁⁺ = x̂ₖ₊₁⁻ + K(y − ŷ)
565+
// Compute the posterior state mean
566+
//
567+
// x̂ = x̂⁻ + K(y − ŷ⁻)
568+
//
569+
// second part of equation (27)
503570
m_xHat = addFuncX.apply(m_xHat, K.times(residualFuncY.apply(y, yHat)));
504571

572+
// Compute the intermediate matrix U for downdating
573+
// the square-root covariance
574+
//
575+
// equation (28)
505576
Matrix<States, R> U = K.times(Sy);
577+
578+
// Downdate the posterior square-root state covariance
579+
//
580+
// equation (29)
506581
for (int i = 0; i < rows.getNum(); i++) {
507-
m_S.rankUpdate(U.extractColumnVector(i), -1, false);
582+
m_S.rankUpdate(U.extractColumnVector(i), -1, true);
508583
}
509584
}
510585
}

wpimath/src/main/native/include/frc/estimator/MerweScaledSigmaPoints.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class MerweScaledSigmaPoints {
5151

5252
/**
5353
* Computes the sigma points for an unscented Kalman filter given the mean
54-
* (x) and square-root covariance(S) of the filter.
54+
* (x) and square-root covariance (S) of the filter.
5555
*
5656
* @param x An array of the means.
5757
* @param S Square-root covariance of the filter.
@@ -68,6 +68,8 @@ class MerweScaledSigmaPoints {
6868
Matrixd<States, States> U = eta * S;
6969

7070
Matrixd<States, 2 * States + 1> sigmas;
71+
72+
// equation (17)
7173
sigmas.template block<States, 1>(0, 0) = x;
7274
for (int k = 0; k < States; ++k) {
7375
sigmas.template block<States, 1>(0, k + 1) =

0 commit comments

Comments
 (0)