Multik wrapper
This commit is contained in:
parent
05ae21580b
commit
827f115a92
@ -48,6 +48,7 @@ kotlin {
|
|||||||
implementation(project(":kmath-nd4j"))
|
implementation(project(":kmath-nd4j"))
|
||||||
implementation(project(":kmath-kotlingrad"))
|
implementation(project(":kmath-kotlingrad"))
|
||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
|
implementation(projects.kmathMultik)
|
||||||
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
||||||
// uncomment if your system supports AVX2
|
// uncomment if your system supports AVX2
|
||||||
// val os = System.getProperty("os.name")
|
// val os = System.getProperty("os.name")
|
||||||
|
@ -9,6 +9,7 @@ import kotlinx.benchmark.Benchmark
|
|||||||
import kotlinx.benchmark.Blackhole
|
import kotlinx.benchmark.Blackhole
|
||||||
import kotlinx.benchmark.Scope
|
import kotlinx.benchmark.Scope
|
||||||
import kotlinx.benchmark.State
|
import kotlinx.benchmark.State
|
||||||
|
import space.kscience.kmath.multik.multikND
|
||||||
import space.kscience.kmath.nd.BufferedFieldOpsND
|
import space.kscience.kmath.nd.BufferedFieldOpsND
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.nd.ndAlgebra
|
import space.kscience.kmath.nd.ndAlgebra
|
||||||
@ -17,8 +18,9 @@ import space.kscience.kmath.nd4j.nd4j
|
|||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||||
import space.kscience.kmath.tensors.core.ones
|
import space.kscience.kmath.tensors.core.one
|
||||||
import space.kscience.kmath.tensors.core.tensorAlgebra
|
import space.kscience.kmath.tensors.core.tensorAlgebra
|
||||||
|
import space.kscience.kmath.viktor.viktorAlgebra
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
internal class NDFieldBenchmark {
|
internal class NDFieldBenchmark {
|
||||||
@ -43,16 +45,30 @@ internal class NDFieldBenchmark {
|
|||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun multikAdd(blackhole: Blackhole) = with(multikField) {
|
||||||
|
var res: StructureND<Double> = one(shape)
|
||||||
|
repeat(n) { res += 1.0 }
|
||||||
|
blackhole.consume(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun viktorAdd(blackhole: Blackhole) = with(viktorField) {
|
||||||
|
var res: StructureND<Double> = one(shape)
|
||||||
|
repeat(n) { res += 1.0 }
|
||||||
|
blackhole.consume(res)
|
||||||
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun tensorAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
fun tensorAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||||
var res: DoubleTensor = ones(dim, dim)
|
var res: DoubleTensor = one(shape)
|
||||||
repeat(n) { res = res + 1.0 }
|
repeat(n) { res = res + 1.0 }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun tensorInPlaceAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
fun tensorInPlaceAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||||
val res: DoubleTensor = ones(dim, dim)
|
val res: DoubleTensor = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
@ -72,5 +88,7 @@ internal class NDFieldBenchmark {
|
|||||||
private val specializedField = DoubleField.ndAlgebra
|
private val specializedField = DoubleField.ndAlgebra
|
||||||
private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
|
private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
|
||||||
private val nd4jField = DoubleField.nd4j
|
private val nd4jField = DoubleField.nd4j
|
||||||
|
private val multikField = DoubleField.multikND
|
||||||
|
private val viktorField = DoubleField.viktorAlgebra
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,16 +105,15 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
|
|
||||||
override fun hashCode(): Int = shape.contentHashCode()
|
override fun hashCode(): Int = shape.contentHashCode()
|
||||||
|
|
||||||
@ThreadLocal
|
|
||||||
public companion object {
|
|
||||||
//private val defaultStridesCache = HashMap<IntArray, Strides>()
|
|
||||||
|
|
||||||
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* Cached builder for default strides
|
* Cached builder for default strides
|
||||||
*/
|
*/
|
||||||
public operator fun invoke(shape: IntArray): Strides = DefaultStrides(shape)
|
public operator fun invoke(shape: IntArray): Strides =
|
||||||
//defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//TODO fix cache
|
@ThreadLocal
|
||||||
}
|
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
||||||
}
|
|
@ -0,0 +1,135 @@
|
|||||||
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.mk
|
||||||
|
import org.jetbrains.kotlinx.multik.api.zeros
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||||
|
import space.kscience.kmath.nd.FieldOpsND
|
||||||
|
import space.kscience.kmath.nd.RingOpsND
|
||||||
|
import space.kscience.kmath.nd.Shape
|
||||||
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.operations.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A ring algebra for Multik operations
|
||||||
|
*/
|
||||||
|
public open class MultikRingOpsND<T, A : Ring<T>> internal constructor(
|
||||||
|
public val type: DataType,
|
||||||
|
override val elementAlgebra: A
|
||||||
|
) : RingOpsND<T, A> {
|
||||||
|
|
||||||
|
protected fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
|
||||||
|
|
||||||
|
override fun produce(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||||
|
val res = mk.zeros<T, DN>(shape, type).asDNArray()
|
||||||
|
for (index in res.multiIndices) {
|
||||||
|
res[index] = elementAlgebra.initializer(index)
|
||||||
|
}
|
||||||
|
return res.wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
|
||||||
|
this
|
||||||
|
} else {
|
||||||
|
produce(shape) { get(it) }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> {
|
||||||
|
//taken directly from Multik sources
|
||||||
|
val array = asMultik().array
|
||||||
|
val data = initMemoryView<T>(array.size, type)
|
||||||
|
var count = 0
|
||||||
|
for (el in array) data[count++] = elementAlgebra.transform(el)
|
||||||
|
return NDArray(data, shape = array.shape, dim = array.dim).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> {
|
||||||
|
//taken directly from Multik sources
|
||||||
|
val array = asMultik().array
|
||||||
|
val data = initMemoryView<T>(array.size, type)
|
||||||
|
val indexIter = array.multiIndices.iterator()
|
||||||
|
var index = 0
|
||||||
|
for (item in array) {
|
||||||
|
if (indexIter.hasNext()) {
|
||||||
|
data[index++] = elementAlgebra.transform(indexIter.next(), item)
|
||||||
|
} else {
|
||||||
|
throw ArithmeticException("Index overflow has happened.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NDArray(data, shape = array.shape, dim = array.dim).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
||||||
|
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||||
|
val leftArray = left.asMultik().array
|
||||||
|
val rightArray = right.asMultik().array
|
||||||
|
val data = initMemoryView<T>(leftArray.size, type)
|
||||||
|
var counter = 0
|
||||||
|
val leftIterator = leftArray.iterator()
|
||||||
|
val rightIterator = rightArray.iterator()
|
||||||
|
//iterating them together
|
||||||
|
while (leftIterator.hasNext()) {
|
||||||
|
data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next())
|
||||||
|
}
|
||||||
|
return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.unaryMinus(): MultikTensor<T> = asMultik().array.unaryMinus().wrap()
|
||||||
|
|
||||||
|
override fun add(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
|
||||||
|
(left.asMultik().array + right.asMultik().array).wrap()
|
||||||
|
|
||||||
|
override fun StructureND<T>.plus(arg: T): MultikTensor<T> =
|
||||||
|
asMultik().array.plus(arg).wrap()
|
||||||
|
|
||||||
|
override fun StructureND<T>.minus(arg: T): MultikTensor<T> = asMultik().array.minus(arg).wrap()
|
||||||
|
|
||||||
|
override fun T.plus(arg: StructureND<T>): MultikTensor<T> = arg + this
|
||||||
|
|
||||||
|
override fun T.minus(arg: StructureND<T>): MultikTensor<T> = arg.map { this@minus - it }
|
||||||
|
|
||||||
|
override fun multiply(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
|
||||||
|
left.asMultik().array.times(right.asMultik().array).wrap()
|
||||||
|
|
||||||
|
override fun StructureND<T>.times(arg: T): MultikTensor<T> =
|
||||||
|
asMultik().array.times(arg).wrap()
|
||||||
|
|
||||||
|
override fun T.times(arg: StructureND<T>): MultikTensor<T> = arg * this
|
||||||
|
|
||||||
|
override fun StructureND<T>.unaryPlus(): MultikTensor<T> = asMultik()
|
||||||
|
|
||||||
|
override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> =
|
||||||
|
asMultik().array.plus(other.asMultik().array).wrap()
|
||||||
|
|
||||||
|
override fun StructureND<T>.minus(other: StructureND<T>): MultikTensor<T> =
|
||||||
|
asMultik().array.minus(other.asMultik().array).wrap()
|
||||||
|
|
||||||
|
override fun StructureND<T>.times(other: StructureND<T>): MultikTensor<T> =
|
||||||
|
asMultik().array.times(other.asMultik().array).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A field algebra for multik operations
|
||||||
|
*/
|
||||||
|
public class MultikFieldOpsND<T, A : Field<T>> internal constructor(
|
||||||
|
type: DataType,
|
||||||
|
elementAlgebra: A
|
||||||
|
) : MultikRingOpsND<T, A>(type, elementAlgebra), FieldOpsND<T, A> {
|
||||||
|
override fun StructureND<T>.div(other: StructureND<T>): StructureND<T> =
|
||||||
|
asMultik().array.div(other.asMultik().array).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public val DoubleField.multikND: MultikFieldOpsND<Double, DoubleField>
|
||||||
|
get() = MultikFieldOpsND(DataType.DoubleDataType, DoubleField)
|
||||||
|
|
||||||
|
public val FloatField.multikND: MultikFieldOpsND<Float, FloatField>
|
||||||
|
get() = MultikFieldOpsND(DataType.FloatDataType, FloatField)
|
||||||
|
|
||||||
|
public val ShortRing.multikND: MultikRingOpsND<Short, ShortRing>
|
||||||
|
get() = MultikRingOpsND(DataType.ShortDataType, ShortRing)
|
||||||
|
|
||||||
|
public val IntRing.multikND: MultikRingOpsND<Int, IntRing>
|
||||||
|
get() = MultikRingOpsND(DataType.IntDataType, IntRing)
|
||||||
|
|
||||||
|
public val LongRing.multikND: MultikRingOpsND<Long, LongRing>
|
||||||
|
get() = MultikRingOpsND(DataType.LongDataType, LongRing)
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.mk
|
||||||
|
import org.jetbrains.kotlinx.multik.api.zeros
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
@ -30,6 +32,7 @@ public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) :
|
|||||||
|
|
||||||
|
|
||||||
public abstract class MultikTensorAlgebra<T>(
|
public abstract class MultikTensorAlgebra<T>(
|
||||||
|
public val type: DataType,
|
||||||
public val elementAlgebra: Ring<T>,
|
public val elementAlgebra: Ring<T>,
|
||||||
public val comparator: Comparator<T>
|
public val comparator: Comparator<T>
|
||||||
) : TensorAlgebra<T> {
|
) : TensorAlgebra<T> {
|
||||||
@ -38,15 +41,19 @@ public abstract class MultikTensorAlgebra<T>(
|
|||||||
* Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
|
* Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
|
||||||
* are not reflected back onto the source
|
* are not reflected back onto the source
|
||||||
*/
|
*/
|
||||||
public fun Tensor<T>.asMultik(): MultikTensor<T> {
|
private fun Tensor<T>.asMultik(): MultikTensor<T> {
|
||||||
return if (this is MultikTensor) {
|
return if (this is MultikTensor) {
|
||||||
this
|
this
|
||||||
} else {
|
} else {
|
||||||
TODO()
|
val res = mk.zeros<T, DN>(shape, type).asDNArray()
|
||||||
|
for (index in res.multiIndices) {
|
||||||
|
res[index] = this[index]
|
||||||
|
}
|
||||||
|
res.wrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
|
private fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
|
||||||
|
|
||||||
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
|
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
|
||||||
get(intArrayOf(0))
|
get(intArrayOf(0))
|
||||||
@ -77,8 +84,7 @@ public abstract class MultikTensorAlgebra<T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO avoid additional copy
|
override fun T.minus(other: Tensor<T>): MultikTensor<T> = (-(other.asMultik().array - this)).wrap()
|
||||||
override fun T.minus(other: Tensor<T>): MultikTensor<T> = -(other - this)
|
|
||||||
|
|
||||||
override fun Tensor<T>.minus(value: T): MultikTensor<T> =
|
override fun Tensor<T>.minus(value: T): MultikTensor<T> =
|
||||||
asMultik().array.deepCopy().apply { minusAssign(value) }.wrap()
|
asMultik().array.deepCopy().apply { minusAssign(value) }.wrap()
|
||||||
@ -130,13 +136,9 @@ public abstract class MultikTensorAlgebra<T>(
|
|||||||
override fun Tensor<T>.unaryMinus(): MultikTensor<T> =
|
override fun Tensor<T>.unaryMinus(): MultikTensor<T> =
|
||||||
asMultik().array.unaryMinus().wrap()
|
asMultik().array.unaryMinus().wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.get(i: Int): MultikTensor<T> {
|
override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> {
|
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap()
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> {
|
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> {
|
||||||
require(shape.all { it > 0 })
|
require(shape.all { it > 0 })
|
||||||
@ -158,16 +160,14 @@ public abstract class MultikTensorAlgebra<T>(
|
|||||||
}.wrap()
|
}.wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> {
|
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape)
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> {
|
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
||||||
TODO("Not yet implemented")
|
TODO("Diagonal embedding not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
|
override fun Tensor<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
|
||||||
|
@ -0,0 +1,13 @@
|
|||||||
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test
|
||||||
|
import space.kscience.kmath.nd.one
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
|
|
||||||
|
internal class MultikNDTest {
|
||||||
|
@Test
|
||||||
|
fun basicAlgebra(): Unit = DoubleField.multikND{
|
||||||
|
one(2,2) + 1.0
|
||||||
|
}
|
||||||
|
}
|
@ -38,6 +38,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
|
|||||||
return struct
|
return struct
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun StructureND<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> {
|
override fun StructureND<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> {
|
||||||
val newStruct = ndArray.dup().wrap()
|
val newStruct = ndArray.dup().wrap()
|
||||||
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementAlgebra.transform(value) }
|
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementAlgebra.transform(value) }
|
||||||
@ -117,7 +118,7 @@ public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>,
|
|||||||
* Creates a most suitable implementation of [RingND] using reified class.
|
* Creates a most suitable implementation of [RingND] using reified class.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRingOps<T, Ring<T>> = when {
|
public inline fun <reified T : Number> auto(): Nd4jArrayRingOps<T, Ring<T>> = when {
|
||||||
T::class == Int::class -> IntRing.nd4j as Nd4jArrayRingOps<T, Ring<T>>
|
T::class == Int::class -> IntRing.nd4j as Nd4jArrayRingOps<T, Ring<T>>
|
||||||
else -> throw UnsupportedOperationException("This factory method only supports Long type.")
|
else -> throw UnsupportedOperationException("This factory method only supports Long type.")
|
||||||
}
|
}
|
||||||
@ -142,7 +143,7 @@ public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldOpsND<T, F>,
|
|||||||
* Creates a most suitable implementation of [FieldND] using reified class.
|
* Creates a most suitable implementation of [FieldND] using reified class.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when {
|
public inline fun <reified T : Any> auto(): Nd4jArrayField<T, Field<T>> = when {
|
||||||
T::class == Float::class -> FloatField.nd4j as Nd4jArrayField<T, Field<T>>
|
T::class == Float::class -> FloatField.nd4j as Nd4jArrayField<T, Field<T>>
|
||||||
T::class == Double::class -> DoubleField.nd4j as Nd4jArrayField<T, Field<T>>
|
T::class == Double::class -> DoubleField.nd4j as Nd4jArrayField<T, Field<T>>
|
||||||
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
|
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
|
||||||
|
@ -5,4 +5,12 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
public fun DoubleTensorAlgebra.ones(vararg shape: Int): DoubleTensor = ones(intArrayOf(*shape))
|
import space.kscience.kmath.nd.Shape
|
||||||
|
import kotlin.jvm.JvmName
|
||||||
|
|
||||||
|
@JvmName("varArgOne")
|
||||||
|
public fun DoubleTensorAlgebra.one(vararg shape: Int): DoubleTensor = ones(intArrayOf(*shape))
|
||||||
|
public fun DoubleTensorAlgebra.one(shape: Shape): DoubleTensor = ones(shape)
|
||||||
|
@JvmName("varArgZero")
|
||||||
|
public fun DoubleTensorAlgebra.zero(vararg shape: Int): DoubleTensor = zeros(intArrayOf(*shape))
|
||||||
|
public fun DoubleTensorAlgebra.zero(shape: Shape): DoubleTensor = zeros(shape)
|
Loading…
Reference in New Issue
Block a user