Explicit broadcasting enforced
This commit is contained in:
parent
1fa0da2810
commit
274be61330
@ -17,7 +17,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
|
||||
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
||||
public fun luPivot(lu: TensorType, pivots: IntTensor): Triple<TensorType, TensorType, TensorType>
|
||||
public fun luPivot(lu: TensorType, pivots: IndexTensorType): Triple<TensorType, TensorType, TensorType>
|
||||
|
||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
||||
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
||||
|
@ -1,10 +1,91 @@
|
||||
package space.kscience.kmath.tensors
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.structures.*
|
||||
import kotlin.math.max
|
||||
|
||||
public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
|
||||
override fun DoubleTensor.plus(other: DoubleTensor): DoubleTensor {
|
||||
val broadcast = broadcastTensors(this, other)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||
newThis.buffer.array()[i] + newOther.buffer.array()[i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
||||
val newOther = broadcastTo(other, this.shape)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] +=
|
||||
newOther.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun DoubleTensor.minus(other: DoubleTensor): DoubleTensor {
|
||||
val broadcast = broadcastTensors(this, other)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||
newThis.buffer.array()[i] - newOther.buffer.array()[i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.minusAssign(other: DoubleTensor) {
|
||||
val newOther = broadcastTo(other, this.shape)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] -=
|
||||
newOther.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
||||
val broadcast = broadcastTensors(this, other)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||
newThis.buffer.array()[newOther.bufferStart + i] *
|
||||
newOther.buffer.array()[newOther.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
||||
val newOther = broadcastTo(other, this.shape)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] *=
|
||||
newOther.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun DoubleTensor.div(other: DoubleTensor): DoubleTensor {
|
||||
val broadcast = broadcastTensors(this, other)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||
newThis.buffer.array()[newOther.bufferStart + i] /
|
||||
newOther.buffer.array()[newOther.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.divAssign(other: DoubleTensor) {
|
||||
val newOther = broadcastTo(other, this.shape)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] /=
|
||||
newOther.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public inline fun <R> broadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R =
|
||||
BroadcastDoubleTensorAlgebra().block()
|
||||
|
||||
|
||||
internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
||||
println(shapes)
|
||||
var totalDim = 0
|
||||
for (shape in shapes) {
|
||||
totalDim = max(totalDim, shape.size)
|
||||
@ -99,68 +180,4 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
||||
val sa = a.shape
|
||||
val sb = b.shape
|
||||
val na = sa.size
|
||||
val nb = sb.size
|
||||
var status: Boolean
|
||||
if (nb == 1) {
|
||||
status = sa.last() == sb[0]
|
||||
} else {
|
||||
status = sa.last() == sb[nb - 2]
|
||||
if ((na > 2) and (nb > 2)) {
|
||||
status = status and
|
||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||
}
|
||||
}
|
||||
check(status) { "Incompatible shapes $sa and $sb for dot product" }
|
||||
}
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||
check((i < dim) and (j < dim)) {
|
||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||
}
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
||||
|
||||
/**
|
||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Int>.array(): IntArray = when(this) {
|
||||
is IntBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Long>.array(): LongArray = when(this) {
|
||||
is LongBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Float>.array(): FloatArray = when(this) {
|
||||
is FloatBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Double>.array(): DoubleArray = when(this) {
|
||||
is RealBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||
}
|
||||
}
|
@ -1,8 +1,10 @@
|
||||
package space.kscience.kmath.tensors
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.linear.Matrix
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.structures.*
|
||||
import space.kscience.kmath.tensors.TensorStrides
|
||||
import space.kscience.kmath.tensors.TensorStructure
|
||||
|
||||
|
||||
public open class BufferedTensor<T>(
|
||||
@ -76,25 +78,25 @@ public open class BufferedTensor<T>(
|
||||
}
|
||||
|
||||
|
||||
public class IntTensor(
|
||||
public class IntTensor internal constructor(
|
||||
shape: IntArray,
|
||||
buffer: IntArray,
|
||||
offset: Int = 0
|
||||
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
|
||||
|
||||
public class LongTensor(
|
||||
public class LongTensor internal constructor(
|
||||
shape: IntArray,
|
||||
buffer: LongArray,
|
||||
offset: Int = 0
|
||||
) : BufferedTensor<Long>(shape, LongBuffer(buffer), offset)
|
||||
|
||||
public class FloatTensor(
|
||||
public class FloatTensor internal constructor(
|
||||
shape: IntArray,
|
||||
buffer: FloatArray,
|
||||
offset: Int = 0
|
||||
) : BufferedTensor<Float>(shape, FloatBuffer(buffer), offset)
|
||||
|
||||
public class DoubleTensor(
|
||||
public class DoubleTensor internal constructor(
|
||||
shape: IntArray,
|
||||
buffer: DoubleArray,
|
||||
offset: Int = 0
|
@ -1,4 +1,6 @@
|
||||
package space.kscience.kmath.tensors
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.tensors.AnalyticTensorAlgebra
|
||||
|
||||
public class DoubleAnalyticTensorAlgebra:
|
||||
AnalyticTensorAlgebra<Double, DoubleTensor>,
|
@ -1,4 +1,6 @@
|
||||
package space.kscience.kmath.tensors
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
||||
|
||||
public class DoubleLinearOpsTensorAlgebra :
|
||||
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
@ -1,4 +1,6 @@
|
||||
package space.kscience.kmath.tensors
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.tensors.OrderedTensorAlgebra
|
||||
|
||||
public open class DoubleOrderedTensorAlgebra:
|
||||
OrderedTensorAlgebra<Double, DoubleTensor>,
|
@ -1,8 +1,10 @@
|
||||
package space.kscience.kmath.tensors
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.tensors.ReduceOpsTensorAlgebra
|
||||
|
||||
public class DoubleReduceOpsTensorAlgebra:
|
||||
DoubleTensorAlgebra(),
|
||||
ReduceOpsTensorAlgebra<Double, DoubleTensor> {
|
||||
ReduceOpsTensorAlgebra<Double, DoubleTensor> {
|
||||
|
||||
override fun DoubleTensor.value(): Double {
|
||||
check(this.shape contentEquals intArrayOf(1)) {
|
@ -1,8 +1,18 @@
|
||||
package space.kscience.kmath.tensors
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra
|
||||
|
||||
|
||||
public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, DoubleTensor> {
|
||||
|
||||
public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor {
|
||||
checkEmptyShape(shape)
|
||||
checkEmptyDoubleBuffer(buffer)
|
||||
checkBufferShapeConsistency(shape, buffer)
|
||||
return DoubleTensor(shape, buffer, 0)
|
||||
}
|
||||
|
||||
|
||||
override operator fun DoubleTensor.get(i: Int): DoubleTensor {
|
||||
val lastShape = this.shape.drop(1).toIntArray()
|
||||
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
||||
@ -53,13 +63,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
override fun DoubleTensor.plus(value: Double): DoubleTensor = value + this
|
||||
|
||||
override fun DoubleTensor.plus(other: DoubleTensor): DoubleTensor {
|
||||
val broadcast = broadcastTensors(this, other)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||
newThis.buffer.array()[i] + newOther.buffer.array()[i]
|
||||
checkShapesCompatible(this, other)
|
||||
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||
this.buffer.array()[i] + other.buffer.array()[i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
return DoubleTensor(this.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.plusAssign(value: Double) {
|
||||
@ -69,10 +77,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
}
|
||||
|
||||
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
||||
val newOther = broadcastTo(other, this.shape)
|
||||
checkShapesCompatible(this, other)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] +=
|
||||
newOther.buffer.array()[this.bufferStart + i]
|
||||
other.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,13 +99,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
}
|
||||
|
||||
override fun DoubleTensor.minus(other: DoubleTensor): DoubleTensor {
|
||||
val broadcast = broadcastTensors(this, other)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||
newThis.buffer.array()[i] - newOther.buffer.array()[i]
|
||||
checkShapesCompatible(this, other)
|
||||
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||
this.buffer.array()[i] - other.buffer.array()[i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
return DoubleTensor(this.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.minusAssign(value: Double) {
|
||||
@ -107,10 +113,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
}
|
||||
|
||||
override fun DoubleTensor.minusAssign(other: DoubleTensor) {
|
||||
val newOther = broadcastTo(other, this.shape)
|
||||
checkShapesCompatible(this, other)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] -=
|
||||
newOther.buffer.array()[this.bufferStart + i]
|
||||
other.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,15 +130,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
override fun DoubleTensor.times(value: Double): DoubleTensor = value * this
|
||||
|
||||
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
||||
val broadcast = broadcastTensors(this, other)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
|
||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||
newThis.buffer.array()[newOther.bufferStart + i] *
|
||||
newOther.buffer.array()[newOther.bufferStart + i]
|
||||
checkShapesCompatible(this, other)
|
||||
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||
this.buffer.array()[other.bufferStart + i] *
|
||||
other.buffer.array()[other.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
return DoubleTensor(this.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.timesAssign(value: Double) {
|
||||
@ -142,10 +145,40 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
}
|
||||
|
||||
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
||||
val newOther = broadcastTo(other, this.shape)
|
||||
checkShapesCompatible(this, other)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] *=
|
||||
newOther.buffer.array()[this.bufferStart + i]
|
||||
other.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun DoubleTensor.div(value: Double): DoubleTensor {
|
||||
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||
this.buffer.array()[this.bufferStart + i] / value
|
||||
}
|
||||
return DoubleTensor(this.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.div(other: DoubleTensor): DoubleTensor {
|
||||
checkShapesCompatible(this, other)
|
||||
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||
this.buffer.array()[other.bufferStart + i] /
|
||||
other.buffer.array()[other.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(this.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.divAssign(value: Double) {
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] /= value
|
||||
}
|
||||
}
|
||||
|
||||
override fun DoubleTensor.divAssign(other: DoubleTensor) {
|
||||
checkShapesCompatible(this, other)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] /=
|
||||
other.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
@ -229,27 +262,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
|
||||
override fun DoubleTensor.div(value: Double): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.div(other: DoubleTensor): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.flatten(startDim: Int, endDim: Int): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.divAssign(value: Double) {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.divAssign(other: DoubleTensor) {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.mean(dim: Int, keepDim: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
@ -272,5 +288,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
|
||||
}
|
||||
|
||||
|
||||
public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R =
|
||||
DoubleTensorAlgebra().block()
|
||||
DoubleTensorAlgebra().block()
|
@ -0,0 +1,67 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.tensors.TensorAlgebra
|
||||
import space.kscience.kmath.tensors.TensorStructure
|
||||
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkEmptyShape(shape: IntArray): Unit =
|
||||
check(shape.isNotEmpty()) {
|
||||
"Illegal empty shape provided"
|
||||
}
|
||||
|
||||
internal inline fun < TensorType : TensorStructure<Double>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>>
|
||||
TorchTensorAlgebraType.checkEmptyDoubleBuffer(buffer: DoubleArray): Unit =
|
||||
check(buffer.isNotEmpty()) {
|
||||
"Illegal empty buffer provided"
|
||||
}
|
||||
|
||||
internal inline fun < TensorType : TensorStructure<Double>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<Double, TensorType>>
|
||||
TorchTensorAlgebraType.checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray): Unit =
|
||||
check(buffer.size == shape.reduce(Int::times)) {
|
||||
"Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided"
|
||||
}
|
||||
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkShapesCompatible(a: TensorType, b: TensorType): Unit =
|
||||
check(a.shape contentEquals b.shape) {
|
||||
"Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} "
|
||||
}
|
||||
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkDot(a: TensorType, b: TensorType): Unit {
|
||||
val sa = a.shape
|
||||
val sb = b.shape
|
||||
val na = sa.size
|
||||
val nb = sb.size
|
||||
var status: Boolean
|
||||
if (nb == 1) {
|
||||
status = sa.last() == sb[0]
|
||||
} else {
|
||||
status = sa.last() == sb[nb - 2]
|
||||
if ((na > 2) and (nb > 2)) {
|
||||
status = status and
|
||||
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
|
||||
}
|
||||
}
|
||||
check(status) { "Incompatible shapes ${sa.toList()} and ${sb.toList()} provided for dot product" }
|
||||
}
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit =
|
||||
check((i < dim) and (j < dim)) {
|
||||
"Cannot transpose $i to $j for a tensor of dim $dim"
|
||||
}
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
@ -0,0 +1,36 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.structures.*
|
||||
|
||||
|
||||
/**
|
||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Int>.array(): IntArray = when (this) {
|
||||
is IntBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Long>.array(): LongArray = when (this) {
|
||||
is LongBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Float>.array(): FloatArray = when (this) {
|
||||
is FloatBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Double>.array(): DoubleArray = when (this) {
|
||||
is RealBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||
}
|
@ -1,105 +0,0 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class TestDoubleTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun doublePlus() = DoubleTensorAlgebra {
|
||||
val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||
val res = 10.0 + tensor
|
||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0,12.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose1x1() = DoubleTensorAlgebra {
|
||||
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0))
|
||||
val res = tensor.transpose(0, 0)
|
||||
|
||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(0.0))
|
||||
assertTrue(res.shape contentEquals intArrayOf(1))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose3x2() = DoubleTensorAlgebra {
|
||||
val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val res = tensor.transpose(1, 0)
|
||||
|
||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose1x2x3() = DoubleTensorAlgebra {
|
||||
val tensor = DoubleTensor(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val res01 = tensor.transpose(0, 1)
|
||||
val res02 = tensor.transpose(0, 2)
|
||||
val res12 = tensor.transpose(1, 2)
|
||||
|
||||
assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3))
|
||||
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
||||
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
||||
|
||||
assertTrue(res01.buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res02.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun broadcastShapes() = DoubleTensorAlgebra {
|
||||
assertTrue(broadcastShapes(
|
||||
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
||||
) contentEquals intArrayOf(1, 2, 3))
|
||||
|
||||
assertTrue(broadcastShapes(
|
||||
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7,), intArrayOf(5, 1, 7)
|
||||
) contentEquals intArrayOf(5, 6, 7))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun broadcastTo() = DoubleTensorAlgebra {
|
||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
|
||||
val res = broadcastTo(tensor2, tensor1.shape)
|
||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun broadcastTensors() = DoubleTensorAlgebra {
|
||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
||||
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
||||
|
||||
assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun minusTensor() = DoubleTensorAlgebra {
|
||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
||||
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
|
||||
assertTrue((tensor2 - tensor1).buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||
|
||||
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue((tensor3 - tensor1).buffer.array()
|
||||
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
|
||||
|
||||
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
|
||||
assertTrue((tensor3 - tensor2).buffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,82 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class TestBroadcasting {
|
||||
|
||||
@Test
|
||||
fun broadcastShapes() = DoubleTensorAlgebra {
|
||||
assertTrue(
|
||||
broadcastShapes(
|
||||
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
||||
) contentEquals intArrayOf(1, 2, 3)
|
||||
)
|
||||
|
||||
assertTrue(
|
||||
broadcastShapes(
|
||||
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7), intArrayOf(5, 1, 7)
|
||||
) contentEquals intArrayOf(5, 6, 7)
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun broadcastTo() = DoubleTensorAlgebra {
|
||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
|
||||
val res = broadcastTo(tensor2, tensor1.shape)
|
||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun broadcastTensors() = DoubleTensorAlgebra {
|
||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
||||
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
||||
|
||||
assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun minusTensor() = DoubleTensorAlgebra {
|
||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
||||
val tensor21 = broadcastDoubleTensorAlgebra {
|
||||
tensor2 - tensor1
|
||||
}
|
||||
|
||||
val tensor31 = broadcastDoubleTensorAlgebra {
|
||||
tensor3 - tensor1
|
||||
}
|
||||
|
||||
val tensor32 = broadcastDoubleTensorAlgebra {
|
||||
tensor3 - tensor2
|
||||
}
|
||||
|
||||
|
||||
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
|
||||
assertTrue(tensor21.buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||
|
||||
assertTrue(tensor31.shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(
|
||||
tensor31.buffer.array()
|
||||
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
|
||||
)
|
||||
|
||||
assertTrue(tensor32.shape contentEquals intArrayOf(1, 1, 3))
|
||||
assertTrue(tensor32.buffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||
}
|
||||
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package space.kscience.kmath.tensors
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
|
||||
import space.kscience.kmath.nd.as1D
|
||||
@ -13,20 +13,20 @@ class TestDoubleTensor {
|
||||
@Test
|
||||
fun valueTest() = DoubleReduceOpsTensorAlgebra {
|
||||
val value = 12.5
|
||||
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(value))
|
||||
val tensor = fromArray(intArrayOf(1), doubleArrayOf(value))
|
||||
assertEquals(tensor.value(), value)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun stridesTest(){
|
||||
val tensor = DoubleTensor(intArrayOf(2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
||||
fun stridesTest() = DoubleTensorAlgebra {
|
||||
val tensor = fromArray(intArrayOf(2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
||||
assertEquals(tensor[intArrayOf(0,1)], 5.8)
|
||||
assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun getTest() = DoubleTensorAlgebra {
|
||||
val tensor = DoubleTensor(intArrayOf(1,2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
||||
val tensor = fromArray(intArrayOf(1,2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
||||
val matrix = tensor[0].as2D()
|
||||
assertEquals(matrix[0,1], 5.8)
|
||||
|
@ -0,0 +1,50 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class TestDoubleTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun doublePlus() = DoubleTensorAlgebra {
|
||||
val tensor = fromArray(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||
val res = 10.0 + tensor
|
||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0, 12.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose1x1() = DoubleTensorAlgebra {
|
||||
val tensor = fromArray(intArrayOf(1), doubleArrayOf(0.0))
|
||||
val res = tensor.transpose(0, 0)
|
||||
|
||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(0.0))
|
||||
assertTrue(res.shape contentEquals intArrayOf(1))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose3x2() = DoubleTensorAlgebra {
|
||||
val tensor = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val res = tensor.transpose(1, 0)
|
||||
|
||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun transpose1x2x3() = DoubleTensorAlgebra {
|
||||
val tensor = fromArray(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val res01 = tensor.transpose(0, 1)
|
||||
val res02 = tensor.transpose(0, 2)
|
||||
val res12 = tensor.transpose(1, 2)
|
||||
|
||||
assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3))
|
||||
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
||||
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
||||
|
||||
assertTrue(res01.buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res02.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user