forked from kscience/kmath
Merge remote-tracking branch 'origin/dev' into even-more-docs
This commit is contained in:
commit
254ee9eced
@ -14,6 +14,7 @@
|
|||||||
- Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`)
|
- Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`)
|
||||||
- `Polynomial` secondary constructor made function.
|
- `Polynomial` secondary constructor made function.
|
||||||
- Kotlin version: 1.3.72 -> 1.4.20-M1
|
- Kotlin version: 1.3.72 -> 1.4.20-M1
|
||||||
|
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
|
@ -5,15 +5,19 @@ import kscience.kmath.structures.NDField
|
|||||||
import kscience.kmath.structures.complex
|
import kscience.kmath.structures.complex
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
|
// 2d element
|
||||||
val element = NDElement.complex(2, 2) { index: IntArray ->
|
val element = NDElement.complex(2, 2) { index: IntArray ->
|
||||||
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
||||||
}
|
}
|
||||||
|
println(element)
|
||||||
|
|
||||||
val compute = (NDField.complex(8)) {
|
// 1d element operation
|
||||||
|
val result = with(NDField.complex(8)) {
|
||||||
val a = produce { (it) -> i * it - it.toDouble() }
|
val a = produce { (it) -> i * it - it.toDouble() }
|
||||||
val b = 3
|
val b = 3
|
||||||
val c = Complex(1.0, 1.0)
|
val c = Complex(1.0, 1.0)
|
||||||
|
|
||||||
(a pow b) + c
|
(a pow b) + c
|
||||||
}
|
}
|
||||||
|
println(result)
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,6 @@ kotlin.sourceSets {
|
|||||||
implementation("org.ow2.asm:asm:8.0.1")
|
implementation("org.ow2.asm:asm:8.0.1")
|
||||||
implementation("org.ow2.asm:asm-commons:8.0.1")
|
implementation("org.ow2.asm:asm-commons:8.0.1")
|
||||||
implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
|
implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
|
||||||
implementation(kotlin("reflect"))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,6 @@ import kscience.kmath.ast.MST
|
|||||||
import kscience.kmath.ast.MstExpression
|
import kscience.kmath.ast.MstExpression
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compiles given MST to an Expression using AST compiler.
|
* Compiles given MST to an Expression using AST compiler.
|
||||||
@ -18,7 +17,8 @@ import kotlin.reflect.KClass
|
|||||||
* @return the compiled expression.
|
* @return the compiled expression.
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
@PublishedApi
|
||||||
|
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
||||||
is MST.Symbolic -> {
|
is MST.Symbolic -> {
|
||||||
val symbol = try {
|
val symbol = try {
|
||||||
@ -61,11 +61,12 @@ public fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expr
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik.
|
* @author Alexander Nozik.
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
|
||||||
|
mst.compileWith(T::class.java, this)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimizes performance of an [MstExpression] using ASM codegen.
|
* Optimizes performance of an [MstExpression] using ASM codegen.
|
||||||
*
|
*
|
||||||
* @author Alexander Nozik.
|
* @author Alexander Nozik.
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
|
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class.java, algebra)
|
||||||
|
@ -10,7 +10,6 @@ import org.objectweb.asm.Opcodes.*
|
|||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
import java.util.*
|
import java.util.*
|
||||||
import java.util.stream.Collectors
|
import java.util.stream.Collectors
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
|
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
|
||||||
@ -23,7 +22,7 @@ import kotlin.reflect.KClass
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
internal class AsmBuilder<T> internal constructor(
|
internal class AsmBuilder<T> internal constructor(
|
||||||
private val classOfT: KClass<*>,
|
private val classOfT: Class<*>,
|
||||||
private val algebra: Algebra<T>,
|
private val algebra: Algebra<T>,
|
||||||
private val className: String,
|
private val className: String,
|
||||||
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
||||||
@ -43,7 +42,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* ASM Type for [algebra].
|
* ASM Type for [algebra].
|
||||||
*/
|
*/
|
||||||
private val tAlgebraType: Type = algebra::class.asm
|
private val tAlgebraType: Type = algebra.javaClass.asm
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [T].
|
* ASM type for [T].
|
||||||
@ -55,16 +54,6 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
*/
|
*/
|
||||||
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
||||||
|
|
||||||
/**
|
|
||||||
* Index of `this` variable in invoke method of the built subclass.
|
|
||||||
*/
|
|
||||||
private val invokeThisVar: Int = 0
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Index of `arguments` variable in invoke method of the built subclass.
|
|
||||||
*/
|
|
||||||
private val invokeArgumentsVar: Int = 1
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List of constants to provide to the subclass.
|
* List of constants to provide to the subclass.
|
||||||
*/
|
*/
|
||||||
@ -76,22 +65,22 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
private lateinit var invokeMethodVisitor: InstructionAdapter
|
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* State if this [AsmBuilder] needs to generate constants field.
|
* States whether this [AsmBuilder] needs to generate constants field.
|
||||||
*/
|
*/
|
||||||
private var hasConstants: Boolean = true
|
private var hasConstants: Boolean = true
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
* States whether [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
||||||
*/
|
*/
|
||||||
internal var primitiveMode: Boolean = false
|
internal var primitiveMode: Boolean = false
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
* Primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||||
*/
|
*/
|
||||||
internal var primitiveMask: Type = OBJECT_TYPE
|
internal var primitiveMask: Type = OBJECT_TYPE
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
* Boxed primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||||
*/
|
*/
|
||||||
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
|
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
|
||||||
|
|
||||||
@ -103,7 +92,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Stack of useful objects types on stack expected by algebra calls.
|
* Stack of useful objects types on stack expected by algebra calls.
|
||||||
*/
|
*/
|
||||||
internal val expectationStack: ArrayDeque<Type> = ArrayDeque(listOf(tType))
|
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>(1).also { it.push(tType) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The cache for instance built by this builder.
|
* The cache for instance built by this builder.
|
||||||
@ -361,7 +350,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* from it).
|
* from it).
|
||||||
*/
|
*/
|
||||||
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
|
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
|
||||||
val boxed = value::class.asm
|
val boxed = value.javaClass.asm
|
||||||
val primitive = BOXED_TO_PRIMITIVES[boxed]
|
val primitive = BOXED_TO_PRIMITIVES[boxed]
|
||||||
|
|
||||||
if (primitive != null) {
|
if (primitive != null) {
|
||||||
@ -475,17 +464,27 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
|
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
|
||||||
|
|
||||||
internal companion object {
|
internal companion object {
|
||||||
|
/**
|
||||||
|
* Index of `this` variable in invoke method of the built subclass.
|
||||||
|
*/
|
||||||
|
private const val invokeThisVar: Int = 0
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Index of `arguments` variable in invoke method of the built subclass.
|
||||||
|
*/
|
||||||
|
private const val invokeArgumentsVar: Int = 1
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps JVM primitive numbers boxed types to their primitive ASM types.
|
* Maps JVM primitive numbers boxed types to their primitive ASM types.
|
||||||
*/
|
*/
|
||||||
private val SIGNATURE_LETTERS: Map<KClass<out Any>, Type> by lazy {
|
private val SIGNATURE_LETTERS: Map<Class<out Any>, Type> by lazy {
|
||||||
hashMapOf(
|
hashMapOf(
|
||||||
java.lang.Byte::class to Type.BYTE_TYPE,
|
java.lang.Byte::class.java to Type.BYTE_TYPE,
|
||||||
java.lang.Short::class to Type.SHORT_TYPE,
|
java.lang.Short::class.java to Type.SHORT_TYPE,
|
||||||
java.lang.Integer::class to Type.INT_TYPE,
|
java.lang.Integer::class.java to Type.INT_TYPE,
|
||||||
java.lang.Long::class to Type.LONG_TYPE,
|
java.lang.Long::class.java to Type.LONG_TYPE,
|
||||||
java.lang.Float::class to Type.FLOAT_TYPE,
|
java.lang.Float::class.java to Type.FLOAT_TYPE,
|
||||||
java.lang.Double::class to Type.DOUBLE_TYPE
|
java.lang.Double::class.java to Type.DOUBLE_TYPE
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -523,43 +522,43 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
||||||
*/
|
*/
|
||||||
private val INLINABLE_NUMBERS: Set<KClass<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [Expression].
|
* ASM type for [Expression].
|
||||||
*/
|
*/
|
||||||
internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm }
|
internal val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/expressions/Expression") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.Number].
|
* ASM type for [java.lang.Number].
|
||||||
*/
|
*/
|
||||||
internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm }
|
internal val NUMBER_TYPE: Type by lazy { Type.getObjectType("java/lang/Number") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.util.Map].
|
* ASM type for [java.util.Map].
|
||||||
*/
|
*/
|
||||||
internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm }
|
internal val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.Object].
|
* ASM type for [java.lang.Object].
|
||||||
*/
|
*/
|
||||||
internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm }
|
internal val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for array of [java.lang.Object].
|
* ASM type for array of [java.lang.Object].
|
||||||
*/
|
*/
|
||||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
||||||
internal val OBJECT_ARRAY_TYPE: Type by lazy { Array<java.lang.Object>::class.asm }
|
internal val OBJECT_ARRAY_TYPE: Type by lazy { Type.getType("[Ljava/lang/Object;") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [Algebra].
|
* ASM type for [Algebra].
|
||||||
*/
|
*/
|
||||||
internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm }
|
internal val ALGEBRA_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/operations/Algebra") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.String].
|
* ASM type for [java.lang.String].
|
||||||
*/
|
*/
|
||||||
internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm }
|
internal val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for MapIntrinsics.
|
* ASM type for MapIntrinsics.
|
||||||
|
@ -10,9 +10,9 @@ import org.objectweb.asm.*
|
|||||||
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
|
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
import java.lang.reflect.Method
|
import java.lang.reflect.Method
|
||||||
|
import java.util.*
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||||
hashMapOf(
|
hashMapOf(
|
||||||
@ -26,12 +26,12 @@ private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns ASM [Type] for given [KClass].
|
* Returns ASM [Type] for given [Class].
|
||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
internal val KClass<*>.asm: Type
|
internal inline val Class<*>.asm: Type
|
||||||
get() = Type.getType(java)
|
get() = Type.getType(this)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
|
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
|
||||||
@ -140,7 +140,7 @@ private fun <T> AsmBuilder<T>.buildExpectationStack(
|
|||||||
if (specific != null)
|
if (specific != null)
|
||||||
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
|
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
|
||||||
else
|
else
|
||||||
repeat(arity) { expectationStack.push(tType) }
|
expectationStack.addAll(Collections.nCopies(arity, tType))
|
||||||
|
|
||||||
return specific != null
|
return specific != null
|
||||||
}
|
}
|
||||||
@ -169,7 +169,7 @@ private fun <T> AsmBuilder<T>.tryInvokeSpecific(
|
|||||||
val arity = parameterTypes.size
|
val arity = parameterTypes.size
|
||||||
val theName = methodNameAdapters[name to arity] ?: name
|
val theName = methodNameAdapters[name to arity] ?: name
|
||||||
val spec = findSpecific(context, theName, parameterTypes) ?: return false
|
val spec = findSpecific(context, theName, parameterTypes) ?: return false
|
||||||
val owner = context::class.asm
|
val owner = context.javaClass.asm
|
||||||
|
|
||||||
invokeAlgebraOperation(
|
invokeAlgebraOperation(
|
||||||
owner = owner.internalName,
|
owner = owner.internalName,
|
||||||
|
@ -7,6 +7,7 @@ import com.github.h0tk3y.betterParse.grammar.parser
|
|||||||
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
||||||
import com.github.h0tk3y.betterParse.lexer.Token
|
import com.github.h0tk3y.betterParse.lexer.Token
|
||||||
import com.github.h0tk3y.betterParse.lexer.TokenMatch
|
import com.github.h0tk3y.betterParse.lexer.TokenMatch
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.literalToken
|
||||||
import com.github.h0tk3y.betterParse.lexer.regexToken
|
import com.github.h0tk3y.betterParse.lexer.regexToken
|
||||||
import com.github.h0tk3y.betterParse.parser.ParseResult
|
import com.github.h0tk3y.betterParse.parser.ParseResult
|
||||||
import com.github.h0tk3y.betterParse.parser.Parser
|
import com.github.h0tk3y.betterParse.parser.Parser
|
||||||
@ -23,14 +24,14 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
|
|||||||
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released
|
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released
|
||||||
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
|
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
|
||||||
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
|
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
|
||||||
private val lpar: Token by regexToken("\\(")
|
private val lpar: Token by literalToken("(")
|
||||||
private val rpar: Token by regexToken("\\)")
|
private val rpar: Token by literalToken(")")
|
||||||
private val comma: Token by regexToken(",")
|
private val comma: Token by literalToken(",")
|
||||||
private val mul: Token by regexToken("\\*")
|
private val mul: Token by literalToken("*")
|
||||||
private val pow: Token by regexToken("\\^")
|
private val pow: Token by literalToken("^")
|
||||||
private val div: Token by regexToken("/")
|
private val div: Token by literalToken("/")
|
||||||
private val minus: Token by regexToken("-")
|
private val minus: Token by literalToken("-")
|
||||||
private val plus: Token by regexToken("\\+")
|
private val plus: Token by literalToken("+")
|
||||||
private val ws: Token by regexToken("\\s+", ignore = true)
|
private val ws: Token by regexToken("\\s+", ignore = true)
|
||||||
|
|
||||||
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||||
|
@ -177,6 +177,11 @@ public data class Complex(val re: Double, val im: Double) : FieldElement<Complex
|
|||||||
|
|
||||||
override fun compareTo(other: Complex): Int = r.compareTo(other.r)
|
override fun compareTo(other: Complex): Int = r.compareTo(other.r)
|
||||||
|
|
||||||
|
override fun toString(): String {
|
||||||
|
return "($re + i*$im)"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
public companion object : MemorySpec<Complex> {
|
public companion object : MemorySpec<Complex> {
|
||||||
override val objectSize: Int
|
override val objectSize: Int
|
||||||
get() = 16
|
get() = 16
|
||||||
|
@ -64,7 +64,7 @@ public interface NDStructure<T> {
|
|||||||
public fun <T> build(
|
public fun <T> build(
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
initializer: (IntArray) -> T
|
initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
@ -73,40 +73,40 @@ public interface NDStructure<T> {
|
|||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> auto(
|
public inline fun <reified T : Any> auto(
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
public inline fun <T : Any> auto(
|
public inline fun <T : Any> auto(
|
||||||
type: KClass<T>,
|
type: KClass<T>,
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
public fun <T> build(
|
public fun <T> build(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
initializer: (IntArray) -> T
|
initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
|
): BufferNDStructure<T> = build(DefaultStrides(shape), bufferFactory, initializer)
|
||||||
|
|
||||||
public inline fun <reified T : Any> auto(
|
public inline fun <reified T : Any> auto(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
auto(DefaultStrides(shape), initializer)
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
@JvmName("autoVarArg")
|
@JvmName("autoVarArg")
|
||||||
public inline fun <reified T : Any> auto(
|
public inline fun <reified T : Any> auto(
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
auto(DefaultStrides(shape), initializer)
|
auto(DefaultStrides(shape), initializer)
|
||||||
|
|
||||||
public inline fun <T : Any> auto(
|
public inline fun <T : Any> auto(
|
||||||
type: KClass<T>,
|
type: KClass<T>,
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
crossinline initializer: (IntArray) -> T
|
crossinline initializer: (IntArray) -> T,
|
||||||
): BufferNDStructure<T> =
|
): BufferNDStructure<T> =
|
||||||
auto(type, DefaultStrides(shape), initializer)
|
auto(type, DefaultStrides(shape), initializer)
|
||||||
}
|
}
|
||||||
@ -268,6 +268,22 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
|||||||
result = 31 * result + buffer.hashCode()
|
result = 31 * result + buffer.hashCode()
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun toString(): String {
|
||||||
|
val bufferRepr: String = when (shape.size) {
|
||||||
|
1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ")
|
||||||
|
2 -> (0 until shape[0]).joinToString(prefix = "[", postfix = "]", separator = ", ") { i ->
|
||||||
|
(0 until shape[1]).joinToString(prefix = "[", postfix = "]", separator = ", ") { j ->
|
||||||
|
val offset = strides.offset(intArrayOf(i, j))
|
||||||
|
buffer[offset].toString()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else -> "..."
|
||||||
|
}
|
||||||
|
return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -275,7 +291,7 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
|||||||
*/
|
*/
|
||||||
public class BufferNDStructure<T>(
|
public class BufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: Buffer<T>
|
override val buffer: Buffer<T>,
|
||||||
) : NDBuffer<T>() {
|
) : NDBuffer<T>() {
|
||||||
init {
|
init {
|
||||||
if (strides.linearSize != buffer.size) {
|
if (strides.linearSize != buffer.size) {
|
||||||
@ -289,7 +305,7 @@ public class BufferNDStructure<T>(
|
|||||||
*/
|
*/
|
||||||
public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
||||||
factory: BufferFactory<R> = Buffer.Companion::auto,
|
factory: BufferFactory<R> = Buffer.Companion::auto,
|
||||||
crossinline transform: (T) -> R
|
crossinline transform: (T) -> R,
|
||||||
): BufferNDStructure<R> {
|
): BufferNDStructure<R> {
|
||||||
return if (this is BufferNDStructure<T>)
|
return if (this is BufferNDStructure<T>)
|
||||||
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
||||||
@ -304,7 +320,7 @@ public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
|||||||
*/
|
*/
|
||||||
public class MutableBufferNDStructure<T>(
|
public class MutableBufferNDStructure<T>(
|
||||||
override val strides: Strides,
|
override val strides: Strides,
|
||||||
override val buffer: MutableBuffer<T>
|
override val buffer: MutableBuffer<T>,
|
||||||
) : NDBuffer<T>(), MutableNDStructure<T> {
|
) : NDBuffer<T>(), MutableNDStructure<T> {
|
||||||
|
|
||||||
init {
|
init {
|
||||||
@ -318,7 +334,7 @@ public class MutableBufferNDStructure<T>(
|
|||||||
|
|
||||||
public inline fun <reified T : Any> NDStructure<T>.combine(
|
public inline fun <reified T : Any> NDStructure<T>.combine(
|
||||||
struct: NDStructure<T>,
|
struct: NDStructure<T>,
|
||||||
crossinline block: (T, T) -> T
|
crossinline block: (T, T) -> T,
|
||||||
): NDStructure<T> {
|
): NDStructure<T> {
|
||||||
require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
|
require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
|
||||||
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
||||||
|
Loading…
Reference in New Issue
Block a user