More informative toString for NDBuffer and Complex

This commit is contained in:
Alexander Nozik 2020-09-30 12:30:06 +03:00
parent 431b574544
commit 049ac89667
3 changed files with 37 additions and 12 deletions

View File

@ -5,15 +5,19 @@ import kscience.kmath.structures.NDField
import kscience.kmath.structures.complex import kscience.kmath.structures.complex
fun main() { fun main() {
// 2d element
val element = NDElement.complex(2, 2) { index: IntArray -> val element = NDElement.complex(2, 2) { index: IntArray ->
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
} }
println(element)
val compute = (NDField.complex(8)) { // 1d element operation
val result = with(NDField.complex(8)) {
val a = produce { (it) -> i * it - it.toDouble() } val a = produce { (it) -> i * it - it.toDouble() }
val b = 3 val b = 3
val c = Complex(1.0, 1.0) val c = Complex(1.0, 1.0)
(a pow b) + c (a pow b) + c
} }
println(result)
} }

View File

@ -176,6 +176,11 @@ public data class Complex(val re: Double, val im: Double) : FieldElement<Complex
override fun compareTo(other: Complex): Int = r.compareTo(other.r) override fun compareTo(other: Complex): Int = r.compareTo(other.r)
override fun toString(): String {
return "($re + i*$im)"
}
public companion object : MemorySpec<Complex> { public companion object : MemorySpec<Complex> {
override val objectSize: Int = 16 override val objectSize: Int = 16

View File

@ -70,7 +70,7 @@ public interface NDStructure<T> {
public fun <T> build( public fun <T> build(
strides: Strides, strides: Strides,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T initializer: (IntArray) -> T,
): BufferNDStructure<T> = ): BufferNDStructure<T> =
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
@ -79,40 +79,40 @@ public interface NDStructure<T> {
*/ */
public inline fun <reified T : Any> auto( public inline fun <reified T : Any> auto(
strides: Strides, strides: Strides,
crossinline initializer: (IntArray) -> T crossinline initializer: (IntArray) -> T,
): BufferNDStructure<T> = ): BufferNDStructure<T> =
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
public inline fun <T : Any> auto( public inline fun <T : Any> auto(
type: KClass<T>, type: KClass<T>,
strides: Strides, strides: Strides,
crossinline initializer: (IntArray) -> T crossinline initializer: (IntArray) -> T,
): BufferNDStructure<T> = ): BufferNDStructure<T> =
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
public fun <T> build( public fun <T> build(
shape: IntArray, shape: IntArray,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T initializer: (IntArray) -> T,
): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer) ): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
public inline fun <reified T : Any> auto( public inline fun <reified T : Any> auto(
shape: IntArray, shape: IntArray,
crossinline initializer: (IntArray) -> T crossinline initializer: (IntArray) -> T,
): BufferNDStructure<T> = ): BufferNDStructure<T> =
auto(DefaultStrides(shape), initializer) auto(DefaultStrides(shape), initializer)
@JvmName("autoVarArg") @JvmName("autoVarArg")
public inline fun <reified T : Any> auto( public inline fun <reified T : Any> auto(
vararg shape: Int, vararg shape: Int,
crossinline initializer: (IntArray) -> T crossinline initializer: (IntArray) -> T,
): BufferNDStructure<T> = ): BufferNDStructure<T> =
auto(DefaultStrides(shape), initializer) auto(DefaultStrides(shape), initializer)
public inline fun <T : Any> auto( public inline fun <T : Any> auto(
type: KClass<T>, type: KClass<T>,
vararg shape: Int, vararg shape: Int,
crossinline initializer: (IntArray) -> T crossinline initializer: (IntArray) -> T,
): BufferNDStructure<T> = ): BufferNDStructure<T> =
auto(type, DefaultStrides(shape), initializer) auto(type, DefaultStrides(shape), initializer)
} }
@ -274,6 +274,22 @@ public abstract class NDBuffer<T> : NDStructure<T> {
result = 31 * result + buffer.hashCode() result = 31 * result + buffer.hashCode()
return result return result
} }
override fun toString(): String {
val bufferRepr: String = when (shape.size) {
1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ")
2 -> (0 until shape[0]).joinToString(prefix = "[", postfix = "]", separator = ", ") { i ->
(0 until shape[1]).joinToString(prefix = "[", postfix = "]", separator = ", ") { j ->
val offset = strides.offset(intArrayOf(i, j))
buffer[offset].toString()
}
}
else -> "..."
}
return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)"
}
} }
/** /**
@ -281,7 +297,7 @@ public abstract class NDBuffer<T> : NDStructure<T> {
*/ */
public class BufferNDStructure<T>( public class BufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: Buffer<T> override val buffer: Buffer<T>,
) : NDBuffer<T>() { ) : NDBuffer<T>() {
init { init {
if (strides.linearSize != buffer.size) { if (strides.linearSize != buffer.size) {
@ -295,7 +311,7 @@ public class BufferNDStructure<T>(
*/ */
public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer( public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
factory: BufferFactory<R> = Buffer.Companion::auto, factory: BufferFactory<R> = Buffer.Companion::auto,
crossinline transform: (T) -> R crossinline transform: (T) -> R,
): BufferNDStructure<R> { ): BufferNDStructure<R> {
return if (this is BufferNDStructure<T>) return if (this is BufferNDStructure<T>)
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
@ -310,7 +326,7 @@ public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
*/ */
public class MutableBufferNDStructure<T>( public class MutableBufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: MutableBuffer<T> override val buffer: MutableBuffer<T>,
) : NDBuffer<T>(), MutableNDStructure<T> { ) : NDBuffer<T>(), MutableNDStructure<T> {
init { init {
@ -324,7 +340,7 @@ public class MutableBufferNDStructure<T>(
public inline fun <reified T : Any> NDStructure<T>.combine( public inline fun <reified T : Any> NDStructure<T>.combine(
struct: NDStructure<T>, struct: NDStructure<T>,
crossinline block: (T, T) -> T crossinline block: (T, T) -> T,
): NDStructure<T> { ): NDStructure<T> {
require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" } require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
return NDStructure.auto(shape) { block(this[it], struct[it]) } return NDStructure.auto(shape) { block(this[it], struct[it]) }