ADMB Documentation  11.5.3197
 All Classes Files Functions Variables Typedefs Friends Defines
vgamdev.cpp
Go to the documentation of this file.
00001 /*
00002   $Id$
00003 
00004   Author: David Fournier
00005   Copyright (c) 2008, 2009, 2010 Regents of the University of California
00006  */
00007 #include <fvar.hpp>
00008 #define ITMAX 200
00009 //#define EPS 3.0e-7
00010 #define EPS 1.0e-9
00011 #define FPMIN 1.0e-30
00012 static void gcf(double& gammcf,double a,double x,double &gln);
00013 static void gser(double& gamser,double a,double x,double& gln);
00014 
00019   dvariable gamma_deviate(const prevariable& _x,const prevariable& _a)
00020   {
00021     prevariable& x= (prevariable&)(_x);
00022     prevariable& a= (prevariable&)(_a);
00023 
00024     dvariable y=cumd_norm(x);
00025 
00026     y=.9999*y+.00005;
00027 
00028     dvariable z=inv_cumd_gamma(y,a);
00029 
00030     return z;
00031   }
00032 
00033 
00034 static double gammp(double a,double x)
00035 {
00036   double gamser = 0.0,gammcf,gln;
00037 
00038   if (x < 0.0 || a <= 0.0)
00039     cerr << "Invalid arguments in routine gammp" << endl;
00040   if (x < (a+1.0)) {
00041     gser(gamser,a,x,gln);
00042     return gamser;
00043   } else {
00044     gcf(gammcf,a,x,gln);
00045     return 1.0-gammcf;
00046   }
00047 }
00048 
00055 static void gcf(double& gammcf,double a,double x,double &gln)
00056 {
00057   int i;
00058   double an,b,c,d,del,h;
00059 
00060   gln=gammln(a);
00061   b=x+1.0-a;
00062   c=1.0/FPMIN;
00063   d=1.0/b;
00064   h=d;
00065   for (i=1;i<=ITMAX;i++) {
00066     an = -i*(i-a);
00067     b += 2.0;
00068     d=an*d+b;
00069     if (fabs(d) < FPMIN) d=FPMIN;
00070     c=b+an/c;
00071     if (fabs(c) < FPMIN) c=FPMIN;
00072     d=1.0/d;
00073     del=d*c;
00074     h *= del;
00075     if (fabs(del-1.0) < EPS) break;
00076   }
00077   if (i > ITMAX)
00078     cerr << "a too large, ITMAX too small in gcf" << endl;
00079   gammcf=exp(-x+a*log(x)-(gln))*h;
00080 }
00081 
00088 static void gser(double& gamser,double a,double x,double& gln)
00089 {
00090   int n;
00091   double sum,del,ap;
00092 
00093   gln=gammln(a);
00094   if (x <= 0.0) {
00095     if (x < 0.0)
00096       cerr << "x less than 0 in routine gser" << endl;
00097     gamser=0.0;
00098     return;
00099   } else {
00100     ap=a;
00101     del=sum=1.0/a;
00102     for (n=1;n<=ITMAX;n++) {
00103       ++ap;
00104       del *= x/ap;
00105       sum += del;
00106       if (fabs(del) < fabs(sum)*EPS) {
00107         gamser=sum*exp(-x+a*log(x)-(gln));
00108         return;
00109       }
00110     }
00111     cerr << "a too large, ITMAX too small in routine gser" << endl;
00112     return;
00113   }
00114 }
00115 
00116 static double get_initial_u(double a,double y);
00117 
00118 double Sn(double x,double a);
00119 
00120 #include <df32fun.h>
00121 df3_two_variable cumd_gamma(const df3_two_variable& x,
00122   const df3_two_variable& a);
00123 
00124 dvariable inv_cumd_gamma(const prevariable& _y,const prevariable& _a)
00125 {
00126   double a=value(_a);
00127   double y=value(_y);
00128   if (a<0.05)
00129   {
00130     cerr << "a musdt be > 0.1" << endl;
00131     ad_exit(1);
00132   }
00133   double u=get_initial_u(a,y);
00134   double h;
00135   int loop_counter=0;
00136   do
00137   {
00138     loop_counter++;
00139     double z=gammp(a,a*exp(u));
00140     double d=y-z;
00141     //cout << d << endl;
00142     double log_fprime=a*log(a)+a*(u-exp(u)) -gammln(a);
00143     double fprime=exp(log_fprime);
00144     h=d/fprime;
00145     u+=h;
00146     if (loop_counter>1000)
00147     {
00148       cerr << "Error in inv_cumd_gamma"
00149         " maximum number of interations exceeded for values"
00150         << endl << "  x = " << y << "  a =  " << a  << "  h =  " << h  << endl;
00151     }
00152   }
00153   while(fabs(h)>1.e-12);
00154 
00155   double x=a*exp(u);
00156 
00157   init_df3_two_variable xx(x);
00158   init_df3_two_variable aa(a);
00159   *xx.get_u_x()=1.0;
00160   *aa.get_u_y()=1.0;
00161 
00162   df3_two_variable z=cumd_gamma(xx,aa);
00163   double F_x=1.0/(*z.get_u_x());
00164   double F_y=-F_x*(*z.get_u_y());
00165 
00166   dvariable vz=0.0;
00167   value(vz)=x;
00168 
00169   gradient_structure::GRAD_STACK1->set_gradient_stack(default_evaluation,
00170     &(vz.v->x),&(_y.v->x),F_x,&(_a.v->x),F_y);
00171 
00172   return vz;
00173 }
00174 
00175 #undef ITMAX
00176 #undef EPS
00177 
00178 double Sn(double x,double a)
00179 {
00180   double summ=1.0;
00181 
00182   const double xp=x;
00183   double prod=1.0;
00184 
00185   int i=1;
00186   for (; i <= 50; i++)
00187   {
00188     prod*=(a+i);
00189     double summand=xp/prod;
00190     if (summand<1.e-4) break;
00191     summ+=summand;
00192   }
00193   if (i > 50)
00194   {
00195     cerr << "convergence error" << endl;
00196     ad_exit(1);
00197   }
00198   return summ;
00199 }
00200 
00201 static double get_initial_u(double a,double y)
00202 {
00203   const double c=0.57721;
00204   // note that P = y;
00205   double logP=log(y);
00206   double logQ=log(1-y);
00207   double logB=logQ+gammln(a);
00208   double x0=1.e+100;
00209   double log_x0=1.e+100;
00210 
00211   if (a<1.0)
00212   {
00213     if ( logB>log(.6) || (logB > log(.45) && a>=.3) )
00214     {
00215       double logu;
00216       if (logB+logQ > log(1.e-8))
00217       {
00218         logu=(logP+gammln(1.0+a))/a;
00219       }
00220       else
00221       {
00222         logu=-exp(logQ)/a -c;
00223       }
00224       double u=exp(logu);
00225       x0=u/(1-u/(1.0+a));
00226       double tmp=log(1-u/(1.0+a));
00227       log_x0=logu;
00228       log_x0-=tmp;
00229     }
00230     else if ( a<.3 && log(.35) <= logB && logB <= log(.6) )
00231     {
00232       double t=exp(-c-exp(logB));
00233       double logt=-c-exp(logB);
00234       double u=t*exp(t);
00235       x0=t*exp(u);
00236       log_x0=logt+u;
00237     }
00238     else if ( (log(.15)<=logB && logB <=log(.35)) ||
00239        ((log(.15)<=logB && logB <=log(.45)) && a>=.3) )
00240     {
00241       double y=-logB;
00242       double v=y-(1-a)*log(y);
00243       x0=y-(1-a)*log(v)-log(1+(1.0-a)/(1.0+v));
00244       log_x0=log(x0);
00245     }
00246     else if (log(.01)<logB && logB < log(.15))
00247     {
00248       double y=-logB;
00249       double v=y-(1-a)*log(y);
00250       x0=y-(1-a)*log(v)-log((v*v+2*(3-a)*v+(2-a)*(3-a))/(v*v +(5-a)*v+2));
00251       log_x0=log(x0);
00252     }
00253     else if (logB < log(.01))
00254     {
00255       double y=-logB;
00256       double v=y-(1-a)*log(y);
00257       x0=y-(1-a)*log(v)-log((v*v+2*(3-a)*v+(2-a)*(3-a))/(v*v +(5-a)*v+2));
00258       log_x0=log(x0);
00259     }
00260     else
00261     {
00262       cerr << "this can't happen" << endl;
00263       ad_exit(1);
00264     }
00265   }
00266   else  if (a>=1.0)
00267   {
00268     const double a0 = 3.31125922108741;
00269     const double b1 = 6.61053765625462;
00270     const double a1 = 11.6616720288968;
00271     const double b2 = 6.40691597760039;
00272     const double a2 = 4.28342155967104;
00273     const double b3 = 1.27364489782223;
00274     const double a3 = .213623493715853;
00275     const double b4 = .03611708101884203;
00276 
00277     int sgn=1;
00278     double logtau;
00279     if (logP< log(0.5))
00280     {
00281       logtau=logP;
00282       sgn=-1;
00283     }
00284     else
00285     {
00286       logtau=logQ;
00287       sgn=1;
00288     }
00289 
00290     double t=sqrt(-2.0*logtau);
00291 
00292     double num = (((a3*t+a2)*t+a1)*t)+a0;
00293     double den = ((((b4*t+b3)*t+b2)*t)+b1)*t+1;
00294     double s=sgn*(t-num/den);
00295     double s2=s*s;
00296     double s3=s2*s;
00297     double s4=s3*s;
00298     double s5=s4*s;
00299     double roota=sqrt(a);
00300     double w=a+s*roota+(s2-1)/3.0+(s3-7.0*s)/(36.*roota)
00301       -(3.0*s4+7.0*s2-16)/(810.0*a)
00302       +(9.0*s5+256.0*s3-433.0*s)/(38880.0*a*roota);
00303     if (logP< log(0.5))
00304     {
00305       if (w>.15*(a+1))
00306       {
00307         x0=w;
00308       }
00309       else
00310       {
00311         double v=logP+gammln(a+1);
00312         double u1=exp(v+w)/a;
00313         double S1=1+u1/(a+1);
00314         double u2=exp((v+u1-log(S1))/a);
00315         double S2=1+u2/(a+1)+u2*u2/((a+1)*(a+2));
00316         double u3=exp((v+u2-log(S2))/a);
00317         double S3=1+u3/(a+1)+u3*u3/((a+1)*(a+2))
00318          + u3*u3*u3/((a+1)*(a+2)*(a+3));
00319         double z=exp((v+u3-log(S3))/a);
00320         if (z<.002*(a+1.0))
00321         {
00322           x0=z;
00323         }
00324         else
00325         {
00326           double sn=Sn(z,a);
00327           double zbar=exp((v+z-log(sn))/a);
00328           x0=zbar*(1.0-(a*log(zbar)-zbar-v+log(sn))/(a-zbar));
00329         }
00330       }
00331       log_x0=log(x0);
00332     }
00333     else
00334     {
00335       double u = -logB +(a-1.0)*log(w)-log(1.0+(1.0-a)/(1+w));
00336       x0=u;
00337       log_x0=log(x0);
00338     }
00339   }
00340   if (a==1.0)
00341   {
00342     x0=-log(1.0-y);
00343     log_x0=log(x0);
00344   }
00345   return log_x0-log(a);
00346 }