Skip to content

Commit

Permalink
Fix gradient of piecewise-linear
Browse files Browse the repository at this point in the history
  • Loading branch information
4ment committed Nov 16, 2023
1 parent 9695602 commit ab3524e
Showing 1 changed file with 45 additions and 41 deletions.
86 changes: 45 additions & 41 deletions src/phyc/demographicmodels.c
Original file line number Diff line number Diff line change
Expand Up @@ -1991,7 +1991,7 @@ double _coalescent_piecewise_linear_grid_calculate( Coalescent* coal ){
if ( coal->need_update_intervals ) {
coal->update_intervals(coal);
}

if ( coal->need_update ) {
coal->logP = 0;
size_t currentGridIndex = 0;
Expand All @@ -2002,15 +2002,16 @@ double _coalescent_piecewise_linear_grid_calculate( Coalescent* coal ){
double lchoose2;
double t = 0;
double popSizeCurrent = popSizeGridStart; // can be a grid pop size or any interval

for(size_t i = 0; i < coal->n; i++){
t += coal->times[i];
if(coal->times[i] != 0.0){
lchoose2 = CHOOSE2(coal->lineages[i]);
double popSize;
if (coal->nodes[i]->index >= 0){
// after the last grid point we use piecewise constant
if(currentGridIndex >= coal->gridCount){
// handle case when consecutive pop sizes are equal: equivalent to constant
if(currentGridIndex >= coal->gridCount || popSizeGridEnd == popSizeGridStart){
popSize = popSizeGridEnd;
}
else{
Expand All @@ -2026,16 +2027,11 @@ double _coalescent_piecewise_linear_grid_calculate( Coalescent* coal ){

// integral
double integral;
if (currentGridIndex < coal->gridCount){
// coal->times[i] == t_i-t_i-1
integral = lchoose2 * coal->times[i] * (log(popSize) - log(popSizeCurrent))/(popSize - popSizeCurrent);
if (currentGridIndex < coal->gridCount && popSizeGridEnd != popSizeGridStart){
coal->logP -= lchoose2 * coal->times[i] * (log(popSize) - log(popSizeCurrent))/(popSize - popSizeCurrent);

}
else{
coal->logP -= lchoose2 * coal->times[i] / popSize;
integral = lchoose2 * coal->times[i] / popSize;
// fprintf(stderr, "%f\n", integral);
}

popSizeCurrent = popSize;
Expand Down Expand Up @@ -2068,8 +2064,7 @@ void _coalescent_piecewise_linear_grid_gradient_time( Coalescent* coal ){

if(compute_grad_theta){
memset(coal->gradient, 0, Parameters_count(coal->p)*sizeof(double));
// analytical derivatives not working yet
// if(useFiniteDifferences){
if(useFiniteDifferences){
for(size_t i = 0; i < Parameters_count(coal->p); i++){
double value = Parameters_value(coal->p, i);
Parameters_set_value(coal->p, i, value+eps);
Expand All @@ -2080,7 +2075,7 @@ void _coalescent_piecewise_linear_grid_gradient_time( Coalescent* coal ){
Parameters_set_value(coal->p, i, value);
}
compute_grad_theta = false;
// }
}

offset += Parameters_count(coal->p);
}
Expand Down Expand Up @@ -2120,7 +2115,7 @@ void _coalescent_piecewise_linear_grid_gradient_time( Coalescent* coal ){

if ( compute_grad_time || compute_grad_theta ) {
double* heightGradient = NULL;

if(compute_grad_time){
heightGradient = dvector(Tree_tip_count(coal->tree)-1);
memset(heightGradient, 0.0, sizeof(double)*(Tree_tip_count(coal->tree)-1));
Expand All @@ -2136,16 +2131,18 @@ void _coalescent_piecewise_linear_grid_gradient_time( Coalescent* coal ){
double t = 0;
double popSizeCurrent = popSizeGridStart; // can be a grid pop size or any interval
Node** nodes = Tree_nodes(coal->tree);

for(size_t i = 0; i < coal->n; i++){
t += coal->times[i];
if(coal->times[i] != 0.0){
lchoose2 = CHOOSE2(coal->lineages[i]);
double popSize;
double logPopSize;
double deltaGrid = timeGridEnd - timeGridStart;

if (coal->nodes[i]->index >= 0){
// after the last grid point we use piecewise constant
if(currentGridIndex >= coal->gridCount){
if(currentGridIndex >= coal->gridCount || popSizeGridEnd == popSizeGridStart){
popSize = popSizeGridEnd;
logPopSize = log(popSize);
if( coal->iscoalescent[i]){
Expand All @@ -2156,19 +2153,19 @@ void _coalescent_piecewise_linear_grid_gradient_time( Coalescent* coal ){
}
}
else{
popSize = popSizeGridStart + (popSizeGridEnd - popSizeGridStart) * (t - timeGridStart)/(timeGridEnd - timeGridStart);
popSize = popSizeGridStart + (popSizeGridEnd - popSizeGridStart) * (t - timeGridStart)/deltaGrid;
logPopSize = log(popSize);

if( coal->iscoalescent[i]){
coal->logP -= logPopSize;
if(compute_grad_time){
double dpopSizedt = (popSizeGridEnd - popSizeGridStart)/(timeGridEnd - timeGridStart);
double dpopSizedt = (popSizeGridEnd - popSizeGridStart)/deltaGrid;
size_t node_class_id = Node_class_id(nodes[coal->nodes[i]->index]);
heightGradient[node_class_id] -= dpopSizedt/popSize;
}

if(compute_grad_theta){
double c = (t - timeGridStart)/(timeGridEnd - timeGridStart);
double c = (t - timeGridStart)/deltaGrid;
coal->gradient[currentGridIndex] -= (1.0 - c)/popSize;
coal->gradient[currentGridIndex + 1] -= c/popSize;
}
Expand All @@ -2181,60 +2178,67 @@ void _coalescent_piecewise_linear_grid_gradient_time( Coalescent* coal ){
}

// integral
if (currentGridIndex < coal->gridCount){
if (currentGridIndex < coal->gridCount && popSizeGridEnd != popSizeGridStart){
coal->logP -= lchoose2 * coal->times[i] * (log(popSize) - log(popSizeCurrent))/(popSize - popSizeCurrent);

if(compute_grad_theta){
double gradStart = 0.0;
double gradEnd = 0.0;
// interval is 2 consecutive grid points or starts with a sampling event at time 0 followed by grid
if (coal->nodes[i]->index < 0 && (coal->nodes[i-1]->index < 0 || t - coal->times[i] == 0.0)){
gradStart = (popSizeGridStart*(log(popSizeGridEnd) + 1.0 - log(popSizeGridStart)) - popSizeGridEnd)/(popSizeGridStart*pow(popSizeGridEnd - popSizeGridStart, 2.0));
gradEnd = (popSizeGridEnd*(log(popSizeGridStart) + 1.0 - log(popSizeGridEnd)) - popSizeGridStart)/(popSizeGridEnd*pow(popSizeGridEnd - popSizeGridStart, 2.0));
double logPopSizeGridEnd = log(popSizeGridEnd);
double logPopSizeGridStart = log(popSizeGridStart);
gradStart = (popSizeGridStart*(logPopSizeGridEnd + 1.0 - logPopSizeGridStart) - popSizeGridEnd)/(popSizeGridStart*pow(popSizeGridEnd - popSizeGridStart, 2.0));
gradEnd = (popSizeGridEnd*(logPopSizeGridStart + 1.0 - logPopSizeGridEnd) - popSizeGridStart)/(popSizeGridEnd*pow(popSizeGridEnd - popSizeGridStart, 2.0));
}
// interval ends with a grid point
else if (coal->nodes[i]->index < 0){
double dpopSizeCurrent_dEnd = (t-coal->times[i-1] - timeGridStart)/(timeGridEnd - timeGridStart);
double logPopSizeGridEnd = log(popSizeGridEnd);
double logPopSizeCurrent = log(popSizeCurrent);
double dpopSizeCurrent_dEnd = (t-coal->times[i] - timeGridStart)/deltaGrid;
double dpopSizeCurrent_dStart = 1.0 - dpopSizeCurrent_dEnd;

gradStart = -(dpopSizeCurrent_dStart*(popSizeCurrent*(log(popSizeCurrent) - log(popSizeGridEnd) - 1.0) + popSizeGridEnd))/(popSizeCurrent*pow(popSizeGridEnd - popSizeCurrent, 2.0));
gradEnd = ((popSizeGridEnd - popSizeCurrent)*(1./popSizeGridEnd - dpopSizeCurrent_dEnd/popSizeCurrent) + (dpopSizeCurrent_dEnd - 1.0)*(log(popSizeGridEnd) - log(popSizeCurrent)))/pow(popSizeGridEnd - popSizeCurrent, 2.0);
gradStart = -(dpopSizeCurrent_dStart*(popSizeCurrent*(logPopSizeCurrent - logPopSizeGridEnd - 1.0) + popSizeGridEnd))/(popSizeCurrent*pow(popSizeGridEnd - popSizeCurrent, 2.0));
gradEnd = ((popSizeGridEnd - popSizeCurrent)*(1./popSizeGridEnd - dpopSizeCurrent_dEnd/popSizeCurrent) + (dpopSizeCurrent_dEnd - 1.0)*(logPopSizeGridEnd - logPopSizeCurrent))/pow(popSizeGridEnd - popSizeCurrent, 2.0);
}
// interval starts with a grid point or sampling event at time 0
else if (coal->nodes[i-1]->index < 0 || t - coal->times[i] == 0.0){
double dPopSize_dEnd = (t - timeGridStart)/(timeGridEnd - timeGridStart);
double logPopSizeGridStart = log(popSizeGridStart);
double dPopSize_dEnd = (t - timeGridStart)/deltaGrid;
double dPopSize_dStart = 1.0 - dPopSize_dEnd;
gradStart = ((popSize - popSizeGridStart)*(dPopSize_dStart/popSize - 1./popSizeGridStart) + (dPopSize_dStart - 1.0)*(log(popSizeGridStart) - log(popSize)))/pow(popSize - popSizeGridStart, 2.0);
gradEnd = (dPopSize_dEnd*(popSize*(log(popSizeGridStart) + 1.0 - log(popSize)) - popSizeGridStart))/(popSize*pow(popSize - popSizeGridStart, 2.0));

gradStart = ((popSize - popSizeGridStart)*(dPopSize_dStart/popSize - 1./popSizeGridStart) + (dPopSize_dStart - 1.0)*(logPopSizeGridStart - logPopSize))/pow(popSize - popSizeGridStart, 2.0);
gradEnd = (dPopSize_dEnd*(popSize*(logPopSizeGridStart + 1.0 - logPopSize) - popSizeGridStart))/(popSize*pow(popSize - popSizeGridStart, 2.0));
}
else{
double dpopSizeCurrent_dEnd = (t-coal->times[i-1] - timeGridStart)/(timeGridEnd - timeGridStart);
double logPopSizeCurrent = log(popSizeCurrent);
double dpopSizeCurrent_dEnd = (t - coal->times[i] - timeGridStart)/deltaGrid;
double dpopSizeCurrent_dStart = 1.0 - dpopSizeCurrent_dEnd;
double dPopSize_dEnd = (t - timeGridStart)/(timeGridEnd - timeGridStart);

double dPopSize_dEnd = (t - timeGridStart)/deltaGrid;
double dPopSize_dStart = 1.0 - dPopSize_dEnd;

gradStart = (dPopSize_dStart/popSize - dpopSizeCurrent_dStart/popSizeCurrent)/(popSize - popSizeCurrent) - (dPopSize_dStart - dpopSizeCurrent_dStart)*(log(popSize) - log(popSizeCurrent))/pow(popSize - popSizeCurrent, 2.0);
gradEnd = (dPopSize_dEnd/popSize - dpopSizeCurrent_dEnd/popSizeCurrent)/(popSize - popSizeCurrent) - (dPopSize_dEnd - dpopSizeCurrent_dEnd)*(log(popSize) - log(popSizeCurrent))/pow(popSize - popSizeCurrent, 2.0);
gradStart = (dPopSize_dStart/popSize - dpopSizeCurrent_dStart/popSizeCurrent)/(popSize - popSizeCurrent) - (dPopSize_dStart - dpopSizeCurrent_dStart)*(logPopSize - logPopSizeCurrent)/pow(popSize - popSizeCurrent, 2.0);
gradEnd = (dPopSize_dEnd/popSize - dpopSizeCurrent_dEnd/popSizeCurrent)/(popSize - popSizeCurrent) - (dPopSize_dEnd - dpopSizeCurrent_dEnd)*(logPopSize - logPopSizeCurrent)/pow(popSize - popSizeCurrent, 2.0);
}
coal->gradient[currentGridIndex] += -lchoose2*coal->times[i]*gradStart;
coal->gradient[currentGridIndex+1] += -lchoose2*coal->times[i]*gradEnd;
}

if(compute_grad_time){
double deltaLogPopSize = logPopSize - log(popSizeCurrent);
if(coal->nodes[i]->index >= 0 && coal->iscoalescent[i]){
double dpopSizedt = (popSizeGridEnd - popSizeGridStart)/(timeGridEnd - timeGridStart);
double d = -coal->times[i]*dpopSizedt*(log(popSize) - log(popSizeCurrent))/pow(popSize - popSizeCurrent, 2);
d += coal->times[i]*dpopSizedt/(popSize*(popSize - popSizeCurrent)) + (log(popSize) - log(popSizeCurrent))/(popSize - popSizeCurrent);
double dpopSizedt = (popSizeGridEnd - popSizeGridStart)/deltaGrid;
double d = -coal->times[i]*dpopSizedt*deltaLogPopSize/pow(popSize - popSizeCurrent, 2);
d += coal->times[i]*dpopSizedt/(popSize*(popSize - popSizeCurrent)) + deltaLogPopSize/(popSize - popSizeCurrent);

size_t node_class_id = Node_class_id(nodes[coal->nodes[i]->index]);
heightGradient[node_class_id] -= lchoose2*d;
}
if(coal->nodes[i-1]->index >= 0 && coal->iscoalescent[i-1]){
double dpopSizedt = (popSizeGridEnd - popSizeGridStart)/(timeGridEnd - timeGridStart);
double d = coal->times[i]*dpopSizedt*(log(popSize) - log(popSizeCurrent))/pow(popSize - popSizeCurrent, 2);
d += -coal->times[i]*dpopSizedt/(popSizeCurrent*(popSize - popSizeCurrent)) - (log(popSize) - log(popSizeCurrent))/(popSize - popSizeCurrent);
double dpopSizedt = (popSizeGridEnd - popSizeGridStart)/deltaGrid;
double d = coal->times[i]*dpopSizedt*deltaLogPopSize/pow(popSize - popSizeCurrent, 2);
d += -coal->times[i]*dpopSizedt/(popSizeCurrent*(popSize - popSizeCurrent)) - deltaLogPopSize/(popSize - popSizeCurrent);

size_t start_node_class_id = Node_class_id(nodes[coal->nodes[i-1]->index]);
heightGradient[start_node_class_id] -= lchoose2*d;
Expand All @@ -2245,7 +2249,7 @@ void _coalescent_piecewise_linear_grid_gradient_time( Coalescent* coal ){
coal->logP -= lchoose2 * coal->times[i] / popSize;

if(compute_grad_theta){
coal->gradient[currentGridIndex+1] += lchoose2 * coal->times[i] / (popSize*popSize);
coal->gradient[currentGridIndex] += lchoose2 * coal->times[i] / (popSize*popSize);
}

if(compute_grad_time){
Expand Down

0 comments on commit ab3524e

Please sign in to comment.