Fixed strides flipping
This commit is contained in:
parent
c2db3a23e1
commit
23ea4a95a1
@ -37,8 +37,7 @@ public class IntTensor internal constructor(
|
|||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: IntArray,
|
buffer: IntArray,
|
||||||
offset: Int = 0
|
offset: Int = 0
|
||||||
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
|
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset) {
|
||||||
{
|
|
||||||
internal constructor(bufferedTensor: BufferedTensor<Int>) :
|
internal constructor(bufferedTensor: BufferedTensor<Int>) :
|
||||||
this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart)
|
this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart)
|
||||||
}
|
}
|
||||||
@ -47,8 +46,7 @@ public class DoubleTensor internal constructor(
|
|||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: DoubleArray,
|
buffer: DoubleArray,
|
||||||
offset: Int = 0
|
offset: Int = 0
|
||||||
) : BufferedTensor<Double>(shape, DoubleBuffer(buffer), offset)
|
) : BufferedTensor<Double>(shape, DoubleBuffer(buffer), offset) {
|
||||||
{
|
|
||||||
internal constructor(bufferedTensor: BufferedTensor<Double>) :
|
internal constructor(bufferedTensor: BufferedTensor<Double>) :
|
||||||
this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart)
|
this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart)
|
||||||
|
|
||||||
@ -59,10 +57,17 @@ public class DoubleTensor internal constructor(
|
|||||||
internal inline fun BufferedTensor<Int>.asTensor(): IntTensor = IntTensor(this)
|
internal inline fun BufferedTensor<Int>.asTensor(): IntTensor = IntTensor(this)
|
||||||
internal inline fun BufferedTensor<Double>.asTensor(): DoubleTensor = DoubleTensor(this)
|
internal inline fun BufferedTensor<Double>.asTensor(): DoubleTensor = DoubleTensor(this)
|
||||||
|
|
||||||
|
internal inline fun <T> TensorStructure<T>.copyToBufferedTensor(): BufferedTensor<T> =
|
||||||
|
BufferedTensor(
|
||||||
|
this.shape,
|
||||||
|
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0
|
||||||
|
)
|
||||||
|
|
||||||
internal inline fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
internal inline fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
||||||
is BufferedTensor<T> -> this
|
is BufferedTensor<T> -> this
|
||||||
is MutableBufferND<T> -> BufferedTensor(this.shape, this.mutableBuffer, 0)
|
is MutableBufferND<T> -> if (this.strides.strides.toIntArray() contentEquals TensorLinearStructure(this.shape).strides)
|
||||||
else -> BufferedTensor(this.shape, this.elements().map{ it.second }.toMutableList().asMutableBuffer(), 0)
|
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
|
||||||
|
else -> this.copyToBufferedTensor()
|
||||||
}
|
}
|
||||||
|
|
||||||
internal val TensorStructure<Double>.tensor: DoubleTensor
|
internal val TensorStructure<Double>.tensor: DoubleTensor
|
||||||
@ -77,3 +82,5 @@ internal val TensorStructure<Int>.tensor: IntTensor
|
|||||||
else -> this.toBufferedTensor().asTensor()
|
else -> this.toBufferedTensor().asTensor()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun TensorStructure<Double>.toTypedTensor(): DoubleTensor = this.tensor
|
||||||
|
public fun TensorStructure<Int>.toTypedTensor(): IntTensor = this.tensor
|
@ -25,7 +25,9 @@ class TestDoubleTensor {
|
|||||||
fun stridesTest() = DoubleTensorAlgebra {
|
fun stridesTest() = DoubleTensorAlgebra {
|
||||||
val tensor = fromArray(intArrayOf(2, 2), doubleArrayOf(3.5, 5.8, 58.4, 2.4))
|
val tensor = fromArray(intArrayOf(2, 2), doubleArrayOf(3.5, 5.8, 58.4, 2.4))
|
||||||
assertEquals(tensor[intArrayOf(0, 1)], 5.8)
|
assertEquals(tensor[intArrayOf(0, 1)], 5.8)
|
||||||
assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray())
|
assertTrue(
|
||||||
|
tensor.elements().map { it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray()
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -53,32 +55,32 @@ class TestDoubleTensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun bufferProtocol() {
|
fun noBufferProtocol() {
|
||||||
|
|
||||||
// create buffers
|
// create buffer
|
||||||
val doubleBuffer = DoubleBuffer(doubleArrayOf(1.0,2.0,3.0))
|
val doubleArray = DoubleBuffer(doubleArrayOf(1.0, 2.0, 3.0))
|
||||||
val doubleList = MutableList(3, doubleBuffer::get)
|
|
||||||
|
|
||||||
// create ND buffers
|
// create ND buffers, no data is copied
|
||||||
val ndBuffer = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleBuffer)
|
val ndArray = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleArray)
|
||||||
val ndList = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleList.asMutableBuffer())
|
|
||||||
|
|
||||||
// map to tensors
|
// map to tensors
|
||||||
val bufferedTensorBuffer = ndBuffer.toBufferedTensor() // strides are flipped
|
val bufferedTensorArray = ndArray.toBufferedTensor() // strides are flipped so data copied
|
||||||
val tensorBuffer = bufferedTensorBuffer.asTensor() // no data copied
|
val tensorArray = bufferedTensorArray.asTensor() // data not contiguous so copied again
|
||||||
|
|
||||||
val bufferedTensorList = ndList.toBufferedTensor() // strides are flipped
|
val tensorArrayPublic = ndArray.toTypedTensor() // public API, data copied twice
|
||||||
val tensorList = bufferedTensorList.asTensor() // data copied
|
val sharedTensorArray = tensorArrayPublic.toTypedTensor() // no data copied by matching type
|
||||||
|
|
||||||
tensorBuffer[intArrayOf(0)] = 55.9
|
assertTrue(tensorArray.buffer.array() contentEquals sharedTensorArray.buffer.array())
|
||||||
assertEquals(ndBuffer[intArrayOf(0)], 55.9)
|
|
||||||
assertEquals(doubleBuffer[0], 55.9)
|
|
||||||
|
|
||||||
tensorList[intArrayOf(0)] = 55.9
|
tensorArray[intArrayOf(0)] = 55.9
|
||||||
assertEquals(ndList[intArrayOf(0)], 1.0)
|
assertEquals(tensorArrayPublic[intArrayOf(0)], 1.0)
|
||||||
assertEquals(doubleList[0], 1.0)
|
|
||||||
|
tensorArrayPublic[intArrayOf(0)] = 55.9
|
||||||
|
assertEquals(sharedTensorArray[intArrayOf(0)], 55.9)
|
||||||
|
assertEquals(bufferedTensorArray[intArrayOf(0)], 1.0)
|
||||||
|
|
||||||
|
bufferedTensorArray[intArrayOf(0)] = 55.9
|
||||||
|
assertEquals(ndArray[intArrayOf(0)], 1.0)
|
||||||
|
|
||||||
ndList[intArrayOf(0)] = 55.9
|
|
||||||
assertEquals(doubleList[0], 55.9)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user