Fixed tests with unsafe accessors
This commit is contained in:
parent
39a0889123
commit
04f6ef1ed0
@ -43,6 +43,14 @@ public fun Buffer<Float>.toFloatArray(): FloatArray = when(this) {
|
||||
else -> FloatArray(size, ::get)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
public fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
|
||||
is FloatBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns [FloatBuffer] over this array.
|
||||
*
|
||||
|
@ -42,6 +42,14 @@ public fun Buffer<Int>.toIntArray(): IntArray = when(this) {
|
||||
else -> IntArray(size, ::get)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
public fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
|
||||
is IntBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns [IntBuffer] over this array.
|
||||
*
|
||||
|
@ -42,6 +42,14 @@ public fun Buffer<Long>.toLongArray(): LongArray = when(this) {
|
||||
else -> LongArray(size, ::get)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
public fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
|
||||
is LongBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns [LongBuffer] over this array.
|
||||
*
|
||||
|
@ -47,6 +47,14 @@ public fun Buffer<Double>.toDoubleArray(): DoubleArray = when(this) {
|
||||
else -> DoubleArray(size, ::get)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
public fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
|
||||
is RealBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns [RealBuffer] over this array.
|
||||
*
|
||||
|
@ -1,7 +1,6 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.structures.toDoubleArray
|
||||
import space.kscience.kmath.structures.toIntArray
|
||||
import space.kscience.kmath.structures.unsafeToIntArray
|
||||
|
||||
public class RealLinearOpsTensorAlgebra :
|
||||
LinearOpsTensorAlgebra<Double, RealTensor>,
|
||||
@ -81,7 +80,7 @@ public class RealLinearOpsTensorAlgebra :
|
||||
// todo checks
|
||||
val n = lu.shape[0]
|
||||
val p = lu.zeroesLike()
|
||||
pivots.buffer.toIntArray().forEachIndexed { i, pivot ->
|
||||
pivots.buffer.unsafeToIntArray().forEachIndexed { i, pivot ->
|
||||
p[i, pivot] = 1.0
|
||||
}
|
||||
val l = lu.zeroesLike()
|
||||
|
@ -1,6 +1,6 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.structures.toDoubleArray
|
||||
import space.kscience.kmath.structures.unsafeToDoubleArray
|
||||
|
||||
|
||||
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||
@ -9,7 +9,7 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
||||
check(this.shape contentEquals intArrayOf(1)) {
|
||||
"Inconsistent value for tensor of shape ${shape.toList()}"
|
||||
}
|
||||
return this.buffer.toDoubleArray()[0]
|
||||
return this.buffer.unsafeToDoubleArray()[0]
|
||||
}
|
||||
|
||||
override fun zeros(shape: IntArray): RealTensor {
|
||||
@ -32,13 +32,13 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
||||
|
||||
override fun RealTensor.copy(): RealTensor {
|
||||
// should be rework as soon as copy() method for NDBuffer will be available
|
||||
return RealTensor(this.shape, this.buffer.toDoubleArray().copyOf())
|
||||
return RealTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf())
|
||||
}
|
||||
|
||||
|
||||
override fun Double.plus(other: RealTensor): RealTensor {
|
||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
||||
other.buffer.toDoubleArray()[i] + this
|
||||
other.buffer.unsafeToDoubleArray()[i] + this
|
||||
}
|
||||
return RealTensor(other.shape, resBuffer)
|
||||
}
|
||||
@ -50,34 +50,34 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
|
||||
newThis.buffer.toDoubleArray()[i] + newOther.buffer.toDoubleArray()[i]
|
||||
newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i]
|
||||
}
|
||||
return RealTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun RealTensor.plusAssign(value: Double) {
|
||||
for (i in this.buffer.toDoubleArray().indices) {
|
||||
this.buffer.toDoubleArray()[i] += value
|
||||
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
||||
this.buffer.unsafeToDoubleArray()[i] += value
|
||||
}
|
||||
}
|
||||
|
||||
override fun RealTensor.plusAssign(other: RealTensor) {
|
||||
//todo should be change with broadcasting
|
||||
for (i in this.buffer.toDoubleArray().indices) {
|
||||
this.buffer.toDoubleArray()[i] += other.buffer.toDoubleArray()[i]
|
||||
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
||||
this.buffer.unsafeToDoubleArray()[i] += other.buffer.unsafeToDoubleArray()[i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Double.minus(other: RealTensor): RealTensor {
|
||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
||||
this - other.buffer.toDoubleArray()[i]
|
||||
this - other.buffer.unsafeToDoubleArray()[i]
|
||||
}
|
||||
return RealTensor(other.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun RealTensor.minus(value: Double): RealTensor {
|
||||
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
||||
this.buffer.toDoubleArray()[i] - value
|
||||
this.buffer.unsafeToDoubleArray()[i] - value
|
||||
}
|
||||
return RealTensor(this.shape, resBuffer)
|
||||
}
|
||||
@ -87,14 +87,14 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
|
||||
newThis.buffer.toDoubleArray()[i] - newOther.buffer.toDoubleArray()[i]
|
||||
newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i]
|
||||
}
|
||||
return RealTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun RealTensor.minusAssign(value: Double) {
|
||||
for (i in this.buffer.toDoubleArray().indices) {
|
||||
this.buffer.toDoubleArray()[i] -= value
|
||||
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
||||
this.buffer.unsafeToDoubleArray()[i] -= value
|
||||
}
|
||||
}
|
||||
|
||||
@ -105,7 +105,7 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
||||
override fun Double.times(other: RealTensor): RealTensor {
|
||||
//todo should be change with broadcasting
|
||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
||||
other.buffer.toDoubleArray()[i] * this
|
||||
other.buffer.unsafeToDoubleArray()[i] * this
|
||||
}
|
||||
return RealTensor(other.shape, resBuffer)
|
||||
}
|
||||
@ -116,28 +116,28 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
||||
override fun RealTensor.times(other: RealTensor): RealTensor {
|
||||
//todo should be change with broadcasting
|
||||
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
||||
this.buffer.toDoubleArray()[i] * other.buffer.toDoubleArray()[i]
|
||||
this.buffer.unsafeToDoubleArray()[i] * other.buffer.unsafeToDoubleArray()[i]
|
||||
}
|
||||
return RealTensor(this.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun RealTensor.timesAssign(value: Double) {
|
||||
//todo should be change with broadcasting
|
||||
for (i in this.buffer.toDoubleArray().indices) {
|
||||
this.buffer.toDoubleArray()[i] *= value
|
||||
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
||||
this.buffer.unsafeToDoubleArray()[i] *= value
|
||||
}
|
||||
}
|
||||
|
||||
override fun RealTensor.timesAssign(other: RealTensor) {
|
||||
//todo should be change with broadcasting
|
||||
for (i in this.buffer.toDoubleArray().indices) {
|
||||
this.buffer.toDoubleArray()[i] *= other.buffer.toDoubleArray()[i]
|
||||
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
||||
this.buffer.unsafeToDoubleArray()[i] *= other.buffer.unsafeToDoubleArray()[i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun RealTensor.unaryMinus(): RealTensor {
|
||||
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
||||
this.buffer.toDoubleArray()[i].unaryMinus()
|
||||
this.buffer.unsafeToDoubleArray()[i].unaryMinus()
|
||||
}
|
||||
return RealTensor(this.shape, resBuffer)
|
||||
}
|
||||
@ -158,14 +158,14 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
||||
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
|
||||
|
||||
val linearIndex = resTensor.strides.offset(newMultiIndex)
|
||||
resTensor.buffer.toDoubleArray()[linearIndex] = this.buffer.toDoubleArray()[offset]
|
||||
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = this.buffer.unsafeToDoubleArray()[offset]
|
||||
}
|
||||
return resTensor
|
||||
}
|
||||
|
||||
|
||||
override fun RealTensor.view(shape: IntArray): RealTensor {
|
||||
return RealTensor(shape, this.buffer.toDoubleArray())
|
||||
return RealTensor(shape, this.buffer.unsafeToDoubleArray())
|
||||
}
|
||||
|
||||
override fun RealTensor.viewAs(other: RealTensor): RealTensor {
|
||||
|
@ -1,6 +1,6 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.structures.toDoubleArray
|
||||
import space.kscience.kmath.structures.unsafeToDoubleArray
|
||||
import kotlin.math.max
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ internal inline fun broadcastTensors(vararg tensors: RealTensor): List<RealTenso
|
||||
}
|
||||
|
||||
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
||||
resTensor.buffer.toDoubleArray()[linearIndex] = tensor.buffer.toDoubleArray()[curLinearIndex]
|
||||
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = tensor.buffer.unsafeToDoubleArray()[curLinearIndex]
|
||||
}
|
||||
res.add(resTensor)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user