Fixed tests with unsafe accessors
This commit is contained in:
parent
39a0889123
commit
04f6ef1ed0
@ -43,6 +43,14 @@ public fun Buffer<Float>.toFloatArray(): FloatArray = when(this) {
|
|||||||
else -> FloatArray(size, ::get)
|
else -> FloatArray(size, ::get)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
public fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
|
||||||
|
is FloatBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns [FloatBuffer] over this array.
|
* Returns [FloatBuffer] over this array.
|
||||||
*
|
*
|
||||||
|
@ -42,6 +42,14 @@ public fun Buffer<Int>.toIntArray(): IntArray = when(this) {
|
|||||||
else -> IntArray(size, ::get)
|
else -> IntArray(size, ::get)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
public fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
|
||||||
|
is IntBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns [IntBuffer] over this array.
|
* Returns [IntBuffer] over this array.
|
||||||
*
|
*
|
||||||
|
@ -42,6 +42,14 @@ public fun Buffer<Long>.toLongArray(): LongArray = when(this) {
|
|||||||
else -> LongArray(size, ::get)
|
else -> LongArray(size, ::get)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
public fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
|
||||||
|
is LongBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns [LongBuffer] over this array.
|
* Returns [LongBuffer] over this array.
|
||||||
*
|
*
|
||||||
|
@ -47,6 +47,14 @@ public fun Buffer<Double>.toDoubleArray(): DoubleArray = when(this) {
|
|||||||
else -> DoubleArray(size, ::get)
|
else -> DoubleArray(size, ::get)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
public fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
|
||||||
|
is RealBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns [RealBuffer] over this array.
|
* Returns [RealBuffer] over this array.
|
||||||
*
|
*
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.structures.toDoubleArray
|
import space.kscience.kmath.structures.unsafeToIntArray
|
||||||
import space.kscience.kmath.structures.toIntArray
|
|
||||||
|
|
||||||
public class RealLinearOpsTensorAlgebra :
|
public class RealLinearOpsTensorAlgebra :
|
||||||
LinearOpsTensorAlgebra<Double, RealTensor>,
|
LinearOpsTensorAlgebra<Double, RealTensor>,
|
||||||
@ -81,7 +80,7 @@ public class RealLinearOpsTensorAlgebra :
|
|||||||
// todo checks
|
// todo checks
|
||||||
val n = lu.shape[0]
|
val n = lu.shape[0]
|
||||||
val p = lu.zeroesLike()
|
val p = lu.zeroesLike()
|
||||||
pivots.buffer.toIntArray().forEachIndexed { i, pivot ->
|
pivots.buffer.unsafeToIntArray().forEachIndexed { i, pivot ->
|
||||||
p[i, pivot] = 1.0
|
p[i, pivot] = 1.0
|
||||||
}
|
}
|
||||||
val l = lu.zeroesLike()
|
val l = lu.zeroesLike()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.structures.toDoubleArray
|
import space.kscience.kmath.structures.unsafeToDoubleArray
|
||||||
|
|
||||||
|
|
||||||
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||||
@ -9,7 +9,7 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
|||||||
check(this.shape contentEquals intArrayOf(1)) {
|
check(this.shape contentEquals intArrayOf(1)) {
|
||||||
"Inconsistent value for tensor of shape ${shape.toList()}"
|
"Inconsistent value for tensor of shape ${shape.toList()}"
|
||||||
}
|
}
|
||||||
return this.buffer.toDoubleArray()[0]
|
return this.buffer.unsafeToDoubleArray()[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun zeros(shape: IntArray): RealTensor {
|
override fun zeros(shape: IntArray): RealTensor {
|
||||||
@ -32,13 +32,13 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
|||||||
|
|
||||||
override fun RealTensor.copy(): RealTensor {
|
override fun RealTensor.copy(): RealTensor {
|
||||||
// should be rework as soon as copy() method for NDBuffer will be available
|
// should be rework as soon as copy() method for NDBuffer will be available
|
||||||
return RealTensor(this.shape, this.buffer.toDoubleArray().copyOf())
|
return RealTensor(this.shape, this.buffer.unsafeToDoubleArray().copyOf())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun Double.plus(other: RealTensor): RealTensor {
|
override fun Double.plus(other: RealTensor): RealTensor {
|
||||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
||||||
other.buffer.toDoubleArray()[i] + this
|
other.buffer.unsafeToDoubleArray()[i] + this
|
||||||
}
|
}
|
||||||
return RealTensor(other.shape, resBuffer)
|
return RealTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -50,34 +50,34 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
|||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
|
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
|
||||||
newThis.buffer.toDoubleArray()[i] + newOther.buffer.toDoubleArray()[i]
|
newThis.buffer.unsafeToDoubleArray()[i] + newOther.buffer.unsafeToDoubleArray()[i]
|
||||||
}
|
}
|
||||||
return RealTensor(newThis.shape, resBuffer)
|
return RealTensor(newThis.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.plusAssign(value: Double) {
|
override fun RealTensor.plusAssign(value: Double) {
|
||||||
for (i in this.buffer.toDoubleArray().indices) {
|
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
||||||
this.buffer.toDoubleArray()[i] += value
|
this.buffer.unsafeToDoubleArray()[i] += value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.plusAssign(other: RealTensor) {
|
override fun RealTensor.plusAssign(other: RealTensor) {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
for (i in this.buffer.toDoubleArray().indices) {
|
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
||||||
this.buffer.toDoubleArray()[i] += other.buffer.toDoubleArray()[i]
|
this.buffer.unsafeToDoubleArray()[i] += other.buffer.unsafeToDoubleArray()[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Double.minus(other: RealTensor): RealTensor {
|
override fun Double.minus(other: RealTensor): RealTensor {
|
||||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
||||||
this - other.buffer.toDoubleArray()[i]
|
this - other.buffer.unsafeToDoubleArray()[i]
|
||||||
}
|
}
|
||||||
return RealTensor(other.shape, resBuffer)
|
return RealTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.minus(value: Double): RealTensor {
|
override fun RealTensor.minus(value: Double): RealTensor {
|
||||||
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
||||||
this.buffer.toDoubleArray()[i] - value
|
this.buffer.unsafeToDoubleArray()[i] - value
|
||||||
}
|
}
|
||||||
return RealTensor(this.shape, resBuffer)
|
return RealTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -87,14 +87,14 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
|||||||
val newThis = broadcast[0]
|
val newThis = broadcast[0]
|
||||||
val newOther = broadcast[1]
|
val newOther = broadcast[1]
|
||||||
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
|
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
|
||||||
newThis.buffer.toDoubleArray()[i] - newOther.buffer.toDoubleArray()[i]
|
newThis.buffer.unsafeToDoubleArray()[i] - newOther.buffer.unsafeToDoubleArray()[i]
|
||||||
}
|
}
|
||||||
return RealTensor(newThis.shape, resBuffer)
|
return RealTensor(newThis.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.minusAssign(value: Double) {
|
override fun RealTensor.minusAssign(value: Double) {
|
||||||
for (i in this.buffer.toDoubleArray().indices) {
|
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
||||||
this.buffer.toDoubleArray()[i] -= value
|
this.buffer.unsafeToDoubleArray()[i] -= value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
|||||||
override fun Double.times(other: RealTensor): RealTensor {
|
override fun Double.times(other: RealTensor): RealTensor {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
||||||
other.buffer.toDoubleArray()[i] * this
|
other.buffer.unsafeToDoubleArray()[i] * this
|
||||||
}
|
}
|
||||||
return RealTensor(other.shape, resBuffer)
|
return RealTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -116,28 +116,28 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
|||||||
override fun RealTensor.times(other: RealTensor): RealTensor {
|
override fun RealTensor.times(other: RealTensor): RealTensor {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
||||||
this.buffer.toDoubleArray()[i] * other.buffer.toDoubleArray()[i]
|
this.buffer.unsafeToDoubleArray()[i] * other.buffer.unsafeToDoubleArray()[i]
|
||||||
}
|
}
|
||||||
return RealTensor(this.shape, resBuffer)
|
return RealTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.timesAssign(value: Double) {
|
override fun RealTensor.timesAssign(value: Double) {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
for (i in this.buffer.toDoubleArray().indices) {
|
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
||||||
this.buffer.toDoubleArray()[i] *= value
|
this.buffer.unsafeToDoubleArray()[i] *= value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.timesAssign(other: RealTensor) {
|
override fun RealTensor.timesAssign(other: RealTensor) {
|
||||||
//todo should be change with broadcasting
|
//todo should be change with broadcasting
|
||||||
for (i in this.buffer.toDoubleArray().indices) {
|
for (i in this.buffer.unsafeToDoubleArray().indices) {
|
||||||
this.buffer.toDoubleArray()[i] *= other.buffer.toDoubleArray()[i]
|
this.buffer.unsafeToDoubleArray()[i] *= other.buffer.unsafeToDoubleArray()[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.unaryMinus(): RealTensor {
|
override fun RealTensor.unaryMinus(): RealTensor {
|
||||||
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
||||||
this.buffer.toDoubleArray()[i].unaryMinus()
|
this.buffer.unsafeToDoubleArray()[i].unaryMinus()
|
||||||
}
|
}
|
||||||
return RealTensor(this.shape, resBuffer)
|
return RealTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
@ -158,14 +158,14 @@ public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealT
|
|||||||
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
|
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
|
||||||
|
|
||||||
val linearIndex = resTensor.strides.offset(newMultiIndex)
|
val linearIndex = resTensor.strides.offset(newMultiIndex)
|
||||||
resTensor.buffer.toDoubleArray()[linearIndex] = this.buffer.toDoubleArray()[offset]
|
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = this.buffer.unsafeToDoubleArray()[offset]
|
||||||
}
|
}
|
||||||
return resTensor
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override fun RealTensor.view(shape: IntArray): RealTensor {
|
override fun RealTensor.view(shape: IntArray): RealTensor {
|
||||||
return RealTensor(shape, this.buffer.toDoubleArray())
|
return RealTensor(shape, this.buffer.unsafeToDoubleArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.viewAs(other: RealTensor): RealTensor {
|
override fun RealTensor.viewAs(other: RealTensor): RealTensor {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.structures.toDoubleArray
|
import space.kscience.kmath.structures.unsafeToDoubleArray
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
|
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ internal inline fun broadcastTensors(vararg tensors: RealTensor): List<RealTenso
|
|||||||
}
|
}
|
||||||
|
|
||||||
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
||||||
resTensor.buffer.toDoubleArray()[linearIndex] = tensor.buffer.toDoubleArray()[curLinearIndex]
|
resTensor.buffer.unsafeToDoubleArray()[linearIndex] = tensor.buffer.unsafeToDoubleArray()[curLinearIndex]
|
||||||
}
|
}
|
||||||
res.add(resTensor)
|
res.add(resTensor)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user