OrderedTensorAlgebra

This commit is contained in:
Roland Grinis 2021-03-15 19:06:33 +00:00
parent b6a5fbfc14
commit b227a82a80
10 changed files with 125 additions and 108 deletions

View File

@ -8,6 +8,12 @@ public interface ComplexTensorAlgebra<T,
//https://pytorch.org/docs/stable/generated/torch.view_as_complex.html
public fun RealTensorType.viewAsComplex(): ComplexTensorType
// Embed a real tensor as real + i * imaginary
public fun RealTensorType.cartesianEmbedding(imaginary: RealTensorType): ComplexTensorType
// Embed a real tensor as real * exp(i * angle)
public fun RealTensorType.polarEmbedding(angle: RealTensorType): ComplexTensorType
//https://pytorch.org/docs/stable/generated/torch.angle.html
public fun ComplexTensorType.angle(): RealTensorType

View File

@ -1,8 +1,8 @@
package space.kscience.kmath.tensors
public class RealAnalyticTensorAlgebra:
public class DoubleAnalyticTensorAlgebra:
AnalyticTensorAlgebra<Double, DoubleTensor>,
RealTensorAlgebra()
DoubleTensorAlgebra()
{
override fun DoubleTensor.exp(): DoubleTensor {
TODO("Not yet implemented")
@ -146,5 +146,5 @@ public class RealAnalyticTensorAlgebra:
}
public inline fun <R> RealAnalyticTensorAlgebra(block: RealAnalyticTensorAlgebra.() -> R): R =
RealAnalyticTensorAlgebra().block()
public inline fun <R> DoubleAnalyticTensorAlgebra(block: DoubleAnalyticTensorAlgebra.() -> R): R =
DoubleAnalyticTensorAlgebra().block()

View File

@ -2,27 +2,12 @@ package space.kscience.kmath.tensors
public class DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, DoubleTensor>,
RealTensorAlgebra() {
override fun eye(n: Int): DoubleTensor {
val shape = intArrayOf(n, n)
val buffer = DoubleArray(n * n) { 0.0 }
val res = DoubleTensor(shape, buffer)
for (i in 0 until n) {
res[intArrayOf(i, i)] = 1.0
}
return res
}
DoubleTensorAlgebra() {
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
TODO("Alya")
override fun DoubleTensor.inv(): DoubleTensor {
TODO("Not yet implemented")
}
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
TODO("Alya")
}
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
// todo checks
val luTensor = this.copy()
@ -115,15 +100,6 @@ public class DoubleLinearOpsTensorAlgebra :
return Triple(p, l, u)
}
override fun DoubleTensor.det(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.inv(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.cholesky(): DoubleTensor {
TODO("Not yet implemented")
}

View File

@ -0,0 +1,41 @@
package space.kscience.kmath.tensors
public class DoubleOrderedTensorAlgebra:
OrderedTensorAlgebra<Double, DoubleTensor>,
DoubleTensorAlgebra()
{
override fun DoubleTensor.max(dim: Int, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.cummax(dim: Int): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.min(dim: Int, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.cummin(dim: Int): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.median(dim: Int, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun maximum(lhs: DoubleTensor, rhs: DoubleTensor) {
TODO("Not yet implemented")
}
override fun minimum(lhs: DoubleTensor, rhs: DoubleTensor) {
TODO("Not yet implemented")
}
override fun DoubleTensor.sort(dim: Int, keepDim: Boolean, descending: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
}
public inline fun <R> DoubleOrderedTensorAlgebra(block: DoubleOrderedTensorAlgebra.() -> R): R =
DoubleOrderedTensorAlgebra().block()

View File

@ -1,7 +1,7 @@
package space.kscience.kmath.tensors
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
override fun DoubleTensor.value(): Double {
check(this.shape contentEquals intArrayOf(1)) {
@ -27,6 +27,15 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, Doubl
override fun DoubleTensor.onesLike(): DoubleTensor {
TODO("Not yet implemented")
}
override fun eye(n: Int): DoubleTensor {
val shape = intArrayOf(n, n)
val buffer = DoubleArray(n * n) { 0.0 }
val res = DoubleTensor(shape, buffer)
for (i in 0 until n) {
res[intArrayOf(i, i)] = 1.0
}
return res
}
override fun DoubleTensor.copy(): DoubleTensor {
// should be rework as soon as copy() method for NDBuffer will be available
@ -199,36 +208,12 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, Doubl
TODO("Not yet implemented")
}
override fun DoubleTensor.max(dim: Int, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
TODO("Alya")
}
override fun DoubleTensor.cummax(dim: Int): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.min(dim: Int, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.cummin(dim: Int): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.median(dim: Int, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun maximum(lhs: DoubleTensor, rhs: DoubleTensor) {
TODO("Not yet implemented")
}
override fun minimum(lhs: DoubleTensor, rhs: DoubleTensor) {
TODO("Not yet implemented")
}
override fun DoubleTensor.sort(dim: Int, keepDim: Boolean, descending: Boolean): DoubleTensor {
TODO("Not yet implemented")
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
TODO("Alya")
}
override fun cat(tensors: List<DoubleTensor>, dim: Int): DoubleTensor {
@ -276,7 +261,11 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, Doubl
TODO("Not yet implemented")
}
override fun DoubleTensor.det(): DoubleTensor {
TODO("Not yet implemented")
}
}
public inline fun <R> RealTensorAlgebra(block: RealTensorAlgebra.() -> R): R =
RealTensorAlgebra().block()
public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R =
DoubleTensorAlgebra().block()

View File

@ -4,21 +4,6 @@ package space.kscience.kmath.tensors
public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
TensorPartialDivisionAlgebra<T, TensorType> {
//https://pytorch.org/docs/stable/generated/torch.eye.html
public fun eye(n: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.matmul.html
public infix fun TensorType.dot(other: TensorType): TensorType
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
public fun diagonalEmbedding(
diagonalEntries: TensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
): TensorType
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
public fun TensorType.det(): TensorType
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
public fun TensorType.inv(): TensorType

View File

@ -0,0 +1,29 @@
package space.kscience.kmath.tensors
public interface OrderedTensorAlgebra<T, TensorType : TensorStructure<T>> :
TensorAlgebra<T, TensorType> {
//https://pytorch.org/docs/stable/generated/torch.max.html#torch.max
public fun TensorType.max(dim: Int, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.cummax.html#torch.cummax
public fun TensorType.cummax(dim: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.min.html#torch.min
public fun TensorType.min(dim: Int, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.cummin.html#torch.cummin
public fun TensorType.cummin(dim: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.median.html#torch.median
public fun TensorType.median(dim: Int, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.maximum.html#torch.maximum
public fun maximum(lhs: TensorType, rhs: TensorType)
//https://pytorch.org/docs/stable/generated/torch.minimum.html#torch.minimum
public fun minimum(lhs: TensorType, rhs: TensorType)
//https://pytorch.org/docs/stable/generated/torch.sort.html#torch.sort
public fun TensorType.sort(dim: Int, keepDim: Boolean, descending: Boolean): TensorType
}

View File

@ -17,6 +17,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
public fun TensorType.fullLike(value: T): TensorType
//https://pytorch.org/docs/stable/generated/torch.eye.html
public fun eye(n: Int): TensorType
public fun TensorType.copy(): TensorType
public operator fun T.plus(other: TensorType): TensorType
@ -46,6 +49,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun TensorType.view(shape: IntArray): TensorType
public fun TensorType.viewAs(other: TensorType): TensorType
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
public fun TensorType.det(): TensorType
//https://pytorch.org/docs/stable/generated/torch.abs.html
public fun TensorType.abs(): TensorType
@ -61,29 +67,14 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.cumprod.html#torch.cumprod
public fun TensorType.cumprod(dim: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.max.html#torch.max
public fun TensorType.max(dim: Int, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.matmul.html
public infix fun TensorType.dot(other: TensorType): TensorType
//https://pytorch.org/docs/stable/generated/torch.cummax.html#torch.cummax
public fun TensorType.cummax(dim: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.min.html#torch.min
public fun TensorType.min(dim: Int, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.cummin.html#torch.cummin
public fun TensorType.cummin(dim: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.median.html#torch.median
public fun TensorType.median(dim: Int, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.maximum.html#torch.maximum
public fun maximum(lhs: TensorType, rhs: TensorType)
//https://pytorch.org/docs/stable/generated/torch.minimum.html#torch.minimum
public fun minimum(lhs: TensorType, rhs: TensorType)
//https://pytorch.org/docs/stable/generated/torch.sort.html#torch.sort
public fun TensorType.sort(dim: Int, keepDim: Boolean, descending: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
public fun diagonalEmbedding(
diagonalEntries: TensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
): TensorType
//https://pytorch.org/docs/stable/generated/torch.cat.html#torch.cat
public fun cat(tensors: List<TensorType>, dim: Int): TensorType

View File

@ -9,7 +9,7 @@ import kotlin.test.assertTrue
class TestRealTensor {
@Test
fun valueTest() = RealTensorAlgebra {
fun valueTest() = DoubleTensorAlgebra {
val value = 12.5
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(value))
assertEquals(tensor.value(), value)

View File

@ -7,14 +7,14 @@ import kotlin.test.assertTrue
class TestRealTensorAlgebra {
@Test
fun doublePlus() = RealTensorAlgebra {
fun doublePlus() = DoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
val res = 10.0 + tensor
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
}
@Test
fun transpose1x1() = RealTensorAlgebra {
fun transpose1x1() = DoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0))
val res = tensor.transpose(0, 0)
@ -23,7 +23,7 @@ class TestRealTensorAlgebra {
}
@Test
fun transpose3x2() = RealTensorAlgebra {
fun transpose3x2() = DoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res = tensor.transpose(1, 0)
@ -32,7 +32,7 @@ class TestRealTensorAlgebra {
}
@Test
fun transpose1x2x3() = RealTensorAlgebra {
fun transpose1x2x3() = DoubleTensorAlgebra {
val tensor = DoubleTensor(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res01 = tensor.transpose(0, 1)
val res02 = tensor.transpose(0, 2)
@ -48,7 +48,7 @@ class TestRealTensorAlgebra {
}
@Test
fun broadcastShapes() = RealTensorAlgebra {
fun broadcastShapes() = DoubleTensorAlgebra {
assertTrue(broadcastShapes(
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
) contentEquals intArrayOf(1, 2, 3))
@ -59,7 +59,7 @@ class TestRealTensorAlgebra {
}
@Test
fun broadcastTensors() = RealTensorAlgebra {
fun broadcastTensors() = DoubleTensorAlgebra {
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
@ -76,7 +76,7 @@ class TestRealTensorAlgebra {
}
@Test
fun minusTensor() = RealTensorAlgebra {
fun minusTensor() = DoubleTensorAlgebra {
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))