Fix missing negative on residual
This commit is contained in:
parent
fdc3ffd164
commit
c9205415f2
@ -99,20 +99,18 @@ pub fn conjugate_gradient<T: XVar<E> + Clone, E: Debug + ConjGradPrime>(
|
|||||||
// Check for convergence
|
// Check for convergence
|
||||||
f = fun.eval(&xs);
|
f = fun.eval(&xs);
|
||||||
if (f - f_iminus1).abs() < tolerance {
|
if (f - f_iminus1).abs() < tolerance {
|
||||||
println!("{f} {f_iminus1}");
|
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
f_iminus1 = f;
|
f_iminus1 = f;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update using polack-ribiere
|
// Update using polack-ribiere
|
||||||
let new_residual = fun.prime(&xs);
|
let new_residual = T::scale_prime(&fun.prime(&xs), -1.0);
|
||||||
let beta = new_residual
|
let beta = (new_residual.mul(&new_residual.sub(&prev_residual)))
|
||||||
.mul(&new_residual.sub(&prev_residual))
|
.div(&prev_residual.mul(&prev_residual));
|
||||||
.div(&new_residual.mul(&new_residual));
|
|
||||||
let beta = beta.max(0.0);
|
let beta = beta.max(0.0);
|
||||||
direction = new_residual.add(&beta.mul(&direction));
|
direction = new_residual.add(&beta.mul(&direction));
|
||||||
prev_residual = new_residual;
|
prev_residual = new_residual.clone();
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,6 +151,7 @@ mod test {
|
|||||||
res.best_fun_val, res.best_xs
|
res.best_fun_val, res.best_xs
|
||||||
);
|
);
|
||||||
|
|
||||||
|
println!("Exitted with {:?}", res.exit_con);
|
||||||
if let ExitCondition::MaxIter = res.exit_con {
|
if let ExitCondition::MaxIter = res.exit_con {
|
||||||
panic!("Failed to converge to minima");
|
panic!("Failed to converge to minima");
|
||||||
}
|
}
|
||||||
@ -186,11 +185,12 @@ mod test {
|
|||||||
gamma: 0.9,
|
gamma: 0.9,
|
||||||
c: 0.01,
|
c: 0.01,
|
||||||
};
|
};
|
||||||
let res = conjugate_gradient(&obj, &vec![3.1, 0.5], 10000, 1e-12, &line_search);
|
let res = conjugate_gradient(&obj, &vec![4.0, 1.00], 10000, 1e-12, &line_search);
|
||||||
println!(
|
println!(
|
||||||
"Best val is {:?} for xs {:?}",
|
"Best val is {:?} for xs {:?}",
|
||||||
res.best_fun_val, res.best_xs
|
res.best_fun_val, res.best_xs
|
||||||
);
|
);
|
||||||
|
|
||||||
println!("Exit condition is: {:?}", res.exit_con);
|
println!("Exit condition is: {:?}", res.exit_con);
|
||||||
assert!(res.best_fun_val < 1e-7);
|
assert!(res.best_fun_val < 1e-7);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user