Testing linear structure
This commit is contained in:
parent
3535e51248
commit
93d3cb47be
@ -3,19 +3,17 @@ package space.kscience.kmath.tensors
|
|||||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.full.html
|
||||||
|
public fun full(value: T, shape: IntArray): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
||||||
|
public fun TensorType.fullLike(value: T): TensorType
|
||||||
|
|
||||||
public fun zeros(shape: IntArray): TensorType
|
public fun zeros(shape: IntArray): TensorType
|
||||||
public fun TensorType.zeroesLike(): TensorType // mb it shouldn't be tensor but algebra method (like in numpy/torch) ?
|
public fun TensorType.zeroesLike(): TensorType // mb it shouldn't be tensor but algebra method (like in numpy/torch) ?
|
||||||
public fun ones(shape: IntArray): TensorType
|
public fun ones(shape: IntArray): TensorType
|
||||||
public fun TensorType.onesLike(): TensorType
|
public fun TensorType.onesLike(): TensorType
|
||||||
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.full.html
|
|
||||||
public fun full(shape: IntArray, value: T): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
|
||||||
public fun TensorType.fullLike(value: T): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
||||||
public fun eye(n: Int): TensorType
|
public fun eye(n: Int): TensorType
|
||||||
|
|
||||||
|
@ -20,23 +20,25 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
return DoubleTensor(newShape, this.buffer.array(), newStart)
|
return DoubleTensor(newShape, this.buffer.array(), newStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun zeros(shape: IntArray): DoubleTensor {
|
override fun full(value: Double, shape: IntArray): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
checkEmptyShape(shape)
|
||||||
}
|
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
|
||||||
|
|
||||||
override fun DoubleTensor.zeroesLike(): DoubleTensor {
|
|
||||||
val shape = this.shape
|
|
||||||
val buffer = DoubleArray(this.strides.linearSize) { 0.0 }
|
|
||||||
return DoubleTensor(shape, buffer)
|
return DoubleTensor(shape, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun ones(shape: IntArray): DoubleTensor {
|
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
val shape = this.shape
|
||||||
|
val buffer = DoubleArray(this.strides.linearSize) { value }
|
||||||
|
return DoubleTensor(shape, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.onesLike(): DoubleTensor {
|
override fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
|
||||||
|
|
||||||
|
override fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
|
||||||
|
|
||||||
|
override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
|
||||||
|
|
||||||
override fun eye(n: Int): DoubleTensor {
|
override fun eye(n: Int): DoubleTensor {
|
||||||
val shape = intArrayOf(n, n)
|
val shape = intArrayOf(n, n)
|
||||||
@ -225,15 +227,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun full(shape: IntArray, value: Double): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
override fun DoubleTensor.sum(dim: Int, keepDim: Boolean): DoubleTensor {
|
override fun DoubleTensor.sum(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
@ -47,4 +47,36 @@ class TestDoubleTensorAlgebra {
|
|||||||
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun linearStructure() = DoubleTensorAlgebra {
|
||||||
|
val shape = intArrayOf(3)
|
||||||
|
val tensorA = full(value = -4.5, shape = shape)
|
||||||
|
val tensorB = full(value = 10.9, shape = shape)
|
||||||
|
val tensorC = full(value = 789.3, shape = shape)
|
||||||
|
val tensorD = full(value = -72.9, shape = shape)
|
||||||
|
val tensorE = full(value = 553.1, shape = shape)
|
||||||
|
val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4
|
||||||
|
|
||||||
|
val expected = fromArray(
|
||||||
|
shape,
|
||||||
|
(1..3).map {
|
||||||
|
15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4
|
||||||
|
}.toDoubleArray()
|
||||||
|
)
|
||||||
|
|
||||||
|
val assignResult = zeros(shape)
|
||||||
|
tensorA *= 15.8
|
||||||
|
tensorB *= 1.5
|
||||||
|
tensorB *= -tensorD
|
||||||
|
tensorC *= 0.02
|
||||||
|
tensorC /= tensorE
|
||||||
|
assignResult += tensorA
|
||||||
|
assignResult -= tensorB
|
||||||
|
assignResult += tensorC
|
||||||
|
assignResult += -39.4
|
||||||
|
|
||||||
|
assertTrue(expected.buffer.array() contentEquals result.buffer.array())
|
||||||
|
assertTrue(expected.buffer.array() contentEquals assignResult.buffer.array())
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user