forked from kscience/kmath
Explicit broadcasting enforced
This commit is contained in:
parent
1fa0da2810
commit
274be61330
@ -17,7 +17,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
|
|||||||
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
||||||
public fun luPivot(lu: TensorType, pivots: IntTensor): Triple<TensorType, TensorType, TensorType>
|
public fun luPivot(lu: TensorType, pivots: IndexTensorType): Triple<TensorType, TensorType, TensorType>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
||||||
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
||||||
|
@ -1,10 +1,91 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.structures.*
|
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
|
|
||||||
|
public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||||
|
|
||||||
|
override fun DoubleTensor.plus(other: DoubleTensor): DoubleTensor {
|
||||||
|
val broadcast = broadcastTensors(this, other)
|
||||||
|
val newThis = broadcast[0]
|
||||||
|
val newOther = broadcast[1]
|
||||||
|
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||||
|
newThis.buffer.array()[i] + newOther.buffer.array()[i]
|
||||||
|
}
|
||||||
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
||||||
|
val newOther = broadcastTo(other, this.shape)
|
||||||
|
for (i in 0 until this.strides.linearSize) {
|
||||||
|
this.buffer.array()[this.bufferStart + i] +=
|
||||||
|
newOther.buffer.array()[this.bufferStart + i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.minus(other: DoubleTensor): DoubleTensor {
|
||||||
|
val broadcast = broadcastTensors(this, other)
|
||||||
|
val newThis = broadcast[0]
|
||||||
|
val newOther = broadcast[1]
|
||||||
|
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||||
|
newThis.buffer.array()[i] - newOther.buffer.array()[i]
|
||||||
|
}
|
||||||
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.minusAssign(other: DoubleTensor) {
|
||||||
|
val newOther = broadcastTo(other, this.shape)
|
||||||
|
for (i in 0 until this.strides.linearSize) {
|
||||||
|
this.buffer.array()[this.bufferStart + i] -=
|
||||||
|
newOther.buffer.array()[this.bufferStart + i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
||||||
|
val broadcast = broadcastTensors(this, other)
|
||||||
|
val newThis = broadcast[0]
|
||||||
|
val newOther = broadcast[1]
|
||||||
|
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||||
|
newThis.buffer.array()[newOther.bufferStart + i] *
|
||||||
|
newOther.buffer.array()[newOther.bufferStart + i]
|
||||||
|
}
|
||||||
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
||||||
|
val newOther = broadcastTo(other, this.shape)
|
||||||
|
for (i in 0 until this.strides.linearSize) {
|
||||||
|
this.buffer.array()[this.bufferStart + i] *=
|
||||||
|
newOther.buffer.array()[this.bufferStart + i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.div(other: DoubleTensor): DoubleTensor {
|
||||||
|
val broadcast = broadcastTensors(this, other)
|
||||||
|
val newThis = broadcast[0]
|
||||||
|
val newOther = broadcast[1]
|
||||||
|
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||||
|
newThis.buffer.array()[newOther.bufferStart + i] /
|
||||||
|
newOther.buffer.array()[newOther.bufferStart + i]
|
||||||
|
}
|
||||||
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.divAssign(other: DoubleTensor) {
|
||||||
|
val newOther = broadcastTo(other, this.shape)
|
||||||
|
for (i in 0 until this.strides.linearSize) {
|
||||||
|
this.buffer.array()[this.bufferStart + i] /=
|
||||||
|
newOther.buffer.array()[this.bufferStart + i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <R> broadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R =
|
||||||
|
BroadcastDoubleTensorAlgebra().block()
|
||||||
|
|
||||||
|
|
||||||
internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
||||||
|
println(shapes)
|
||||||
var totalDim = 0
|
var totalDim = 0
|
||||||
for (shape in shapes) {
|
for (shape in shapes) {
|
||||||
totalDim = max(totalDim, shape.size)
|
totalDim = max(totalDim, shape.size)
|
||||||
@ -100,67 +181,3 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
|
||||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
|
||||||
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
|
||||||
val sa = a.shape
|
|
||||||
val sb = b.shape
|
|
||||||
val na = sa.size
|
|
||||||
val nb = sb.size
|
|
||||||
var status: Boolean
|
|
||||||
if (nb == 1) {
|
|
||||||
status = sa.last() == sb[0]
|
|
||||||
} else {
|
|
||||||
status = sa.last() == sb[nb - 2]
|
|
||||||
if ((na > 2) and (nb > 2)) {
|
|
||||||
status = status and
|
|
||||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
|
||||||
}
|
|
||||||
|
|
||||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
|
||||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
|
||||||
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
|
||||||
check((i < dim) and (j < dim)) {
|
|
||||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
|
||||||
}
|
|
||||||
|
|
||||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
|
||||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
|
||||||
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
|
||||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
|
||||||
*/
|
|
||||||
internal fun Buffer<Int>.array(): IntArray = when(this) {
|
|
||||||
is IntBuffer -> array
|
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
|
||||||
*/
|
|
||||||
internal fun Buffer<Long>.array(): LongArray = when(this) {
|
|
||||||
is LongBuffer -> array
|
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
|
||||||
*/
|
|
||||||
internal fun Buffer<Float>.array(): FloatArray = when(this) {
|
|
||||||
is FloatBuffer -> array
|
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
|
||||||
*/
|
|
||||||
internal fun Buffer<Double>.array(): DoubleArray = when(this) {
|
|
||||||
is RealBuffer -> array
|
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
|
||||||
}
|
|
@ -1,8 +1,10 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.linear.Matrix
|
import space.kscience.kmath.linear.Matrix
|
||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
|
import space.kscience.kmath.tensors.TensorStrides
|
||||||
|
import space.kscience.kmath.tensors.TensorStructure
|
||||||
|
|
||||||
|
|
||||||
public open class BufferedTensor<T>(
|
public open class BufferedTensor<T>(
|
||||||
@ -76,25 +78,25 @@ public open class BufferedTensor<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public class IntTensor(
|
public class IntTensor internal constructor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: IntArray,
|
buffer: IntArray,
|
||||||
offset: Int = 0
|
offset: Int = 0
|
||||||
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
|
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
|
||||||
|
|
||||||
public class LongTensor(
|
public class LongTensor internal constructor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: LongArray,
|
buffer: LongArray,
|
||||||
offset: Int = 0
|
offset: Int = 0
|
||||||
) : BufferedTensor<Long>(shape, LongBuffer(buffer), offset)
|
) : BufferedTensor<Long>(shape, LongBuffer(buffer), offset)
|
||||||
|
|
||||||
public class FloatTensor(
|
public class FloatTensor internal constructor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: FloatArray,
|
buffer: FloatArray,
|
||||||
offset: Int = 0
|
offset: Int = 0
|
||||||
) : BufferedTensor<Float>(shape, FloatBuffer(buffer), offset)
|
) : BufferedTensor<Float>(shape, FloatBuffer(buffer), offset)
|
||||||
|
|
||||||
public class DoubleTensor(
|
public class DoubleTensor internal constructor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: DoubleArray,
|
buffer: DoubleArray,
|
||||||
offset: Int = 0
|
offset: Int = 0
|
@ -1,4 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.tensors.AnalyticTensorAlgebra
|
||||||
|
|
||||||
public class DoubleAnalyticTensorAlgebra:
|
public class DoubleAnalyticTensorAlgebra:
|
||||||
AnalyticTensorAlgebra<Double, DoubleTensor>,
|
AnalyticTensorAlgebra<Double, DoubleTensor>,
|
@ -1,4 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
||||||
|
|
||||||
public class DoubleLinearOpsTensorAlgebra :
|
public class DoubleLinearOpsTensorAlgebra :
|
||||||
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
@ -1,4 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.tensors.OrderedTensorAlgebra
|
||||||
|
|
||||||
public open class DoubleOrderedTensorAlgebra:
|
public open class DoubleOrderedTensorAlgebra:
|
||||||
OrderedTensorAlgebra<Double, DoubleTensor>,
|
OrderedTensorAlgebra<Double, DoubleTensor>,
|
@ -1,4 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.tensors.ReduceOpsTensorAlgebra
|
||||||
|
|
||||||
public class DoubleReduceOpsTensorAlgebra:
|
public class DoubleReduceOpsTensorAlgebra:
|
||||||
DoubleTensorAlgebra(),
|
DoubleTensorAlgebra(),
|
@ -1,8 +1,18 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra
|
||||||
|
|
||||||
|
|
||||||
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
||||||
|
|
||||||
|
public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor {
|
||||||
|
checkEmptyShape(shape)
|
||||||
|
checkEmptyDoubleBuffer(buffer)
|
||||||
|
checkBufferShapeConsistency(shape, buffer)
|
||||||
|
return DoubleTensor(shape, buffer, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
override operator fun DoubleTensor.get(i: Int): DoubleTensor {
|
override operator fun DoubleTensor.get(i: Int): DoubleTensor {
|
||||||
val lastShape = this.shape.drop(1).toIntArray()
|
val lastShape = this.shape.drop(1).toIntArray()
|
||||||
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
||||||
@ -53,13 +63,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
override fun DoubleTensor.plus(value: Double): DoubleTensor = value + this
|
override fun DoubleTensor.plus(value: Double): DoubleTensor = value + this
|
||||||
|
|
||||||
override fun DoubleTensor.plus(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.plus(other: DoubleTensor): DoubleTensor {
|
||||||
val broadcast = broadcastTensors(this, other)
|
checkShapesCompatible(this, other)
|
||||||
val newThis = broadcast[0]
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
val newOther = broadcast[1]
|
this.buffer.array()[i] + other.buffer.array()[i]
|
||||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
|
||||||
newThis.buffer.array()[i] + newOther.buffer.array()[i]
|
|
||||||
}
|
}
|
||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.plusAssign(value: Double) {
|
override fun DoubleTensor.plusAssign(value: Double) {
|
||||||
@ -69,10 +77,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
||||||
val newOther = broadcastTo(other, this.shape)
|
checkShapesCompatible(this, other)
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.array()[this.bufferStart + i] +=
|
this.buffer.array()[this.bufferStart + i] +=
|
||||||
newOther.buffer.array()[this.bufferStart + i]
|
other.buffer.array()[this.bufferStart + i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,13 +99,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.minus(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.minus(other: DoubleTensor): DoubleTensor {
|
||||||
val broadcast = broadcastTensors(this, other)
|
checkShapesCompatible(this, other)
|
||||||
val newThis = broadcast[0]
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
val newOther = broadcast[1]
|
this.buffer.array()[i] - other.buffer.array()[i]
|
||||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
|
||||||
newThis.buffer.array()[i] - newOther.buffer.array()[i]
|
|
||||||
}
|
}
|
||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.minusAssign(value: Double) {
|
override fun DoubleTensor.minusAssign(value: Double) {
|
||||||
@ -107,10 +113,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.minusAssign(other: DoubleTensor) {
|
override fun DoubleTensor.minusAssign(other: DoubleTensor) {
|
||||||
val newOther = broadcastTo(other, this.shape)
|
checkShapesCompatible(this, other)
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.array()[this.bufferStart + i] -=
|
this.buffer.array()[this.bufferStart + i] -=
|
||||||
newOther.buffer.array()[this.bufferStart + i]
|
other.buffer.array()[this.bufferStart + i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -124,15 +130,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
override fun DoubleTensor.times(value: Double): DoubleTensor = value * this
|
override fun DoubleTensor.times(value: Double): DoubleTensor = value * this
|
||||||
|
|
||||||
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
||||||
val broadcast = broadcastTensors(this, other)
|
checkShapesCompatible(this, other)
|
||||||
val newThis = broadcast[0]
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
val newOther = broadcast[1]
|
this.buffer.array()[other.bufferStart + i] *
|
||||||
|
other.buffer.array()[other.bufferStart + i]
|
||||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
|
||||||
newThis.buffer.array()[newOther.bufferStart + i] *
|
|
||||||
newOther.buffer.array()[newOther.bufferStart + i]
|
|
||||||
}
|
}
|
||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.timesAssign(value: Double) {
|
override fun DoubleTensor.timesAssign(value: Double) {
|
||||||
@ -142,10 +145,40 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
||||||
val newOther = broadcastTo(other, this.shape)
|
checkShapesCompatible(this, other)
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.array()[this.bufferStart + i] *=
|
this.buffer.array()[this.bufferStart + i] *=
|
||||||
newOther.buffer.array()[this.bufferStart + i]
|
other.buffer.array()[this.bufferStart + i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.div(value: Double): DoubleTensor {
|
||||||
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
|
this.buffer.array()[this.bufferStart + i] / value
|
||||||
|
}
|
||||||
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.div(other: DoubleTensor): DoubleTensor {
|
||||||
|
checkShapesCompatible(this, other)
|
||||||
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
|
this.buffer.array()[other.bufferStart + i] /
|
||||||
|
other.buffer.array()[other.bufferStart + i]
|
||||||
|
}
|
||||||
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.divAssign(value: Double) {
|
||||||
|
for (i in 0 until this.strides.linearSize) {
|
||||||
|
this.buffer.array()[this.bufferStart + i] /= value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.divAssign(other: DoubleTensor) {
|
||||||
|
checkShapesCompatible(this, other)
|
||||||
|
for (i in 0 until this.strides.linearSize) {
|
||||||
|
this.buffer.array()[this.bufferStart + i] /=
|
||||||
|
other.buffer.array()[this.bufferStart + i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -229,27 +262,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun DoubleTensor.div(value: Double): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.div(other: DoubleTensor): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.flatten(startDim: Int, endDim: Int): DoubleTensor {
|
override fun DoubleTensor.flatten(startDim: Int, endDim: Int): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.divAssign(value: Double) {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.divAssign(other: DoubleTensor) {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.mean(dim: Int, keepDim: Boolean): DoubleTensor {
|
override fun DoubleTensor.mean(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
@ -272,5 +288,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R =
|
public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R =
|
||||||
DoubleTensorAlgebra().block()
|
DoubleTensorAlgebra().block()
|
@ -0,0 +1,67 @@
|
|||||||
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.tensors.TensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.TensorStructure
|
||||||
|
|
||||||
|
|
||||||
|
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkEmptyShape(shape: IntArray): Unit =
|
||||||
|
check(shape.isNotEmpty()) {
|
||||||
|
"Illegal empty shape provided"
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun < TensorType : TensorStructure<Double>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkEmptyDoubleBuffer(buffer: DoubleArray): Unit =
|
||||||
|
check(buffer.isNotEmpty()) {
|
||||||
|
"Illegal empty buffer provided"
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun < TensorType : TensorStructure<Double>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray): Unit =
|
||||||
|
check(buffer.size == shape.reduce(Int::times)) {
|
||||||
|
"Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkShapesCompatible(a: TensorType, b: TensorType): Unit =
|
||||||
|
check(a.shape contentEquals b.shape) {
|
||||||
|
"Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} "
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
||||||
|
val sa = a.shape
|
||||||
|
val sb = b.shape
|
||||||
|
val na = sa.size
|
||||||
|
val nb = sb.size
|
||||||
|
var status: Boolean
|
||||||
|
if (nb == 1) {
|
||||||
|
status = sa.last() == sb[0]
|
||||||
|
} else {
|
||||||
|
status = sa.last() == sb[nb - 2]
|
||||||
|
if ((na > 2) and (nb > 2)) {
|
||||||
|
status = status and
|
||||||
|
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
check(status) { "Incompatible shapes ${sa.toList()} and ${sb.toList()} provided for dot product" }
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||||
|
check((i < dim) and (j < dim)) {
|
||||||
|
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||||
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
|
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
||||||
|
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
@ -0,0 +1,36 @@
|
|||||||
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.structures.*
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
internal fun Buffer<Int>.array(): IntArray = when (this) {
|
||||||
|
is IntBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
internal fun Buffer<Long>.array(): LongArray = when (this) {
|
||||||
|
is LongBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
internal fun Buffer<Float>.array(): FloatArray = when (this) {
|
||||||
|
is FloatBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
internal fun Buffer<Double>.array(): DoubleArray = when (this) {
|
||||||
|
is RealBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||||
|
}
|
@ -1,105 +0,0 @@
|
|||||||
package space.kscience.kmath.tensors
|
|
||||||
|
|
||||||
|
|
||||||
import kotlin.test.Test
|
|
||||||
import kotlin.test.assertTrue
|
|
||||||
|
|
||||||
class TestDoubleTensorAlgebra {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun doublePlus() = DoubleTensorAlgebra {
|
|
||||||
val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
|
||||||
val res = 10.0 + tensor
|
|
||||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0,12.0))
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun transpose1x1() = DoubleTensorAlgebra {
|
|
||||||
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0))
|
|
||||||
val res = tensor.transpose(0, 0)
|
|
||||||
|
|
||||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(0.0))
|
|
||||||
assertTrue(res.shape contentEquals intArrayOf(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun transpose3x2() = DoubleTensorAlgebra {
|
|
||||||
val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
|
||||||
val res = tensor.transpose(1, 0)
|
|
||||||
|
|
||||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
|
||||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun transpose1x2x3() = DoubleTensorAlgebra {
|
|
||||||
val tensor = DoubleTensor(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
|
||||||
val res01 = tensor.transpose(0, 1)
|
|
||||||
val res02 = tensor.transpose(0, 2)
|
|
||||||
val res12 = tensor.transpose(1, 2)
|
|
||||||
|
|
||||||
assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3))
|
|
||||||
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
|
||||||
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
|
||||||
|
|
||||||
assertTrue(res01.buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
|
||||||
assertTrue(res02.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
|
||||||
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun broadcastShapes() = DoubleTensorAlgebra {
|
|
||||||
assertTrue(broadcastShapes(
|
|
||||||
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
|
||||||
) contentEquals intArrayOf(1, 2, 3))
|
|
||||||
|
|
||||||
assertTrue(broadcastShapes(
|
|
||||||
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7,), intArrayOf(5, 1, 7)
|
|
||||||
) contentEquals intArrayOf(5, 6, 7))
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun broadcastTo() = DoubleTensorAlgebra {
|
|
||||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
|
||||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
|
||||||
|
|
||||||
val res = broadcastTo(tensor2, tensor1.shape)
|
|
||||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
|
||||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun broadcastTensors() = DoubleTensorAlgebra {
|
|
||||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
|
||||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
|
||||||
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
|
||||||
|
|
||||||
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
|
||||||
|
|
||||||
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
|
||||||
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
|
||||||
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
|
||||||
|
|
||||||
assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
|
||||||
assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
|
||||||
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun minusTensor() = DoubleTensorAlgebra {
|
|
||||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
|
||||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
|
||||||
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
|
||||||
|
|
||||||
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
|
|
||||||
assertTrue((tensor2 - tensor1).buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
|
||||||
|
|
||||||
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
|
|
||||||
assertTrue((tensor3 - tensor1).buffer.array()
|
|
||||||
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
|
|
||||||
|
|
||||||
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
|
|
||||||
assertTrue((tensor3 - tensor2).buffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -0,0 +1,82 @@
|
|||||||
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
class TestBroadcasting {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun broadcastShapes() = DoubleTensorAlgebra {
|
||||||
|
assertTrue(
|
||||||
|
broadcastShapes(
|
||||||
|
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
||||||
|
) contentEquals intArrayOf(1, 2, 3)
|
||||||
|
)
|
||||||
|
|
||||||
|
assertTrue(
|
||||||
|
broadcastShapes(
|
||||||
|
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7), intArrayOf(5, 1, 7)
|
||||||
|
) contentEquals intArrayOf(5, 6, 7)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun broadcastTo() = DoubleTensorAlgebra {
|
||||||
|
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
|
||||||
|
val res = broadcastTo(tensor2, tensor1.shape)
|
||||||
|
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||||
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun broadcastTensors() = DoubleTensorAlgebra {
|
||||||
|
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
|
|
||||||
|
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
||||||
|
|
||||||
|
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
|
||||||
|
assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||||
|
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun minusTensor() = DoubleTensorAlgebra {
|
||||||
|
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
|
|
||||||
|
val tensor21 = broadcastDoubleTensorAlgebra {
|
||||||
|
tensor2 - tensor1
|
||||||
|
}
|
||||||
|
|
||||||
|
val tensor31 = broadcastDoubleTensorAlgebra {
|
||||||
|
tensor3 - tensor1
|
||||||
|
}
|
||||||
|
|
||||||
|
val tensor32 = broadcastDoubleTensorAlgebra {
|
||||||
|
tensor3 - tensor2
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
|
||||||
|
assertTrue(tensor21.buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||||
|
|
||||||
|
assertTrue(tensor31.shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
assertTrue(
|
||||||
|
tensor31.buffer.array()
|
||||||
|
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
assertTrue(tensor32.shape contentEquals intArrayOf(1, 1, 3))
|
||||||
|
assertTrue(tensor32.buffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
|
||||||
import space.kscience.kmath.nd.as1D
|
import space.kscience.kmath.nd.as1D
|
||||||
@ -13,20 +13,20 @@ class TestDoubleTensor {
|
|||||||
@Test
|
@Test
|
||||||
fun valueTest() = DoubleReduceOpsTensorAlgebra {
|
fun valueTest() = DoubleReduceOpsTensorAlgebra {
|
||||||
val value = 12.5
|
val value = 12.5
|
||||||
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(value))
|
val tensor = fromArray(intArrayOf(1), doubleArrayOf(value))
|
||||||
assertEquals(tensor.value(), value)
|
assertEquals(tensor.value(), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun stridesTest(){
|
fun stridesTest() = DoubleTensorAlgebra {
|
||||||
val tensor = DoubleTensor(intArrayOf(2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
val tensor = fromArray(intArrayOf(2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
||||||
assertEquals(tensor[intArrayOf(0,1)], 5.8)
|
assertEquals(tensor[intArrayOf(0,1)], 5.8)
|
||||||
assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray())
|
assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun getTest() = DoubleTensorAlgebra {
|
fun getTest() = DoubleTensorAlgebra {
|
||||||
val tensor = DoubleTensor(intArrayOf(1,2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
val tensor = fromArray(intArrayOf(1,2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
||||||
val matrix = tensor[0].as2D()
|
val matrix = tensor[0].as2D()
|
||||||
assertEquals(matrix[0,1], 5.8)
|
assertEquals(matrix[0,1], 5.8)
|
||||||
|
|
@ -0,0 +1,50 @@
|
|||||||
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
class TestDoubleTensorAlgebra {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun doublePlus() = DoubleTensorAlgebra {
|
||||||
|
val tensor = fromArray(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||||
|
val res = 10.0 + tensor
|
||||||
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0, 12.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun transpose1x1() = DoubleTensorAlgebra {
|
||||||
|
val tensor = fromArray(intArrayOf(1), doubleArrayOf(0.0))
|
||||||
|
val res = tensor.transpose(0, 0)
|
||||||
|
|
||||||
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(0.0))
|
||||||
|
assertTrue(res.shape contentEquals intArrayOf(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun transpose3x2() = DoubleTensorAlgebra {
|
||||||
|
val tensor = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val res = tensor.transpose(1, 0)
|
||||||
|
|
||||||
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||||
|
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun transpose1x2x3() = DoubleTensorAlgebra {
|
||||||
|
val tensor = fromArray(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val res01 = tensor.transpose(0, 1)
|
||||||
|
val res02 = tensor.transpose(0, 2)
|
||||||
|
val res12 = tensor.transpose(1, 2)
|
||||||
|
|
||||||
|
assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3))
|
||||||
|
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
||||||
|
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
||||||
|
|
||||||
|
assertTrue(res01.buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
assertTrue(res02.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
|
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user