00001
00002
00003
00004
00005
00006
00007 #include <fvar.hpp>
00008 #define ITMAX 200
00009
00010 #define EPS 1.0e-9
00011 #define FPMIN 1.0e-30
00012 static void gcf(double& gammcf,double a,double x,double &gln);
00013 static void gser(double& gamser,double a,double x,double& gln);
00014
00019 dvariable gamma_deviate(const prevariable& _x,const prevariable& _a)
00020 {
00021 prevariable& x= (prevariable&)(_x);
00022 prevariable& a= (prevariable&)(_a);
00023
00024 dvariable y=cumd_norm(x);
00025
00026 y=.9999*y+.00005;
00027
00028 dvariable z=inv_cumd_gamma(y,a);
00029
00030 return z;
00031 }
00032
00033
00034 static double gammp(double a,double x)
00035 {
00036 double gamser = 0.0,gammcf,gln;
00037
00038 if (x < 0.0 || a <= 0.0)
00039 cerr << "Invalid arguments in routine gammp" << endl;
00040 if (x < (a+1.0)) {
00041 gser(gamser,a,x,gln);
00042 return gamser;
00043 } else {
00044 gcf(gammcf,a,x,gln);
00045 return 1.0-gammcf;
00046 }
00047 }
00048
00055 static void gcf(double& gammcf,double a,double x,double &gln)
00056 {
00057 int i;
00058 double an,b,c,d,del,h;
00059
00060 gln=gammln(a);
00061 b=x+1.0-a;
00062 c=1.0/FPMIN;
00063 d=1.0/b;
00064 h=d;
00065 for (i=1;i<=ITMAX;i++) {
00066 an = -i*(i-a);
00067 b += 2.0;
00068 d=an*d+b;
00069 if (fabs(d) < FPMIN) d=FPMIN;
00070 c=b+an/c;
00071 if (fabs(c) < FPMIN) c=FPMIN;
00072 d=1.0/d;
00073 del=d*c;
00074 h *= del;
00075 if (fabs(del-1.0) < EPS) break;
00076 }
00077 if (i > ITMAX)
00078 cerr << "a too large, ITMAX too small in gcf" << endl;
00079 gammcf=exp(-x+a*log(x)-(gln))*h;
00080 }
00081
00088 static void gser(double& gamser,double a,double x,double& gln)
00089 {
00090 int n;
00091 double sum,del,ap;
00092
00093 gln=gammln(a);
00094 if (x <= 0.0) {
00095 if (x < 0.0)
00096 cerr << "x less than 0 in routine gser" << endl;
00097 gamser=0.0;
00098 return;
00099 } else {
00100 ap=a;
00101 del=sum=1.0/a;
00102 for (n=1;n<=ITMAX;n++) {
00103 ++ap;
00104 del *= x/ap;
00105 sum += del;
00106 if (fabs(del) < fabs(sum)*EPS) {
00107 gamser=sum*exp(-x+a*log(x)-(gln));
00108 return;
00109 }
00110 }
00111 cerr << "a too large, ITMAX too small in routine gser" << endl;
00112 return;
00113 }
00114 }
00115
00116 static double get_initial_u(double a,double y);
00117
00118 double Sn(double x,double a);
00119
00120 #include <df32fun.h>
00121 df3_two_variable cumd_gamma(const df3_two_variable& x,
00122 const df3_two_variable& a);
00123
00124 dvariable inv_cumd_gamma(const prevariable& _y,const prevariable& _a)
00125 {
00126 double a=value(_a);
00127 double y=value(_y);
00128 if (a<0.05)
00129 {
00130 cerr << "a musdt be > 0.1" << endl;
00131 ad_exit(1);
00132 }
00133 double u=get_initial_u(a,y);
00134 double h;
00135 int loop_counter=0;
00136 do
00137 {
00138 loop_counter++;
00139 double z=gammp(a,a*exp(u));
00140 double d=y-z;
00141
00142 double log_fprime=a*log(a)+a*(u-exp(u)) -gammln(a);
00143 double fprime=exp(log_fprime);
00144 h=d/fprime;
00145 u+=h;
00146 if (loop_counter>1000)
00147 {
00148 cerr << "Error in inv_cumd_gamma"
00149 " maximum number of interations exceeded for values"
00150 << endl << " x = " << y << " a = " << a << " h = " << h << endl;
00151 }
00152 }
00153 while(fabs(h)>1.e-12);
00154
00155 double x=a*exp(u);
00156
00157 init_df3_two_variable xx(x);
00158 init_df3_two_variable aa(a);
00159 *xx.get_u_x()=1.0;
00160 *aa.get_u_y()=1.0;
00161
00162 df3_two_variable z=cumd_gamma(xx,aa);
00163 double F_x=1.0/(*z.get_u_x());
00164 double F_y=-F_x*(*z.get_u_y());
00165
00166 dvariable vz=0.0;
00167 value(vz)=x;
00168
00169 gradient_structure::GRAD_STACK1->set_gradient_stack(default_evaluation,
00170 &(vz.v->x),&(_y.v->x),F_x,&(_a.v->x),F_y);
00171
00172 return vz;
00173 }
00174
00175 #undef ITMAX
00176 #undef EPS
00177
00178 double Sn(double x,double a)
00179 {
00180 double summ=1.0;
00181
00182 const double xp=x;
00183 double prod=1.0;
00184
00185 int i=1;
00186 for (; i <= 50; i++)
00187 {
00188 prod*=(a+i);
00189 double summand=xp/prod;
00190 if (summand<1.e-4) break;
00191 summ+=summand;
00192 }
00193 if (i > 50)
00194 {
00195 cerr << "convergence error" << endl;
00196 ad_exit(1);
00197 }
00198 return summ;
00199 }
00200
00201 static double get_initial_u(double a,double y)
00202 {
00203 const double c=0.57721;
00204
00205 double logP=log(y);
00206 double logQ=log(1-y);
00207 double logB=logQ+gammln(a);
00208 double x0=1.e+100;
00209 double log_x0=1.e+100;
00210
00211 if (a<1.0)
00212 {
00213 if ( logB>log(.6) || (logB > log(.45) && a>=.3) )
00214 {
00215 double logu;
00216 if (logB+logQ > log(1.e-8))
00217 {
00218 logu=(logP+gammln(1.0+a))/a;
00219 }
00220 else
00221 {
00222 logu=-exp(logQ)/a -c;
00223 }
00224 double u=exp(logu);
00225 x0=u/(1-u/(1.0+a));
00226 double tmp=log(1-u/(1.0+a));
00227 log_x0=logu;
00228 log_x0-=tmp;
00229 }
00230 else if ( a<.3 && log(.35) <= logB && logB <= log(.6) )
00231 {
00232 double t=exp(-c-exp(logB));
00233 double logt=-c-exp(logB);
00234 double u=t*exp(t);
00235 x0=t*exp(u);
00236 log_x0=logt+u;
00237 }
00238 else if ( (log(.15)<=logB && logB <=log(.35)) ||
00239 ((log(.15)<=logB && logB <=log(.45)) && a>=.3) )
00240 {
00241 double y=-logB;
00242 double v=y-(1-a)*log(y);
00243 x0=y-(1-a)*log(v)-log(1+(1.0-a)/(1.0+v));
00244 log_x0=log(x0);
00245 }
00246 else if (log(.01)<logB && logB < log(.15))
00247 {
00248 double y=-logB;
00249 double v=y-(1-a)*log(y);
00250 x0=y-(1-a)*log(v)-log((v*v+2*(3-a)*v+(2-a)*(3-a))/(v*v +(5-a)*v+2));
00251 log_x0=log(x0);
00252 }
00253 else if (logB < log(.01))
00254 {
00255 double y=-logB;
00256 double v=y-(1-a)*log(y);
00257 x0=y-(1-a)*log(v)-log((v*v+2*(3-a)*v+(2-a)*(3-a))/(v*v +(5-a)*v+2));
00258 log_x0=log(x0);
00259 }
00260 else
00261 {
00262 cerr << "this can't happen" << endl;
00263 ad_exit(1);
00264 }
00265 }
00266 else if (a>=1.0)
00267 {
00268 const double a0 = 3.31125922108741;
00269 const double b1 = 6.61053765625462;
00270 const double a1 = 11.6616720288968;
00271 const double b2 = 6.40691597760039;
00272 const double a2 = 4.28342155967104;
00273 const double b3 = 1.27364489782223;
00274 const double a3 = .213623493715853;
00275 const double b4 = .03611708101884203;
00276
00277 int sgn=1;
00278 double logtau;
00279 if (logP< log(0.5))
00280 {
00281 logtau=logP;
00282 sgn=-1;
00283 }
00284 else
00285 {
00286 logtau=logQ;
00287 sgn=1;
00288 }
00289
00290 double t=sqrt(-2.0*logtau);
00291
00292 double num = (((a3*t+a2)*t+a1)*t)+a0;
00293 double den = ((((b4*t+b3)*t+b2)*t)+b1)*t+1;
00294 double s=sgn*(t-num/den);
00295 double s2=s*s;
00296 double s3=s2*s;
00297 double s4=s3*s;
00298 double s5=s4*s;
00299 double roota=sqrt(a);
00300 double w=a+s*roota+(s2-1)/3.0+(s3-7.0*s)/(36.*roota)
00301 -(3.0*s4+7.0*s2-16)/(810.0*a)
00302 +(9.0*s5+256.0*s3-433.0*s)/(38880.0*a*roota);
00303 if (logP< log(0.5))
00304 {
00305 if (w>.15*(a+1))
00306 {
00307 x0=w;
00308 }
00309 else
00310 {
00311 double v=logP+gammln(a+1);
00312 double u1=exp(v+w)/a;
00313 double S1=1+u1/(a+1);
00314 double u2=exp((v+u1-log(S1))/a);
00315 double S2=1+u2/(a+1)+u2*u2/((a+1)*(a+2));
00316 double u3=exp((v+u2-log(S2))/a);
00317 double S3=1+u3/(a+1)+u3*u3/((a+1)*(a+2))
00318 + u3*u3*u3/((a+1)*(a+2)*(a+3));
00319 double z=exp((v+u3-log(S3))/a);
00320 if (z<.002*(a+1.0))
00321 {
00322 x0=z;
00323 }
00324 else
00325 {
00326 double sn=Sn(z,a);
00327 double zbar=exp((v+z-log(sn))/a);
00328 x0=zbar*(1.0-(a*log(zbar)-zbar-v+log(sn))/(a-zbar));
00329 }
00330 }
00331 log_x0=log(x0);
00332 }
00333 else
00334 {
00335 double u = -logB +(a-1.0)*log(w)-log(1.0+(1.0-a)/(1+w));
00336 x0=u;
00337 log_x0=log(x0);
00338 }
00339 }
00340 if (a==1.0)
00341 {
00342 x0=-log(1.0-y);
00343 log_x0=log(x0);
00344 }
00345 return log_x0-log(a);
00346 }