Fixed tests with unsafe accessors

This commit is contained in:
Roland Grinis 2021-03-15 08:31:19 +00:00
parent 39a0889123
commit 04f6ef1ed0
7 changed files with 59 additions and 28 deletions

View File

@ -43,6 +43,14 @@ public fun Buffer<Float>.toFloatArray(): FloatArray = when(this) {
else -> FloatArray(size, ::get) else -> FloatArray(size, ::get)
} }
/**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
}
/** /**
* Returns [FloatBuffer] over this array. * Returns [FloatBuffer] over this array.
* *

View File

@ -42,6 +42,14 @@ public fun Buffer<Int>.toIntArray(): IntArray = when(this) {
else -> IntArray(size, ::get) else -> IntArray(size, ::get)
} }
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
}
/** /**
* Returns [IntBuffer] over this array. * Returns [IntBuffer] over this array.
* *

View File

@ -42,6 +42,14 @@ public fun Buffer<Long>.toLongArray(): LongArray = when(this) {
else -> LongArray(size, ::get) else -> LongArray(size, ::get)
} }
/**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
}
/** /**
* Returns [LongBuffer] over this array. * Returns [LongBuffer] over this array.
* *

View File

@ -47,6 +47,14 @@ public fun Buffer<Double>.toDoubleArray(): DoubleArray = when(this) {
else -> DoubleArray(size, ::get) else -> DoubleArray(size, ::get)
} }
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
is RealBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
}
/** /**
* Returns [RealBuffer] over this array. * Returns [RealBuffer] over this array.
* *

View File

@ -1,7 +1,6 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.toDoubleArray import space.kscience.kmath.structures.unsafeToIntArray
import space.kscience.kmath.structures.toIntArray
public class RealLinearOpsTensorAlgebra : public class RealLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, RealTensor>, LinearOpsTensorAlgebra<Double, RealTensor>,
@ -81,7 +80,7 @@ public class RealLinearOpsTensorAlgebra :
// todo checks // todo checks
val n = lu.shape[0] val n = lu.shape[0]
val p = lu.zeroesLike() val p = lu.zeroesLike()
pivots.buffer.toIntArray().forEachIndexed { i, pivot -> pivots.buffer.unsafeToIntArray().forEachIndexed { i, pivot ->
p[i, pivot] = 1.0 p[i, pivot] = 1.0
} }
val l = lu.zeroesLike() val l = lu.zeroesLike()

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.toDoubleArray import space.kscience.kmath.structures.unsafeToDoubleArray
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> { public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
@ -9,7 +9,7 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
check(this.shape contentEquals intArrayOf(1)) { check(this.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}" "Inconsistent value for tensor of shape ${shape.toList()}"
} }
return this.buffer.toDoubleArray()[0] return this.buffer.unsafeToDoubleArray()[0]
} }
override fun zeros(shape: IntArray): RealTensor { override fun zeros(shape: IntArray): RealTensor {
@ -32,13 +32,13 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
override fun RealTensor.copy(): RealTensor { override fun RealTensor.copy(): RealTensor {
// should be rework as soon as copy() method for NDBuffer will be available // should be rework as soon as copy() method for NDBuffer will be available
return RealTensor(this.shape, this.buffer.toDoubleArray().copyOf()) return RealTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf())
} }
override fun Double.plus(other: RealTensor): RealTensor { override fun Double.plus(other: RealTensor): RealTensor {
val resBuffer = DoubleArray(other.buffer.size) { i -> val resBuffer = DoubleArray(other.buffer.size) { i ->
other.buffer.toDoubleArray()[i] + this other.buffer.unsafeToDoubleArray()[i] + this
} }
return RealTensor(other.shape, resBuffer) return RealTensor(other.shape, resBuffer)
} }
@ -50,34 +50,34 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.buffer.size) { i -> val resBuffer = DoubleArray(newThis.buffer.size) { i ->
newThis.buffer.toDoubleArray()[i] + newOther.buffer.toDoubleArray()[i] newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i]
} }
return RealTensor(newThis.shape, resBuffer) return RealTensor(newThis.shape, resBuffer)
} }
override fun RealTensor.plusAssign(value: Double) { override fun RealTensor.plusAssign(value: Double) {
for (i in this.buffer.toDoubleArray().indices) { for (i in this.buffer.unsafeToDoubleArray().indices) {
this.buffer.toDoubleArray()[i] += value this.buffer.unsafeToDoubleArray()[i] += value
} }
} }
override fun RealTensor.plusAssign(other: RealTensor) { override fun RealTensor.plusAssign(other: RealTensor) {
//todo should be change with broadcasting //todo should be change with broadcasting
for (i in this.buffer.toDoubleArray().indices) { for (i in this.buffer.unsafeToDoubleArray().indices) {
this.buffer.toDoubleArray()[i] += other.buffer.toDoubleArray()[i] this.buffer.unsafeToDoubleArray()[i] += other.buffer.unsafeToDoubleArray()[i]
} }
} }
override fun Double.minus(other: RealTensor): RealTensor { override fun Double.minus(other: RealTensor): RealTensor {
val resBuffer = DoubleArray(other.buffer.size) { i -> val resBuffer = DoubleArray(other.buffer.size) { i ->
this - other.buffer.toDoubleArray()[i] this - other.buffer.unsafeToDoubleArray()[i]
} }
return RealTensor(other.shape, resBuffer) return RealTensor(other.shape, resBuffer)
} }
override fun RealTensor.minus(value: Double): RealTensor { override fun RealTensor.minus(value: Double): RealTensor {
val resBuffer = DoubleArray(this.buffer.size) { i -> val resBuffer = DoubleArray(this.buffer.size) { i ->
this.buffer.toDoubleArray()[i] - value this.buffer.unsafeToDoubleArray()[i] - value
} }
return RealTensor(this.shape, resBuffer) return RealTensor(this.shape, resBuffer)
} }
@ -87,14 +87,14 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.buffer.size) { i -> val resBuffer = DoubleArray(newThis.buffer.size) { i ->
newThis.buffer.toDoubleArray()[i] - newOther.buffer.toDoubleArray()[i] newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i]
} }
return RealTensor(newThis.shape, resBuffer) return RealTensor(newThis.shape, resBuffer)
} }
override fun RealTensor.minusAssign(value: Double) { override fun RealTensor.minusAssign(value: Double) {
for (i in this.buffer.toDoubleArray().indices) { for (i in this.buffer.unsafeToDoubleArray().indices) {
this.buffer.toDoubleArray()[i] -= value this.buffer.unsafeToDoubleArray()[i] -= value
} }
} }
@ -105,7 +105,7 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
override fun Double.times(other: RealTensor): RealTensor { override fun Double.times(other: RealTensor): RealTensor {
//todo should be change with broadcasting //todo should be change with broadcasting
val resBuffer = DoubleArray(other.buffer.size) { i -> val resBuffer = DoubleArray(other.buffer.size) { i ->
other.buffer.toDoubleArray()[i] * this other.buffer.unsafeToDoubleArray()[i] * this
} }
return RealTensor(other.shape, resBuffer) return RealTensor(other.shape, resBuffer)
} }
@ -116,28 +116,28 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
override fun RealTensor.times(other: RealTensor): RealTensor { override fun RealTensor.times(other: RealTensor): RealTensor {
//todo should be change with broadcasting //todo should be change with broadcasting
val resBuffer = DoubleArray(this.buffer.size) { i -> val resBuffer = DoubleArray(this.buffer.size) { i ->
this.buffer.toDoubleArray()[i] * other.buffer.toDoubleArray()[i] this.buffer.unsafeToDoubleArray()[i] * other.buffer.unsafeToDoubleArray()[i]
} }
return RealTensor(this.shape, resBuffer) return RealTensor(this.shape, resBuffer)
} }
override fun RealTensor.timesAssign(value: Double) { override fun RealTensor.timesAssign(value: Double) {
//todo should be change with broadcasting //todo should be change with broadcasting
for (i in this.buffer.toDoubleArray().indices) { for (i in this.buffer.unsafeToDoubleArray().indices) {
this.buffer.toDoubleArray()[i] *= value this.buffer.unsafeToDoubleArray()[i] *= value
} }
} }
override fun RealTensor.timesAssign(other: RealTensor) { override fun RealTensor.timesAssign(other: RealTensor) {
//todo should be change with broadcasting //todo should be change with broadcasting
for (i in this.buffer.toDoubleArray().indices) { for (i in this.buffer.unsafeToDoubleArray().indices) {
this.buffer.toDoubleArray()[i] *= other.buffer.toDoubleArray()[i] this.buffer.unsafeToDoubleArray()[i] *= other.buffer.unsafeToDoubleArray()[i]
} }
} }
override fun RealTensor.unaryMinus(): RealTensor { override fun RealTensor.unaryMinus(): RealTensor {
val resBuffer = DoubleArray(this.buffer.size) { i -> val resBuffer = DoubleArray(this.buffer.size) { i ->
this.buffer.toDoubleArray()[i].unaryMinus() this.buffer.unsafeToDoubleArray()[i].unaryMinus()
} }
return RealTensor(this.shape, resBuffer) return RealTensor(this.shape, resBuffer)
} }
@ -158,14 +158,14 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] } newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
val linearIndex = resTensor.strides.offset(newMultiIndex) val linearIndex = resTensor.strides.offset(newMultiIndex)
resTensor.buffer.toDoubleArray()[linearIndex] = this.buffer.toDoubleArray()[offset] resTensor.buffer.unsafeToDoubleArray()[linearIndex] = this.buffer.unsafeToDoubleArray()[offset]
} }
return resTensor return resTensor
} }
override fun RealTensor.view(shape: IntArray): RealTensor { override fun RealTensor.view(shape: IntArray): RealTensor {
return RealTensor(shape, this.buffer.toDoubleArray()) return RealTensor(shape, this.buffer.unsafeToDoubleArray())
} }
override fun RealTensor.viewAs(other: RealTensor): RealTensor { override fun RealTensor.viewAs(other: RealTensor): RealTensor {

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.toDoubleArray import space.kscience.kmath.structures.unsafeToDoubleArray
import kotlin.math.max import kotlin.math.max
@ -55,7 +55,7 @@ internal inline fun broadcastTensors(vararg tensors: RealTensor): List<RealTenso
} }
val curLinearIndex = tensor.strides.offset(curMultiIndex) val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.toDoubleArray()[linearIndex] = tensor.buffer.toDoubleArray()[curLinearIndex] resTensor.buffer.unsafeToDoubleArray()[linearIndex] = tensor.buffer.unsafeToDoubleArray()[curLinearIndex]
} }
res.add(resTensor) res.add(resTensor)
} }