Fixed tests with unsafe accessors

This commit is contained in:
Roland Grinis 2021-03-15 08:31:19 +00:00
parent 39a0889123
commit 04f6ef1ed0
7 changed files with 59 additions and 28 deletions

View File

@ -43,6 +43,14 @@ public fun Buffer<Float>.toFloatArray(): FloatArray = when(this) {
else -> FloatArray(size, ::get)
}
/**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
}
/**
* Returns [FloatBuffer] over this array.
*

View File

@ -42,6 +42,14 @@ public fun Buffer<Int>.toIntArray(): IntArray = when(this) {
else -> IntArray(size, ::get)
}
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
}
/**
* Returns [IntBuffer] over this array.
*

View File

@ -42,6 +42,14 @@ public fun Buffer<Long>.toLongArray(): LongArray = when(this) {
else -> LongArray(size, ::get)
}
/**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
}
/**
* Returns [LongBuffer] over this array.
*

View File

@ -47,6 +47,14 @@ public fun Buffer<Double>.toDoubleArray(): DoubleArray = when(this) {
else -> DoubleArray(size, ::get)
}
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
is RealBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
}
/**
* Returns [RealBuffer] over this array.
*

View File

@ -1,7 +1,6 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.structures.toDoubleArray
import space.kscience.kmath.structures.toIntArray
import space.kscience.kmath.structures.unsafeToIntArray
public class RealLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, RealTensor>,
@ -81,7 +80,7 @@ public class RealLinearOpsTensorAlgebra :
// todo checks
val n = lu.shape[0]
val p = lu.zeroesLike()
pivots.buffer.toIntArray().forEachIndexed { i, pivot ->
pivots.buffer.unsafeToIntArray().forEachIndexed { i, pivot ->
p[i, pivot] = 1.0
}
val l = lu.zeroesLike()

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.structures.toDoubleArray
import space.kscience.kmath.structures.unsafeToDoubleArray
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
@ -9,7 +9,7 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
check(this.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}"
}
return this.buffer.toDoubleArray()[0]
return this.buffer.unsafeToDoubleArray()[0]
}
override fun zeros(shape: IntArray): RealTensor {
@ -32,13 +32,13 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
override fun RealTensor.copy(): RealTensor {
// should be rework as soon as copy() method for NDBuffer will be available
return RealTensor(this.shape, this.buffer.toDoubleArray().copyOf())
return RealTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf())
}
override fun Double.plus(other: RealTensor): RealTensor {
val resBuffer = DoubleArray(other.buffer.size) { i ->
other.buffer.toDoubleArray()[i] + this
other.buffer.unsafeToDoubleArray()[i] + this
}
return RealTensor(other.shape, resBuffer)
}
@ -50,34 +50,34 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
newThis.buffer.toDoubleArray()[i] + newOther.buffer.toDoubleArray()[i]
newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i]
}
return RealTensor(newThis.shape, resBuffer)
}
override fun RealTensor.plusAssign(value: Double) {
for (i in this.buffer.toDoubleArray().indices) {
this.buffer.toDoubleArray()[i] += value
for (i in this.buffer.unsafeToDoubleArray().indices) {
this.buffer.unsafeToDoubleArray()[i] += value
}
}
override fun RealTensor.plusAssign(other: RealTensor) {
//todo should be change with broadcasting
for (i in this.buffer.toDoubleArray().indices) {
this.buffer.toDoubleArray()[i] += other.buffer.toDoubleArray()[i]
for (i in this.buffer.unsafeToDoubleArray().indices) {
this.buffer.unsafeToDoubleArray()[i] += other.buffer.unsafeToDoubleArray()[i]
}
}
override fun Double.minus(other: RealTensor): RealTensor {
val resBuffer = DoubleArray(other.buffer.size) { i ->
this - other.buffer.toDoubleArray()[i]
this - other.buffer.unsafeToDoubleArray()[i]
}
return RealTensor(other.shape, resBuffer)
}
override fun RealTensor.minus(value: Double): RealTensor {
val resBuffer = DoubleArray(this.buffer.size) { i ->
this.buffer.toDoubleArray()[i] - value
this.buffer.unsafeToDoubleArray()[i] - value
}
return RealTensor(this.shape, resBuffer)
}
@ -87,14 +87,14 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
newThis.buffer.toDoubleArray()[i] - newOther.buffer.toDoubleArray()[i]
newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i]
}
return RealTensor(newThis.shape, resBuffer)
}
override fun RealTensor.minusAssign(value: Double) {
for (i in this.buffer.toDoubleArray().indices) {
this.buffer.toDoubleArray()[i] -= value
for (i in this.buffer.unsafeToDoubleArray().indices) {
this.buffer.unsafeToDoubleArray()[i] -= value
}
}
@ -105,7 +105,7 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
override fun Double.times(other: RealTensor): RealTensor {
//todo should be change with broadcasting
val resBuffer = DoubleArray(other.buffer.size) { i ->
other.buffer.toDoubleArray()[i] * this
other.buffer.unsafeToDoubleArray()[i] * this
}
return RealTensor(other.shape, resBuffer)
}
@ -116,28 +116,28 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
override fun RealTensor.times(other: RealTensor): RealTensor {
//todo should be change with broadcasting
val resBuffer = DoubleArray(this.buffer.size) { i ->
this.buffer.toDoubleArray()[i] * other.buffer.toDoubleArray()[i]
this.buffer.unsafeToDoubleArray()[i] * other.buffer.unsafeToDoubleArray()[i]
}
return RealTensor(this.shape, resBuffer)
}
override fun RealTensor.timesAssign(value: Double) {
//todo should be change with broadcasting
for (i in this.buffer.toDoubleArray().indices) {
this.buffer.toDoubleArray()[i] *= value
for (i in this.buffer.unsafeToDoubleArray().indices) {
this.buffer.unsafeToDoubleArray()[i] *= value
}
}
override fun RealTensor.timesAssign(other: RealTensor) {
//todo should be change with broadcasting
for (i in this.buffer.toDoubleArray().indices) {
this.buffer.toDoubleArray()[i] *= other.buffer.toDoubleArray()[i]
for (i in this.buffer.unsafeToDoubleArray().indices) {
this.buffer.unsafeToDoubleArray()[i] *= other.buffer.unsafeToDoubleArray()[i]
}
}
override fun RealTensor.unaryMinus(): RealTensor {
val resBuffer = DoubleArray(this.buffer.size) { i ->
this.buffer.toDoubleArray()[i].unaryMinus()
this.buffer.unsafeToDoubleArray()[i].unaryMinus()
}
return RealTensor(this.shape, resBuffer)
}
@ -158,14 +158,14 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
val linearIndex = resTensor.strides.offset(newMultiIndex)
resTensor.buffer.toDoubleArray()[linearIndex] = this.buffer.toDoubleArray()[offset]
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = this.buffer.unsafeToDoubleArray()[offset]
}
return resTensor
}
override fun RealTensor.view(shape: IntArray): RealTensor {
return RealTensor(shape, this.buffer.toDoubleArray())
return RealTensor(shape, this.buffer.unsafeToDoubleArray())
}
override fun RealTensor.viewAs(other: RealTensor): RealTensor {

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.structures.toDoubleArray
import space.kscience.kmath.structures.unsafeToDoubleArray
import kotlin.math.max
@ -55,7 +55,7 @@ internal inline fun broadcastTensors(vararg tensors: RealTensor): List<RealTenso
}
val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.toDoubleArray()[linearIndex] = tensor.buffer.toDoubleArray()[curLinearIndex]
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = tensor.buffer.unsafeToDoubleArray()[curLinearIndex]
}
res.add(resTensor)
}