forked from kscience/kmath
OrderedTensorAlgebra
This commit is contained in:
parent
b6a5fbfc14
commit
b227a82a80
@ -8,6 +8,12 @@ public interface ComplexTensorAlgebra<T,
|
|||||||
//https://pytorch.org/docs/stable/generated/torch.view_as_complex.html
|
//https://pytorch.org/docs/stable/generated/torch.view_as_complex.html
|
||||||
public fun RealTensorType.viewAsComplex(): ComplexTensorType
|
public fun RealTensorType.viewAsComplex(): ComplexTensorType
|
||||||
|
|
||||||
|
// Embed a real tensor as real + i * imaginary
|
||||||
|
public fun RealTensorType.cartesianEmbedding(imaginary: RealTensorType): ComplexTensorType
|
||||||
|
|
||||||
|
// Embed a real tensor as real * exp(i * angle)
|
||||||
|
public fun RealTensorType.polarEmbedding(angle: RealTensorType): ComplexTensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.angle.html
|
//https://pytorch.org/docs/stable/generated/torch.angle.html
|
||||||
public fun ComplexTensorType.angle(): RealTensorType
|
public fun ComplexTensorType.angle(): RealTensorType
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
public class RealAnalyticTensorAlgebra:
|
public class DoubleAnalyticTensorAlgebra:
|
||||||
AnalyticTensorAlgebra<Double, DoubleTensor>,
|
AnalyticTensorAlgebra<Double, DoubleTensor>,
|
||||||
RealTensorAlgebra()
|
DoubleTensorAlgebra()
|
||||||
{
|
{
|
||||||
override fun DoubleTensor.exp(): DoubleTensor {
|
override fun DoubleTensor.exp(): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
@ -146,5 +146,5 @@ public class RealAnalyticTensorAlgebra:
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <R> RealAnalyticTensorAlgebra(block: RealAnalyticTensorAlgebra.() -> R): R =
|
public inline fun <R> DoubleAnalyticTensorAlgebra(block: DoubleAnalyticTensorAlgebra.() -> R): R =
|
||||||
RealAnalyticTensorAlgebra().block()
|
DoubleAnalyticTensorAlgebra().block()
|
@ -2,27 +2,12 @@ package space.kscience.kmath.tensors
|
|||||||
|
|
||||||
public class DoubleLinearOpsTensorAlgebra :
|
public class DoubleLinearOpsTensorAlgebra :
|
||||||
LinearOpsTensorAlgebra<Double, DoubleTensor>,
|
LinearOpsTensorAlgebra<Double, DoubleTensor>,
|
||||||
RealTensorAlgebra() {
|
DoubleTensorAlgebra() {
|
||||||
override fun eye(n: Int): DoubleTensor {
|
|
||||||
val shape = intArrayOf(n, n)
|
override fun DoubleTensor.inv(): DoubleTensor {
|
||||||
val buffer = DoubleArray(n * n) { 0.0 }
|
TODO("Not yet implemented")
|
||||||
val res = DoubleTensor(shape, buffer)
|
|
||||||
for (i in 0 until n) {
|
|
||||||
res[intArrayOf(i, i)] = 1.0
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
|
|
||||||
TODO("Alya")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
|
||||||
TODO("Alya")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||||
// todo checks
|
// todo checks
|
||||||
val luTensor = this.copy()
|
val luTensor = this.copy()
|
||||||
@ -115,15 +100,6 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
return Triple(p, l, u)
|
return Triple(p, l, u)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.det(): DoubleTensor {
|
|
||||||
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.inv(): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,41 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
public class DoubleOrderedTensorAlgebra:
|
||||||
|
OrderedTensorAlgebra<Double, DoubleTensor>,
|
||||||
|
DoubleTensorAlgebra()
|
||||||
|
{
|
||||||
|
override fun DoubleTensor.max(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.cummax(dim: Int): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.min(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.cummin(dim: Int): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.median(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun maximum(lhs: DoubleTensor, rhs: DoubleTensor) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun minimum(lhs: DoubleTensor, rhs: DoubleTensor) {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.sort(dim: Int, keepDim: Boolean, descending: Boolean): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <R> DoubleOrderedTensorAlgebra(block: DoubleOrderedTensorAlgebra.() -> R): R =
|
||||||
|
DoubleOrderedTensorAlgebra().block()
|
@ -1,7 +1,7 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
|
||||||
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
||||||
|
|
||||||
override fun DoubleTensor.value(): Double {
|
override fun DoubleTensor.value(): Double {
|
||||||
check(this.shape contentEquals intArrayOf(1)) {
|
check(this.shape contentEquals intArrayOf(1)) {
|
||||||
@ -27,6 +27,15 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, Doubl
|
|||||||
override fun DoubleTensor.onesLike(): DoubleTensor {
|
override fun DoubleTensor.onesLike(): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
override fun eye(n: Int): DoubleTensor {
|
||||||
|
val shape = intArrayOf(n, n)
|
||||||
|
val buffer = DoubleArray(n * n) { 0.0 }
|
||||||
|
val res = DoubleTensor(shape, buffer)
|
||||||
|
for (i in 0 until n) {
|
||||||
|
res[intArrayOf(i, i)] = 1.0
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.copy(): DoubleTensor {
|
override fun DoubleTensor.copy(): DoubleTensor {
|
||||||
// should be rework as soon as copy() method for NDBuffer will be available
|
// should be rework as soon as copy() method for NDBuffer will be available
|
||||||
@ -199,36 +208,12 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, Doubl
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.max(dim: Int, keepDim: Boolean): DoubleTensor {
|
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Alya")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.cummax(dim: Int): DoubleTensor {
|
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Alya")
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.min(dim: Int, keepDim: Boolean): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.cummin(dim: Int): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.median(dim: Int, keepDim: Boolean): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun maximum(lhs: DoubleTensor, rhs: DoubleTensor) {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun minimum(lhs: DoubleTensor, rhs: DoubleTensor) {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.sort(dim: Int, keepDim: Boolean, descending: Boolean): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun cat(tensors: List<DoubleTensor>, dim: Int): DoubleTensor {
|
override fun cat(tensors: List<DoubleTensor>, dim: Int): DoubleTensor {
|
||||||
@ -276,7 +261,11 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, Doubl
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.det(): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <R> RealTensorAlgebra(block: RealTensorAlgebra.() -> R): R =
|
public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R =
|
||||||
RealTensorAlgebra().block()
|
DoubleTensorAlgebra().block()
|
@ -4,21 +4,6 @@ package space.kscience.kmath.tensors
|
|||||||
public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
||||||
TensorPartialDivisionAlgebra<T, TensorType> {
|
TensorPartialDivisionAlgebra<T, TensorType> {
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
|
||||||
public fun eye(n: Int): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
|
||||||
public infix fun TensorType.dot(other: TensorType): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
|
||||||
public fun diagonalEmbedding(
|
|
||||||
diagonalEntries: TensorType,
|
|
||||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
|
||||||
): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
|
|
||||||
public fun TensorType.det(): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
||||||
public fun TensorType.inv(): TensorType
|
public fun TensorType.inv(): TensorType
|
||||||
|
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
public interface OrderedTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
||||||
|
TensorAlgebra<T, TensorType> {
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.max.html#torch.max
|
||||||
|
public fun TensorType.max(dim: Int, keepDim: Boolean): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.cummax.html#torch.cummax
|
||||||
|
public fun TensorType.cummax(dim: Int): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.min.html#torch.min
|
||||||
|
public fun TensorType.min(dim: Int, keepDim: Boolean): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.cummin.html#torch.cummin
|
||||||
|
public fun TensorType.cummin(dim: Int): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.median.html#torch.median
|
||||||
|
public fun TensorType.median(dim: Int, keepDim: Boolean): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.maximum.html#torch.maximum
|
||||||
|
public fun maximum(lhs: TensorType, rhs: TensorType)
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.minimum.html#torch.minimum
|
||||||
|
public fun minimum(lhs: TensorType, rhs: TensorType)
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.sort.html#torch.sort
|
||||||
|
public fun TensorType.sort(dim: Int, keepDim: Boolean, descending: Boolean): TensorType
|
||||||
|
}
|
@ -17,6 +17,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
||||||
public fun TensorType.fullLike(value: T): TensorType
|
public fun TensorType.fullLike(value: T): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
||||||
|
public fun eye(n: Int): TensorType
|
||||||
|
|
||||||
public fun TensorType.copy(): TensorType
|
public fun TensorType.copy(): TensorType
|
||||||
|
|
||||||
public operator fun T.plus(other: TensorType): TensorType
|
public operator fun T.plus(other: TensorType): TensorType
|
||||||
@ -46,6 +49,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
public fun TensorType.view(shape: IntArray): TensorType
|
public fun TensorType.view(shape: IntArray): TensorType
|
||||||
public fun TensorType.viewAs(other: TensorType): TensorType
|
public fun TensorType.viewAs(other: TensorType): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
|
||||||
|
public fun TensorType.det(): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.abs.html
|
//https://pytorch.org/docs/stable/generated/torch.abs.html
|
||||||
public fun TensorType.abs(): TensorType
|
public fun TensorType.abs(): TensorType
|
||||||
|
|
||||||
@ -61,29 +67,14 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
//https://pytorch.org/docs/stable/generated/torch.cumprod.html#torch.cumprod
|
//https://pytorch.org/docs/stable/generated/torch.cumprod.html#torch.cumprod
|
||||||
public fun TensorType.cumprod(dim: Int): TensorType
|
public fun TensorType.cumprod(dim: Int): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.max.html#torch.max
|
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||||
public fun TensorType.max(dim: Int, keepDim: Boolean): TensorType
|
public infix fun TensorType.dot(other: TensorType): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.cummax.html#torch.cummax
|
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
||||||
public fun TensorType.cummax(dim: Int): TensorType
|
public fun diagonalEmbedding(
|
||||||
|
diagonalEntries: TensorType,
|
||||||
//https://pytorch.org/docs/stable/generated/torch.min.html#torch.min
|
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||||
public fun TensorType.min(dim: Int, keepDim: Boolean): TensorType
|
): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.cummin.html#torch.cummin
|
|
||||||
public fun TensorType.cummin(dim: Int): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.median.html#torch.median
|
|
||||||
public fun TensorType.median(dim: Int, keepDim: Boolean): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.maximum.html#torch.maximum
|
|
||||||
public fun maximum(lhs: TensorType, rhs: TensorType)
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.minimum.html#torch.minimum
|
|
||||||
public fun minimum(lhs: TensorType, rhs: TensorType)
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.sort.html#torch.sort
|
|
||||||
public fun TensorType.sort(dim: Int, keepDim: Boolean, descending: Boolean): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.cat.html#torch.cat
|
//https://pytorch.org/docs/stable/generated/torch.cat.html#torch.cat
|
||||||
public fun cat(tensors: List<TensorType>, dim: Int): TensorType
|
public fun cat(tensors: List<TensorType>, dim: Int): TensorType
|
||||||
|
@ -9,7 +9,7 @@ import kotlin.test.assertTrue
|
|||||||
class TestRealTensor {
|
class TestRealTensor {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun valueTest() = RealTensorAlgebra {
|
fun valueTest() = DoubleTensorAlgebra {
|
||||||
val value = 12.5
|
val value = 12.5
|
||||||
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(value))
|
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(value))
|
||||||
assertEquals(tensor.value(), value)
|
assertEquals(tensor.value(), value)
|
||||||
|
@ -7,14 +7,14 @@ import kotlin.test.assertTrue
|
|||||||
class TestRealTensorAlgebra {
|
class TestRealTensorAlgebra {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun doublePlus() = RealTensorAlgebra {
|
fun doublePlus() = DoubleTensorAlgebra {
|
||||||
val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||||
val res = 10.0 + tensor
|
val res = 10.0 + tensor
|
||||||
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
|
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun transpose1x1() = RealTensorAlgebra {
|
fun transpose1x1() = DoubleTensorAlgebra {
|
||||||
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0))
|
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0))
|
||||||
val res = tensor.transpose(0, 0)
|
val res = tensor.transpose(0, 0)
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ class TestRealTensorAlgebra {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun transpose3x2() = RealTensorAlgebra {
|
fun transpose3x2() = DoubleTensorAlgebra {
|
||||||
val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val res = tensor.transpose(1, 0)
|
val res = tensor.transpose(1, 0)
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ class TestRealTensorAlgebra {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun transpose1x2x3() = RealTensorAlgebra {
|
fun transpose1x2x3() = DoubleTensorAlgebra {
|
||||||
val tensor = DoubleTensor(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor = DoubleTensor(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val res01 = tensor.transpose(0, 1)
|
val res01 = tensor.transpose(0, 1)
|
||||||
val res02 = tensor.transpose(0, 2)
|
val res02 = tensor.transpose(0, 2)
|
||||||
@ -48,7 +48,7 @@ class TestRealTensorAlgebra {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun broadcastShapes() = RealTensorAlgebra {
|
fun broadcastShapes() = DoubleTensorAlgebra {
|
||||||
assertTrue(broadcastShapes(
|
assertTrue(broadcastShapes(
|
||||||
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
||||||
) contentEquals intArrayOf(1, 2, 3))
|
) contentEquals intArrayOf(1, 2, 3))
|
||||||
@ -59,7 +59,7 @@ class TestRealTensorAlgebra {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun broadcastTensors() = RealTensorAlgebra {
|
fun broadcastTensors() = DoubleTensorAlgebra {
|
||||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
@ -76,7 +76,7 @@ class TestRealTensorAlgebra {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun minusTensor() = RealTensorAlgebra {
|
fun minusTensor() = DoubleTensorAlgebra {
|
||||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
|
Loading…
Reference in New Issue
Block a user