feat(Core): add new flavors of permSort: allow user to specify a comparator (sort with) or a custom field to use in buffer values (sort by).

This commit is contained in:
Alexis Manin 2021-11-18 17:44:53 +01:00
parent 0f7a25762e
commit 06a6a99ef0
2 changed files with 61 additions and 15 deletions

View File

@ -6,23 +6,38 @@
package space.kscience.kmath.misc package space.kscience.kmath.misc
import kotlin.comparisons.* import kotlin.comparisons.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.indices
/** /**
* Return a new list filled with buffer indices. Indice order is defined by sorting associated buffer value. * Return a new list filled with buffer indices. Indice order is defined by sorting associated buffer value.
* This feature allows to sort buffer values without reordering its content. * This feature allows to sort buffer values without reordering its content.
* *
* @param descending True to revert sort order from highest to lowest values. Default to ascending order.
* @return List of buffer indices, sorted by associated value. * @return List of buffer indices, sorted by associated value.
*/ */
@PerformancePitfall @PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V: Comparable<V>> Buffer<V>.permSort(descending : Boolean = false) : IntArray { public fun <V: Comparable<V>> Buffer<V>.permSort() : IntArray = _permSortWith(compareBy<Int> { get(it) })
if (size < 2) return IntArray(size)
val comparator = if (descending) compareByDescending<Int> { get(it) } else compareBy<Int> { get(it) } @PerformancePitfall
@UnstableKMathAPI
public fun <V: Comparable<V>> Buffer<V>.permSortDescending() : IntArray = _permSortWith(compareByDescending<Int> { get(it) })
@PerformancePitfall
@UnstableKMathAPI
public fun <V, C: Comparable<C>> Buffer<V>.permSortBy(selector: (V) -> C) : IntArray = _permSortWith(compareBy<Int> { selector(get(it)) })
@PerformancePitfall
@UnstableKMathAPI
public fun <V, C: Comparable<C>> Buffer<V>.permSortByDescending(selector: (V) -> C) : IntArray = _permSortWith(compareByDescending<Int> { selector(get(it)) })
@PerformancePitfall
@UnstableKMathAPI
public fun <V> Buffer<V>.permSortWith(comparator : Comparator<V>) : IntArray = _permSortWith { i1, i2 -> comparator.compare(get(i1), get(i2)) }
@PerformancePitfall
@UnstableKMathAPI
private fun <V> Buffer<V>._permSortWith(comparator : Comparator<Int>) : IntArray {
if (size < 2) return IntArray(size)
/* TODO: optimisation : keep a constant big array of indices (Ex: from 0 to 4096), then create indice /* TODO: optimisation : keep a constant big array of indices (Ex: from 0 to 4096), then create indice
* arrays more efficiently by copying subpart of cached one. For bigger needs, we could copy entire * arrays more efficiently by copying subpart of cached one. For bigger needs, we could copy entire
@ -31,10 +46,10 @@ public fun <V: Comparable<V>> Buffer<V>.permSort(descending : Boolean = false) :
* 2. Some benchmark would be needed to ensure it would really provide better performance * 2. Some benchmark would be needed to ensure it would really provide better performance
*/ */
val packedIndices = IntArray(size) { idx -> idx } val packedIndices = IntArray(size) { idx -> idx }
/* TODO: find an efficient way to sort in-place instead, and return directly the IntArray. /* TODO: find an efficient way to sort in-place instead, and return directly the IntArray.
* Not done for now, because no standard utility is provided yet. An open issue exists for this. * Not done for now, because no standard utility is provided yet. An open issue exists for this.
* See: https://youtrack.jetbrains.com/issue/KT-37860 * See: https://youtrack.jetbrains.com/issue/KT-37860
*/ */
return packedIndices.sortedWith(comparator).toIntArray() return packedIndices.sortedWith(comparator).toIntArray()
} }

View File

@ -5,17 +5,24 @@
package space.kscience.kmath.misc package space.kscience.kmath.misc
import kotlin.collections.mutableListOf import space.kscience.kmath.misc.PermSortTest.Platform.*
import kotlin.random.Random import kotlin.random.Random
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
import kotlin.test.fail
import space.kscience.kmath.structures.IntBuffer import space.kscience.kmath.structures.IntBuffer
import space.kscience.kmath.structures.asBuffer
import kotlin.test.assertContentEquals
class PermSortTest { class PermSortTest {
private enum class Platform {
ANDROID, JVM, JS, NATIVE, WASM
}
private val platforms = Platform.values().asBuffer()
/** /**
* Permutation on empty buffer should immediately return an empty array. * Permutation on empty buffer should immediately return an empty array.
*/ */
@ -24,7 +31,7 @@ class PermSortTest {
val emptyBuffer = IntBuffer(0) {it} val emptyBuffer = IntBuffer(0) {it}
var permutations = emptyBuffer.permSort() var permutations = emptyBuffer.permSort()
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result") assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
permutations = emptyBuffer.permSort(true) permutations = emptyBuffer.permSortDescending()
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result") assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
} }
@ -34,10 +41,34 @@ class PermSortTest {
} }
@Test @Test
public fun testOnSomeValues() { fun testOnSomeValues() {
testPermutation(10) testPermutation(10)
} }
@Test
fun testPermSortBy() {
val permutations = platforms.permSortBy { it.name }
val expected = listOf(ANDROID, JS, JVM, NATIVE, WASM)
assertContentEquals(expected, permutations.map { platforms[it] }, "Ascending PermSort by name")
}
@Test
fun testPermSortByDescending() {
val permutations = platforms.permSortByDescending { it.name }
val expected = listOf(WASM, NATIVE, JVM, JS, ANDROID)
assertContentEquals(expected, permutations.map { platforms[it] }, "Descending PermSort by name")
}
@Test
fun testPermSortWith() {
var permutations = platforms.permSortWith { p1, p2 -> p1.name.length.compareTo(p2.name.length) }
val expected = listOf(JS, JVM, WASM, NATIVE, ANDROID)
assertContentEquals(expected, permutations.map { platforms[it] }, "PermSort using custom ascending comparator")
permutations = platforms.permSortWith(compareByDescending { it.name.length })
assertContentEquals(expected.reversed(), permutations.map { platforms[it] }, "PermSort using custom descending comparator")
}
private fun testPermutation(bufferSize: Int) { private fun testPermutation(bufferSize: Int) {
val seed = Random.nextLong() val seed = Random.nextLong()
@ -56,7 +87,7 @@ class PermSortTest {
assertTrue(current <= next, "Permutation indices not properly sorted") assertTrue(current <= next, "Permutation indices not properly sorted")
} }
val descIndices = buffer.permSort(true) val descIndices = buffer.permSortDescending()
assertEquals(bufferSize, descIndices.size) assertEquals(bufferSize, descIndices.size)
// Ensure no doublon is present in indices // Ensure no doublon is present in indices
assertEquals(descIndices.toSet().size, descIndices.size) assertEquals(descIndices.toSet().size, descIndices.size)