Int Tensor Algebra implementation

This commit is contained in:
Alexander Nozik 2022-09-05 16:30:39 +03:00
parent ad97751327
commit 5042fda751
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
17 changed files with 593 additions and 36 deletions

View File

@ -14,7 +14,7 @@ allprojects {
}
group = "space.kscience"
version = "0.3.1-dev-2"
version = "0.3.1-dev-3"
}
subprojects {

View File

@ -1,8 +1,7 @@
plugins {
kotlin("jvm") version "1.7.20-Beta"
`kotlin-dsl`
`version-catalog`
alias(npmlibs.plugins.kotlin.plugin.serialization)
kotlin("plugin.serialization") version "1.6.21"
}
java.targetCompatibility = JavaVersion.VERSION_11

View File

@ -56,7 +56,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
* @param strides The strides to access elements of [MutableBuffer] by linear indices.
* @param buffer The underlying buffer.
*/
public class MutableBufferND<T>(
public open class MutableBufferND<T>(
strides: ShapeIndexer,
override val buffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, buffer) {

View File

@ -16,7 +16,7 @@ import kotlin.math.pow as kpow
public class DoubleBufferND(
indexes: ShapeIndexer,
override val buffer: DoubleBuffer,
) : BufferND<Double>(indexes, buffer)
) : MutableBufferND<Double>(indexes, buffer)
public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(DoubleField.bufferAlgebra),

View File

@ -0,0 +1,50 @@
/*
* Copyright 2018-2022 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.NumbersAddOps
import space.kscience.kmath.operations.bufferAlgebra
import space.kscience.kmath.structures.IntBuffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
public class IntBufferND(
indexes: ShapeIndexer,
override val buffer: IntBuffer,
) : MutableBufferND<Int>(indexes, buffer)
public sealed class IntRingOpsND : BufferedRingOpsND<Int, IntRing>(IntRing.bufferAlgebra) {
override fun structureND(shape: Shape, initializer: IntRing.(IntArray) -> Int): IntBufferND {
val indexer = indexerBuilder(shape)
return IntBufferND(
indexer,
IntBuffer(indexer.linearSize) { offset ->
elementAlgebra.initializer(indexer.index(offset))
}
)
}
public companion object : IntRingOpsND()
}
@OptIn(UnstableKMathAPI::class)
public class IntRingND(
override val shape: Shape
) : IntRingOpsND(), RingND<Int, IntRing>, NumbersAddOps<StructureND<Int>> {
override fun number(value: Number): BufferND<Int> {
val int = value.toInt() // minimize conversions
return structureND(shape) { int }
}
}
public inline fun <R> IntRing.withNdAlgebra(vararg shape: Int, action: IntRingND.() -> R): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return IntRingND(shape).run(action)
}

View File

@ -22,8 +22,9 @@ public class ShortRingND(
) : ShortRingOpsND(), RingND<Short, ShortRing>, NumbersAddOps<StructureND<Short>> {
override fun number(value: Number): BufferND<Short> {
val d = value.toShort() // minimize conversions
return structureND(shape) { d }
val short
= value.toShort() // minimize conversions
return structureND(shape) { short }
}
}

View File

@ -82,7 +82,7 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
public fun contentEquals(
st1: StructureND<Double>,
st2: StructureND<Double>,
tolerance: Double = 1e-11
tolerance: Double = 1e-11,
): Boolean {
if (st1 === st2) return true
@ -101,11 +101,17 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
val bufferRepr: String = when (structure.shape.size) {
1 -> (0 until structure.shape[0]).map { structure[it] }
.joinToString(prefix = "[", postfix = "]", separator = ", ")
2 -> (0 until structure.shape[0]).joinToString(prefix = "[\n", postfix = "\n]", separator = ",\n") { i ->
2 -> (0 until structure.shape[0]).joinToString(
prefix = "[\n",
postfix = "\n]",
separator = ",\n"
) { i ->
(0 until structure.shape[1]).joinToString(prefix = " [", postfix = "]", separator = ", ") { j ->
structure[i, j].toString()
}
}
else -> "..."
}
val className = structure::class.simpleName ?: "StructureND"
@ -226,6 +232,13 @@ public interface MutableStructureND<T> : StructureND<T> {
public operator fun set(index: IntArray, value: T)
}
/**
* Set value at specified indices
*/
public operator fun <T> MutableStructureND<T>.set(vararg index: Int, value: T) {
set(index, value)
}
/**
* Transform a structure element-by element in place.
*/

View File

@ -142,6 +142,9 @@ public open class BufferRingOps<T, A : Ring<T>>(
super<BufferAlgebra>.binaryOperationFunction(operation)
}
public val IntRing.bufferAlgebra: BufferRingOps<Int, IntRing>
get() = BufferRingOps(IntRing)
public val ShortRing.bufferAlgebra: BufferRingOps<Short, ShortRing>
get() = BufferRingOps(ShortRing)

View File

@ -4,18 +4,12 @@ plugins {
kscience{
native()
}
kotlin.sourceSets.commonMain {
withContextReceivers()
dependencies{
api(projects.kmath.kmathComplex)
}
}
kscience {
withContextReceivers()
}
readme {
maturity = space.kscience.gradle.Maturity.PROTOTYPE
}

View File

@ -4,6 +4,10 @@ plugins {
kscience{
native()
dependencies {
api(projects.kmathCore)
api(projects.kmathStat)
}
}
kotlin.sourceSets {

View File

@ -31,8 +31,7 @@ public open class DoubleTensorAlgebra :
public companion object : DoubleTensorAlgebra()
override val elementAlgebra: DoubleField
get() = DoubleField
override val elementAlgebra: DoubleField get() = DoubleField
/**
@ -622,7 +621,8 @@ public open class DoubleTensorAlgebra :
}
val resNumElements = resShape.reduce(Int::times)
val init = foldFunction(DoubleArray(1) { 0.0 })
val resTensor = BufferedTensor(resShape,
val resTensor = BufferedTensor(
resShape,
MutableBuffer.auto(resNumElements) { init }, 0
)
for (index in resTensor.indices) {

View File

@ -11,10 +11,10 @@ import space.kscience.kmath.tensors.core.internal.array
/**
* Default [BufferedTensor] implementation for [Int] values
*/
public class IntTensor internal constructor(
public class IntTensor @PublishedApi internal constructor(
shape: IntArray,
buffer: IntArray,
offset: Int = 0
offset: Int = 0,
) : BufferedTensor<Int>(shape, IntBuffer(buffer), offset) {
public fun asDouble(): DoubleTensor =
DoubleTensor(shape, mutableBuffer.array().map { it.toDouble() }.toDoubleArray(), bufferStart)

View File

@ -0,0 +1,493 @@
/*
* Copyright 2018-2022 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
@file:OptIn(PerformancePitfall::class)
package space.kscience.kmath.tensors.core
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.tensors.api.*
import space.kscience.kmath.tensors.core.internal.*
import kotlin.math.*
/**
* Implementation of basic operations over double tensors and basic algebra operations on them.
*/
public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
public companion object : IntTensorAlgebra()
override fun StructureND<Int>.dot(other: StructureND<Int>): Tensor<Int> {
TODO("Not yet implemented")
}
override val elementAlgebra: IntRing get() = IntRing
/**
* Applies the [transform] function to each element of the tensor and returns the resulting modified tensor.
*
* @param transform the function to be applied to each element of the tensor.
* @return the resulting tensor after applying the function.
*/
@PerformancePitfall
@Suppress("OVERRIDE_BY_INLINE")
final override inline fun StructureND<Int>.map(transform: IntRing.(Int) -> Int): IntTensor {
val tensor = this.tensor
//TODO remove additional copy
val sourceArray = tensor.copyArray()
val array = IntArray(tensor.numElements) { IntRing.transform(sourceArray[it]) }
return IntTensor(
tensor.shape,
array,
tensor.bufferStart
)
}
@PerformancePitfall
@Suppress("OVERRIDE_BY_INLINE")
final override inline fun StructureND<Int>.mapIndexed(transform: IntRing.(index: IntArray, Int) -> Int): IntTensor {
val tensor = this.tensor
//TODO remove additional copy
val sourceArray = tensor.copyArray()
val array = IntArray(tensor.numElements) { IntRing.transform(tensor.indices.index(it), sourceArray[it]) }
return IntTensor(
tensor.shape,
array,
tensor.bufferStart
)
}
@PerformancePitfall
override fun zip(
left: StructureND<Int>,
right: StructureND<Int>,
transform: IntRing.(Int, Int) -> Int,
): IntTensor {
require(left.shape.contentEquals(right.shape)) {
"The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}"
}
val leftTensor = left.tensor
val leftArray = leftTensor.copyArray()
val rightTensor = right.tensor
val rightArray = rightTensor.copyArray()
val array = IntArray(leftTensor.numElements) { IntRing.transform(leftArray[it], rightArray[it]) }
return IntTensor(
leftTensor.shape,
array
)
}
override fun StructureND<Int>.valueOrNull(): Int? = if (tensor.shape contentEquals intArrayOf(1))
tensor.mutableBuffer.array()[tensor.bufferStart] else null
override fun StructureND<Int>.value(): Int = valueOrNull()
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")
/**
* Constructs a tensor with the specified shape and data.
*
* @param shape the desired shape for the tensor.
* @param buffer one-dimensional data array.
* @return tensor with the [shape] shape and [buffer] data.
*/
public fun fromArray(shape: IntArray, buffer: IntArray): IntTensor {
checkEmptyShape(shape)
check(buffer.isNotEmpty()) { "Illegal empty buffer provided" }
check(buffer.size == shape.reduce(Int::times)) {
"Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided"
}
return IntTensor(shape, buffer, 0)
}
/**
* Constructs a tensor with the specified shape and initializer.
*
* @param shape the desired shape for the tensor.
* @param initializer mapping tensor indices to values.
* @return tensor with the [shape] shape and data generated by the [initializer].
*/
override fun structureND(shape: IntArray, initializer: IntRing.(IntArray) -> Int): IntTensor = fromArray(
shape,
TensorLinearStructure(shape).asSequence().map { IntRing.initializer(it) }.toMutableList().toIntArray()
)
override operator fun Tensor<Int>.get(i: Int): IntTensor {
val lastShape = tensor.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
return IntTensor(newShape, tensor.mutableBuffer.array(), newStart)
}
/**
* Creates a tensor of a given shape and fills all elements with a given value.
*
* @param value the value to fill the output tensor with.
* @param shape array of integers defining the shape of the output tensor.
* @return tensor with the [shape] shape and filled with [value].
*/
public fun full(value: Int, shape: IntArray): IntTensor {
checkEmptyShape(shape)
val buffer = IntArray(shape.reduce(Int::times)) { value }
return IntTensor(shape, buffer)
}
/**
* Returns a tensor with the same shape as `input` filled with [value].
*
* @param value the value to fill the output tensor with.
* @return tensor with the `input` tensor shape and filled with [value].
*/
public fun Tensor<Int>.fullLike(value: Int): IntTensor {
val shape = tensor.shape
val buffer = IntArray(tensor.numElements) { value }
return IntTensor(shape, buffer)
}
/**
* Returns a tensor filled with the scalar value `0`, with the shape defined by the variable argument [shape].
*
* @param shape array of integers defining the shape of the output tensor.
* @return tensor filled with the scalar value `0`, with the [shape] shape.
*/
public fun zeros(shape: IntArray): IntTensor = full(0, shape)
/**
* Returns a tensor filled with the scalar value `0`, with the same shape as a given array.
*
* @return tensor filled with the scalar value `0`, with the same shape as `input` tensor.
*/
public fun StructureND<Int>.zeroesLike(): IntTensor = tensor.fullLike(0)
/**
* Returns a tensor filled with the scalar value `1`, with the shape defined by the variable argument [shape].
*
* @param shape array of integers defining the shape of the output tensor.
* @return tensor filled with the scalar value `1`, with the [shape] shape.
*/
public fun ones(shape: IntArray): IntTensor = full(1, shape)
/**
* Returns a tensor filled with the scalar value `1`, with the same shape as a given array.
*
* @return tensor filled with the scalar value `1`, with the same shape as `input` tensor.
*/
public fun Tensor<Int>.onesLike(): IntTensor = tensor.fullLike(1)
/**
* Returns a 2D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere.
*
* @param n the number of rows and columns
* @return a 2-D tensor with ones on the diagonal and zeros elsewhere.
*/
public fun eye(n: Int): IntTensor {
val shape = intArrayOf(n, n)
val buffer = IntArray(n * n) { 0 }
val res = IntTensor(shape, buffer)
for (i in 0 until n) {
res[intArrayOf(i, i)] = 1
}
return res
}
/**
* Return a copy of the tensor.
*
* @return a copy of the `input` tensor with a copied buffer.
*/
public fun StructureND<Int>.copy(): IntTensor =
IntTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
override fun Int.plus(arg: StructureND<Int>): IntTensor {
val resBuffer = IntArray(arg.tensor.numElements) { i ->
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + this
}
return IntTensor(arg.shape, resBuffer)
}
override fun StructureND<Int>.plus(arg: Int): IntTensor = arg + tensor
override fun StructureND<Int>.plus(arg: StructureND<Int>): IntTensor {
checkShapesCompatible(tensor, arg.tensor)
val resBuffer = IntArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] + arg.tensor.mutableBuffer.array()[i]
}
return IntTensor(tensor.shape, resBuffer)
}
override fun Tensor<Int>.plusAssign(value: Int) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += value
}
}
override fun Tensor<Int>.plusAssign(arg: StructureND<Int>) {
checkShapesCompatible(tensor, arg.tensor)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Int.minus(arg: StructureND<Int>): IntTensor {
val resBuffer = IntArray(arg.tensor.numElements) { i ->
this - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i]
}
return IntTensor(arg.shape, resBuffer)
}
override fun StructureND<Int>.minus(arg: Int): IntTensor {
val resBuffer = IntArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] - arg
}
return IntTensor(tensor.shape, resBuffer)
}
override fun StructureND<Int>.minus(arg: StructureND<Int>): IntTensor {
checkShapesCompatible(tensor, arg)
val resBuffer = IntArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] - arg.tensor.mutableBuffer.array()[i]
}
return IntTensor(tensor.shape, resBuffer)
}
override fun Tensor<Int>.minusAssign(value: Int) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value
}
}
override fun Tensor<Int>.minusAssign(arg: StructureND<Int>) {
checkShapesCompatible(tensor, arg)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun Int.times(arg: StructureND<Int>): IntTensor {
val resBuffer = IntArray(arg.tensor.numElements) { i ->
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this
}
return IntTensor(arg.shape, resBuffer)
}
override fun StructureND<Int>.times(arg: Int): IntTensor = arg * tensor
override fun StructureND<Int>.times(arg: StructureND<Int>): IntTensor {
checkShapesCompatible(tensor, arg)
val resBuffer = IntArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] *
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i]
}
return IntTensor(tensor.shape, resBuffer)
}
override fun Tensor<Int>.timesAssign(value: Int) {
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value
}
}
override fun Tensor<Int>.timesAssign(arg: StructureND<Int>) {
checkShapesCompatible(tensor, arg)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i]
}
}
override fun StructureND<Int>.unaryMinus(): IntTensor {
val resBuffer = IntArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus()
}
return IntTensor(tensor.shape, resBuffer)
}
override fun Tensor<Int>.transpose(i: Int, j: Int): IntTensor {
val ii = tensor.minusIndex(i)
val jj = tensor.minusIndex(j)
checkTranspose(tensor.dimension, ii, jj)
val n = tensor.numElements
val resBuffer = IntArray(n)
val resShape = tensor.shape.copyOf()
resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] }
val resTensor = IntTensor(resShape, resBuffer)
for (offset in 0 until n) {
val oldMultiIndex = tensor.indices.index(offset)
val newMultiIndex = oldMultiIndex.copyOf()
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
val linearIndex = resTensor.indices.offset(newMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + offset]
}
return resTensor
}
override fun Tensor<Int>.view(shape: IntArray): IntTensor {
checkView(tensor, shape)
return IntTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart)
}
override fun Tensor<Int>.viewAs(other: StructureND<Int>): IntTensor =
tensor.view(other.shape)
override fun diagonalEmbedding(
diagonalEntries: Tensor<Int>,
offset: Int,
dim1: Int,
dim2: Int,
): IntTensor {
val n = diagonalEntries.shape.size
val d1 = minusIndexFrom(n + 1, dim1)
val d2 = minusIndexFrom(n + 1, dim2)
check(d1 != d2) {
"Diagonal dimensions cannot be identical $d1, $d2"
}
check(d1 <= n && d2 <= n) {
"Dimension out of range"
}
var lessDim = d1
var greaterDim = d2
var realOffset = offset
if (lessDim > greaterDim) {
realOffset *= -1
lessDim = greaterDim.also { greaterDim = lessDim }
}
val resShape = diagonalEntries.shape.slice(0 until lessDim).toIntArray() +
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
diagonalEntries.shape.slice(lessDim until greaterDim - 1).toIntArray() +
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray()
val resTensor = zeros(resShape)
for (i in 0 until diagonalEntries.tensor.numElements) {
val multiIndex = diagonalEntries.tensor.indices.index(i)
var offset1 = 0
var offset2 = abs(realOffset)
if (realOffset < 0) {
offset1 = offset2.also { offset2 = offset1 }
}
val diagonalMultiIndex = multiIndex.slice(0 until lessDim).toIntArray() +
intArrayOf(multiIndex[n - 1] + offset1) +
multiIndex.slice(lessDim until greaterDim - 1).toIntArray() +
intArrayOf(multiIndex[n - 1] + offset2) +
multiIndex.slice(greaterDim - 1 until n - 1).toIntArray()
resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex]
}
return resTensor.tensor
}
private infix fun Tensor<Int>.eq(
other: Tensor<Int>,
): Boolean {
checkShapesCompatible(tensor, other)
val n = tensor.numElements
if (n != other.tensor.numElements) {
return false
}
for (i in 0 until n) {
if (tensor.mutableBuffer[tensor.bufferStart + i] != other.tensor.mutableBuffer[other.tensor.bufferStart + i]) {
return false
}
}
return true
}
/**
* Concatenates a sequence of tensors with equal shapes along the first dimension.
*
* @param tensors the [List] of tensors with same shapes to concatenate
* @return tensor with concatenation result
*/
public fun stack(tensors: List<Tensor<Int>>): IntTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val shape = tensors[0].shape
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
val resShape = intArrayOf(tensors.size) + shape
val resBuffer = tensors.flatMap {
it.tensor.mutableBuffer.array().drop(it.tensor.bufferStart).take(it.tensor.numElements)
}.toIntArray()
return IntTensor(resShape, resBuffer, 0)
}
/**
* Builds tensor from rows of the input tensor.
*
* @param indices the [IntArray] of 1-dimensional indices
* @return tensor with rows corresponding to row by [indices]
*/
public fun Tensor<Int>.rowsByIndices(indices: IntArray): IntTensor = stack(indices.map { this[it] })
private inline fun StructureND<Int>.fold(foldFunction: (IntArray) -> Int): Int =
foldFunction(tensor.copyArray())
private inline fun <reified R : Any> StructureND<Int>.foldDim(
dim: Int,
keepDim: Boolean,
foldFunction: (IntArray) -> R,
): BufferedTensor<R> {
check(dim < dimension) { "Dimension $dim out of range $dimension" }
val resShape = if (keepDim) {
shape.take(dim).toIntArray() + intArrayOf(1) + shape.takeLast(dimension - dim - 1).toIntArray()
} else {
shape.take(dim).toIntArray() + shape.takeLast(dimension - dim - 1).toIntArray()
}
val resNumElements = resShape.reduce(Int::times)
val init = foldFunction(IntArray(1) { 0 })
val resTensor = BufferedTensor(
resShape,
MutableBuffer.auto(resNumElements) { init }, 0
)
for (index in resTensor.indices) {
val prefix = index.take(dim).toIntArray()
val suffix = index.takeLast(dimension - dim - 1).toIntArray()
resTensor[index] = foldFunction(IntArray(shape[dim]) { i ->
tensor[prefix + intArrayOf(i) + suffix]
})
}
return resTensor
}
override fun StructureND<Int>.sum(): Int = tensor.fold { it.sum() }
override fun StructureND<Int>.sum(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x -> x.sum() }.toIntTensor()
override fun StructureND<Int>.min(): Int = this.fold { it.minOrNull()!! }
override fun StructureND<Int>.min(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x -> x.minOrNull()!! }.toIntTensor()
override fun StructureND<Int>.max(): Int = this.fold { it.maxOrNull()!! }
override fun StructureND<Int>.max(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x -> x.maxOrNull()!! }.toIntTensor()
override fun StructureND<Int>.argMax(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x ->
x.withIndex().maxByOrNull { it.value }?.index!!
}.toIntTensor()
}
public val Int.Companion.tensorAlgebra: IntTensorAlgebra.Companion get() = IntTensorAlgebra
public val IntRing.tensorAlgebra: IntTensorAlgebra.Companion get() = IntTensorAlgebra

