KMP library for tensors #300

Merged
grinisrit merged 215 commits from feature/tensor-algebra into dev 2021-05-08 09:48:04 +03:00
7 changed files with 24 additions and 46 deletions
Showing only changes of commit 51eca003af - Show all commits

View File

@ -2654,10 +2654,6 @@ public final class space/kscience/kmath/tensors/LinearOpsTensorAlgebra$DefaultIm
public static synthetic fun symEig$default (Lspace/kscience/kmath/tensors/LinearOpsTensorAlgebra;Lspace/kscience/kmath/nd/MutableStructureND;ZILjava/lang/Object;)Lkotlin/Pair;
}
public abstract interface class space/kscience/kmath/tensors/ReduceOpsTensorAlgebra : space/kscience/kmath/tensors/TensorAlgebra {
public abstract fun value (Lspace/kscience/kmath/nd/MutableStructureND;)Ljava/lang/Object;
}
public abstract interface class space/kscience/kmath/tensors/TensorAlgebra {
public abstract fun copy (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun diagonalEmbedding (Lspace/kscience/kmath/nd/MutableStructureND;III)Lspace/kscience/kmath/nd/MutableStructureND;
@ -2686,6 +2682,7 @@ public abstract interface class space/kscience/kmath/tensors/TensorAlgebra {
public abstract fun timesAssign (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)V
public abstract fun transpose (Lspace/kscience/kmath/nd/MutableStructureND;II)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun unaryMinus (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun value (Lspace/kscience/kmath/nd/MutableStructureND;)Ljava/lang/Object;
public abstract fun view (Lspace/kscience/kmath/nd/MutableStructureND;[I)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun viewAs (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
public abstract fun zeroesLike (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
@ -2809,16 +2806,6 @@ public final class space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebr
public static final fun DoubleLinearOpsTensorAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
}
public final class space/kscience/kmath/tensors/core/DoubleReduceOpsTensorAlgebra : space/kscience/kmath/tensors/core/DoubleTensorAlgebra, space/kscience/kmath/tensors/ReduceOpsTensorAlgebra {
public fun <init> ()V
public synthetic fun value (Lspace/kscience/kmath/nd/MutableStructureND;)Ljava/lang/Object;
public fun value (Lspace/kscience/kmath/tensors/core/DoubleTensor;)Ljava/lang/Double;
}
public final class space/kscience/kmath/tensors/core/DoubleReduceOpsTensorAlgebraKt {
public static final fun DoubleReduceOpsTensorAlgebra (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
}
public final class space/kscience/kmath/tensors/core/DoubleTensor : space/kscience/kmath/tensors/core/BufferedTensor {
}
@ -2892,6 +2879,8 @@ public class space/kscience/kmath/tensors/core/DoubleTensorAlgebra : space/kscie
public fun transpose (Lspace/kscience/kmath/tensors/core/DoubleTensor;II)Lspace/kscience/kmath/tensors/core/DoubleTensor;
public synthetic fun unaryMinus (Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;
public fun unaryMinus (Lspace/kscience/kmath/tensors/core/DoubleTensor;)Lspace/kscience/kmath/tensors/core/DoubleTensor;
public synthetic fun value (Lspace/kscience/kmath/nd/MutableStructureND;)Ljava/lang/Object;
public fun value (Lspace/kscience/kmath/tensors/core/DoubleTensor;)Ljava/lang/Double;
public synthetic fun view (Lspace/kscience/kmath/nd/MutableStructureND;[I)Lspace/kscience/kmath/nd/MutableStructureND;
public fun view (Lspace/kscience/kmath/tensors/core/DoubleTensor;[I)Lspace/kscience/kmath/tensors/core/DoubleTensor;
public synthetic fun viewAs (Lspace/kscience/kmath/nd/MutableStructureND;Lspace/kscience/kmath/nd/MutableStructureND;)Lspace/kscience/kmath/nd/MutableStructureND;

View File

@ -1,7 +0,0 @@
package space.kscience.kmath.tensors
public interface ReduceOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
TensorAlgebra<T, TensorType> {
public fun TensorType.value(): T
}

View File

@ -5,6 +5,8 @@ import space.kscience.kmath.tensors.core.DoubleTensor
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun TensorType.value(): T
//https://pytorch.org/docs/stable/generated/torch.full.html
public fun full(value: T, shape: IntArray): TensorType

View File

@ -1,18 +0,0 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.ReduceOpsTensorAlgebra
public class DoubleReduceOpsTensorAlgebra:
DoubleTensorAlgebra(),
ReduceOpsTensorAlgebra<Double, DoubleTensor> {
override fun DoubleTensor.value(): Double {
check(this.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}"
}
return this.buffer.array()[this.bufferStart]
}
}
public inline fun <R> DoubleReduceOpsTensorAlgebra(block: DoubleReduceOpsTensorAlgebra.() -> R): R =
DoubleReduceOpsTensorAlgebra().block()

View File

@ -7,6 +7,13 @@ import kotlin.math.abs
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
override fun DoubleTensor.value(): Double {
check(this.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}"
}
return this.buffer.array()[this.bufferStart]
}
public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor {
checkEmptyShape(shape)
checkEmptyDoubleBuffer(buffer)

View File

@ -72,11 +72,8 @@ class TestDoubleLinearOpsTensorAlgebra {
@Test
fun testScalarProduct() = DoubleLinearOpsTensorAlgebra {
val a = fromArray(intArrayOf(3), doubleArrayOf(1.8,2.5, 6.8))
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5,2.6, 6.4))
DoubleReduceOpsTensorAlgebra {
assertEquals(a.dot(b).value(), 59.92)
}
val a = fromArray(intArrayOf(3), doubleArrayOf(1.8, 2.5, 6.8))
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5, 2.6, 6.4))
assertEquals(a.dot(b).value(), 59.92)
}
}

View File

@ -10,7 +10,7 @@ import kotlin.test.assertTrue
class TestDoubleTensor {
@Test
fun valueTest() = DoubleReduceOpsTensorAlgebra {
fun valueTest() = DoubleTensorAlgebra {
val value = 12.5
val tensor = fromArray(intArrayOf(1), doubleArrayOf(value))
assertEquals(tensor.value(), value)
@ -37,5 +37,13 @@ class TestDoubleTensor {
vector[0] = 109.56
assertEquals(tensor[intArrayOf(0,1,0)], 109.56)
tensor.matrixSequence().forEach {
val a = it.asTensor()
val secondRow = a[1].as1D()
val secondColumn = a.transpose(0,1)[1].as1D()
assertEquals(secondColumn[0], 77.89)
assertEquals(secondRow[1], secondColumn[1])
}
}
}