refactor
This commit is contained in:
parent
bfba653904
commit
ac6608b5b4
@ -14,7 +14,7 @@ import space.kscience.kmath.tensors.core.algebras.DoubleLinearOpsTensorAlgebra
|
||||
fun main () {
|
||||
|
||||
// work in context with linear operations
|
||||
DoubleLinearOpsTensorAlgebra.invoke {
|
||||
DoubleLinearOpsTensorAlgebra {
|
||||
|
||||
// set true value of x
|
||||
val trueX = fromArray(
|
||||
|
@ -19,7 +19,7 @@ fun main() {
|
||||
val randSeed = 100500L
|
||||
|
||||
// work in context with linear operations
|
||||
DoubleLinearOpsTensorAlgebra.invoke {
|
||||
DoubleLinearOpsTensorAlgebra {
|
||||
// take coefficient vector from normal distribution
|
||||
val alpha = randomNormal(
|
||||
intArrayOf(5),
|
||||
|
@ -3,6 +3,9 @@ plugins {
|
||||
}
|
||||
|
||||
kotlin.sourceSets {
|
||||
all {
|
||||
languageSettings.useExperimentalAnnotation("space.kscience.kmath.misc.UnstableKMathAPI")
|
||||
}
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
|
@ -109,9 +109,13 @@ public fun Array<IntArray>.toIntTensor(): IntTensor {
|
||||
}
|
||||
|
||||
public fun DoubleTensor.toDoubleArray(): DoubleArray {
|
||||
return tensor.mutableBuffer.array().drop(bufferStart).take(numElements).toDoubleArray()
|
||||
return DoubleArray(numElements) { i ->
|
||||
mutableBuffer[bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
public fun IntTensor.toIntArray(): IntArray {
|
||||
return tensor.mutableBuffer.array().drop(bufferStart).take(numElements).toIntArray()
|
||||
return IntArray(numElements) { i ->
|
||||
mutableBuffer[bufferStart + i]
|
||||
}
|
||||
}
|
@ -5,38 +5,35 @@ import space.kscience.kmath.tensors.core.algebras.DoubleLinearOpsTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
|
||||
|
||||
|
||||
internal fun checkEmptyShape(shape: IntArray): Unit =
|
||||
internal fun checkEmptyShape(shape: IntArray) =
|
||||
check(shape.isNotEmpty()) {
|
||||
"Illegal empty shape provided"
|
||||
}
|
||||
|
||||
internal fun checkEmptyDoubleBuffer(buffer: DoubleArray): Unit =
|
||||
internal fun checkEmptyDoubleBuffer(buffer: DoubleArray) =
|
||||
check(buffer.isNotEmpty()) {
|
||||
"Illegal empty buffer provided"
|
||||
}
|
||||
|
||||
internal fun checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray): Unit =
|
||||
internal fun checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray) =
|
||||
check(buffer.size == shape.reduce(Int::times)) {
|
||||
"Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided"
|
||||
}
|
||||
|
||||
|
||||
internal fun <T> checkShapesCompatible(a: TensorStructure<T>, b: TensorStructure<T>): Unit =
|
||||
internal fun <T> checkShapesCompatible(a: TensorStructure<T>, b: TensorStructure<T>) =
|
||||
check(a.shape contentEquals b.shape) {
|
||||
"Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} "
|
||||
}
|
||||
|
||||
|
||||
internal fun checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||
internal fun checkTranspose(dim: Int, i: Int, j: Int) =
|
||||
check((i < dim) and (j < dim)) {
|
||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||
}
|
||||
|
||||
internal fun <T> checkView(a: TensorStructure<T>, shape: IntArray): Unit =
|
||||
internal fun <T> checkView(a: TensorStructure<T>, shape: IntArray) =
|
||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
||||
|
||||
|
||||
internal fun checkSquareMatrix(shape: IntArray): Unit {
|
||||
internal fun checkSquareMatrix(shape: IntArray) {
|
||||
val n = shape.size
|
||||
check(n >= 2) {
|
||||
"Expected tensor with 2 or more dimensions, got size $n instead"
|
||||
@ -48,14 +45,12 @@ internal fun checkSquareMatrix(shape: IntArray): Unit {
|
||||
|
||||
internal fun DoubleTensorAlgebra.checkSymmetric(
|
||||
tensor: TensorStructure<Double>, epsilon: Double = 1e-6
|
||||
): Unit =
|
||||
) =
|
||||
check(tensor.eq(tensor.transpose(), epsilon)) {
|
||||
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
|
||||
}
|
||||
|
||||
internal fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(
|
||||
tensor: DoubleTensor, epsilon: Double = 1e-6
|
||||
): Unit {
|
||||
internal fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor, epsilon: Double = 1e-6) {
|
||||
checkSymmetric(tensor, epsilon)
|
||||
for (mat in tensor.matrixSequence())
|
||||
check(mat.asTensor().detLU().value() > 0.0) {
|
||||
|
@ -34,18 +34,6 @@ internal fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTensor<T>>
|
||||
}
|
||||
}
|
||||
|
||||
internal inline fun <T> BufferedTensor<T>.forEachVector(vectorAction: (BufferedTensor<T>) -> Unit) {
|
||||
for (vector in vectorSequence()) {
|
||||
vectorAction(vector)
|
||||
}
|
||||
}
|
||||
|
||||
internal inline fun <T> BufferedTensor<T>.forEachMatrix(matrixAction: (BufferedTensor<T>) -> Unit) {
|
||||
for (matrix in matrixSequence()) {
|
||||
matrixAction(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun dotHelper(
|
||||
a: MutableStructure2D<Double>,
|
||||
b: MutableStructure2D<Double>,
|
||||
|
@ -14,7 +14,6 @@ internal fun Buffer<Int>.array(): IntArray = when (this) {
|
||||
else -> this.toIntArray()
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data.
|
||||
*/
|
||||
@ -31,7 +30,7 @@ internal fun getRandomNormals(n: Int, seed: Long): DoubleArray {
|
||||
|
||||
internal fun getRandomUnitVector(n: Int, seed: Long): DoubleArray {
|
||||
val unnorm = getRandomNormals(n, seed)
|
||||
val norm = sqrt(unnorm.map { it * it }.sum())
|
||||
val norm = sqrt(unnorm.sumOf { it * it })
|
||||
return unnorm.map { it / norm }.toDoubleArray()
|
||||
}
|
||||
|
||||
@ -45,23 +44,33 @@ internal fun minusIndexFrom(n: Int, i: Int): Int = if (i >= 0) i else {
|
||||
|
||||
internal fun <T> BufferedTensor<T>.minusIndex(i: Int): Int = minusIndexFrom(this.dimension, i)
|
||||
|
||||
internal fun format(value: Double, digits: Int = 4): String {
|
||||
internal fun format(value: Double, digits: Int = 4): String = buildString {
|
||||
val res = buildString {
|
||||
val ten = 10.0
|
||||
val approxOrder = if (value == 0.0) 0 else ceil(log10(abs(value))).toInt()
|
||||
val order = if (
|
||||
((value % ten) == 0.0) or
|
||||
(value == 1.0) or
|
||||
((value % ten) == 0.0) ||
|
||||
(value == 1.0) ||
|
||||
((1 / value) % ten == 0.0)
|
||||
) approxOrder else approxOrder - 1
|
||||
val lead = value / ten.pow(order)
|
||||
val leadDisplay = round(lead * ten.pow(digits)) / ten.pow(digits)
|
||||
val orderDisplay = if (order == 0) "" else if (order > 0) "E+$order" else "E$order"
|
||||
val valueDisplay = "$leadDisplay$orderDisplay"
|
||||
val res = if (value < 0.0) valueDisplay else " $valueDisplay"
|
||||
|
||||
if (value >= 0.0) append(' ')
|
||||
append(round(lead * ten.pow(digits)) / ten.pow(digits))
|
||||
when {
|
||||
order == 0 -> Unit
|
||||
order > 0 -> {
|
||||
append("e+")
|
||||
append(order)
|
||||
}
|
||||
else -> {
|
||||
append('e')
|
||||
append(order)
|
||||
}
|
||||
}
|
||||
}
|
||||
val fLength = digits + 6
|
||||
val endSpace = " ".repeat(fLength - res.length)
|
||||
return "$res$endSpace"
|
||||
append(res)
|
||||
repeat(fLength - res.length) { append(' ') }
|
||||
}
|
||||
|
||||
internal fun DoubleTensor.toPrettyString(): String = buildString {
|
||||
|
Loading…
Reference in New Issue
Block a user