Testing linear structure
This commit is contained in:
parent
3535e51248
commit
93d3cb47be
@ -3,19 +3,17 @@ package space.kscience.kmath.tensors
|
||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.full.html
|
||||
public fun full(value: T, shape: IntArray): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
||||
public fun TensorType.fullLike(value: T): TensorType
|
||||
|
||||
public fun zeros(shape: IntArray): TensorType
|
||||
public fun TensorType.zeroesLike(): TensorType // mb it shouldn't be tensor but algebra method (like in numpy/torch) ?
|
||||
public fun ones(shape: IntArray): TensorType
|
||||
public fun TensorType.onesLike(): TensorType
|
||||
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.full.html
|
||||
public fun full(shape: IntArray, value: T): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
||||
public fun TensorType.fullLike(value: T): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
||||
public fun eye(n: Int): TensorType
|
||||
|
||||
|
@ -20,23 +20,25 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
return DoubleTensor(newShape, this.buffer.array(), newStart)
|
||||
}
|
||||
|
||||
override fun zeros(shape: IntArray): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.zeroesLike(): DoubleTensor {
|
||||
val shape = this.shape
|
||||
val buffer = DoubleArray(this.strides.linearSize) { 0.0 }
|
||||
override fun full(value: Double, shape: IntArray): DoubleTensor {
|
||||
checkEmptyShape(shape)
|
||||
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
|
||||
return DoubleTensor(shape, buffer)
|
||||
}
|
||||
|
||||
override fun ones(shape: IntArray): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
|
||||
val shape = this.shape
|
||||
val buffer = DoubleArray(this.strides.linearSize) { value }
|
||||
return DoubleTensor(shape, buffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.onesLike(): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
override fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
|
||||
|
||||
override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
|
||||
|
||||
override fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
|
||||
|
||||
override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
|
||||
|
||||
override fun eye(n: Int): DoubleTensor {
|
||||
val shape = intArrayOf(n, n)
|
||||
@ -225,15 +227,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun full(shape: IntArray, value: Double): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
|
||||
override fun DoubleTensor.sum(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
@ -47,4 +47,36 @@ class TestDoubleTensorAlgebra {
|
||||
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun linearStructure() = DoubleTensorAlgebra {
|
||||
val shape = intArrayOf(3)
|
||||
val tensorA = full(value = -4.5, shape = shape)
|
||||
val tensorB = full(value = 10.9, shape = shape)
|
||||
val tensorC = full(value = 789.3, shape = shape)
|
||||
val tensorD = full(value = -72.9, shape = shape)
|
||||
val tensorE = full(value = 553.1, shape = shape)
|
||||
val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4
|
||||
|
||||
val expected = fromArray(
|
||||
shape,
|
||||
(1..3).map {
|
||||
15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4
|
||||
}.toDoubleArray()
|
||||
)
|
||||
|
||||
val assignResult = zeros(shape)
|
||||
tensorA *= 15.8
|
||||
tensorB *= 1.5
|
||||
tensorB *= -tensorD
|
||||
tensorC *= 0.02
|
||||
tensorC /= tensorE
|
||||
assignResult += tensorA
|
||||
assignResult -= tensorB
|
||||
assignResult += tensorC
|
||||
assignResult += -39.4
|
||||
|
||||
assertTrue(expected.buffer.array() contentEquals result.buffer.array())
|
||||
assertTrue(expected.buffer.array() contentEquals assignResult.buffer.array())
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user