matrix_alg.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. /*
  2. * matrix3.cpp
  3. * Copyright (C) Siddharth Bharat Purohit, 3DRobotics Inc. 2015
  4. *
  5. * This file is free software: you can redistribute it and/or modify it
  6. * under the terms of the GNU General Public License as published by the
  7. * Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This file is distributed in the hope that it will be useful, but
  11. * WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  13. * See the GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License along
  16. * with this program. If not, see <http://www.gnu.org/licenses/>.
  17. */
  18. #pragma GCC optimize("O2")
  19. #include <AP_HAL/AP_HAL.h>
  20. #include <stdio.h>
  21. #if CONFIG_HAL_BOARD == HAL_BOARD_SITL
  22. #include <fenv.h>
  23. #endif
  24. #include <AP_Math/AP_Math.h>
  25. extern const AP_HAL::HAL& hal;
  26. //TODO: use higher precision datatypes to achieve more accuracy for matrix algebra operations
  /*
   * Multiplies two square matrices stored row-major and returns A*B.
   *
   * @param A, left operand, n*n elements, row-major
   * @param B, right operand, n*n elements, row-major
   * @param n, dimension of the square matrices
   * @returns newly allocated (new[]) array of n*n elements holding A*B;
   *          the caller owns it and must release it with delete[]
   */
  float* mat_mul(float *A, float *B, uint8_t n)
  {
      float* ret = new float[n*n];
      for (uint8_t i = 0; i < n; i++) {
          for (uint8_t j = 0; j < n; j++) {
              // Accumulate the dot product locally. Every output element is
              // assigned exactly once, so no separate zero-fill pass is needed.
              float sum = 0.0f;
              for (uint8_t k = 0; k < n; k++) {
                  sum += A[i*n + k] * B[k*n + j];
              }
              ret[i*n + j] = sum;
          }
      }
      return ret;
  }
  /*
   * Exchanges the values of the two referenced floats in place.
   */
  static inline void swap(float &a, float &b)
  {
      const float held = b;
      b = a;
      a = held;
  }
  /*
   * Computes a row-permutation matrix P that moves large-magnitude entries
   * onto the diagonal of P*A (pivot chosen up front from the entries of A).
   *
   * For each column i, the row with the largest |value| in that column is
   * chosen among the rows not yet assigned to positions 0..i-1. The
   * permutation built so far is tracked explicitly, so a row that has already
   * been moved to an earlier position is never considered again. Comparing
   * entries in A's original row order instead can select an already-used row
   * and leave a zero pivot for a non-singular matrix.
   *
   * NOTE: the choice is made from A alone, without the elimination updates of
   * a full partial-pivoting LU, so a zero pivot can still arise for some
   * non-singular inputs.
   *
   * @param A, input n x n matrix, row-major (not modified)
   * @param pivot, output n x n permutation matrix, row-major
   * @param n, dimension of the square matrix
   */
  static void mat_pivot(float* A, float* pivot, uint8_t n)
  {
      // perm[r] = index of the row of A that ends up at row r of P*A.
      // n is a uint8_t, so at most 255 entries are ever used.
      uint8_t perm[256];
      for (uint8_t r = 0; r < n; r++) {
          perm[r] = r;
      }

      for (uint8_t i = 0; i < n; i++) {
          uint8_t max_r = i;
          for (uint8_t r = i + 1; r < n; r++) {
              if (fabsf(A[perm[r]*n + i]) > fabsf(A[perm[max_r]*n + i])) {
                  max_r = r;
              }
          }
          if (max_r != i) {
              const uint8_t tmp = perm[i];
              perm[i] = perm[max_r];
              perm[max_r] = tmp;
          }
      }

      // expand the permutation into matrix form: row r selects row perm[r] of A
      for (uint8_t r = 0; r < n; r++) {
          for (uint8_t c = 0; c < n; c++) {
              pivot[r*n + c] = static_cast<float>(perm[r] == c);
          }
      }
  }
  /*
   * Inverts a lower-triangular matrix by forward substitution, solving
   * L * out = I one column at a time.
   *
   * Every element of out is written (the strictly upper triangle is set to 0),
   * so the result does not depend on the previous contents of out.
   * A zero on the diagonal of L produces inf/NaN entries; no check is done here.
   *
   * @param L, input lower triangular n x n matrix, row-major
   * @param out, output inverse of L (lower triangular), row-major
   * @param n, dimension of matrix
   */
  static void mat_forward_sub(float *L, float *out, uint8_t n)
  {
      for (int i = 0; i < n; i++) {
          // the inverse of a lower triangular matrix is lower triangular
          for (int j = 0; j < i; j++) {
              out[j*n + i] = 0.0f;
          }
          out[i*n + i] = 1/L[i*n + i];
          for (int j = i+1; j < n; j++) {
              // rows i..j-1 of this column are already solved
              float acc = 0.0f;
              for (int k = i; k < j; k++) {
                  acc -= L[j*n + k] * out[k*n + i];
              }
              out[j*n + i] = acc / L[j*n + j];
          }
      }
  }
  /*
   * Inverts an upper-triangular matrix by backward substitution, solving
   * U * out = I one column at a time.
   *
   * Every element of out is written (the strictly lower triangle is set to 0),
   * so the result does not depend on the previous contents of out.
   * A zero on the diagonal of U produces inf/NaN entries; no check is done here.
   *
   * @param U, input upper triangular n x n matrix, row-major
   * @param out, output inverse of U (upper triangular), row-major
   * @param n, dimension of matrix
   */
  static void mat_back_sub(float *U, float *out, uint8_t n)
  {
      for (int i = n-1; i >= 0; i--) {
          // the inverse of an upper triangular matrix is upper triangular
          for (int j = i+1; j < n; j++) {
              out[j*n + i] = 0.0f;
          }
          out[i*n + i] = 1/U[i*n + i];
          for (int j = i - 1; j >= 0; j--) {
              // rows j+1..i of this column are already solved
              float acc = 0.0f;
              for (int k = i; k > j; k--) {
                  acc -= U[j*n + k] * out[k*n + i];
              }
              out[j*n + i] = acc / U[j*n + j];
          }
      }
  }
  124. /*
  125. * Decomposes square matrix into Lower and Upper triangular matrices such that
  126. * A*P = L*U, where P is the pivot matrix
  127. * ref: http://rosettacode.org/wiki/LU_decomposition
  128. * @param U, upper triangular matrix
  129. * @param out, Output inverted upper triangular matrix
  130. * @param n, dimension of matrix
  131. */
  132. static void mat_LU_decompose(float* A, float* L, float* U, float *P, uint8_t n)
  133. {
  134. memset(L,0,n*n*sizeof(float));
  135. memset(U,0,n*n*sizeof(float));
  136. memset(P,0,n*n*sizeof(float));
  137. mat_pivot(A,P,n);
  138. float *APrime = mat_mul(P,A,n);
  139. for(uint8_t i = 0; i < n; i++) {
  140. L[i*n + i] = 1;
  141. }
  142. for(uint8_t i = 0; i < n; i++) {
  143. for(uint8_t j = 0; j < n; j++) {
  144. if(j <= i) {
  145. U[j*n + i] = APrime[j*n + i];
  146. for(uint8_t k = 0; k < j; k++) {
  147. U[j*n + i] -= L[j*n + k] * U[k*n + i];
  148. }
  149. }
  150. if(j >= i) {
  151. L[j*n + i] = APrime[j*n + i];
  152. for(uint8_t k = 0; k < i; k++) {
  153. L[j*n + i] -= L[j*n + k] * U[k*n + i];
  154. }
  155. L[j*n + i] /= U[i*n + i];
  156. }
  157. }
  158. }
  159. delete[] APrime;
  160. }
  161. /*
  162. * matrix inverse code for any square matrix using LU decomposition
  163. * inv = inv(U)*inv(L)*P, where L and U are triagular matrices and P the pivot matrix
  164. * ref: http://www.cl.cam.ac.uk/teaching/1314/NumMethods/supporting/mcmaster-kiruba-ludecomp.pdf
  165. * @param m, input 4x4 matrix
  166. * @param inv, Output inverted 4x4 matrix
  167. * @param n, dimension of square matrix
  168. * @returns false = matrix is Singular, true = matrix inversion successful
  169. */
  170. static bool mat_inverse(float* A, float* inv, uint8_t n)
  171. {
  172. float *L, *U, *P;
  173. bool ret = true;
  174. L = new float[n*n];
  175. U = new float[n*n];
  176. P = new float[n*n];
  177. mat_LU_decompose(A,L,U,P,n);
  178. float *L_inv = new float[n*n];
  179. float *U_inv = new float[n*n];
  180. memset(L_inv,0,n*n*sizeof(float));
  181. mat_forward_sub(L,L_inv,n);
  182. memset(U_inv,0,n*n*sizeof(float));
  183. mat_back_sub(U,U_inv,n);
  184. // decomposed matrices no longer required
  185. delete[] L;
  186. delete[] U;
  187. float *inv_unpivoted = mat_mul(U_inv,L_inv,n);
  188. float *inv_pivoted = mat_mul(inv_unpivoted, P, n);
  189. //check sanity of results
  190. for(uint8_t i = 0; i < n; i++) {
  191. for(uint8_t j = 0; j < n; j++) {
  192. if(isnan(inv_pivoted[i*n+j]) || isinf(inv_pivoted[i*n+j])){
  193. ret = false;
  194. }
  195. }
  196. }
  197. memcpy(inv,inv_pivoted,n*n*sizeof(float));
  198. //free memory
  199. delete[] inv_pivoted;
  200. delete[] inv_unpivoted;
  201. delete[] P;
  202. delete[] U_inv;
  203. delete[] L_inv;
  204. return ret;
  205. }
  206. /*
  207. * fast matrix inverse code only for 3x3 square matrix
  208. *
  209. * @param m, input 4x4 matrix
  210. * @param invOut, Output inverted 4x4 matrix
  211. * @returns false = matrix is Singular, true = matrix inversion successful
  212. */
  213. bool inverse3x3(float m[], float invOut[])
  214. {
  215. float inv[9];
  216. // computes the inverse of a matrix m
  217. float det = m[0] * (m[4] * m[8] - m[7] * m[5]) -
  218. m[1] * (m[3] * m[8] - m[5] * m[6]) +
  219. m[2] * (m[3] * m[7] - m[4] * m[6]);
  220. if (is_zero(det) || isinf(det)) {
  221. return false;
  222. }
  223. float invdet = 1 / det;
  224. inv[0] = (m[4] * m[8] - m[7] * m[5]) * invdet;
  225. inv[1] = (m[2] * m[7] - m[1] * m[8]) * invdet;
  226. inv[2] = (m[1] * m[5] - m[2] * m[4]) * invdet;
  227. inv[3] = (m[5] * m[6] - m[3] * m[8]) * invdet;
  228. inv[4] = (m[0] * m[8] - m[2] * m[6]) * invdet;
  229. inv[5] = (m[3] * m[2] - m[0] * m[5]) * invdet;
  230. inv[6] = (m[3] * m[7] - m[6] * m[4]) * invdet;
  231. inv[7] = (m[6] * m[1] - m[0] * m[7]) * invdet;
  232. inv[8] = (m[0] * m[4] - m[3] * m[1]) * invdet;
  233. for(uint8_t i = 0; i < 9; i++){
  234. invOut[i] = inv[i];
  235. }
  236. return true;
  237. }
  238. /*
  239. * fast matrix inverse code only for 4x4 square matrix copied from
  240. * gluInvertMatrix implementation in opengl for 4x4 matrices.
  241. *
  242. * @param m, input 4x4 matrix
  243. * @param invOut, Output inverted 4x4 matrix
  244. * @returns false = matrix is Singular, true = matrix inversion successful
  245. */
  246. bool inverse4x4(float m[],float invOut[])
  247. {
  248. float inv[16], det;
  249. uint8_t i;
  250. #if CONFIG_HAL_BOARD == HAL_BOARD_SITL
  251. int old = fedisableexcept(FE_OVERFLOW);
  252. if (old < 0) {
  253. hal.console->printf("inverse4x4(): warning: error on disabling FE_OVERFLOW floating point exception\n");
  254. }
  255. #endif
  256. inv[0] = m[5] * m[10] * m[15] -
  257. m[5] * m[11] * m[14] -
  258. m[9] * m[6] * m[15] +
  259. m[9] * m[7] * m[14] +
  260. m[13] * m[6] * m[11] -
  261. m[13] * m[7] * m[10];
  262. inv[4] = -m[4] * m[10] * m[15] +
  263. m[4] * m[11] * m[14] +
  264. m[8] * m[6] * m[15] -
  265. m[8] * m[7] * m[14] -
  266. m[12] * m[6] * m[11] +
  267. m[12] * m[7] * m[10];
  268. inv[8] = m[4] * m[9] * m[15] -
  269. m[4] * m[11] * m[13] -
  270. m[8] * m[5] * m[15] +
  271. m[8] * m[7] * m[13] +
  272. m[12] * m[5] * m[11] -
  273. m[12] * m[7] * m[9];
  274. inv[12] = -m[4] * m[9] * m[14] +
  275. m[4] * m[10] * m[13] +
  276. m[8] * m[5] * m[14] -
  277. m[8] * m[6] * m[13] -
  278. m[12] * m[5] * m[10] +
  279. m[12] * m[6] * m[9];
  280. inv[1] = -m[1] * m[10] * m[15] +
  281. m[1] * m[11] * m[14] +
  282. m[9] * m[2] * m[15] -
  283. m[9] * m[3] * m[14] -
  284. m[13] * m[2] * m[11] +
  285. m[13] * m[3] * m[10];
  286. inv[5] = m[0] * m[10] * m[15] -
  287. m[0] * m[11] * m[14] -
  288. m[8] * m[2] * m[15] +
  289. m[8] * m[3] * m[14] +
  290. m[12] * m[2] * m[11] -
  291. m[12] * m[3] * m[10];
  292. inv[9] = -m[0] * m[9] * m[15] +
  293. m[0] * m[11] * m[13] +
  294. m[8] * m[1] * m[15] -
  295. m[8] * m[3] * m[13] -
  296. m[12] * m[1] * m[11] +
  297. m[12] * m[3] * m[9];
  298. inv[13] = m[0] * m[9] * m[14] -
  299. m[0] * m[10] * m[13] -
  300. m[8] * m[1] * m[14] +
  301. m[8] * m[2] * m[13] +
  302. m[12] * m[1] * m[10] -
  303. m[12] * m[2] * m[9];
  304. inv[2] = m[1] * m[6] * m[15] -
  305. m[1] * m[7] * m[14] -
  306. m[5] * m[2] * m[15] +
  307. m[5] * m[3] * m[14] +
  308. m[13] * m[2] * m[7] -
  309. m[13] * m[3] * m[6];
  310. inv[6] = -m[0] * m[6] * m[15] +
  311. m[0] * m[7] * m[14] +
  312. m[4] * m[2] * m[15] -
  313. m[4] * m[3] * m[14] -
  314. m[12] * m[2] * m[7] +
  315. m[12] * m[3] * m[6];
  316. inv[10] = m[0] * m[5] * m[15] -
  317. m[0] * m[7] * m[13] -
  318. m[4] * m[1] * m[15] +
  319. m[4] * m[3] * m[13] +
  320. m[12] * m[1] * m[7] -
  321. m[12] * m[3] * m[5];
  322. inv[14] = -m[0] * m[5] * m[14] +
  323. m[0] * m[6] * m[13] +
  324. m[4] * m[1] * m[14] -
  325. m[4] * m[2] * m[13] -
  326. m[12] * m[1] * m[6] +
  327. m[12] * m[2] * m[5];
  328. inv[3] = -m[1] * m[6] * m[11] +
  329. m[1] * m[7] * m[10] +
  330. m[5] * m[2] * m[11] -
  331. m[5] * m[3] * m[10] -
  332. m[9] * m[2] * m[7] +
  333. m[9] * m[3] * m[6];
  334. inv[7] = m[0] * m[6] * m[11] -
  335. m[0] * m[7] * m[10] -
  336. m[4] * m[2] * m[11] +
  337. m[4] * m[3] * m[10] +
  338. m[8] * m[2] * m[7] -
  339. m[8] * m[3] * m[6];
  340. inv[11] = -m[0] * m[5] * m[11] +
  341. m[0] * m[7] * m[9] +
  342. m[4] * m[1] * m[11] -
  343. m[4] * m[3] * m[9] -
  344. m[8] * m[1] * m[7] +
  345. m[8] * m[3] * m[5];
  346. inv[15] = m[0] * m[5] * m[10] -
  347. m[0] * m[6] * m[9] -
  348. m[4] * m[1] * m[10] +
  349. m[4] * m[2] * m[9] +
  350. m[8] * m[1] * m[6] -
  351. m[8] * m[2] * m[5];
  352. det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];
  353. #if CONFIG_HAL_BOARD == HAL_BOARD_SITL
  354. if (old >= 0 && feenableexcept(old) < 0) {
  355. hal.console->printf("inverse4x4(): warning: error on restoring floating exception mask\n");
  356. }
  357. #endif
  358. if (is_zero(det) || isinf(det)){
  359. return false;
  360. }
  361. det = 1.0f / det;
  362. for (i = 0; i < 16; i++)
  363. invOut[i] = inv[i] * det;
  364. return true;
  365. }
  366. /*
  367. * generic matrix inverse code
  368. *
  369. * @param x, input nxn matrix
  370. * @param y, Output inverted nxn matrix
  371. * @param n, dimension of square matrix
  372. * @returns false = matrix is Singular, true = matrix inversion successful
  373. */
  374. bool inverse(float x[], float y[], uint16_t dim)
  375. {
  376. switch(dim){
  377. case 3: return inverse3x3(x,y);
  378. case 4: return inverse4x4(x,y);
  379. default: return mat_inverse(x,y,dim);
  380. }
  381. }