Buffer mini-protocol
This commit is contained in:
parent
287e2aeba2
commit
4f593aec63
@ -748,7 +748,7 @@ public final class space/kscience/kmath/nd/BufferAlgebraNDKt {
|
|||||||
public static final fun ring (Lspace/kscience/kmath/nd/AlgebraND$Companion;Lspace/kscience/kmath/operations/Ring;Lkotlin/jvm/functions/Function2;[I)Lspace/kscience/kmath/nd/BufferedRingND;
|
public static final fun ring (Lspace/kscience/kmath/nd/AlgebraND$Companion;Lspace/kscience/kmath/operations/Ring;Lkotlin/jvm/functions/Function2;[I)Lspace/kscience/kmath/nd/BufferedRingND;
|
||||||
}
|
}
|
||||||
|
|
||||||
public final class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
|
public class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/StructureND {
|
||||||
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/Buffer;)V
|
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/Buffer;)V
|
||||||
public fun elements ()Lkotlin/sequences/Sequence;
|
public fun elements ()Lkotlin/sequences/Sequence;
|
||||||
public fun get ([I)Ljava/lang/Object;
|
public fun get ([I)Ljava/lang/Object;
|
||||||
@ -876,6 +876,12 @@ public abstract interface class space/kscience/kmath/nd/GroupND : space/kscience
|
|||||||
public final class space/kscience/kmath/nd/GroupND$Companion {
|
public final class space/kscience/kmath/nd/GroupND$Companion {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/nd/MutableBufferND : space/kscience/kmath/nd/BufferND, space/kscience/kmath/nd/MutableStructureND {
|
||||||
|
public fun <init> (Lspace/kscience/kmath/nd/Strides;Lspace/kscience/kmath/structures/MutableBuffer;)V
|
||||||
|
public final fun getMutableBuffer ()Lspace/kscience/kmath/structures/MutableBuffer;
|
||||||
|
public fun set ([ILjava/lang/Object;)V
|
||||||
|
}
|
||||||
|
|
||||||
public abstract interface class space/kscience/kmath/nd/MutableStructure1D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure1D, space/kscience/kmath/structures/MutableBuffer {
|
public abstract interface class space/kscience/kmath/nd/MutableStructure1D : space/kscience/kmath/nd/MutableStructureND, space/kscience/kmath/nd/Structure1D, space/kscience/kmath/structures/MutableBuffer {
|
||||||
public fun set ([ILjava/lang/Object;)V
|
public fun set ([ILjava/lang/Object;)V
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,8 @@ package space.kscience.kmath.nd
|
|||||||
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [StructureND] over [Buffer].
|
* Represents [StructureND] over [Buffer].
|
||||||
@ -15,7 +17,7 @@ import space.kscience.kmath.structures.BufferFactory
|
|||||||
* @param strides The strides to access elements of [Buffer] by linear indices.
|
* @param strides The strides to access elements of [Buffer] by linear indices.
|
||||||
* @param buffer The underlying buffer.
|
* @param buffer The underlying buffer.
|
||||||
*/
|
*/
|
||||||
public class BufferND<T>(
|
public open class BufferND<T>(
|
||||||
public val strides: Strides,
|
public val strides: Strides,
|
||||||
public val buffer: Buffer<T>,
|
public val buffer: Buffer<T>,
|
||||||
) : StructureND<T> {
|
) : StructureND<T> {
|
||||||
@ -51,3 +53,34 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
|||||||
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [MutableStructureND] over [MutableBuffer].
|
||||||
|
*
|
||||||
|
* @param T the type of items.
|
||||||
|
* @param strides The strides to access elements of [MutableBuffer] by linear indices.
|
||||||
|
* @param mutableBuffer The underlying buffer.
|
||||||
|
*/
|
||||||
|
public class MutableBufferND<T>(
|
||||||
|
strides: Strides,
|
||||||
|
public val mutableBuffer: MutableBuffer<T>,
|
||||||
|
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
|
||||||
|
override fun set(index: IntArray, value: T) {
|
||||||
|
mutableBuffer[strides.offset(index)] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND]
|
||||||
|
*/
|
||||||
|
public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
|
||||||
|
factory: MutableBufferFactory<R> = MutableBuffer.Companion::auto,
|
||||||
|
crossinline transform: (T) -> R,
|
||||||
|
): MutableBufferND<R> {
|
||||||
|
return if (this is MutableBufferND<T>)
|
||||||
|
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) })
|
||||||
|
else {
|
||||||
|
val strides = DefaultStrides(shape)
|
||||||
|
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||||
|
}
|
||||||
|
}
|
@ -1,5 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.MutableBufferND
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
import space.kscience.kmath.tensors.api.TensorStructure
|
import space.kscience.kmath.tensors.api.TensorStructure
|
||||||
import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure
|
import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure
|
||||||
@ -60,6 +61,7 @@ internal inline fun BufferedTensor<Double>.asTensor(): DoubleTensor = DoubleTens
|
|||||||
|
|
||||||
internal inline fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
internal inline fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
||||||
is BufferedTensor<T> -> this
|
is BufferedTensor<T> -> this
|
||||||
|
is MutableBufferND<T> -> BufferedTensor(this.shape, this.mutableBuffer, 0)
|
||||||
else -> BufferedTensor(this.shape, this.elements().map{ it.second }.toMutableList().asMutableBuffer(), 0)
|
else -> BufferedTensor(this.shape, this.elements().map{ it.second }.toMutableList().asMutableBuffer(), 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,35 +7,20 @@ import space.kscience.kmath.structures.*
|
|||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
* Returns a reference to [IntArray] containing all of the elements of this [Buffer] or copy the data.
|
||||||
*/
|
*/
|
||||||
internal fun Buffer<Int>.array(): IntArray = when (this) {
|
internal fun Buffer<Int>.array(): IntArray = when (this) {
|
||||||
is IntBuffer -> array
|
is IntBuffer -> array
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
else -> this.toIntArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
|
||||||
*/
|
|
||||||
internal fun Buffer<Long>.array(): LongArray = when (this) {
|
|
||||||
is LongBuffer -> array
|
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data.
|
||||||
*/
|
|
||||||
internal fun Buffer<Float>.array(): FloatArray = when (this) {
|
|
||||||
is FloatBuffer -> array
|
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
|
||||||
*/
|
*/
|
||||||
internal fun Buffer<Double>.array(): DoubleArray = when (this) {
|
internal fun Buffer<Double>.array(): DoubleArray = when (this) {
|
||||||
is DoubleBuffer -> array
|
is DoubleBuffer -> array
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
else -> this.toDoubleArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray {
|
internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray {
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.DefaultStrides
|
||||||
|
import space.kscience.kmath.nd.MutableBufferND
|
||||||
import space.kscience.kmath.nd.as1D
|
import space.kscience.kmath.nd.as1D
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
|
import space.kscience.kmath.structures.asMutableBuffer
|
||||||
import space.kscience.kmath.structures.toDoubleArray
|
import space.kscience.kmath.structures.toDoubleArray
|
||||||
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
|
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -47,4 +51,34 @@ class TestDoubleTensor {
|
|||||||
assertEquals(secondRow[1], secondColumn[1])
|
assertEquals(secondRow[1], secondColumn[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun bufferProtocol() {
|
||||||
|
|
||||||
|
// create buffers
|
||||||
|
val doubleBuffer = DoubleBuffer(doubleArrayOf(1.0,2.0,3.0))
|
||||||
|
val doubleList = MutableList(3, doubleBuffer::get)
|
||||||
|
|
||||||
|
// create ND buffers
|
||||||
|
val ndBuffer = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleBuffer)
|
||||||
|
val ndList = MutableBufferND(DefaultStrides(intArrayOf(3)), doubleList.asMutableBuffer())
|
||||||
|
|
||||||
|
// map to tensors
|
||||||
|
val bufferedTensorBuffer = ndBuffer.toBufferedTensor() // strides are flipped
|
||||||
|
val tensorBuffer = bufferedTensorBuffer.asTensor() // no data copied
|
||||||
|
|
||||||
|
val bufferedTensorList = ndList.toBufferedTensor() // strides are flipped
|
||||||
|
val tensorList = bufferedTensorList.asTensor() // data copied
|
||||||
|
|
||||||
|
tensorBuffer[intArrayOf(0)] = 55.9
|
||||||
|
assertEquals(ndBuffer[intArrayOf(0)], 55.9)
|
||||||
|
assertEquals(doubleBuffer[0], 55.9)
|
||||||
|
|
||||||
|
tensorList[intArrayOf(0)] = 55.9
|
||||||
|
assertEquals(ndList[intArrayOf(0)], 1.0)
|
||||||
|
assertEquals(doubleList[0], 1.0)
|
||||||
|
|
||||||
|
ndList[intArrayOf(0)] = 55.9
|
||||||
|
assertEquals(doubleList[0], 55.9)
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user