forked from kscience/kmath
Buffer mini-protocol
This commit is contained in:
parent
287e2aeba2
commit
4f593aec63
@ -748,7 +748,7 @@ public final class space/kscience/kmath/nd/BufferAlgebraNDKt {
|
||||
public static final fun ring (Lspace/kscience/kmath/nd/AlgebraND$Companion;Lspace/kscience/kmath/operations/Ring;Lkotlin/jvm/functions/Function2;[I)Lspace/kscience/kmath/nd/BufferedRingND;
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
|
||||
public class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
|
||||
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/Buffer;)V
|
||||
public fun elements ()Lkotlin/sequences/Sequence;
|
||||
public fun get ([I)Ljava/lang/Object;
|
||||
@ -876,6 +876,12 @@ public abstract interface class space/kscience/kmath/nd/GroupND : space/kscience
|
||||
public final class space/kscience/kmath/nd/GroupND$Companion {
|
||||
}
|
||||
|
||||
public final class space/kscience/kmath/nd/MutableBufferND : space/kscience/kmath/nd/BufferND, space/kscience/kmath/nd/MutableStructureND {
|
||||
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/MutableBuffer;)V
|
||||
public final fun getMutableBuffer ()Lspace/kscience/kmath/structures/MutableBuffer;
|
||||
public fun set ([ILjava/lang/Object;)V
|
||||
}
|
||||
|
||||
public abstract interface class space/kscience/kmath/nd/MutableStructure1D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure1D, space/kscience/kmath/structures/MutableBuffer {
|
||||
public fun set ([ILjava/lang/Object;)V
|
||||
}
|
||||
|
@ -7,6 +7,8 @@ package space.kscience.kmath.nd
|
||||
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.BufferFactory
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
import space.kscience.kmath.structures.MutableBufferFactory
|
||||
|
||||
/**
|
||||
* Represents [StructureND] over [Buffer].
|
||||
@ -15,7 +17,7 @@ import space.kscience.kmath.structures.BufferFactory
|
||||
* @param strides The strides to access elements of [Buffer] by linear indices.
|
||||
* @param buffer The underlying buffer.
|
||||
*/
|
||||
public class BufferND<T>(
|
||||
public open class BufferND<T>(
|
||||
public val strides: Strides,
|
||||
public val buffer: Buffer<T>,
|
||||
) : StructureND<T> {
|
||||
@ -50,4 +52,35 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
||||
val strides = DefaultStrides(shape)
|
||||
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [MutableStructureND] over [MutableBuffer].
|
||||
*
|
||||
* @param T the type of items.
|
||||
* @param strides The strides to access elements of [MutableBuffer] by linear indices.
|
||||
* @param mutableBuffer The underlying buffer.
|
||||
*/
|
||||
public class MutableBufferND<T>(
|
||||
strides: Strides,
|
||||
public val mutableBuffer: MutableBuffer<T>,
|
||||
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
|
||||
override fun set(index: IntArray, value: T) {
|
||||
mutableBuffer[strides.offset(index)] = value
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND]
|
||||
*/
|
||||
public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
|
||||
factory: MutableBufferFactory<R> = MutableBuffer.Companion::auto,
|
||||
crossinline transform: (T) -> R,
|
||||
): MutableBufferND<R> {
|
||||
return if (this is MutableBufferND<T>)
|
||||
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) })
|
||||
else {
|
||||
val strides = DefaultStrides(shape)
|
||||
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||
}
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.nd.MutableBufferND
|
||||
import space.kscience.kmath.structures.*
|
||||
import space.kscience.kmath.tensors.api.TensorStructure
|
||||
import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure
|
||||
@ -60,6 +61,7 @@ internal inline fun BufferedTensor<Double>.asTensor(): DoubleTensor = DoubleTens
|
||||
|
||||
internal inline fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
||||
is BufferedTensor<T> -> this
|
||||
is MutableBufferND<T> -> BufferedTensor(this.shape, this.mutableBuffer, 0)
|
||||
else -> BufferedTensor(this.shape, this.elements().map{ it.second }.toMutableList().asMutableBuffer(), 0)
|
||||
}
|
||||
|
||||
|
@ -7,35 +7,20 @@ import space.kscience.kmath.structures.*
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer] or copy the data.
|
||||
*/
|
||||
internal fun Buffer<Int>.array(): IntArray = when (this) {
|
||||
is IntBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
||||
else -> this.toIntArray()
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Long>.array(): LongArray = when (this) {
|
||||
is LongBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Float>.array(): FloatArray = when (this) {
|
||||
is FloatBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data.
|
||||
*/
|
||||
internal fun Buffer<Double>.array(): DoubleArray = when (this) {
|
||||
is DoubleBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||
else -> this.toDoubleArray()
|
||||
}
|
||||
|
||||
internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray {
|
||||
|
@ -1,7 +1,11 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.nd.DefaultStrides
|
||||
import space.kscience.kmath.nd.MutableBufferND
|
||||
import space.kscience.kmath.nd.as1D
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import space.kscience.kmath.structures.asMutableBuffer
|
||||
import space.kscience.kmath.structures.toDoubleArray
|
||||
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
|
||||
import kotlin.test.Test
|
||||
@ -47,4 +51,34 @@ class TestDoubleTensor {
|
||||
assertEquals(secondRow[1], secondColumn[1])
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun bufferProtocol() {
|
||||
|
||||
// create buffers
|
||||
val doubleBuffer = DoubleBuffer(doubleArrayOf(1.0,2.0,3.0))
|
||||
val doubleList = MutableList(3, doubleBuffer::get)
|
||||
|
||||
// create ND buffers
|
||||
val ndBuffer = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleBuffer)
|
||||
val ndList = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleList.asMutableBuffer())
|
||||
|
||||
// map to tensors
|
||||
val bufferedTensorBuffer = ndBuffer.toBufferedTensor() // strides are flipped
|
||||
val tensorBuffer = bufferedTensorBuffer.asTensor() // no data copied
|
||||
|
||||
val bufferedTensorList = ndList.toBufferedTensor() // strides are flipped
|
||||
val tensorList = bufferedTensorList.asTensor() // data copied
|
||||
|
||||
tensorBuffer[intArrayOf(0)] = 55.9
|
||||
assertEquals(ndBuffer[intArrayOf(0)], 55.9)
|
||||
assertEquals(doubleBuffer[0], 55.9)
|
||||
|
||||
tensorList[intArrayOf(0)] = 55.9
|
||||
assertEquals(ndList[intArrayOf(0)], 1.0)
|
||||
assertEquals(doubleList[0], 1.0)
|
||||
|
||||
ndList[intArrayOf(0)] = 55.9
|
||||
assertEquals(doubleList[0], 55.9)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user