Testing linear structure

This commit is contained in:
Roland Grinis 2021-03-19 20:10:08 +00:00
parent 3535e51248
commit 93d3cb47be
3 changed files with 51 additions and 28 deletions

View File

@ -3,19 +3,17 @@ package space.kscience.kmath.tensors
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring // https://proofwiki.org/wiki/Definition:Algebra_over_Ring
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> { public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.full.html
public fun full(value: T, shape: IntArray): TensorType
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
public fun TensorType.fullLike(value: T): TensorType
public fun zeros(shape: IntArray): TensorType public fun zeros(shape: IntArray): TensorType
public fun TensorType.zeroesLike(): TensorType // mb it shouldn't be tensor but algebra method (like in numpy/torch) ? public fun TensorType.zeroesLike(): TensorType // mb it shouldn't be tensor but algebra method (like in numpy/torch) ?
public fun ones(shape: IntArray): TensorType public fun ones(shape: IntArray): TensorType
public fun TensorType.onesLike(): TensorType public fun TensorType.onesLike(): TensorType
//https://pytorch.org/docs/stable/generated/torch.full.html
public fun full(shape: IntArray, value: T): TensorType
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
public fun TensorType.fullLike(value: T): TensorType
//https://pytorch.org/docs/stable/generated/torch.eye.html //https://pytorch.org/docs/stable/generated/torch.eye.html
public fun eye(n: Int): TensorType public fun eye(n: Int): TensorType

View File

@ -20,23 +20,25 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
return DoubleTensor(newShape, this.buffer.array(), newStart) return DoubleTensor(newShape, this.buffer.array(), newStart)
} }
override fun zeros(shape: IntArray): DoubleTensor { override fun full(value: Double, shape: IntArray): DoubleTensor {
TODO("Not yet implemented") checkEmptyShape(shape)
} val buffer = DoubleArray(shape.reduce(Int::times)) { value }
override fun DoubleTensor.zeroesLike(): DoubleTensor {
val shape = this.shape
val buffer = DoubleArray(this.strides.linearSize) { 0.0 }
return DoubleTensor(shape, buffer) return DoubleTensor(shape, buffer)
} }
override fun ones(shape: IntArray): DoubleTensor { override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
TODO("Not yet implemented") val shape = this.shape
val buffer = DoubleArray(this.strides.linearSize) { value }
return DoubleTensor(shape, buffer)
} }
override fun DoubleTensor.onesLike(): DoubleTensor { override fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
TODO("Not yet implemented")
} override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
override fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
override fun eye(n: Int): DoubleTensor { override fun eye(n: Int): DoubleTensor {
val shape = intArrayOf(n, n) val shape = intArrayOf(n, n)
@ -225,15 +227,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun full(shape: IntArray, value: Double): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.sum(dim: Int, keepDim: Boolean): DoubleTensor { override fun DoubleTensor.sum(dim: Int, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
} }

View File

@ -47,4 +47,36 @@ class TestDoubleTensorAlgebra {
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
} }
@Test
fun linearStructure() = DoubleTensorAlgebra {
val shape = intArrayOf(3)
val tensorA = full(value = -4.5, shape = shape)
val tensorB = full(value = 10.9, shape = shape)
val tensorC = full(value = 789.3, shape = shape)
val tensorD = full(value = -72.9, shape = shape)
val tensorE = full(value = 553.1, shape = shape)
val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4
val expected = fromArray(
shape,
(1..3).map {
15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4
}.toDoubleArray()
)
val assignResult = zeros(shape)
tensorA *= 15.8
tensorB *= 1.5
tensorB *= -tensorD
tensorC *= 0.02
tensorC /= tensorE
assignResult += tensorA
assignResult -= tensorB
assignResult += tensorC
assignResult += -39.4
assertTrue(expected.buffer.array() contentEquals result.buffer.array())
assertTrue(expected.buffer.array() contentEquals assignResult.buffer.array())
}
} }