forked from kscience/kmath
get dim 0 operator for tensors
This commit is contained in:
parent
7cb5cd8f71
commit
791f55ee8a
@ -72,7 +72,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
return Pair(luTensor, pivotsTensor)*/
|
return Pair(luTensor, pivotsTensor)*/
|
||||||
|
|
||||||
TODO("Andrei, first we need to view and get(Int)")
|
TODO("Andrei, use view, get, as2D, as1D")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
|
@ -7,11 +7,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
check(this.shape contentEquals intArrayOf(1)) {
|
check(this.shape contentEquals intArrayOf(1)) {
|
||||||
"Inconsistent value for tensor of shape ${shape.toList()}"
|
"Inconsistent value for tensor of shape ${shape.toList()}"
|
||||||
}
|
}
|
||||||
return this.buffer.unsafeToDoubleArray()[this.bufferStart]
|
return this.buffer.array()[this.bufferStart]
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.get(i: Int): DoubleTensor {
|
override operator fun DoubleTensor.get(i: Int): DoubleTensor {
|
||||||
TODO("TOP PRIORITY")
|
val lastShape = this.shape.drop(1).toIntArray()
|
||||||
|
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
||||||
|
val newStart = newShape.reduce(Int::times) * i + this.bufferStart
|
||||||
|
return DoubleTensor(newShape, this.buffer.array(), newStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun zeros(shape: IntArray): DoubleTensor {
|
override fun zeros(shape: IntArray): DoubleTensor {
|
||||||
@ -43,13 +46,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.copy(): DoubleTensor {
|
override fun DoubleTensor.copy(): DoubleTensor {
|
||||||
return DoubleTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf(), this.bufferStart)
|
return DoubleTensor(this.shape, this.buffer.array().copyOf(), this.bufferStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun Double.plus(other: DoubleTensor): DoubleTensor {
|
override fun Double.plus(other: DoubleTensor): DoubleTensor {
|
||||||
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
||||||
other.buffer.unsafeToDoubleArray()[other.bufferStart + i] + this
|
other.buffer.array()[other.bufferStart + i] + this
|
||||||
}
|
}
|
||||||
return DoubleTensor(other.shape, resBuffer)
|
return DoubleTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -61,35 +64,35 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||||
newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i]
|
newThis.buffer.array()[i] + newOther.buffer.array()[i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.plusAssign(value: Double) {
|
override fun DoubleTensor.plusAssign(value: Double) {
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] += value
|
this.buffer.array()[this.bufferStart + i] += value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] +=
|
this.buffer.array()[this.bufferStart + i] +=
|
||||||
other.buffer.unsafeToDoubleArray()[this.bufferStart + i]
|
other.buffer.array()[this.bufferStart + i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Double.minus(other: DoubleTensor): DoubleTensor {
|
override fun Double.minus(other: DoubleTensor): DoubleTensor {
|
||||||
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
||||||
this - other.buffer.unsafeToDoubleArray()[other.bufferStart + i]
|
this - other.buffer.array()[other.bufferStart + i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(other.shape, resBuffer)
|
return DoubleTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.minus(value: Double): DoubleTensor {
|
override fun DoubleTensor.minus(value: Double): DoubleTensor {
|
||||||
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] - value
|
this.buffer.array()[this.bufferStart + i] - value
|
||||||
}
|
}
|
||||||
return DoubleTensor(this.shape, resBuffer)
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -99,14 +102,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||||
newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i]
|
newThis.buffer.array()[i] - newOther.buffer.array()[i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(newThis.shape, resBuffer)
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.minusAssign(value: Double) {
|
override fun DoubleTensor.minusAssign(value: Double) {
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] -= value
|
this.buffer.array()[this.bufferStart + i] -= value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,7 +120,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
override fun Double.times(other: DoubleTensor): DoubleTensor {
|
override fun Double.times(other: DoubleTensor): DoubleTensor {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
||||||
other.buffer.unsafeToDoubleArray()[other.bufferStart + i] * this
|
other.buffer.array()[other.bufferStart + i] * this
|
||||||
}
|
}
|
||||||
return DoubleTensor(other.shape, resBuffer)
|
return DoubleTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -128,8 +131,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
this.buffer.unsafeToDoubleArray()[other.bufferStart + i] *
|
this.buffer.array()[other.bufferStart + i] *
|
||||||
other.buffer.unsafeToDoubleArray()[other.bufferStart + i]
|
other.buffer.array()[other.bufferStart + i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(this.shape, resBuffer)
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -137,21 +140,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
override fun DoubleTensor.timesAssign(value: Double) {
|
override fun DoubleTensor.timesAssign(value: Double) {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] *= value
|
this.buffer.array()[this.bufferStart + i] *= value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.unsafeToDoubleArray()[this.bufferStart + i] *=
|
this.buffer.array()[this.bufferStart + i] *=
|
||||||
other.buffer.unsafeToDoubleArray()[this.bufferStart + i]
|
other.buffer.array()[this.bufferStart + i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.unaryMinus(): DoubleTensor {
|
override fun DoubleTensor.unaryMinus(): DoubleTensor {
|
||||||
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||||
this.buffer.unsafeToDoubleArray()[this.bufferStart + i].unaryMinus()
|
this.buffer.array()[this.bufferStart + i].unaryMinus()
|
||||||
}
|
}
|
||||||
return DoubleTensor(this.shape, resBuffer)
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -172,8 +175,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
|
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
|
||||||
|
|
||||||
val linearIndex = resTensor.strides.offset(newMultiIndex)
|
val linearIndex = resTensor.strides.offset(newMultiIndex)
|
||||||
resTensor.buffer.unsafeToDoubleArray()[linearIndex] =
|
resTensor.buffer.array()[linearIndex] =
|
||||||
this.buffer.unsafeToDoubleArray()[this.bufferStart + offset]
|
this.buffer.array()[this.bufferStart + offset]
|
||||||
}
|
}
|
||||||
return resTensor
|
return resTensor
|
||||||
}
|
}
|
||||||
@ -181,7 +184,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
|
|
||||||
override fun DoubleTensor.view(shape: IntArray): DoubleTensor {
|
override fun DoubleTensor.view(shape: IntArray): DoubleTensor {
|
||||||
checkView(this, shape)
|
checkView(this, shape)
|
||||||
return DoubleTensor(shape, this.buffer.unsafeToDoubleArray(), this.bufferStart)
|
return DoubleTensor(shape, this.buffer.array(), this.bufferStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.viewAs(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.viewAs(other: DoubleTensor): DoubleTensor {
|
||||||
|
@ -42,7 +42,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
public operator fun TensorType.unaryMinus(): TensorType
|
public operator fun TensorType.unaryMinus(): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
//https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
||||||
public fun TensorType.get(i: Int): TensorType
|
public operator fun TensorType.get(i: Int): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
||||||
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
||||||
|
@ -55,8 +55,8 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
|
|||||||
}
|
}
|
||||||
|
|
||||||
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
||||||
resTensor.buffer.unsafeToDoubleArray()[linearIndex] =
|
resTensor.buffer.array()[linearIndex] =
|
||||||
tensor.buffer.unsafeToDoubleArray()[tensor.bufferStart + curLinearIndex]
|
tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
|
||||||
}
|
}
|
||||||
res.add(resTensor)
|
res.add(resTensor)
|
||||||
}
|
}
|
||||||
@ -99,7 +99,7 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
|||||||
/**
|
/**
|
||||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||||
*/
|
*/
|
||||||
internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
|
internal fun Buffer<Int>.array(): IntArray = when(this) {
|
||||||
is IntBuffer -> array
|
is IntBuffer -> array
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
||||||
}
|
}
|
||||||
@ -107,7 +107,7 @@ internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
|
|||||||
/**
|
/**
|
||||||
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
||||||
*/
|
*/
|
||||||
internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
|
internal fun Buffer<Long>.array(): LongArray = when(this) {
|
||||||
is LongBuffer -> array
|
is LongBuffer -> array
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
||||||
}
|
}
|
||||||
@ -115,7 +115,7 @@ internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
|
|||||||
/**
|
/**
|
||||||
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
||||||
*/
|
*/
|
||||||
internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
|
internal fun Buffer<Float>.array(): FloatArray = when(this) {
|
||||||
is FloatBuffer -> array
|
is FloatBuffer -> array
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
||||||
}
|
}
|
||||||
@ -123,7 +123,7 @@ internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
|
|||||||
/**
|
/**
|
||||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
||||||
*/
|
*/
|
||||||
internal fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
|
internal fun Buffer<Double>.array(): DoubleArray = when(this) {
|
||||||
is RealBuffer -> array
|
is RealBuffer -> array
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||||
}
|
}
|
||||||
|
@ -21,4 +21,10 @@ class TestDoubleTensor {
|
|||||||
assertEquals(tensor[intArrayOf(0,1)], 5.8)
|
assertEquals(tensor[intArrayOf(0,1)], 5.8)
|
||||||
assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray())
|
assertTrue(tensor.elements().map{ it.second }.toList().toDoubleArray() contentEquals tensor.buffer.toDoubleArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun getTest() = DoubleTensorAlgebra {
|
||||||
|
val tensor = DoubleTensor(intArrayOf(2,2), doubleArrayOf(3.5,5.8,58.4,2.4))
|
||||||
|
assertTrue(tensor[0].elements().map{ it.second }.toList().toDoubleArray() contentEquals doubleArrayOf(3.5,5.8))
|
||||||
|
}
|
||||||
}
|
}
|
@ -10,7 +10,7 @@ class TestDoubleTensorAlgebra {
|
|||||||
fun doublePlus() = DoubleTensorAlgebra {
|
fun doublePlus() = DoubleTensorAlgebra {
|
||||||
val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
val tensor = DoubleTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||||
val res = 10.0 + tensor
|
val res = 10.0 + tensor
|
||||||
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0,12.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -18,7 +18,7 @@ class TestDoubleTensorAlgebra {
|
|||||||
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0))
|
val tensor = DoubleTensor(intArrayOf(1), doubleArrayOf(0.0))
|
||||||
val res = tensor.transpose(0, 0)
|
val res = tensor.transpose(0, 0)
|
||||||
|
|
||||||
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(0.0))
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(0.0))
|
||||||
assertTrue(res.shape contentEquals intArrayOf(1))
|
assertTrue(res.shape contentEquals intArrayOf(1))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ class TestDoubleTensorAlgebra {
|
|||||||
val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor = DoubleTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val res = tensor.transpose(1, 0)
|
val res = tensor.transpose(1, 0)
|
||||||
|
|
||||||
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,9 +42,9 @@ class TestDoubleTensorAlgebra {
|
|||||||
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
||||||
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
||||||
|
|
||||||
assertTrue(res01.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
assertTrue(res01.buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
assertTrue(res02.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
assertTrue(res02.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
assertTrue(res12.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -70,9 +70,9 @@ class TestDoubleTensorAlgebra {
|
|||||||
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
|
||||||
assertTrue(res[0].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
assertTrue(res[1].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||||
assertTrue(res[2].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -82,14 +82,14 @@ class TestDoubleTensorAlgebra {
|
|||||||
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
val tensor3 = DoubleTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
|
|
||||||
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
|
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
|
||||||
assertTrue((tensor2 - tensor1).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
assertTrue((tensor2 - tensor1).buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||||
|
|
||||||
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
|
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
|
||||||
assertTrue((tensor3 - tensor1).buffer.unsafeToDoubleArray()
|
assertTrue((tensor3 - tensor1).buffer.array()
|
||||||
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
|
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
|
||||||
|
|
||||||
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
|
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
|
||||||
assertTrue((tensor3 - tensor2).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
assertTrue((tensor3 - tensor2).buffer.array() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user