Buffer mini-protocol

This commit is contained in:
Roland Grinis 2021-04-24 18:53:21 +01:00
parent 287e2aeba2
commit 4f593aec63
5 changed files with 81 additions and 21 deletions

View File

@ -748,7 +748,7 @@ public final class space/kscience/kmath/nd/BufferAlgebraNDKt {
public static final fun ring (Lspace/kscience/kmath/nd/AlgebraND$Companion;Lspace/kscience/kmath/operations/Ring;Lkotlin/jvm/functions/Function2;[I)Lspace/kscience/kmath/nd/BufferedRingND;
}
public final class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
public class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/Buffer;)V
public fun elements ()Lkotlin/sequences/Sequence;
public fun get ([I)Ljava/lang/Object;
@ -876,6 +876,12 @@ public abstract interface class space/kscience/kmath/nd/GroupND : space/kscience
public final class space/kscience/kmath/nd/GroupND$Companion {
}
public final class space/kscience/kmath/nd/MutableBufferND : space/kscience/kmath/nd/BufferND, space/kscience/kmath/nd/MutableStructureND {
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/MutableBuffer;)V
public final fun getMutableBuffer ()Lspace/kscience/kmath/structures/MutableBuffer;
public fun set ([ILjava/lang/Object;)V
}
public abstract interface class space/kscience/kmath/nd/MutableStructure1D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure1D, space/kscience/kmath/structures/MutableBuffer {
public fun set ([ILjava/lang/Object;)V
}

View File

@ -7,6 +7,8 @@ package space.kscience.kmath.nd
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.MutableBufferFactory
/**
* Represents [StructureND] over [Buffer].
@ -15,7 +17,7 @@ import space.kscience.kmath.structures.BufferFactory
* @param strides The strides to access elements of [Buffer] by linear indices.
* @param buffer The underlying buffer.
*/
public class BufferND<T>(
public open class BufferND<T>(
public val strides: Strides,
public val buffer: Buffer<T>,
) : StructureND<T> {
@ -50,4 +52,35 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
val strides = DefaultStrides(shape)
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
}
}
/**
* Represents [MutableStructureND] over [MutableBuffer].
*
* @param T the type of items.
* @param strides The strides to access elements of [MutableBuffer] by linear indices.
* @param mutableBuffer The underlying buffer.
*/
public class MutableBufferND<T>(
strides: Strides,
public val mutableBuffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
override fun set(index: IntArray, value: T) {
mutableBuffer[strides.offset(index)] = value
}
}
/**
* Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND]
*/
public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
factory: MutableBufferFactory<R> = MutableBuffer.Companion::auto,
crossinline transform: (T) -> R,
): MutableBufferND<R> {
return if (this is MutableBufferND<T>)
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) })
else {
val strides = DefaultStrides(shape)
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
}
}

View File

@ -1,5 +1,6 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.api.TensorStructure
import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure
@ -60,6 +61,7 @@ internal inline fun BufferedTensor<Double>.asTensor(): DoubleTensor = DoubleTens
internal inline fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this
is MutableBufferND<T> -> BufferedTensor(this.shape, this.mutableBuffer, 0)
else -> BufferedTensor(this.shape, this.elements().map{ it.second }.toMutableList().asMutableBuffer(), 0)
}

View File

@ -7,35 +7,20 @@ import space.kscience.kmath.structures.*
import kotlin.math.*
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
* Returns a reference to [IntArray] containing all of the elements of this [Buffer] or copy the data.
*/
internal fun Buffer<Int>.array(): IntArray = when (this) {
is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
else -> this.toIntArray()
}
/**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Long>.array(): LongArray = when (this) {
is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
}
/**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Float>.array(): FloatArray = when (this) {
is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
}
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data.
*/
internal fun Buffer<Double>.array(): DoubleArray = when (this) {
is DoubleBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
else -> this.toDoubleArray()
}
internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray {

View File

@ -1,7 +1,11 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.structures.toDoubleArray
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
import kotlin.test.Test
@ -47,4 +51,34 @@ class TestDoubleTensor {
assertEquals(secondRow[1], secondColumn[1])
}
}
@Test
fun bufferProtocol() {
// create buffers
val doubleBuffer = DoubleBuffer(doubleArrayOf(1.0,2.0,3.0))
val doubleList = MutableList(3, doubleBuffer::get)
// create ND buffers
val ndBuffer = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleBuffer)
val ndList = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleList.asMutableBuffer())
// map to tensors
val bufferedTensorBuffer = ndBuffer.toBufferedTensor() // strides are flipped
val tensorBuffer = bufferedTensorBuffer.asTensor() // no data copied
val bufferedTensorList = ndList.toBufferedTensor() // strides are flipped
val tensorList = bufferedTensorList.asTensor() // data copied
tensorBuffer[intArrayOf(0)] = 55.9
assertEquals(ndBuffer[intArrayOf(0)], 55.9)
assertEquals(doubleBuffer[0], 55.9)
tensorList[intArrayOf(0)] = 55.9
assertEquals(ndList[intArrayOf(0)], 1.0)
assertEquals(doubleList[0], 1.0)
ndList[intArrayOf(0)] = 55.9
assertEquals(doubleList[0], 55.9)
}
}