done TODOs, deleted prints and added type of convergence to output of lm
This commit is contained in:
parent
64e563340a
commit
cfe8e9bfee
@ -120,13 +120,22 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
|
|||||||
*/
|
*/
|
||||||
public fun solve(a: MutableStructure2D<Double>, b: MutableStructure2D<Double>): MutableStructure2D<Double>
|
public fun solve(a: MutableStructure2D<Double>, b: MutableStructure2D<Double>): MutableStructure2D<Double>
|
||||||
|
|
||||||
data class LMResultInfo (
|
public enum class TypeOfConvergence{
|
||||||
|
inRHS_JtWdy,
|
||||||
|
inParameters,
|
||||||
|
inReducedChi_square,
|
||||||
|
noConvergence
|
||||||
|
}
|
||||||
|
|
||||||
|
public data class LMResultInfo (
|
||||||
var iterations:Int,
|
var iterations:Int,
|
||||||
var func_calls: Int,
|
var func_calls: Int,
|
||||||
var example_number: Int,
|
var example_number: Int,
|
||||||
var result_chi_sq: Double,
|
var result_chi_sq: Double,
|
||||||
var result_lambda: Double,
|
var result_lambda: Double,
|
||||||
var result_parameters: MutableStructure2D<Double>
|
var result_parameters: MutableStructure2D<Double>,
|
||||||
|
var typeOfConvergence: TypeOfConvergence,
|
||||||
|
var epsilon: Double
|
||||||
)
|
)
|
||||||
|
|
||||||
public fun lm(
|
public fun lm(
|
||||||
|
@ -723,9 +723,9 @@ public open class DoubleTensorAlgebra :
|
|||||||
weight_input: MutableStructure2D<Double>, dp_input: MutableStructure2D<Double>, p_min_input: MutableStructure2D<Double>, p_max_input: MutableStructure2D<Double>,
|
weight_input: MutableStructure2D<Double>, dp_input: MutableStructure2D<Double>, p_min_input: MutableStructure2D<Double>, p_max_input: MutableStructure2D<Double>,
|
||||||
c_input: MutableStructure2D<Double>, opts_input: DoubleArray, nargin: Int, example_number: Int): LinearOpsTensorAlgebra.LMResultInfo {
|
c_input: MutableStructure2D<Double>, opts_input: DoubleArray, nargin: Int, example_number: Int): LinearOpsTensorAlgebra.LMResultInfo {
|
||||||
|
|
||||||
var resultInfo = LinearOpsTensorAlgebra.LMResultInfo(0, 0, example_number, 0.0, 0.0, p_input)
|
var resultInfo = LinearOpsTensorAlgebra.LMResultInfo(0, 0, example_number, 0.0,
|
||||||
|
0.0, p_input, LinearOpsTensorAlgebra.TypeOfConvergence.noConvergence, 0.0)
|
||||||
|
|
||||||
val tensor_parameter = 0
|
|
||||||
val eps:Double = 2.2204e-16
|
val eps:Double = 2.2204e-16
|
||||||
|
|
||||||
var settings = LMSettings(0, 0, example_number)
|
var settings = LMSettings(0, 0, example_number)
|
||||||
@ -751,7 +751,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
var cvg_hist = 0
|
var cvg_hist = 0
|
||||||
|
|
||||||
if (length(t) != length(y_dat)) {
|
if (length(t) != length(y_dat)) {
|
||||||
println("lm.m error: the length of t must equal the length of y_dat")
|
// println("lm.m error: the length of t must equal the length of y_dat")
|
||||||
val length_t = length(t)
|
val length_t = length(t)
|
||||||
val length_y_dat = length(y_dat)
|
val length_y_dat = length(y_dat)
|
||||||
X2 = 0.0
|
X2 = 0.0
|
||||||
@ -761,10 +761,6 @@ public open class DoubleTensorAlgebra :
|
|||||||
sigma_y = 0
|
sigma_y = 0
|
||||||
R_sq = 0
|
R_sq = 0
|
||||||
cvg_hist = 0
|
cvg_hist = 0
|
||||||
|
|
||||||
// if (tensor_parameter != 0) { // Зачем эта проверка?
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var weight = weight_input
|
var weight = weight_input
|
||||||
@ -827,8 +823,8 @@ public open class DoubleTensorAlgebra :
|
|||||||
val y_init = feval(func, t, p, settings) // residual error using p_try
|
val y_init = feval(func, t, p, settings) // residual error using p_try
|
||||||
|
|
||||||
if (weight.shape.component1() == 1 || variance(weight) == 0.0) { // identical weights vector
|
if (weight.shape.component1() == 1 || variance(weight) == 0.0) { // identical weights vector
|
||||||
weight = ones(ShapeND(intArrayOf(Npnt, 1))).div(1 / abs(weight[0, 0])).as2D() // !!! need to check
|
weight = ones(ShapeND(intArrayOf(Npnt, 1))).div(1 / abs(weight[0, 0])).as2D()
|
||||||
println("using uniform weights for error analysis")
|
// println("using uniform weights for error analysis")
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
weight = make_column(weight)
|
weight = make_column(weight)
|
||||||
@ -844,8 +840,8 @@ public open class DoubleTensorAlgebra :
|
|||||||
J = lm_matx_ans[4]
|
J = lm_matx_ans[4]
|
||||||
|
|
||||||
if ( abs(JtWdy).max()!! < epsilon_1 ) {
|
if ( abs(JtWdy).max()!! < epsilon_1 ) {
|
||||||
println(" *** Your Initial Guess is Extremely Close to Optimal ***\n")
|
// println(" *** Your Initial Guess is Extremely Close to Optimal ***\n")
|
||||||
println(" *** epsilon_1 = %e\n$epsilon_1")
|
// println(" *** epsilon_1 = %e\n$epsilon_1")
|
||||||
stop = true
|
stop = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -885,11 +881,14 @@ public open class DoubleTensorAlgebra :
|
|||||||
|
|
||||||
var delta_y = y_dat.minus(feval(func, t, p_try, settings)) // residual error using p_try
|
var delta_y = y_dat.minus(feval(func, t, p_try, settings)) // residual error using p_try
|
||||||
|
|
||||||
// TODO
|
for (i in 0 until delta_y.shape.component1()) { // floating point error; break
|
||||||
//if ~all(isfinite(delta_y)) // floating point error; break
|
for (j in 0 until delta_y.shape.component2()) {
|
||||||
// stop = 1;
|
if (delta_y[i, j] == Double.POSITIVE_INFINITY || delta_y[i, j] == Double.NEGATIVE_INFINITY) {
|
||||||
// break
|
stop = true
|
||||||
//end
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
settings.func_calls += 1
|
settings.func_calls += 1
|
||||||
|
|
||||||
@ -900,17 +899,16 @@ public open class DoubleTensorAlgebra :
|
|||||||
if (Update_Type == 2) { // Quadratic
|
if (Update_Type == 2) { // Quadratic
|
||||||
// One step of quadratic line update in the h direction for minimum X2
|
// One step of quadratic line update in the h direction for minimum X2
|
||||||
|
|
||||||
// TODO
|
val alpha = JtWdy.transpose().dot(h) / ( (X2_try.minus(X2)).div(2.0).plus(2 * JtWdy.transpose().dot(h)) )
|
||||||
// val alpha = JtWdy.transpose().dot(h) / ((X2_try.minus(X2)).div(2.0).plus(2 * JtWdy.transpose().dot(h)))
|
h = h.dot(alpha)
|
||||||
// alpha = JtWdy'*h / ( (X2_try - X2)/2 + 2*JtWdy'*h ) ;
|
p_try = p.plus(h).as2D() // update only [idx] elements
|
||||||
// h = alpha * h;
|
p_try = smallest_element_comparison(largest_element_comparison(p_min, p_try), p_max) // apply constraints
|
||||||
//
|
|
||||||
// p_try = p + h(idx); % update only [idx] elements
|
var delta_y = y_dat.minus(feval(func, t, p_try, settings)) // residual error using p_try
|
||||||
// p_try = min(max(p_min,p_try),p_max); % apply constraints
|
settings.func_calls += 1
|
||||||
//
|
|
||||||
// delta_y = y_dat - feval(func,t,p_try,c); % residual error using p_try
|
val tmp = delta_y.times(weight)
|
||||||
// func_calls = func_calls + 1;
|
X2_try = delta_y.as2D().transpose().dot(tmp) // Chi-squared error criteria
|
||||||
// тX2_try = delta_y' * ( delta_y .* weight ); % Chi-squared error criteria
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val rho = when (Update_Type) { // Nielsen
|
val rho = when (Update_Type) { // Nielsen
|
||||||
@ -924,9 +922,6 @@ public open class DoubleTensorAlgebra :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println()
|
|
||||||
println("rho = " + rho)
|
|
||||||
|
|
||||||
if (rho > epsilon_4) { // it IS significantly better
|
if (rho > epsilon_4) { // it IS significantly better
|
||||||
val dX2 = X2.minus(X2_old)
|
val dX2 = X2.minus(X2_old)
|
||||||
X2_old = X2
|
X2_old = X2
|
||||||
@ -984,15 +979,15 @@ public open class DoubleTensorAlgebra :
|
|||||||
|
|
||||||
if (prnt > 1) {
|
if (prnt > 1) {
|
||||||
val chi_sq = X2 / DoF
|
val chi_sq = X2 / DoF
|
||||||
println("Iteration $settings | chi_sq=$chi_sq | lambda=$lambda")
|
// println("Iteration $settings | chi_sq=$chi_sq | lambda=$lambda")
|
||||||
print("param: ")
|
// print("param: ")
|
||||||
for (pn in 0 until Npar) {
|
// for (pn in 0 until Npar) {
|
||||||
print(p[pn, 0].toString() + " ")
|
// print(p[pn, 0].toString() + " ")
|
||||||
}
|
// }
|
||||||
print("\ndp/p: ")
|
// print("\ndp/p: ")
|
||||||
for (pn in 0 until Npar) {
|
// for (pn in 0 until Npar) {
|
||||||
print((h.as2D()[pn, 0] / p[pn, 0]).toString() + " ")
|
// print((h.as2D()[pn, 0] / p[pn, 0]).toString() + " ")
|
||||||
}
|
// }
|
||||||
resultInfo.iterations = settings.iteration
|
resultInfo.iterations = settings.iteration
|
||||||
resultInfo.func_calls = settings.func_calls
|
resultInfo.func_calls = settings.func_calls
|
||||||
resultInfo.result_chi_sq = chi_sq
|
resultInfo.result_chi_sq = chi_sq
|
||||||
@ -1004,22 +999,30 @@ public open class DoubleTensorAlgebra :
|
|||||||
// cvg_hst(iteration,:) = [ func_calls p' X2/DoF lambda ];
|
// cvg_hst(iteration,:) = [ func_calls p' X2/DoF lambda ];
|
||||||
|
|
||||||
if (abs(JtWdy).max()!! < epsilon_1 && settings.iteration > 2) {
|
if (abs(JtWdy).max()!! < epsilon_1 && settings.iteration > 2) {
|
||||||
println(" **** Convergence in r.h.s. (\"JtWdy\") ****")
|
// println(" **** Convergence in r.h.s. (\"JtWdy\") ****")
|
||||||
println(" **** epsilon_1 = $epsilon_1")
|
// println(" **** epsilon_1 = $epsilon_1")
|
||||||
|
resultInfo.typeOfConvergence = LinearOpsTensorAlgebra.TypeOfConvergence.inRHS_JtWdy
|
||||||
|
resultInfo.epsilon = epsilon_1
|
||||||
stop = true
|
stop = true
|
||||||
}
|
}
|
||||||
if ((abs(h.as2D()).div(abs(p) + 1e-12)).max() < epsilon_2 && settings.iteration > 2) {
|
if ((abs(h.as2D()).div(abs(p) + 1e-12)).max() < epsilon_2 && settings.iteration > 2) {
|
||||||
println(" **** Convergence in Parameters ****")
|
// println(" **** Convergence in Parameters ****")
|
||||||
println(" **** epsilon_2 = $epsilon_2")
|
// println(" **** epsilon_2 = $epsilon_2")
|
||||||
|
resultInfo.typeOfConvergence = LinearOpsTensorAlgebra.TypeOfConvergence.inParameters
|
||||||
|
resultInfo.epsilon = epsilon_2
|
||||||
stop = true
|
stop = true
|
||||||
}
|
}
|
||||||
if (X2 / DoF < epsilon_3 && settings.iteration > 2) {
|
if (X2 / DoF < epsilon_3 && settings.iteration > 2) {
|
||||||
println(" **** Convergence in reduced Chi-square **** ")
|
// println(" **** Convergence in reduced Chi-square **** ")
|
||||||
println(" **** epsilon_3 = $epsilon_3")
|
// println(" **** epsilon_3 = $epsilon_3")
|
||||||
|
resultInfo.typeOfConvergence = LinearOpsTensorAlgebra.TypeOfConvergence.inReducedChi_square
|
||||||
|
resultInfo.epsilon = epsilon_3
|
||||||
stop = true
|
stop = true
|
||||||
}
|
}
|
||||||
if (settings.iteration == MaxIter) {
|
if (settings.iteration == MaxIter) {
|
||||||
println(" !! Maximum Number of Iterations Reached Without Convergence !!")
|
// println(" !! Maximum Number of Iterations Reached Without Convergence !!")
|
||||||
|
resultInfo.typeOfConvergence = LinearOpsTensorAlgebra.TypeOfConvergence.noConvergence
|
||||||
|
resultInfo.epsilon = 0.0
|
||||||
stop = true
|
stop = true
|
||||||
}
|
}
|
||||||
} // --- End of Main Loop
|
} // --- End of Main Loop
|
||||||
|
@ -8,6 +8,7 @@ package space.kscience.kmath.tensors.core
|
|||||||
|
|
||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
|
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||||
import space.kscience.kmath.tensors.core.internal.LMSettings
|
import space.kscience.kmath.tensors.core.internal.LMSettings
|
||||||
import space.kscience.kmath.testutils.assertBufferEquals
|
import space.kscience.kmath.testutils.assertBufferEquals
|
||||||
import kotlin.math.roundToInt
|
import kotlin.math.roundToInt
|
||||||
@ -290,5 +291,6 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
assertEquals(1, result.example_number)
|
assertEquals(1, result.example_number)
|
||||||
assertEquals(0.9131368192633, (result.result_chi_sq * 1e13).roundToLong() / 1e13)
|
assertEquals(0.9131368192633, (result.result_chi_sq * 1e13).roundToLong() / 1e13)
|
||||||
assertEquals(3.7790980 * 1e-7, (result.result_lambda * 1e13).roundToLong() / 1e13)
|
assertEquals(3.7790980 * 1e-7, (result.result_lambda * 1e13).roundToLong() / 1e13)
|
||||||
|
assertEquals(result.typeOfConvergence, LinearOpsTensorAlgebra.TypeOfConvergence.inParameters)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user