refactor BT + docs

This commit is contained in:
Andrei Kislitsyn 2021-05-06 12:30:13 +03:00
parent 0680a3a1cb
commit 499cf85ff0
5 changed files with 136 additions and 90 deletions

View File

@ -1,116 +1,48 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure
/**
public open class BufferedTensor<T>( * [Tensor] implementation provided with [MutableBuffer]
*/
public open class BufferedTensor<T> internal constructor(
override val shape: IntArray, override val shape: IntArray,
internal val mutableBuffer: MutableBuffer<T>, internal val mutableBuffer: MutableBuffer<T>,
internal val bufferStart: Int internal val bufferStart: Int
) : Tensor<T> { ) : Tensor<T> {
/**
* [TensorLinearStructure] with the same shape
*/
public val linearStructure: TensorLinearStructure public val linearStructure: TensorLinearStructure
get() = TensorLinearStructure(shape) get() = TensorLinearStructure(shape)
/**
* Number of elements in tensor
*/
public val numElements: Int public val numElements: Int
get() = linearStructure.linearSize get() = linearStructure.linearSize
/**
* @param index [IntArray] with size equal to tensor dimension
* @return the element by multidimensional index
*/
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)] override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)]
/**
* @param index the [IntArray] with size equal to tensor dimension
* @param value the value to set
*/
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
mutableBuffer[bufferStart + linearStructure.offset(index)] = value mutableBuffer[bufferStart + linearStructure.offset(index)] = value
} }
/**
* @return the sequence of pairs multidimensional indices and values
*/
override fun elements(): Sequence<Pair<IntArray, T>> = linearStructure.indices().map { override fun elements(): Sequence<Pair<IntArray, T>> = linearStructure.indices().map {
it to this[it] it to this[it]
} }
} }
public class IntTensor internal constructor(
shape: IntArray,
buffer: IntArray,
offset: Int = 0
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)
public class DoubleTensor internal constructor(
shape: IntArray,
buffer: DoubleArray,
offset: Int = 0
) : BufferedTensor<Double>(shape, DoubleBuffer(buffer), offset) {
override fun toString(): String = toPrettyString()
}
internal fun BufferedTensor<Int>.asTensor(): IntTensor =
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor =
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
BufferedTensor(
this.shape,
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0
)
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
else -> this.copyToBufferedTensor()
}
internal val Tensor<Double>.tensor: DoubleTensor
get() = when (this) {
is DoubleTensor -> this
else -> this.toBufferedTensor().asTensor()
}
internal val Tensor<Int>.tensor: IntTensor
get() = when (this) {
is IntTensor -> this
else -> this.toBufferedTensor().asTensor()
}
public fun Tensor<Double>.toDoubleTensor(): DoubleTensor = this.tensor
public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor
public fun Array<DoubleArray>.toDoubleTensor(): DoubleTensor {
val n = size
check(n > 0) { "An empty array cannot be casted to tensor" }
val m = first().size
check(m > 0) { "Inner arrays must have at least 1 argument" }
check(all { size == m }) { "Inner arrays must be the same size" }
val shape = intArrayOf(n, m)
val buffer = this.flatMap { arr -> arr.map { it } }.toDoubleArray()
return DoubleTensor(shape, buffer, 0)
}
public fun Array<IntArray>.toIntTensor(): IntTensor {
val n = size
check(n > 0) { "An empty array cannot be casted to tensor" }
val m = first().size
check(m > 0) { "Inner arrays must have at least 1 argument" }
check(all { size == m }) { "Inner arrays must be the same size" }
val shape = intArrayOf(n, m)
val buffer = this.flatMap { arr -> arr.map { it } }.toIntArray()
return IntTensor(shape, buffer, 0)
}
public fun DoubleTensor.toDoubleArray(): DoubleArray {
return DoubleArray(numElements) { i ->
mutableBuffer[bufferStart + i]
}
}
public fun IntTensor.toIntArray(): IntArray {
return IntArray(numElements) { i ->
mutableBuffer[bufferStart + i]
}
}

View File

@ -0,0 +1,19 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.DoubleBuffer
/**
* Default [BufferedTensor] implementation for [Double] values
*/
public class DoubleTensor internal constructor(
shape: IntArray,
buffer: DoubleArray,
offset: Int = 0
) : BufferedTensor<Double>(shape, DoubleBuffer(buffer), offset) {
override fun toString(): String = toPrettyString()
}

View File

@ -0,0 +1,17 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.IntBuffer
/**
* Default [BufferedTensor] implementation for [Int] values
*/
public class IntTensor internal constructor(
shape: IntArray,
buffer: IntArray,
offset: Int = 0
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset)

View File

@ -0,0 +1,36 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.api.Tensor
/**
* Casts [Tensor<Double>] to [DoubleTensor]
*/
public fun Tensor<Double>.toDoubleTensor(): DoubleTensor = this.tensor
/**
* Casts [Tensor<Int>] to [IntTensor]
*/
public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor
/**
* @return [DoubleArray] of tensor elements
*/
public fun DoubleTensor.toDoubleArray(): DoubleArray {
return DoubleArray(numElements) { i ->
mutableBuffer[bufferStart + i]
}
}
/**
* @return [IntArray] of tensor elements
*/
public fun IntTensor.toIntArray(): IntArray {
return IntArray(numElements) { i ->
mutableBuffer[bufferStart + i]
}
}

View File

@ -0,0 +1,42 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure
internal fun BufferedTensor<Int>.asTensor(): IntTensor =
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor =
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
BufferedTensor(
this.shape,
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0
)
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
else -> this.copyToBufferedTensor()
}
internal val Tensor<Double>.tensor: DoubleTensor
get() = when (this) {
is DoubleTensor -> this
else -> this.toBufferedTensor().asTensor()
}
internal val Tensor<Int>.tensor: IntTensor
get() = when (this) {
is IntTensor -> this
else -> this.toBufferedTensor().asTensor()
}