Optimize RealMatrix dot operation

This commit is contained in:
Alexander Nozik 2021-01-18 21:33:53 +03:00
parent 758508ba96
commit 4635080317
4 changed files with 8 additions and 4 deletions

View File

@ -45,14 +45,14 @@ fun main() {
measureAndPrint("Specialized addition") { measureAndPrint("Specialized addition") {
specializedField { specializedField {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { res += 1.0 } repeat(n) { res += one }
} }
} }
measureAndPrint("Nd4j specialized addition") { measureAndPrint("Nd4j specialized addition") {
nd4jField { nd4jField {
var res = one var res = one
repeat(n) { res += 1.0 as Number } repeat(n) { res += one }
} }
} }

View File

@ -61,7 +61,7 @@ public object RealMatrixContext : MatrixContext<Double, BufferMatrix<Double>> {
override fun multiply(a: Matrix<Double>, k: Number): BufferMatrix<Double> = override fun multiply(a: Matrix<Double>, k: Number): BufferMatrix<Double> =
produce(a.rowNum, a.colNum) { i, j -> a.get(i, j) * k.toDouble() } produce(a.rowNum, a.colNum) { i, j -> a[i, j] * k.toDouble() }
} }

View File

@ -38,6 +38,7 @@ public interface NDStructure<T> {
*/ */
public fun elements(): Sequence<Pair<IntArray, T>> public fun elements(): Sequence<Pair<IntArray, T>>
//force override equality and hash code
public override fun equals(other: Any?): Boolean public override fun equals(other: Any?): Boolean
public override fun hashCode(): Int public override fun hashCode(): Int
@ -133,6 +134,9 @@ public interface MutableNDStructure<T> : NDStructure<T> {
public operator fun set(index: IntArray, value: T) public operator fun set(index: IntArray, value: T)
} }
/**
* Transform a structure element-by element in place.
*/
public inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T): Unit = public inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T): Unit =
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }

View File

@ -8,7 +8,7 @@ pluginManagement {
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
} }
val toolsVersion = "0.7.1" val toolsVersion = "0.7.2-dev-2"
val kotlinVersion = "1.4.21" val kotlinVersion = "1.4.21"
plugins { plugins {