diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt index fd8bcef11..ef0d91aab 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt @@ -32,7 +32,7 @@ inline class Real(val value: Double) : FieldElement { /** * A field for double without boxing. Does not produce appropriate field element */ -object RealField : AbstractField(),ExtendedField, Norm { +object RealField : AbstractField(), ExtendedField, Norm { override val zero: Double = 0.0 override fun add(a: Double, b: Double): Double = a + b override fun multiply(a: Double, @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") b: Double): Double = a * b @@ -52,15 +52,25 @@ object RealField : AbstractField(),ExtendedField, Norm { +object IntRing : Ring { override val zero: Int = 0 override fun add(a: Int, b: Int): Int = a + b override fun multiply(a: Int, b: Int): Int = a * b override fun multiply(a: Int, k: Double): Int = (k * a).toInt() override val one: Int = 1 - override fun divide(a: Int, b: Int): Int = a / b +} + +/** + * A field for [Short] without boxing. Does not produce appropriate field element + */ +object ShortRing : Ring { + override val zero: Short = 0 + override fun add(a: Short, b: Short): Short = (a + b).toShort() + override fun multiply(a: Short, b: Short): Short = (a * b).toShort() + override fun multiply(a: Short, k: Double): Short = (a * k).toShort() + override val one: Short = 1 } //interface FieldAdapter : Field { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt index b859b6dd3..00f19ec7e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt @@ -32,7 +32,6 @@ interface NDField, N : NDStructure> : Field { * Create a nd-field for [Double] values */ fun real(shape: IntArray) = RealNDField(shape) - /** * Create a nd-field with boxing generic buffer */ @@ -42,12 +41,11 @@ interface NDField, N : NDStructure> : Field { /** * Create a most suitable implementation for nd-field using reified class. */ + @Suppress("UNCHECKED_CAST") inline fun > auto(shape: IntArray, field: F): StridedNDField = - if (T::class == Double::class) { - @Suppress("UNCHECKED_CAST") - real(shape) as StridedNDField - } else { - BufferNDField(shape, field, Buffer.Companion::auto) + when { + T::class == Double::class -> real(shape) as StridedNDField + else -> BufferNDField(shape, field, Buffer.Companion::auto) } } }