00001
00002
00003
00004
00005
00006
00011 #include <fvar.hpp>
00012 #include <admodel.h>
00013 #include <df1b2fun.h>
00014 #include <adrndeff.h>
00015 #ifndef OPT_LIB
00016 #include <cassert>
00017 #include <climits>
00018 #endif
00019 double evaluate_function(const dvector& x,function_minimizer * pfmin);
00020 void get_second_ders(int xs,int us,const init_df1b2vector y,dmatrix& Hess,
00021 dmatrix& Dux, df1b2_gradlist * f1b2gradlist,function_minimizer * pfmin,
00022 laplace_approximation_calculator* lap);
00023 double calculate_laplace_approximation(const dvector& x,const dvector& u0,
00024 const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00025 const dmatrix& _Hessadjoint,function_minimizer * pmin);
00026
00027 double calculate_importance_sample(const dvector& x,const dvector& u0,
00028 const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00029 const dmatrix& _Hessadjoint,function_minimizer * pmin);
00030
00031 double calculate_importance_sample_funnel(const dvector& x,const dvector& u0,
00032 const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00033 const dmatrix& _Hessadjoint,function_minimizer * pmin);
00034
00035 dmatrix choleski_decomp_positive(const dmatrix& M,double b);
00036
00041 void laplace_approximation_calculator::
00042 check_derivatives(const dvector& _x,function_minimizer * pfmin,double f)
00043 {
00044 cerr << "need to define this" << endl;
00045 ad_exit(1);
00046 }
00047
00052 dvector laplace_approximation_calculator::
00053 default_calculations_check_derivatives(const dvector& _x,
00054 function_minimizer * pfmin, const double& _f)
00055 {
00056
00057 ADUNCONST(dvector,x)
00058 int i,j;
00059 double& f = (double&)_f;
00060
00061 initial_params::set_inactive_only_random_effects();
00062 gradient_structure::set_NO_DERIVATIVES();
00063 initial_params::reset(x);
00064
00065
00066 pfmin->AD_uf_inner();
00067 double fval1=value(*objective_function_value::pobjfun);
00068
00069 gradient_structure::set_YES_DERIVATIVES();
00070
00071 initial_params::set_active_only_random_effects();
00072 initial_params::xinit(uhat);
00073
00074 if (ad_comm::time_flag)
00075 {
00076 if (ad_comm::ptm1)
00077 {
00078 ad_comm::ptm1->get_elapsed_time_and_reset();
00079 }
00080 if (ad_comm::ptm)
00081 {
00082 ad_comm::ptm->get_elapsed_time_and_reset();
00083 }
00084 }
00085 if (ad_comm::time_flag)
00086 {
00087 if (ad_comm::ptm)
00088 {
00089 double time=ad_comm::ptm->get_elapsed_time();
00090 if (ad_comm::global_logfile)
00091 {
00092 (*ad_comm::global_logfile) << " Time pos 0 "
00093 << time << endl;
00094 }
00095 }
00096 }
00097
00098 double maxg = 0;
00099 dvector uhat_old(1,usize);
00100
00101
00102 for (i=1;i<=xsize;i++)
00103 {
00104 y(i)=x(i);
00105 }
00106 for (i=1;i<=usize;i++)
00107 {
00108 y(i+xsize)=uhat(i);
00109 }
00110
00111 int ierr=0;
00112 int niters=0;
00113 if (function_minimizer::first_hessian_flag)
00114 niters=num_nr_iters+1;
00115 else
00116 niters=num_nr_iters;
00117
00118 int nv=0;
00119 if (quadratic_prior::get_num_quadratic_prior()>0)
00120 {
00121 nv=initial_df1b2params::set_index();
00122 if (allocated(used_flags))
00123 {
00124 if (used_flags.indexmax() != nv)
00125 {
00126 used_flags.safe_deallocate();
00127 }
00128 }
00129 if (!allocated(used_flags))
00130 {
00131 used_flags.safe_allocate(1,nv);
00132 }
00133 }
00134
00135 for(int ii=1;ii<=niters;ii++)
00136 {
00137 if (quadratic_prior::get_num_quadratic_prior()>0)
00138 {
00139 check_pool_size();
00140 }
00141 {
00142
00143 Hess.initialize();
00144 cout << "Checking derivatives " << ii << endl;
00145 check_derivatives(x,pfmin,fval1);
00146
00147 if (quadratic_prior::get_num_quadratic_prior()>0)
00148 {
00149 laplace_approximation_calculator::where_are_we_flag=2;
00150 evaluate_function_quiet(uhat,pfmin);
00151 laplace_approximation_calculator::where_are_we_flag=0;
00152 quadratic_prior::get_cHessian_contribution(Hess,xsize);
00153 quadratic_prior::get_cgradient_contribution(grad,xsize);
00154 }
00155
00156
00157
00158
00159
00160
00161 #ifdef DIAG
00162 int print_hess_in_newton_raphson_flag=0;
00163 if (print_hess_in_newton_raphson_flag)
00164 {
00165 cout << norm2(Hess-trans(Hess)) << endl;
00166 if (ad_comm::global_logfile)
00167 {
00168 (*ad_comm::global_logfile) << setprecision(4) << setscientific()
00169 << setw(12) << sort(eigenvalues(Hess)) << endl;
00170 (*ad_comm::global_logfile) << setprecision(4) << setscientific()
00171 << setw(12) << Hess << endl;
00172 }
00173 }
00174 #endif
00175
00176 dvector step;
00177 #if defined(USE_ATLAS)
00178 if (!ad_comm::no_atlas_flag)
00179 {
00180 step=-atlas_solve_spd(Hess,grad,ierr);
00181 }
00182 else
00183 {
00184 dmatrix A=choleski_decomp_positive(Hess,ierr);
00185 if (!ierr)
00186 {
00187 step=-solve(Hess,grad);
00188
00189 }
00190 }
00191 if (ierr)
00192 {
00193 f1b2gradlist->reset();
00194 f1b2gradlist->list.initialize();
00195 f1b2gradlist->list2.initialize();
00196 f1b2gradlist->list3.initialize();
00197 f1b2gradlist->nlist.initialize();
00198 f1b2gradlist->nlist2.initialize();
00199 f1b2gradlist->nlist3.initialize();
00200 break;
00201 }
00202 #else
00203 step=-solve(Hess,grad);
00204 #endif
00205
00206 if (ad_comm::time_flag)
00207 {
00208 if (ad_comm::ptm)
00209 {
00210 double time=ad_comm::ptm->get_elapsed_time_and_reset();
00211 if (ad_comm::global_logfile)
00212 {
00213 (*ad_comm::global_logfile) << " time_in solve " << ii << " "
00214 << time << endl;
00215 }
00216 }
00217 }
00218
00219 f1b2gradlist->reset();
00220 f1b2gradlist->list.initialize();
00221 f1b2gradlist->list2.initialize();
00222 f1b2gradlist->list3.initialize();
00223 f1b2gradlist->nlist.initialize();
00224 f1b2gradlist->nlist2.initialize();
00225 f1b2gradlist->nlist3.initialize();
00226
00227 uhat_old=uhat;
00228 uhat+=step;
00229
00230 double maxg_old=maxg;
00231 maxg=fabs(evaluate_function(uhat,pfmin));
00232 if (maxg>maxg_old)
00233 {
00234 uhat=uhat_old;
00235 evaluate_function(uhat,pfmin);
00236 break;
00237 }
00238 if (maxg < 1.e-13)
00239 {
00240 break;
00241 }
00242 }
00243 for (i=1;i<=usize;i++)
00244 {
00245 y(i+xsize)=uhat(i);
00246 }
00247 }
00248
00249 if (num_nr_iters<=0)
00250 {
00251 evaluate_function(uhat,pfmin);
00252 }
00253
00254 for (i=1;i<=usize;i++)
00255 {
00256 y(i+xsize)=uhat(i);
00257 }
00258
00259
00260 if (ad_comm::time_flag)
00261 {
00262 if (ad_comm::ptm)
00263 {
00264 double time=ad_comm::ptm->get_elapsed_time_and_reset();
00265 if (ad_comm::global_logfile)
00266 {
00267 (*ad_comm::global_logfile) << " Time in reset and evaluate function"
00268 << time << endl;
00269 }
00270 }
00271 }
00272 get_second_ders(xsize,usize,y,Hess,Dux,f1b2gradlist,pfmin,this);
00273
00274
00275 if (ad_comm::time_flag)
00276 {
00277 if (ad_comm::ptm)
00278 {
00279 double time=ad_comm::ptm->get_elapsed_time_and_reset();
00280 if (ad_comm::global_logfile)
00281 {
00282 (*ad_comm::global_logfile) << " Time in dget second ders "
00283 << time << endl;
00284 }
00285 }
00286 }
00287 if (!ierr)
00288 {
00289 if (num_importance_samples==0)
00290 {
00291
00292 f=calculate_laplace_approximation(x,uhat,Hess,xadjoint,uadjoint,
00293 Hessadjoint,pfmin);
00294 }
00295 else
00296 {
00297 if (isfunnel_flag==0)
00298 {
00299 f=calculate_importance_sample(x,uhat,Hess,xadjoint,uadjoint,
00300 Hessadjoint,pfmin);
00301 }
00302 else
00303 {
00304 f=calculate_importance_sample_funnel(x,uhat,Hess,xadjoint,uadjoint,
00305 Hessadjoint,pfmin);
00306 }
00307 }
00308 }
00309 else
00310 {
00311 f=1.e+30;
00312 }
00313
00314 if (ad_comm::time_flag)
00315 {
00316 if (ad_comm::ptm)
00317 {
00318 double time=ad_comm::ptm->get_elapsed_time_and_reset();
00319 if (ad_comm::global_logfile)
00320 {
00321 (*ad_comm::global_logfile) << "Time in calculate laplace approximation "
00322 << time << endl;
00323 }
00324 }
00325 }
00326
00327 for (int ip=num_der_blocks;ip>=1;ip--)
00328 {
00329 df1b2variable::minder=minder(ip);
00330 df1b2variable::maxder=maxder(ip);
00331 int mind=y(1).minder;
00332 int jmin=max(mind,xsize+1);
00333 int jmax=min(y(1).maxder,xsize+usize);
00334 for (i=1;i<=usize;i++)
00335 {
00336 for (j=jmin;j<=jmax;j++)
00337 {
00338
00339 y(i+xsize).get_u_bar_tilde()[j-mind]=Hessadjoint(i,j-xsize);
00340 }
00341 }
00342
00343 if (initial_df1b2params::separable_flag)
00344 {
00345 for (j=1;j<=xsize+usize;j++)
00346 {
00347 *y(j).get_u_tilde()=0;
00348 }
00349 Hess.initialize();
00350 initial_df1b2params::separable_calculation_type=3;
00351 pfmin->user_function();
00352 }
00353 else
00354 {
00355 if (ip<num_der_blocks)
00356 {
00357 f1b2gradlist->reset();
00358 set_u_dot(ip);
00359 df1b2_gradlist::set_yes_derivatives();
00360 (*re_objective_function_value::pobjfun)=0;
00361 df1b2variable pen=0.0;
00362 df1b2variable zz=0.0;
00363
00364 initial_df1b2params::reset(y,pen);
00365 pfmin->user_function();
00366
00367 re_objective_function_value::fun_without_pen=
00368 value(*re_objective_function_value::pobjfun);
00369
00370 (*re_objective_function_value::pobjfun)+=pen;
00371 (*re_objective_function_value::pobjfun)+=zz;
00372
00373 set_dependent_variable(*re_objective_function_value::pobjfun);
00374 df1b2_gradlist::set_no_derivatives();
00375 df1b2variable::passnumber=1;
00376 df1b2_gradcalc1();
00377 }
00378
00379 for (i=1;i<=usize;i++)
00380 {
00381 for (j=jmin;j<=jmax;j++)
00382 {
00383
00384 y(i+xsize).get_u_bar_tilde()[j-mind]=Hessadjoint(i,j-xsize);
00385 }
00386 }
00387
00388
00389 df1b2variable::passnumber=2;
00390 df1b2_gradcalc1();
00391
00392 df1b2variable::passnumber=3;
00393 df1b2_gradcalc1();
00394
00395 f1b2gradlist->reset();
00396 f1b2gradlist->list.initialize();
00397 f1b2gradlist->list2.initialize();
00398 f1b2gradlist->list3.initialize();
00399 f1b2gradlist->nlist.initialize();
00400 f1b2gradlist->nlist2.initialize();
00401 f1b2gradlist->nlist3.initialize();
00402 }
00403
00404 if (ad_comm::time_flag)
00405 {
00406 if (ad_comm::ptm)
00407 {
00408 double time=ad_comm::ptm->get_elapsed_time_and_reset();
00409 if (ad_comm::global_logfile)
00410 {
00411 (*ad_comm::global_logfile) << " time for 3rd derivatives "
00412 << time << endl;
00413 }
00414 }
00415 }
00416
00417 dvector dtmp(1,xsize);
00418 for (i=1;i<=xsize;i++)
00419 {
00420 dtmp(i)=*y(i).get_u_tilde();
00421 }
00422 if (initial_df1b2params::separable_flag)
00423 {
00424 #ifndef OPT_LIB
00425 assert(nvar <= INT_MAX);
00426 #endif
00427 dvector scale(1,(int)nvar);
00428 initial_params::stddev_scale(scale,x);
00429 dvector sscale=scale(1,Dux(1).indexmax());
00430 for (i=1;i<=usize;i++)
00431 {
00432 Dux(i)=elem_prod(Dux(i),sscale);
00433 }
00434 dtmp=elem_prod(dtmp,sscale);
00435 }
00436
00437 for (i=1;i<=xsize;i++)
00438 {
00439 xadjoint(i)+=dtmp(i);
00440 }
00441 for (i=1;i<=usize;i++)
00442 uadjoint(i)+=*y(xsize+i).get_u_tilde();
00443 }
00444
00445
00446
00447
00448 int xstuff=3;
00449 if (xstuff && df1b2quadratic_prior::get_num_quadratic_prior()>0)
00450 {
00451 initial_params::straight_through_flag=0;
00452 funnel_init_var::lapprox=0;
00453 block_diagonal_flag=0;
00454 #ifndef OPT_LIB
00455 assert(nvar <= INT_MAX);
00456 #endif
00457 dvector scale1(1,(int)nvar);
00458 initial_params::set_inactive_only_random_effects();
00459 initial_params::stddev_scale(scale1,x);
00460
00461 laplace_approximation_calculator::where_are_we_flag=3;
00462 quadratic_prior::in_qp_calculations=1;
00463 funnel_init_var::lapprox=this;
00464 df1b2_gradlist::set_no_derivatives();
00465 dvector scale(1,(int)nvar);
00466 initial_params::stddev_scale(scale,x);
00467 dvector sscale=scale(1,Dux(1).indexmax());
00468
00469 for (i=1;i<=usize;i++)
00470 {
00471 Dux(i)=elem_div(Dux(i),sscale);
00472 }
00473
00474 if (xstuff>1)
00475 {
00476 df1b2quadratic_prior::get_Lxu_contribution(Dux);
00477 }
00478 quadratic_prior::in_qp_calculations=0;
00479 funnel_init_var::lapprox=0;
00480 laplace_approximation_calculator::where_are_we_flag=0;
00481
00482 for (i=1;i<=usize;i++)
00483 {
00484 Dux(i)=elem_prod(Dux(i),sscale);
00485 }
00486
00487
00488 if (xstuff>2)
00489 {
00490 dvector tmp=evaluate_function_with_quadprior(x,usize,pfmin);
00491 for (i=1;i<=xsize;i++)
00492 {
00493 xadjoint(i)+=tmp(i);
00494 }
00495 }
00496
00497 if (xstuff>2)
00498 {
00499 quadratic_prior::get_cHessian_contribution_from_vHessian(Hess,xsize);
00500 }
00501 }
00502
00503
00504
00505
00506 if (ad_comm::ptm)
00507 {
00508 ad_comm::ptm->get_elapsed_time_and_reset();
00509 }
00510
00511 #if defined(USE_ATLAS)
00512 if (!ad_comm::no_atlas_flag)
00513 {
00514
00515 xadjoint -= atlas_solve_spd_trans(Hess,uadjoint)*Dux;
00516 }
00517 else
00518 {
00519
00520 xadjoint -= solve(Hess,uadjoint)*Dux;
00521 }
00522 #else
00523
00524 xadjoint -= solve(Hess,uadjoint)*Dux;
00525 #endif
00526
00527
00528 if (ad_comm::ptm)
00529 {
00530 double time=ad_comm::ptm->get_elapsed_time_and_reset();
00531 if (ad_comm::global_logfile)
00532 {
00533 (*ad_comm::global_logfile) << " Time in second solve "
00534 << time << endl;
00535 }
00536 }
00537 if (ad_comm::ptm1)
00538 {
00539 double time=ad_comm::ptm1->get_elapsed_time_and_reset();
00540 if (ad_comm::global_logfile)
00541 {
00542 (*ad_comm::global_logfile) << " Total time in function evaluation "
00543 << time << endl << endl;
00544 }
00545 }
00546
00547 return xadjoint;
00548 }