00001
00002
00003
00004
00005
00006
00015 #include <fvar.hpp>
00016 #ifndef OPT_LIB
00017 #include <cassert>
00018 #endif
00019
// SIGN(a,b): the magnitude of a with the sign taken from b -- expands to
// fabs(a) when b >= 0.0 and to -fabs(a) otherwise.
// NOTE(review): nothing in the visible part of this file uses the macro;
// check the rest of the translation unit before removing it.
00020 #define SIGN(a,b) ((b) >= 0.0 ? fabs(a) : -fabs(a))
00021
// Forward declarations of the three SVD routines defined further down.
//
// svd() is the entry point: it rebases the index ranges of its arguments to
// start at 0, calls svd_nlm() when m >= n and svd_mln() otherwise, and then
// restores the original index ranges.
// All three return 0 on success; a nonzero value is the (0-based) index of the
// singular value whose iteration was abandoned after more than 30 passes.
// The _q, _u and _v arguments are declared const but are written through.
00022 int svd(int m, int n, int withu, int withv, double eps, double tol,
00023 const dmatrix& a, const dvector& _q,
00024 const dmatrix& _u, const dmatrix& _v);
// Worker selected by svd() when m >= n.
00025 int svd_nlm(int m, int n, int withu, int withv, double eps, double tol,
00026 const dmatrix& aa, const dvector& _q,
00027 const dmatrix& _u, const dmatrix& _v);
// Worker selected by svd() when m < n.
00028 int svd_mln(int m, int n, int withu, int withv, double eps, double tol,
00029 const dmatrix& aa, const dvector& _q,
00030 const dmatrix& _u, const dmatrix& _v);
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00062 sing_val_decomp::sing_val_decomp(const dmatrix& _a, const dvector & _w,
00063 const dmatrix& _v) :
00064 a(_a), w(_w), v(_v)
00065 {}
00066
00071 sing_val_decomp singval_decomp(const dmatrix &_a)
00072 {
00073 if (_a.indexmin() !=1 )
00074 {
00075 cerr << "index error in singval_decomp" << endl;
00076 ad_exit(1);
00077 }
00078 int m = _a.indexmax();
00079 int n = _a(1).indexmax();
00080 dmatrix a(1,m,1,n);
00081 a=_a;
00082 dvector w(1,n);
00083 dmatrix u(1,m,1,n);
00084 dmatrix v(1,n,1,n);
00085
00086 double eps = 1.e-12;
00087 double tol = eps;
00088 int k = svd(m,n,1,1,eps,tol,a,w,u,v);
00089 if(k!=0)
00090 {
00091 cerr << "Error in singval_decomp in iteration " << k << endl;
00092 ad_exit(1);
00093 }
00094 return sing_val_decomp(u,w,v);
00095 }
00096
00118 int svd(int m,int n,int withu,int withv,double eps,double tol,
00119 const dmatrix& aa, const dvector& _q,
00120 const dmatrix& _u, const dmatrix& _v)
00121 {
00122 ADUNCONST(dmatrix,u)
00123 ADUNCONST(dmatrix,v)
00124 ADUNCONST(dvector,q)
00125
00126 int urlb=u.rowmin();
00127 int uclb=u.colmin();
00128 u.rowshift(0);
00129 u.colshift(0);
00130 int vrlb=v.rowmin();
00131 int vclb=v.colmin();
00132 v.rowshift(0);
00133 v.colshift(0);
00134 int qlb=q.indexmin();
00135 q.shift(0);
00136 dmatrix a=aa;
00137 int arlb=a.rowmin();
00138 int aclb=a.colmin();
00139 a.rowshift(0);
00140 a.colshift(0);
00141
00142 int k;
00143 if(m>=n)
00144 k = svd_nlm(m,n,withu,withv,eps,tol,a,q,u,v);
00145 else
00146 k = svd_mln(m,n,withu,withv,eps,tol,a,q,u,v);
00147
00148 u.rowshift(urlb);
00149 u.colshift(uclb);
00150 v.rowshift(vrlb);
00151 v.colshift(vclb);
00152 q.shift(qlb);
00153 a.rowshift(arlb);
00154 a.colshift(aclb);
00155
00156 return k;
00157 }
00158
// SVD worker that svd() selects when m < n.
//
// All indexing is 0-based. The routine works in three stages: Householder
// reduction of the matrix to bidiagonal form (diagonal in q, super-diagonal in
// the scratch array e), optional accumulation of the right-hand (v) and
// left-hand (u) transformations, and an iterative diagonalisation of the
// bidiagonal form using plane rotations.
//
// Parameters:
//   m, n   row and column counts used as loop limits
//   withu  nonzero: accumulate the left-hand transformations in _u
//   withv  nonzero: accumulate the right-hand transformations in _v
//   eps    multiplied by the largest |q[i]|+|e[i]| to form the convergence bound
//   tol    a sum of squares below this is treated as zero (no reflection built)
//   aa     input matrix, assigned to u before any work is done
//   _q     output: singular values (made non-negative at the end)
//   _u,_v  outputs; declared const but written through ADUNCONST references
// Returns 0 on success, or the index k of the singular value whose iteration
// was abandoned after more than 30 passes.
00170 int svd_mln(int m, int n, int withu, int withv, double eps, double tol,
00171 const dmatrix& aa, const dvector& _q,
00172 const dmatrix& _u, const dmatrix& _v)
00173 {
00174 ADUNCONST(dmatrix,u)
00175 ADUNCONST(dmatrix,v)
00176 ADUNCONST(dvector,q)
00177
00178 int i,j,k,l,l1,iter,retval;
00179 double c,f,g,h,s,x,y,z;
00180
00181 #ifndef OPT_LIB
00182 assert(n > 0);
00183 #endif
00184
// Scratch array of n zero-initialised doubles; released by free(e) at the end.
// NOTE(review): the calloc result is not checked before it is dereferenced.
00185 double* e = (double*)calloc((size_t)n,sizeof(double));
00186 retval = 0;
00187
// From here on the input is only accessed through u.
00188 u=aa;
00189
00190
// ---- Stage 1: Householder reduction to bidiagonal form. ----
// x ends up as the largest |q[i]| + |e[i]| seen.
00191 g = x = 0.0;
00192 for (i=0;i<n;i++)
00193 {
00194 e[i] = g;
00195 s = g = 0.0;
00196 l = i+1;
// Column reflection; only columns that have a row i (i < m) get one.
00197 if( i<m )
00198 {
00199 for (j=i;j<m;j++)
00200 {
00201 s += (u[j][i]*u[j][i]);
00202 }
00203 if (s < tol)
00204 g = 0.0;
00205 else
00206 {
00207 f = u[i][i];
// g takes the sign opposite to f.
00208 g = (f < 0) ? sqrt(s) : -sqrt(s);
00209 h = f * g - s;
00210 u[i][i] = f - g;
// Apply the reflection to the remaining columns.
00211 for (j=l;j<n;j++)
00212 {
00213 s = 0.0;
00214 for (k=i;k<m;k++)
00215 {
00216 s += (u[k][i] * u[k][j]);
00217 }
00218 f = s / h;
00219 for (k=i;k<m;k++)
00220 {
00221 u[k][j] += (f * u[k][i]);
00222 }
00223 }
00224 }
00225 }
00226 q[i] = g;
00227 s = g = 0.0;
// Row reflection acting on columns l..n-1 of row i.
00228 if( i<m && i!=n-1 )
00229 {
00230 for (j=l;j<n;j++)
00231 {
00232 s += (u[i][j] * u[i][j]);
00233 }
00234 if (s < tol)
00235 g = 0.0;
00236 else
00237 {
00238 f = u[i][i+1];
00239 g = (f < 0) ? sqrt(s) : -sqrt(s);
00240 h = f * g - s;
00241 u[i][i+1] = f - g;
// e[l..n-1] serves as temporary storage here; e[i+1] is overwritten with g at
// the top of the next pass.
00242 for (j=l;j<n;j++)
00243 {
00244 e[j] = u[i][j]/h;
00245 }
00246 for (j=l;j<m;j++)
00247 {
00248 s = 0.0;
00249 for (k=l;k<n;k++)
00250 {
00251 s += (u[j][k] * u[i][k]);
00252 }
00253 for (k=l;k<n;k++)
00254 {
00255 u[j][k] += (s * e[k]);
00256 }
00257 }
00258 }
00259 }
00260 y = fabs(q[i]) + fabs(e[i]);
00261 if (y > x)
00262 {
00263 x = y;
00264 }
00265 }
00266
00267
// ---- Stage 2a: accumulate the right-hand transformations in v. ----
00268 if (withv)
00269 {
00270
00271 l = n;
00272 for (i=n-1;i>=0;i--)
00273 {
// NOTE(review): svd_nlm runs the body of this `if` for every i. Here it is
// skipped for i == n-1 (where the j-loops would be empty anyway, l == n) and
// for i == n-2, so this loop never writes v[n-2][n-1] / v[n-1][n-2] and a row
// reflection built at i == n-2 (possible when m == n-1) is not accumulated.
// Confirm that this is intended.
00274 if ( i < n-2 )
00275 {
00276 if (g != 0.0)
00277 {
00278 h = u[i][i+1] * g;
00279 for (j=l;j<n;j++)
00280 {
00281 v[j][i] = u[i][j]/h;
00282 }
00283 for (j=l;j<n;j++)
00284 {
00285 s = 0.0;
00286 for (k=l;k<n;k++)
00287 {
00288 s += (u[i][k] * v[k][j]);
00289 }
00290 for (k=l;k<n;k++)
00291 {
00292 v[k][j] += (s * v[k][i]);
00293 }
00294 }
00295 }
00296 for (j=l;j<n;j++)
00297 {
00298 v[i][j] = v[j][i] = 0.0;
00299 }
00300 }
00301 v[i][i] = 1.0;
00302 g = e[i];
00303 l = i;
00304 }
00305 }
00306
00307
// ---- Stage 2b: accumulate the left-hand transformations in u. ----
// Only indices below min(m,n) are visited.
00308 if (withu) {
00309 for (i=min(m,n)-1;i>=0;i--) {
00310 l = i + 1;
00311 g = q[i];
00312 for (j=l;j<n;j++)
00313 u[i][j] = 0.0;
00314 if (g != 0.0) {
00315 h = u[i][i] * g;
00316 for (j=l;j<n;j++) {
00317 s = 0.0;
// The sum starts at row l: u[i][j] was set to 0 just above.
00318 for (k=l;k<m;k++)
00319 s += (u[k][i] * u[k][j]);
00320 f = s / h;
00321 for (k=i;k<m;k++)
00322 u[k][j] += (f * u[k][i]);
00323 }
00324 for (j=i;j<m;j++)
00325 u[j][i] /= g;
00326 }
00327 else {
00328 for (j=i;j<m;j++)
00329 u[j][i] = 0.0;
00330 }
00331 u[i][i] += 1.0;
00332 }
00333 }
00334
00335
// ---- Stage 3: diagonalise the bidiagonal form, last singular value first. ----
// eps becomes an absolute bound scaled by the largest entry magnitude.
00336 eps *= x;
00337 for (k=n-1;k>=0;k--) {
00338 iter = 0;
00339 test_f_splitting:
// Scan upward from k for a negligible super-diagonal or diagonal entry.
// e[0] is 0 when this stage starts (set from g = 0.0 at i == 0) and is reset
// by `e[l] = 0.0` after each sweep, so the first test is expected to succeed at
// l == 0 before q[l-1] is read. NOTE(review): that does not hold if eps is NaN.
00340 for (l=k;l>=0;l--) {
00341 if (fabs(e[l]) <= eps) goto test_f_convergence;
00342 if (fabs(q[l-1]) <= eps) goto cancellation;
00343 }
00344
00345
// q[l-1] is negligible: rotate e[l..k] away.
00346 cancellation:
00347 c = 0.0;
00348 s = 1.0;
00349 l1 = l - 1;
00350 for (i=l;i<=k;i++) {
00351 f = s * e[i];
00352 e[i] *= c;
00353 if (fabs(f) <= eps) goto test_f_convergence;
00354 g = q[i];
00355 h = q[i] = sqrt(f*f + g*g);
00356 c = g / h;
00357 s = -f / h;
00358 if (withu) {
00359 for (j=0;j<m;j++) {
00360 y = u[j][l1];
00361 z = u[j][i];
00362 u[j][l1] = y * c + z * s;
00363 u[j][i] = -y * s + z * c;
00364 }
00365 }
00366 }
00367 test_f_convergence:
00368 z = q[k];
// l == k means the block has shrunk to the single entry q[k].
00369 if (l == k) goto convergence;
00370
00371
// Give up on this singular value after 30 passes; the break leaves the k loop
// and the function returns k.
00372 iter++;
00373 if (iter > 30) {
00374 retval = k;
00375 break;
00376 }
// Compute the shift from the trailing entries of the active block.
00377 x = q[l];
00378 y = q[k-1];
00379 g = e[k-1];
00380 h = e[k];
00381 f = ((y-z)*(y+z) + (g-h)*(g+h)) / (2*h*y);
00382 g = sqrt(f*f + 1.0);
00383 f = ((x-z)*(x+z) + h*(y/((f<0)?(f-g):(f+g))-h))/x;
00384
// One sweep of plane rotations over the active block l..k.
00385 c = s = 1.0;
00386 for (i=l+1;i<=k;i++) {
00387 g = e[i];
00388 y = q[i];
00389 h = s * g;
00390 g *= c;
00391 e[i-1] = z = sqrt(f*f+h*h);
00392 c = f / z;
00393 s = h / z;
00394 f = x * c + g * s;
00395 g = -x * s + g * c;
00396 h = y * s;
00397 y *= c;
// Apply the rotation to columns i-1 and i of v.
00398 if (withv) {
00399 for (j=0;j<n;j++) {
00400 x = v[j][i-1];
00401 z = v[j][i];
00402 v[j][i-1] = x * c + z * s;
00403 v[j][i] = -x * s + z * c;
00404 }
00405 }
00406 q[i-1] = z = sqrt(f*f + h*h);
00407 c = f/z;
00408 s = h/z;
00409 f = c * g + s * y;
00410 x = -s * g + c * y;
// Apply the rotation to columns i-1 and i of u.
00411 if (withu) {
00412 for (j=0;j<m;j++) {
00413 y = u[j][i-1];
00414 z = u[j][i];
00415 u[j][i-1] = y * c + z * s;
00416 u[j][i] = -y * s + z * c;
00417 }
00418 }
00419 }
00420 e[l] = 0.0;
00421 e[k] = f;
00422 q[k] = x;
00423 goto test_f_splitting;
00424 convergence:
// Make the singular value non-negative, flipping column k of v to match.
00425 if (z < 0.0) {
00426
00427 q[k] = - z;
00428 if (withv) {
00429 for (j=0;j<n;j++)
00430 v[j][k] = -v[j][k];
00431 }
00432 }
00433 }
00434
00435 free(e);
00436
00437 return retval;
00438 }
00439
// SVD worker that svd() selects when m >= n.
//
// All indexing is 0-based. The routine works in three stages: Householder
// reduction of the matrix to bidiagonal form (diagonal in q, super-diagonal in
// the scratch array e), optional accumulation of the right-hand (v) and
// left-hand (u) transformations, and an iterative diagonalisation of the
// bidiagonal form using plane rotations.
//
// Parameters:
//   m, n   row and column counts used as loop limits
//   withu  nonzero: accumulate the left-hand transformations in _u
//   withv  nonzero: accumulate the right-hand transformations in _v
//   eps    multiplied by the largest |q[i]|+|e[i]| to form the convergence bound
//   tol    a sum of squares below this is treated as zero (no reflection built)
//   aa     input matrix, assigned to u before any work is done
//   _q     output: singular values (made non-negative at the end)
//   _u,_v  outputs; declared const but written through ADUNCONST references
// Returns 0 on success, or the index k of the singular value whose iteration
// was abandoned after more than 30 passes.
00451 int svd_nlm(int m, int n, int withu, int withv, double eps, double tol,
00452 const dmatrix& aa, const dvector& _q,
00453 const dmatrix& _u, const dmatrix& _v)
00454 {
00455 ADUNCONST(dmatrix,u)
00456 ADUNCONST(dmatrix,v)
00457 ADUNCONST(dvector,q)
00458
00459 int i,j,k,l,l1,iter,retval;
00460 double c,f,g,h,s,x,y,z;
00461
00462 #ifndef OPT_LIB
00463 assert(n > 0);
00464 #endif
// Scratch array of n zero-initialised doubles; released by free(e) at the end.
// NOTE(review): the calloc result is not checked before it is dereferenced.
00465 double* e = (double *)calloc((size_t)n, sizeof(double));
00466 retval = 0;
00467
// From here on the input is only accessed through u.
00468 u=aa;
00469
// ---- Stage 1: Householder reduction to bidiagonal form. ----
// x ends up as the largest |q[i]| + |e[i]| seen.
00470 g = x = 0.0;
00471 for (i=0;i<n;i++)
00472 {
00473 e[i] = g;
00474 s = 0.0;
00475 l = i+1;
// Column reflection built from rows i..m-1 of column i.
00476 for (j=i;j<m;j++)
00477 {
00478 s += (u[j][i]*u[j][i]);
00479 }
00480 if (s < tol)
00481 g = 0.0;
00482 else
00483 {
00484 f = u[i][i];
// g takes the sign opposite to f.
00485 g = (f < 0) ? sqrt(s) : -sqrt(s);
00486 h = f * g - s;
00487 u[i][i] = f - g;
// Apply the reflection to the remaining columns.
00488 for (j=l;j<n;j++)
00489 {
00490 s = 0.0;
00491 for (k=i;k<m;k++)
00492 {
00493 s += (u[k][i] * u[k][j]);
00494 }
00495 f = s / h;
00496 for (k=i;k<m;k++)
00497 {
00498 u[k][j] += (f * u[k][i]);
00499 }
00500 }
00501 }
00502 q[i] = g;
// Row reflection acting on columns l..n-1 of row i.
00503 s = 0.0;
00504 for (j=l;j<n;j++)
00505 {
00506 s += (u[i][j] * u[i][j]);
00507 }
// At i == n-1 the sum above is over an empty range, so s == 0. With tol > 0
// the else-branch (which reads u[i][i+1]) is therefore not taken on that pass.
// NOTE(review): a caller passing tol <= 0 would reach u[i][n].
00508 if (s < tol)
00509 g = 0.0;
00510 else
00511 {
00512 f = u[i][i+1];
00513 g = (f < 0) ? sqrt(s) : -sqrt(s);
00514 h = f * g - s;
00515 u[i][i+1] = f - g;
// e[l..n-1] serves as temporary storage here; e[i+1] is overwritten with g at
// the top of the next pass.
00516 for (j=l;j<n;j++)
00517 {
00518 e[j] = u[i][j]/h;
00519 }
00520 for (j=l;j<m;j++)
00521 {
00522 s = 0.0;
00523 for (k=l;k<n;k++)
00524 {
00525 s += (u[j][k] * u[i][k]);
00526 }
00527 for (k=l;k<n;k++)
00528 {
00529 u[j][k] += (s * e[k]);
00530 }
00531 }
00532 }
00533 y = fabs(q[i]) + fabs(e[i]);
00534 if (y > x)
00535 {
00536 x = y;
00537 }
00538 }
00539
00540
// ---- Stage 2a: accumulate the right-hand transformations in v. ----
// On entry g still holds the value left by the last pass of stage 1.
00541 if (withv)
00542 {
00543
00544 l = n;
00545 for (i=n-1;i>=0;i--)
00546 {
00547 if (g != 0.0)
00548 {
00549 h = u[i][i+1] * g;
00550 for (j=l;j<n;j++)
00551 {
00552 v[j][i] = u[i][j]/h;
00553 }
00554 for (j=l;j<n;j++)
00555 {
00556 s = 0.0;
00557 for (k=l;k<n;k++)
00558 {
00559 s += (u[i][k] * v[k][j]);
00560 }
00561 for (k=l;k<n;k++)
00562 {
00563 v[k][j] += (s * v[k][i]);
00564 }
00565 }
00566 }
// Clear row i and column i of v beyond the diagonal, then set the diagonal.
00567 for (j=l;j<n;j++)
00568 {
00569 v[i][j] = v[j][i] = 0.0;
00570 }
00571 v[i][i] = 1.0;
00572 g = e[i];
00573 l = i;
00574 }
00575 }
00576
00577
// ---- Stage 2b: accumulate the left-hand transformations in u. ----
00578 if (withu) {
00579 for (i=n-1;i>=0;i--) {
00580 l = i + 1;
00581 g = q[i];
00582 for (j=l;j<n;j++)
00583 u[i][j] = 0.0;
00584 if (g != 0.0) {
00585 h = u[i][i] * g;
00586 for (j=l;j<n;j++) {
00587 s = 0.0;
// The sum starts at row l: u[i][j] was set to 0 just above.
00588 for (k=l;k<m;k++)
00589 s += (u[k][i] * u[k][j]);
00590 f = s / h;
00591 for (k=i;k<m;k++)
00592 u[k][j] += (f * u[k][i]);
00593 }
00594 for (j=i;j<m;j++)
00595 u[j][i] /= g;
00596 }
00597 else {
00598 for (j=i;j<m;j++)
00599 u[j][i] = 0.0;
00600 }
00601 u[i][i] += 1.0;
00602 }
00603 }
00604
00605
// ---- Stage 3: diagonalise the bidiagonal form, last singular value first. ----
// eps becomes an absolute bound scaled by the largest entry magnitude.
00606 eps *= x;
00607 for (k=n-1;k>=0;k--) {
00608 iter = 0;
00609 test_f_splitting:
// Scan upward from k for a negligible super-diagonal or diagonal entry.
// e[0] is 0 when this stage starts (set from g = 0.0 at i == 0) and is reset
// by `e[l] = 0.0` after each sweep, so the first test is expected to succeed at
// l == 0 before q[l-1] is read. NOTE(review): that does not hold if eps is NaN.
00610 for (l=k;l>=0;l--) {
00611 if (fabs(e[l]) <= eps) goto test_f_convergence;
00612 if (fabs(q[l-1]) <= eps) goto cancellation;
00613 }
00614
00615
// q[l-1] is negligible: rotate e[l..k] away.
00616 cancellation:
00617 c = 0.0;
00618 s = 1.0;
00619 l1 = l - 1;
00620 for (i=l;i<=k;i++) {
00621 f = s * e[i];
00622 e[i] *= c;
00623 if (fabs(f) <= eps) goto test_f_convergence;
00624 g = q[i];
00625 h = q[i] = sqrt(f*f + g*g);
00626 c = g / h;
00627 s = -f / h;
00628 if (withu) {
00629 for (j=0;j<m;j++) {
00630 y = u[j][l1];
00631 z = u[j][i];
00632 u[j][l1] = y * c + z * s;
00633 u[j][i] = -y * s + z * c;
00634 }
00635 }
00636 }
00637 test_f_convergence:
00638 z = q[k];
// l == k means the block has shrunk to the single entry q[k].
00639 if (l == k) goto convergence;
00640
00641
// Give up on this singular value after 30 passes; the break leaves the k loop
// and the function returns k.
00642 iter++;
00643 if (iter > 30) {
00644 retval = k;
00645 break;
00646 }
// Compute the shift from the trailing entries of the active block.
00647 x = q[l];
00648 y = q[k-1];
00649 g = e[k-1];
00650 h = e[k];
00651 f = ((y-z)*(y+z) + (g-h)*(g+h)) / (2*h*y);
00652 g = sqrt(f*f + 1.0);
00653 f = ((x-z)*(x+z) + h*(y/((f<0)?(f-g):(f+g))-h))/x;
00654
// One sweep of plane rotations over the active block l..k.
00655 c = s = 1.0;
00656 for (i=l+1;i<=k;i++) {
00657 g = e[i];
00658 y = q[i];
00659 h = s * g;
00660 g *= c;
00661 e[i-1] = z = sqrt(f*f+h*h);
00662 c = f / z;
00663 s = h / z;
00664 f = x * c + g * s;
00665 g = -x * s + g * c;
00666 h = y * s;
00667 y *= c;
// Apply the rotation to columns i-1 and i of v.
00668 if (withv) {
00669 for (j=0;j<n;j++) {
00670 x = v[j][i-1];
00671 z = v[j][i];
00672 v[j][i-1] = x * c + z * s;
00673 v[j][i] = -x * s + z * c;
00674 }
00675 }
00676 q[i-1] = z = sqrt(f*f + h*h);
00677 c = f/z;
00678 s = h/z;
00679 f = c * g + s * y;
00680 x = -s * g + c * y;
// Apply the rotation to columns i-1 and i of u.
00681 if (withu) {
00682 for (j=0;j<m;j++) {
00683 y = u[j][i-1];
00684 z = u[j][i];
00685 u[j][i-1] = y * c + z * s;
00686 u[j][i] = -y * s + z * c;
00687 }
00688 }
00689 }
00690 e[l] = 0.0;
00691 e[k] = f;
00692 q[k] = x;
00693 goto test_f_splitting;
00694 convergence:
// Make the singular value non-negative, flipping column k of v to match.
00695 if (z < 0.0) {
00696
00697 q[k] = - z;
00698 if (withv) {
00699 for (j=0;j<n;j++)
00700 v[j][k] = -v[j][k];
00701 }
00702 }
00703 }
00704
00705 free(e);
00706
00707 return retval;
00708 }