Global refactor of tensors
This commit is contained in:
parent
b5d04ba02c
commit
20886d6f6b
@ -132,7 +132,10 @@ public open class DoubleTensorAlgebra :
|
|||||||
val dt = asDoubleTensor()
|
val dt = asDoubleTensor()
|
||||||
val lastShape = shape.drop(1).toIntArray()
|
val lastShape = shape.drop(1).toIntArray()
|
||||||
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
||||||
return DoubleTensor(newShape, dt.source.view(newShape.reduce(Int::times) * i))
|
return DoubleTensor(
|
||||||
|
newShape,
|
||||||
|
dt.source.view(newShape.reduce(Int::times) * i, TensorLinearStructure.linearSizeOf(newShape))
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -227,7 +230,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
|
|
||||||
override fun StructureND<Double>.minus(arg: Double): DoubleTensor = map { it - arg }
|
override fun StructureND<Double>.minus(arg: Double): DoubleTensor = map { it - arg }
|
||||||
|
|
||||||
override fun StructureND<Double>.minus(arg: StructureND<Double>): DoubleTensor = zip(this, arg) { l, r -> l + r }
|
override fun StructureND<Double>.minus(arg: StructureND<Double>): DoubleTensor = zip(this, arg) { l, r -> l - r }
|
||||||
|
|
||||||
override fun Tensor<Double>.minusAssign(value: Double) {
|
override fun Tensor<Double>.minusAssign(value: Double) {
|
||||||
mapInPlace { it - value }
|
mapInPlace { it - value }
|
||||||
|
@ -221,7 +221,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
|
|||||||
|
|
||||||
override fun StructureND<Int>.minus(arg: Int): IntTensor = map { it - arg }
|
override fun StructureND<Int>.minus(arg: Int): IntTensor = map { it - arg }
|
||||||
|
|
||||||
override fun StructureND<Int>.minus(arg: StructureND<Int>): IntTensor = zip(this, arg) { l, r -> l + r }
|
override fun StructureND<Int>.minus(arg: StructureND<Int>): IntTensor = zip(this, arg) { l, r -> l - r }
|
||||||
|
|
||||||
override fun Tensor<Int>.minusAssign(value: Int) {
|
override fun Tensor<Int>.minusAssign(value: Int) {
|
||||||
mapInPlace { it - value }
|
mapInPlace { it - value }
|
||||||
@ -283,8 +283,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
|
|||||||
view(other.shape)
|
view(other.shape)
|
||||||
|
|
||||||
override fun StructureND<Int>.dot(other: StructureND<Int>): IntTensor {
|
override fun StructureND<Int>.dot(other: StructureND<Int>): IntTensor {
|
||||||
return if (dimension in 0..2 && other.dimension in 0..2) TODO("not implemented")
|
TODO("not implemented for integers")
|
||||||
else error("Only vectors and matrices are allowed in non-broadcasting dot operation")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(
|
override fun diagonalEmbedding(
|
||||||
|
@ -15,14 +15,12 @@ import kotlin.math.max
|
|||||||
* @param shape the shape of the tensor.
|
* @param shape the shape of the tensor.
|
||||||
*/
|
*/
|
||||||
public class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
public class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
||||||
override val strides: IntArray
|
override val strides: IntArray get() = stridesFromShape(shape)
|
||||||
get() = stridesFromShape(shape)
|
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray =
|
override fun index(offset: Int): IntArray =
|
||||||
indexFromOffset(offset, strides, shape.size)
|
indexFromOffset(offset, strides, shape.size)
|
||||||
|
|
||||||
override val linearSize: Int
|
override val linearSize: Int get() = linearSizeOf(shape)
|
||||||
get() = shape.reduce(Int::times)
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
@ -41,6 +39,8 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
|||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
|
|
||||||
|
public fun linearSizeOf(shape: IntArray): Int = shape.reduce(Int::times)
|
||||||
|
|
||||||
public fun stridesFromShape(shape: IntArray): IntArray {
|
public fun stridesFromShape(shape: IntArray): IntArray {
|
||||||
val nDim = shape.size
|
val nDim = shape.size
|
||||||
val res = IntArray(nDim)
|
val res = IntArray(nDim)
|
||||||
|
@ -14,11 +14,8 @@ import space.kscience.kmath.structures.DoubleBuffer
|
|||||||
import space.kscience.kmath.structures.VirtualBuffer
|
import space.kscience.kmath.structures.VirtualBuffer
|
||||||
import space.kscience.kmath.structures.asBuffer
|
import space.kscience.kmath.structures.asBuffer
|
||||||
import space.kscience.kmath.structures.indices
|
import space.kscience.kmath.structures.indices
|
||||||
|
import space.kscience.kmath.tensors.core.*
|
||||||
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.eye
|
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.eye
|
||||||
import space.kscience.kmath.tensors.core.BufferedTensor
|
|
||||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
|
||||||
import space.kscience.kmath.tensors.core.OffsetDoubleBuffer
|
|
||||||
import space.kscience.kmath.tensors.core.copyToTensor
|
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
@ -165,7 +162,7 @@ internal val DoubleTensor.vectors: VirtualBuffer<DoubleTensor>
|
|||||||
|
|
||||||
return VirtualBuffer(linearSize / vectorOffset) { index ->
|
return VirtualBuffer(linearSize / vectorOffset) { index ->
|
||||||
val offset = index * vectorOffset
|
val offset = index * vectorOffset
|
||||||
DoubleTensor(vectorShape, source.view(offset))
|
DoubleTensor(vectorShape, source.view(offset, vectorShape.first()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,16 +171,18 @@ internal fun DoubleTensor.vectorSequence(): Sequence<DoubleTensor> = vectors.asS
|
|||||||
|
|
||||||
|
|
||||||
internal val DoubleTensor.matrices: VirtualBuffer<DoubleTensor>
|
internal val DoubleTensor.matrices: VirtualBuffer<DoubleTensor>
|
||||||
get(){
|
get() {
|
||||||
val n = shape.size
|
val n = shape.size
|
||||||
check(n >= 2) { "Expected tensor with 2 or more dimensions, got size $n" }
|
check(n >= 2) { "Expected tensor with 2 or more dimensions, got size $n" }
|
||||||
val matrixOffset = shape[n - 1] * shape[n - 2]
|
val matrixOffset = shape[n - 1] * shape[n - 2]
|
||||||
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1])
|
val matrixShape = intArrayOf(shape[n - 2], shape[n - 1])
|
||||||
|
|
||||||
return VirtualBuffer(linearSize / matrixOffset) { index ->
|
val size = TensorLinearStructure.linearSizeOf(matrixShape)
|
||||||
val offset = index * matrixOffset
|
|
||||||
DoubleTensor(matrixShape, source.view(offset))
|
return VirtualBuffer(linearSize / matrixOffset) { index ->
|
||||||
|
val offset = index * matrixOffset
|
||||||
|
DoubleTensor(matrixShape, source.view(offset, size))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
internal fun DoubleTensor.matrixSequence(): Sequence<DoubleTensor> = matrices.asSequence()
|
internal fun DoubleTensor.matrixSequence(): Sequence<DoubleTensor> = matrices.asSequence()
|
@ -40,7 +40,7 @@ internal val IntTensor.vectors: VirtualBuffer<IntTensor>
|
|||||||
|
|
||||||
return VirtualBuffer(linearSize / vectorOffset) { index ->
|
return VirtualBuffer(linearSize / vectorOffset) { index ->
|
||||||
val offset = index * vectorOffset
|
val offset = index * vectorOffset
|
||||||
IntTensor(vectorShape, source.view(offset))
|
IntTensor(vectorShape, source.view(offset, vectorShape.first()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,14 +5,20 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.DoubleBufferND
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import space.kscience.kmath.structures.asBuffer
|
import space.kscience.kmath.structures.asBuffer
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a mutable copy of given [StructureND].
|
||||||
|
*/
|
||||||
public fun StructureND<Double>.copyToTensor(): DoubleTensor = if (this is DoubleTensor) {
|
public fun StructureND<Double>.copyToTensor(): DoubleTensor = if (this is DoubleTensor) {
|
||||||
DoubleTensor(shape, source.copy())
|
DoubleTensor(shape, source.copy())
|
||||||
|
} else if (this is DoubleBufferND && indices is TensorLinearStructure) {
|
||||||
|
DoubleTensor(shape, buffer.copy())
|
||||||
} else {
|
} else {
|
||||||
DoubleTensor(
|
DoubleTensor(
|
||||||
shape,
|
shape,
|
||||||
@ -36,11 +42,14 @@ public fun StructureND<Int>.toDoubleTensor(): DoubleTensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Casts [Tensor] of [Double] to [DoubleTensor]
|
* Transforms [StructureND] of [Double] to [DoubleTensor]. Zero copy if possible, but is not guaranteed
|
||||||
*/
|
*/
|
||||||
public fun StructureND<Double>.asDoubleTensor(): DoubleTensor = when (this) {
|
public fun StructureND<Double>.asDoubleTensor(): DoubleTensor = if (this is DoubleTensor) {
|
||||||
is DoubleTensor -> this
|
this
|
||||||
else -> copyToTensor()
|
} else if (this is DoubleBufferND && indices is TensorLinearStructure) {
|
||||||
|
DoubleTensor(shape, buffer)
|
||||||
|
} else {
|
||||||
|
copyToTensor()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -6,10 +6,7 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.nd.DefaultStrides
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.nd.MutableBufferND
|
|
||||||
import space.kscience.kmath.nd.as1D
|
|
||||||
import space.kscience.kmath.nd.as2D
|
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import space.kscience.kmath.structures.toDoubleArray
|
import space.kscience.kmath.structures.toDoubleArray
|
||||||
@ -66,16 +63,16 @@ internal class TestDoubleTensor {
|
|||||||
fun testNoBufferProtocol() {
|
fun testNoBufferProtocol() {
|
||||||
|
|
||||||
// create buffer
|
// create buffer
|
||||||
val doubleArray = DoubleBuffer(doubleArrayOf(1.0, 2.0, 3.0))
|
val doubleArray = DoubleBuffer(1.0, 2.0, 3.0)
|
||||||
|
|
||||||
// create ND buffers, no data is copied
|
// create ND buffers, no data is copied
|
||||||
val ndArray: MutableBufferND<Double> = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleArray)
|
val ndArray: MutableBufferND<Double> = DoubleBufferND(DefaultStrides(intArrayOf(3)), doubleArray)
|
||||||
|
|
||||||
// map to tensors
|
// map to tensors
|
||||||
val bufferedTensorArray = ndArray.asDoubleTensor() // strides are flipped so data copied
|
val tensorArray = ndArray.asDoubleTensor() // Data is copied because of strides change.
|
||||||
val tensorArray = bufferedTensorArray.asDoubleTensor() // data not contiguous so copied again
|
|
||||||
|
|
||||||
val tensorArrayPublic = ndArray.asDoubleTensor() // public API, data copied twice
|
//protective copy
|
||||||
|
val tensorArrayPublic = ndArray.copyToTensor() // public API, data copied twice
|
||||||
val sharedTensorArray = tensorArrayPublic.asDoubleTensor() // no data copied by matching type
|
val sharedTensorArray = tensorArrayPublic.asDoubleTensor() // no data copied by matching type
|
||||||
|
|
||||||
assertTrue(tensorArray.source contentEquals sharedTensorArray.source)
|
assertTrue(tensorArray.source contentEquals sharedTensorArray.source)
|
||||||
@ -83,11 +80,11 @@ internal class TestDoubleTensor {
|
|||||||
tensorArray[intArrayOf(0)] = 55.9
|
tensorArray[intArrayOf(0)] = 55.9
|
||||||
assertEquals(tensorArrayPublic[intArrayOf(0)], 1.0)
|
assertEquals(tensorArrayPublic[intArrayOf(0)], 1.0)
|
||||||
|
|
||||||
tensorArrayPublic[intArrayOf(0)] = 55.9
|
tensorArrayPublic[intArrayOf(0)] = 57.9
|
||||||
assertEquals(sharedTensorArray[intArrayOf(0)], 55.9)
|
assertEquals(sharedTensorArray[intArrayOf(0)], 57.9)
|
||||||
assertEquals(bufferedTensorArray[intArrayOf(0)], 1.0)
|
assertEquals(tensorArray[intArrayOf(0)], 55.9)
|
||||||
|
|
||||||
bufferedTensorArray[intArrayOf(0)] = 55.9
|
tensorArray[intArrayOf(0)] = 55.9
|
||||||
assertEquals(ndArray[intArrayOf(0)], 1.0)
|
assertEquals(ndArray[intArrayOf(0)], 1.0)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ package space.kscience.kmath.tensors.core
|
|||||||
|
|
||||||
import space.kscience.kmath.nd.get
|
import space.kscience.kmath.nd.get
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
|
import space.kscience.kmath.testutils.assertBufferEquals
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertFalse
|
import kotlin.test.assertFalse
|
||||||
@ -98,8 +99,8 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
assignResult += tensorC
|
assignResult += tensorC
|
||||||
assignResult += -39.4
|
assignResult += -39.4
|
||||||
|
|
||||||
assertTrue(expected.source contentEquals result.source)
|
assertBufferEquals(expected.source, result.source)
|
||||||
assertTrue(expected.source contentEquals assignResult.source)
|
assertBufferEquals(expected.source, assignResult.source)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -202,6 +203,6 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
val r = tensor.getTensor(1).map { it - 1.0 }
|
val r = tensor.getTensor(1).map { it - 1.0 }
|
||||||
val res = l + r
|
val res = l + r
|
||||||
assertTrue { intArrayOf(5, 5) contentEquals res.shape }
|
assertTrue { intArrayOf(5, 5) contentEquals res.shape }
|
||||||
assertEquals(1.0, res[4, 4])
|
assertEquals(2.0, res[4, 4])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user