forked from kscience/kmath
OrderedTensorAlgebra
This commit is contained in:
parent
b6a5fbfc14
commit
b227a82a80
@ -8,6 +8,12 @@ public interface ComplexTensorAlgebra<T,
|
||||
//https://pytorch.org/docs/stable/generated/torch.view_as_complex.html
|
||||
public fun RealTensorType.viewAsComplex(): ComplexTensorType
|
||||
|
||||
// Embed a real tensor as real + i * imaginary
|
||||
public fun RealTensorType.cartesianEmbedding(imaginary: RealTensorType): ComplexTensorType
|
||||
|
||||
// Embed a real tensor as real * exp(i * angle)
|
||||
public fun RealTensorType.polarEmbedding(angle: RealTensorType): ComplexTensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.angle.html
|
||||
public fun ComplexTensorType.angle(): RealTensorType
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
public class RealAnalyticTensorAlgebra:
|
||||
public class DoubleAnalyticTensorAlgebra:
|
||||
AnalyticTensorAlgebra<Double, DoubleTensor>,
|
||||
RealTensorAlgebra()
|
||||
DoubleTensorAlgebra()
|
||||
{
|
||||
override fun DoubleTensor.exp(): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
@ -146,5 +146,5 @@ public class RealAnalyticTensorAlgebra:
|
||||
|
||||
}
|
||||
|
||||
public inline fun <R> RealAnalyticTensorAlgebra(block: RealAnalyticTensorAlgebra.() -> R): R =
|
||||
RealAnalyticTensorAlgebra().block()
|
||||
public inline fun <R> DoubleAnalyticTensorAlgebra(block: DoubleAnalyticTensorAlgebra.() -> R): R =
|
||||
DoubleAnalyticTensorAlgebra().block()
|
@ -2,27 +2,12 @@ package space.kscience.kmath.tensors
|
||||
|
||||
public class DoubleLinearOpsTensorAlgebra :
|
||||
LinearOpsTensorAlgebra<Double, DoubleTensor>,
|
||||
RealTensorAlgebra() {
|
||||
override fun eye(n: Int): DoubleTensor {
|
||||
val shape = intArrayOf(n, n)
|
||||
val buffer = DoubleArray(n * n) { 0.0 }
|
||||
val res = DoubleTensor(shape, buffer)
|
||||
for (i in 0 until n) {
|
||||
res[intArrayOf(i, i)] = 1.0
|
||||
}
|
||||
return res
|
||||
}
|
||||
DoubleTensorAlgebra() {
|
||||
|
||||
|
||||
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
|
||||
TODO("Alya")
|
||||
override fun DoubleTensor.inv(): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
||||
TODO("Alya")
|
||||
}
|
||||
|
||||
|
||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||
// todo checks
|
||||
val luTensor = this.copy()
|
||||
@ -115,15 +100,6 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
return Triple(p, l, u)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.det(): DoubleTensor {
|
||||
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.inv(): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
@ -0,0 +1,41 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
public class DoubleOrderedTensorAlgebra:
|
||||
OrderedTensorAlgebra<Double, DoubleTensor>,
|
||||
DoubleTensorAlgebra()
|
||||
{
|
||||
override fun DoubleTensor.max(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cummax(dim: Int): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.min(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cummin(dim: Int): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.median(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun maximum(lhs: DoubleTensor, rhs: DoubleTensor) {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun minimum(lhs: DoubleTensor, rhs: DoubleTensor) {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.sort(dim: Int, keepDim: Boolean, descending: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
}
|
||||
|
||||
public inline fun <R> DoubleOrderedTensorAlgebra(block: DoubleOrderedTensorAlgebra.() -> R): R =
|
||||
DoubleOrderedTensorAlgebra().block()
|
@ -1,7 +1,7 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
|
||||
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
||||
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
||||
|
||||
override fun DoubleTensor.value(): Double {
|
||||
check(this.shape contentEquals intArrayOf(1)) {
|
||||
@ -27,6 +27,15 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, Doubl
|
||||
override fun DoubleTensor.onesLike(): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
override fun eye(n: Int): DoubleTensor {
|
||||
val shape = intArrayOf(n, n)
|
||||
val buffer = DoubleArray(n * n) { 0.0 }
|
||||
val res = DoubleTensor(shape, buffer)
|
||||
for (i in 0 until n) {
|
||||
res[intArrayOf(i, i)] = 1.0
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
override fun DoubleTensor.copy(): DoubleTensor {
|
||||
// should be rework as soon as copy() method for NDBuffer will be available
|
||||
@ -199,36 +208,12 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, Doubl
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.max(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
|
||||
TODO("Alya")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cummax(dim: Int): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.min(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cummin(dim: Int): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.median(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun maximum(lhs: DoubleTensor, rhs: DoubleTensor) {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun minimum(lhs: DoubleTensor, rhs: DoubleTensor) {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.sort(dim: Int, keepDim: Boolean, descending: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
||||
TODO("Alya")
|
||||
}
|
||||
|
||||
override fun cat(tensors: List<DoubleTensor>, dim: Int): DoubleTensor {
|
||||
@ -276,7 +261,11 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, Doubl
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.det(): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
public inline fun <R> RealTensorAlgebra(block: RealTensorAlgebra.() -> R): R =
|
||||
RealTensorAlgebra().block()
|
||||
}
|
||||
|
||||
public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R =
|
||||
DoubleTensorAlgebra().block()
|
@ -4,21 +4,6 @@ package space.kscience.kmath.tensors
|
||||
public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
||||
TensorPartialDivisionAlgebra<T, TensorType> {
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
||||
public fun eye(n: Int): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||
public infix fun TensorType.dot(other: TensorType): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
||||
public fun diagonalEmbedding(
|
||||
diagonalEntries: TensorType,
|
||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||
): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
|
||||
public fun TensorType.det(): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
||||
public fun TensorType.inv(): TensorType
|
||||
|
||||
|
@ -0,0 +1,29 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
public interface OrderedTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
||||
TensorAlgebra<T, TensorType> {
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.max.html#torch.max
|
||||
public fun TensorType.max(dim: Int, keepDim: Boolean): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.cummax.html#torch.cummax
|
||||
public fun TensorType.cummax(dim: Int): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.min.html#torch.min
|
||||
public fun TensorType.min(dim: Int, keepDim: Boolean): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.cummin.html#torch.cummin
|
||||
public fun TensorType.cummin(dim: Int): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.median.html#torch.median
|
||||
public fun TensorType.median(dim: Int, keepDim: Boolean): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.maximum.html#torch.maximum
|
||||
public fun maximum(lhs: TensorType, rhs: TensorType)
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.minimum.html#torch.minimum
|
||||
public fun minimum(lhs: TensorType, rhs: TensorType)
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.sort.html#torch.sort
|
||||
public fun TensorType.sort(dim: Int, keepDim: Boolean, descending: Boolean): TensorType
|
||||
}
|
@ -17,6 +17,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
||||
public fun TensorType.fullLike(value: T): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
||||
public fun eye(n: Int): TensorType
|
||||
|
||||
public fun TensorType.copy(): TensorType
|
||||
|
||||
public operator fun T.plus(other: TensorType): TensorType
|
||||
@ -46,6 +49,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
public fun TensorType.view(shape: IntArray): TensorType
|
||||
public fun TensorType.viewAs(other: TensorType): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
|
||||
public fun TensorType.det(): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.abs.html
|
||||
public fun TensorType.abs(): TensorType
|
||||
|
||||
@ -61,29 +67,14 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
//https://pytorch.org/docs/stable/generated/torch.cumprod.html#torch.cumprod
|
||||
public fun TensorType.cumprod(dim: Int): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.max.html#torch.max
|
||||
public fun TensorType.max(dim: Int, keepDim: Boolean): TensorType
|
||||
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||
public infix fun TensorType.dot(other: TensorType): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.cummax.html#torch.cummax
|
||||
public fun TensorType.cummax(dim: Int): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.min.html#torch.min
|
||||
public fun TensorType.min(dim: Int, keepDim: Boolean): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.cummin.html#torch.cummin
|
||||
public fun TensorType.cummin(dim: Int): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.median.html#torch.median
|
||||
public fun TensorType.median(dim: Int, keepDim: Boolean): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.maximum.html#torch.maximum
|
||||
public fun maximum(lhs: TensorType, rhs: TensorType)
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.minimum.html#torch.minimum
|
||||
public fun minimum(lhs: TensorType, rhs: TensorType)
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.sort.html#torch.sort
|
||||
public fun TensorType.sort(dim: Int, keepDim: Boolean, descending: Boolean): TensorType
|
||||
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
||||
public fun diagonalEmbedding(
|
||||
diagonalEntries: TensorType,
|
||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||
): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.cat.html#torch.cat
|
||||
public fun cat(tensors: List<TensorType>, dim: Int): TensorType
|
||||
|
@ -9,7 +9,7 @@ import kotlin.test.assertTrue
|
||||
class TestRealTensor {
|
||||
|
||||
@Test
|
||||
fun valueTest() = RealTensorAlgebra {
|
||||
fun valueTest() = DoubleTensorAlgebra {
|
||||
val value = 12.5
|
||||
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(value))
|
||||
assertEquals(tensor.value(), value)
|
||||
|
@ -7,14 +7,14 @@ import kotlin.test.assertTrue
|
||||
class TestRealTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun doublePlus() = RealTensorAlgebra {
|
||||
fun doublePlus() = DoubleTensorAlgebra {
|
||||
val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||
val res = 10.0 + tensor
|
||||
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose1x1() = RealTensorAlgebra {
|
||||
fun transpose1x1() = DoubleTensorAlgebra {
|
||||
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0))
|
||||
val res = tensor.transpose(0, 0)
|
||||
|
||||
@ -23,7 +23,7 @@ class TestRealTensorAlgebra {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose3x2() = RealTensorAlgebra {
|
||||
fun transpose3x2() = DoubleTensorAlgebra {
|
||||
val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val res = tensor.transpose(1, 0)
|
||||
|
||||
@ -32,7 +32,7 @@ class TestRealTensorAlgebra {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose1x2x3() = RealTensorAlgebra {
|
||||
fun transpose1x2x3() = DoubleTensorAlgebra {
|
||||
val tensor = DoubleTensor(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val res01 = tensor.transpose(0, 1)
|
||||
val res02 = tensor.transpose(0, 2)
|
||||
@ -48,7 +48,7 @@ class TestRealTensorAlgebra {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun broadcastShapes() = RealTensorAlgebra {
|
||||
fun broadcastShapes() = DoubleTensorAlgebra {
|
||||
assertTrue(broadcastShapes(
|
||||
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
||||
) contentEquals intArrayOf(1, 2, 3))
|
||||
@ -59,7 +59,7 @@ class TestRealTensorAlgebra {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun broadcastTensors() = RealTensorAlgebra {
|
||||
fun broadcastTensors() = DoubleTensorAlgebra {
|
||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
@ -76,7 +76,7 @@ class TestRealTensorAlgebra {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun minusTensor() = RealTensorAlgebra {
|
||||
fun minusTensor() = DoubleTensorAlgebra {
|
||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
Loading…
Reference in New Issue
Block a user