View File

@ -16,8 +16,7 @@ internal fun checkEmptyShape(shape: IntArray) =
"Illegal empty shape provided"
}
internal fun checkEmptyDoubleBuffer(buffer: DoubleArray) =
check(buffer.isNotEmpty()) {
internal fun checkEmptyDoubleBuffer(buffer: DoubleArray) = check(buffer.isNotEmpty()) {
"Illegal empty buffer provided"
}
@ -50,7 +49,7 @@ internal fun checkSquareMatrix(shape: IntArray) {
}
internal fun DoubleTensorAlgebra.checkSymmetric(
tensor: Tensor<Double>, epsilon: Double = 1e-6
tensor: Tensor<Double>, epsilon: Double = 1e-6,
) =
check(tensor.eq(tensor.transpose(), epsilon)) {
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"

View File

@ -8,7 +8,6 @@ package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.BufferedTensor
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.IntTensor
@ -43,7 +42,8 @@ internal val StructureND<Double>.tensor: DoubleTensor
else -> this.toBufferedTensor().asTensor()
}
internal val Tensor<Int>.tensor: IntTensor
@PublishedApi
internal val StructureND<Int>.tensor: IntTensor
get() = when (this) {
is IntTensor -> this
else -> this.toBufferedTensor().asTensor()

View File

@ -13,7 +13,10 @@ import kotlin.jvm.JvmName
@JvmName("varArgOne")
public fun DoubleTensorAlgebra.one(vararg shape: Int): DoubleTensor = ones(intArrayOf(*shape))
public fun DoubleTensorAlgebra.one(shape: Shape): DoubleTensor = ones(shape)
@JvmName("varArgZero")
public fun DoubleTensorAlgebra.zero(vararg shape: Int): DoubleTensor = zeros(intArrayOf(*shape))
public fun DoubleTensorAlgebra.zero(shape: Shape): DoubleTensor = zeros(shape)

View File

@ -4,9 +4,7 @@ plugins {
kscience{
native()
}
kotlin.sourceSets.commonMain {
withContextReceivers()
dependencies {
api(projects.kmath.kmathGeometry)
}