ADMB Documentation  11.5.3197
 All Classes Files Functions Variables Typedefs Friends Defines
df1b2chkder.cpp
Go to the documentation of this file.
00001 /*
00002  * $Id$
00003  *
00004  * Author: David Fournier
00005  * Copyright (c) 2008-2012 Regents of the University of California
00006  */
00011 #include <fvar.hpp>
00012 #include <admodel.h>
00013 #include <df1b2fun.h>
00014 #include <adrndeff.h>
00015 #ifndef OPT_LIB
00016   #include <cassert>
00017   #include <climits>
00018 #endif
00019 double evaluate_function(const dvector& x,function_minimizer * pfmin);
00020 void get_second_ders(int xs,int us,const init_df1b2vector y,dmatrix& Hess,
00021   dmatrix& Dux, df1b2_gradlist * f1b2gradlist,function_minimizer * pfmin,
00022   laplace_approximation_calculator* lap);
00023 double calculate_laplace_approximation(const dvector& x,const dvector& u0,
00024   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00025   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00026 
00027 double calculate_importance_sample(const dvector& x,const dvector& u0,
00028   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00029   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00030 
00031 double calculate_importance_sample_funnel(const dvector& x,const dvector& u0,
00032   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00033   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00034 
00035 dmatrix choleski_decomp_positive(const dmatrix& M,double b);
00036 
00041 void laplace_approximation_calculator::
00042   check_derivatives(const dvector& _x,function_minimizer * pfmin,double f)
00043 {
00044   cerr << "need to define this" << endl;
00045   ad_exit(1);
00046 }
00047 
00052 dvector laplace_approximation_calculator::
00053   default_calculations_check_derivatives(const dvector& _x,
00054     function_minimizer * pfmin, const double& _f)
00055 {
00056   // for use when there is no separability
00057   ADUNCONST(dvector,x)
00058   int i,j;
00059   double& f = (double&)_f;
00060 
00061   initial_params::set_inactive_only_random_effects();
00062   gradient_structure::set_NO_DERIVATIVES();
00063   initial_params::reset(x);    // get current x values into the model
00064 
00065 
00066   pfmin->AD_uf_inner();
00067   double fval1=value(*objective_function_value::pobjfun);
00068 
00069   gradient_structure::set_YES_DERIVATIVES();
00070 
00071   initial_params::set_active_only_random_effects();
00072   initial_params::xinit(uhat);    // get current x values into the model
00073   //int lmn_flag=0;
00074   if (ad_comm::time_flag)
00075   {
00076     if (ad_comm::ptm1)
00077     {
00078       ad_comm::ptm1->get_elapsed_time_and_reset();
00079     }
00080     if (ad_comm::ptm)
00081     {
00082       ad_comm::ptm->get_elapsed_time_and_reset();
00083     }
00084   }
00085   if (ad_comm::time_flag)
00086   {
00087     if (ad_comm::ptm)
00088     {
00089       double time=ad_comm::ptm->get_elapsed_time();
00090       if (ad_comm::global_logfile)
00091       {
00092         (*ad_comm::global_logfile) << " Time pos 0 "
00093           << time << endl;
00094       }
00095     }
00096   }
00097 
00098   double maxg = 0;
00099   dvector uhat_old(1,usize);
00100   //double f_from_1=0.0;
00101 
00102   for (i=1;i<=xsize;i++)
00103   {
00104     y(i)=x(i);
00105   }
00106   for (i=1;i<=usize;i++)
00107   {
00108     y(i+xsize)=uhat(i);
00109   }
00110 
00111   int ierr=0;
00112   int niters=0;
00113   if (function_minimizer::first_hessian_flag)
00114     niters=num_nr_iters+1;
00115   else
00116     niters=num_nr_iters;
00117 
00118   int nv=0;
00119   if (quadratic_prior::get_num_quadratic_prior()>0)
00120   {
00121     nv=initial_df1b2params::set_index();
00122     if (allocated(used_flags))
00123     {
00124       if (used_flags.indexmax() != nv)
00125       {
00126         used_flags.safe_deallocate();
00127       }
00128     }
00129     if (!allocated(used_flags))
00130     {
00131       used_flags.safe_allocate(1,nv);
00132     }
00133   }
00134 
00135   for(int ii=1;ii<=niters;ii++)
00136   {
00137     if (quadratic_prior::get_num_quadratic_prior()>0)
00138     {
00139       check_pool_size();
00140     }
00141     {
00142       // test newton raphson
00143       Hess.initialize();
00144       cout << "Checking derivatives " << ii << endl;
00145       check_derivatives(x,pfmin,fval1);
00146 
00147       if (quadratic_prior::get_num_quadratic_prior()>0)
00148       {
00149         laplace_approximation_calculator::where_are_we_flag=2;
00150         /*double maxg = */evaluate_function_quiet(uhat,pfmin);
00151         laplace_approximation_calculator::where_are_we_flag=0;
00152         quadratic_prior::get_cHessian_contribution(Hess,xsize);
00153         quadratic_prior::get_cgradient_contribution(grad,xsize);
00154       }
00155 
00156       /*
00157       if (ii == 1)
00158         { double diff = fabs(re_objective_function_value::fun_without_pen - objective_function_value::fun_without_pen); }
00159       */
00160 
00161 #ifdef DIAG
00162       int print_hess_in_newton_raphson_flag=0;
00163       if (print_hess_in_newton_raphson_flag)
00164       {
00165         cout << norm2(Hess-trans(Hess)) << endl;
00166         if (ad_comm::global_logfile)
00167         {
00168           (*ad_comm::global_logfile) << setprecision(4) << setscientific()
00169             << setw(12) << sort(eigenvalues(Hess)) << endl;
00170           (*ad_comm::global_logfile) << setprecision(4) << setscientific()
00171             << setw(12) << Hess << endl;
00172         }
00173       }
00174 #endif
00175 
00176       dvector step;
00177 #if defined(USE_ATLAS)
00178       if (!ad_comm::no_atlas_flag)
00179       {
00180         step=-atlas_solve_spd(Hess,grad,ierr);
00181       }
00182       else
00183       {
00184         dmatrix A=choleski_decomp_positive(Hess,ierr);
00185         if (!ierr)
00186         {
00187           step=-solve(Hess,grad);
00188           //step=-solve(A*trans(A),grad);
00189         }
00190       }
00191       if (ierr)
00192       {
00193         f1b2gradlist->reset();
00194         f1b2gradlist->list.initialize();
00195         f1b2gradlist->list2.initialize();
00196         f1b2gradlist->list3.initialize();
00197         f1b2gradlist->nlist.initialize();
00198         f1b2gradlist->nlist2.initialize();
00199         f1b2gradlist->nlist3.initialize();
00200         break;
00201       }
00202 #else
00203       step=-solve(Hess,grad);
00204 #endif
00205 
00206       if (ad_comm::time_flag)
00207       {
00208         if (ad_comm::ptm)
00209         {
00210           double time=ad_comm::ptm->get_elapsed_time_and_reset();
00211           if (ad_comm::global_logfile)
00212           {
00213             (*ad_comm::global_logfile) << " time_in solve " <<  ii << "  "
00214               << time << endl;
00215           }
00216         }
00217       }
00218 
00219       f1b2gradlist->reset();
00220       f1b2gradlist->list.initialize();
00221       f1b2gradlist->list2.initialize();
00222       f1b2gradlist->list3.initialize();
00223       f1b2gradlist->nlist.initialize();
00224       f1b2gradlist->nlist2.initialize();
00225       f1b2gradlist->nlist3.initialize();
00226 
00227       uhat_old=uhat;
00228       uhat+=step;
00229 
00230       double maxg_old=maxg;
00231       maxg=fabs(evaluate_function(uhat,pfmin));
00232       if (maxg>maxg_old)
00233       {
00234         uhat=uhat_old;
00235         evaluate_function(uhat,pfmin);
00236         break;
00237       }
00238       if (maxg < 1.e-13)
00239       {
00240         break;
00241       }
00242     }
00243     for (i=1;i<=usize;i++)
00244     {
00245       y(i+xsize)=uhat(i);
00246     }
00247   }
00248 
00249   if (num_nr_iters<=0)
00250   {
00251     evaluate_function(uhat,pfmin);
00252   }
00253 
00254   for (i=1;i<=usize;i++)
00255   {
00256     y(i+xsize)=uhat(i);
00257   }
00258 
00259 
00260   if (ad_comm::time_flag)
00261   {
00262     if (ad_comm::ptm)
00263     {
00264       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00265       if (ad_comm::global_logfile)
00266       {
00267         (*ad_comm::global_logfile) << " Time in reset and evaluate function"
00268           << time << endl;
00269       }
00270     }
00271   }
00272   get_second_ders(xsize,usize,y,Hess,Dux,f1b2gradlist,pfmin,this);
00273   //int sgn=0;
00274 
00275   if (ad_comm::time_flag)
00276   {
00277     if (ad_comm::ptm)
00278     {
00279       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00280       if (ad_comm::global_logfile)
00281       {
00282         (*ad_comm::global_logfile) << " Time in dget second ders "
00283           << time << endl;
00284       }
00285     }
00286   }
00287   if (!ierr)
00288   {
00289     if (num_importance_samples==0)
00290     {
00291       //cout << "Hess " << endl << Hess << endl;
00292       f=calculate_laplace_approximation(x,uhat,Hess,xadjoint,uadjoint,
00293         Hessadjoint,pfmin);
00294     }
00295     else
00296     {
00297       if (isfunnel_flag==0)
00298       {
00299         f=calculate_importance_sample(x,uhat,Hess,xadjoint,uadjoint,
00300           Hessadjoint,pfmin);
00301       }
00302       else
00303       {
00304         f=calculate_importance_sample_funnel(x,uhat,Hess,xadjoint,uadjoint,
00305           Hessadjoint,pfmin);
00306       }
00307     }
00308   }
00309   else
00310   {
00311     f=1.e+30;
00312   }
00313 
00314   if (ad_comm::time_flag)
00315   {
00316     if (ad_comm::ptm)
00317     {
00318       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00319       if (ad_comm::global_logfile)
00320       {
00321         (*ad_comm::global_logfile) << "Time in calculate laplace approximation "
00322           << time << endl;
00323       }
00324     }
00325   }
00326 
00327   for (int ip=num_der_blocks;ip>=1;ip--)
00328   {
00329     df1b2variable::minder=minder(ip);
00330     df1b2variable::maxder=maxder(ip);
00331     int mind=y(1).minder;
00332     int jmin=max(mind,xsize+1);
00333     int jmax=min(y(1).maxder,xsize+usize);
00334     for (i=1;i<=usize;i++)
00335     {
00336       for (j=jmin;j<=jmax;j++)
00337       {
00338         //Hess(i,j-xsize)=y(i+xsize).u_bar[j-mind];
00339         y(i+xsize).get_u_bar_tilde()[j-mind]=Hessadjoint(i,j-xsize);
00340       }
00341     }
00342 
00343     if (initial_df1b2params::separable_flag)
00344     {
00345       for (j=1;j<=xsize+usize;j++)
00346       {
00347         *y(j).get_u_tilde()=0;
00348       }
00349       Hess.initialize();
00350       initial_df1b2params::separable_calculation_type=3;
00351       pfmin->user_function();
00352     }
00353     else
00354     {
00355       if (ip<num_der_blocks)
00356       {
00357         f1b2gradlist->reset();
00358         set_u_dot(ip);
00359         df1b2_gradlist::set_yes_derivatives();
00360         (*re_objective_function_value::pobjfun)=0;
00361         df1b2variable pen=0.0;
00362         df1b2variable zz=0.0;
00363 
00364         initial_df1b2params::reset(y,pen);
00365         pfmin->user_function();
00366 
00367         re_objective_function_value::fun_without_pen=
00368           value(*re_objective_function_value::pobjfun);
00369 
00370         (*re_objective_function_value::pobjfun)+=pen;
00371         (*re_objective_function_value::pobjfun)+=zz;
00372 
00373         set_dependent_variable(*re_objective_function_value::pobjfun);
00374         df1b2_gradlist::set_no_derivatives();
00375         df1b2variable::passnumber=1;
00376         df1b2_gradcalc1();
00377       }
00378 
00379       for (i=1;i<=usize;i++)
00380       {
00381         for (j=jmin;j<=jmax;j++)
00382         {
00383           //Hess(i,j-xsize)=y(i+xsize).u_bar[j-mind];
00384           y(i+xsize).get_u_bar_tilde()[j-mind]=Hessadjoint(i,j-xsize);
00385         }
00386       }
00387 
00388       //int mind=y(1).minder;
00389       df1b2variable::passnumber=2;
00390       df1b2_gradcalc1();
00391 
00392       df1b2variable::passnumber=3;
00393       df1b2_gradcalc1();
00394 
00395       f1b2gradlist->reset();
00396       f1b2gradlist->list.initialize();
00397       f1b2gradlist->list2.initialize();
00398       f1b2gradlist->list3.initialize();
00399       f1b2gradlist->nlist.initialize();
00400       f1b2gradlist->nlist2.initialize();
00401       f1b2gradlist->nlist3.initialize();
00402     }
00403 
00404     if (ad_comm::time_flag)
00405     {
00406       if (ad_comm::ptm)
00407       {
00408         double time=ad_comm::ptm->get_elapsed_time_and_reset();
00409         if (ad_comm::global_logfile)
00410         {
00411           (*ad_comm::global_logfile) << " time for 3rd derivatives "
00412             << time << endl;
00413         }
00414       }
00415     }
00416 
00417     dvector dtmp(1,xsize);
00418     for (i=1;i<=xsize;i++)
00419     {
00420       dtmp(i)=*y(i).get_u_tilde();
00421     }
00422     if (initial_df1b2params::separable_flag)
00423     {
00424 #ifndef OPT_LIB
00425       assert(nvar <= INT_MAX);
00426 #endif
00427       dvector scale(1,(int)nvar);   // need to get scale from somewhere
00428       /*int check=*/initial_params::stddev_scale(scale,x);
00429       dvector sscale=scale(1,Dux(1).indexmax());
00430       for (i=1;i<=usize;i++)
00431       {
00432         Dux(i)=elem_prod(Dux(i),sscale);
00433       }
00434       dtmp=elem_prod(dtmp,sscale);
00435     }
00436 
00437     for (i=1;i<=xsize;i++)
00438     {
00439       xadjoint(i)+=dtmp(i);
00440     }
00441     for (i=1;i<=usize;i++)
00442       uadjoint(i)+=*y(xsize+i).get_u_tilde();
00443   }
00444  // *****************************************************************
00445  // new stuff to deal with quadraticprior
00446  // *****************************************************************
00447 
00448     int xstuff=3;
00449     if (xstuff && df1b2quadratic_prior::get_num_quadratic_prior()>0)
00450     {
00451       initial_params::straight_through_flag=0;
00452       funnel_init_var::lapprox=0;
00453       block_diagonal_flag=0;
00454 #ifndef OPT_LIB
00455       assert(nvar <= INT_MAX);
00456 #endif
00457       dvector scale1(1,(int)nvar);   // need to get scale from somewhere
00458       initial_params::set_inactive_only_random_effects();
00459       /*int check=*/initial_params::stddev_scale(scale1,x);
00460 
00461       laplace_approximation_calculator::where_are_we_flag=3;
00462       quadratic_prior::in_qp_calculations=1;
00463       funnel_init_var::lapprox=this;
00464       df1b2_gradlist::set_no_derivatives();
00465       dvector scale(1,(int)nvar);   // need to get scale from somewhere
00466       /*check=*/initial_params::stddev_scale(scale,x);
00467       dvector sscale=scale(1,Dux(1).indexmax());
00468 
00469       for (i=1;i<=usize;i++)
00470       {
00471         Dux(i)=elem_div(Dux(i),sscale);
00472       }
00473 
00474       if (xstuff>1)
00475       {
00476         df1b2quadratic_prior::get_Lxu_contribution(Dux);
00477       }
00478       quadratic_prior::in_qp_calculations=0;
00479       funnel_init_var::lapprox=0;
00480       laplace_approximation_calculator::where_are_we_flag=0;
00481 
00482       for (i=1;i<=usize;i++)
00483       {
00484         Dux(i)=elem_prod(Dux(i),sscale);
00485       }
00486       //local_dtemp=elem_prod(local_dtemp,sscale);
00487 
00488       if (xstuff>2)
00489       {
00490         dvector tmp=evaluate_function_with_quadprior(x,usize,pfmin);
00491         for (i=1;i<=xsize;i++)
00492         {
00493           xadjoint(i)+=tmp(i);
00494         }
00495       }
00496 
00497       if (xstuff>2)
00498       {
00499         quadratic_prior::get_cHessian_contribution_from_vHessian(Hess,xsize);
00500       }
00501     }
00502 
00503  // *****************************************************************
00504  // new stuff to deal with quadraticprior
00505  // *****************************************************************
00506   if (ad_comm::ptm)
00507   {
00508     /*double time=*/ad_comm::ptm->get_elapsed_time_and_reset();
00509   }
00510 
00511 #if defined(USE_ATLAS)
00512       if (!ad_comm::no_atlas_flag)
00513       {
00514         //xadjoint -= uadjoint*atlas_solve_spd_trans(Hess,Dux);
00515         xadjoint -= atlas_solve_spd_trans(Hess,uadjoint)*Dux;
00516       }
00517       else
00518       {
00519         //xadjoint -= uadjoint*solve(Hess,Dux);
00520         xadjoint -= solve(Hess,uadjoint)*Dux;
00521       }
00522 #else
00523       //xadjoint -= uadjoint*solve(Hess,Dux);
00524       xadjoint -= solve(Hess,uadjoint)*Dux;
00525 #endif
00526 
00527 
00528   if (ad_comm::ptm)
00529   {
00530     double time=ad_comm::ptm->get_elapsed_time_and_reset();
00531     if (ad_comm::global_logfile)
00532     {
00533       (*ad_comm::global_logfile) << " Time in second solve "
00534         << time << endl;
00535     }
00536   }
00537   if (ad_comm::ptm1)
00538   {
00539     double time=ad_comm::ptm1->get_elapsed_time_and_reset();
00540     if (ad_comm::global_logfile)
00541     {
00542       (*ad_comm::global_logfile) << " Total time in function evaluation "
00543         << time << endl << endl;
00544     }
00545   }
00546 
00547   return xadjoint;
00548 }