basic algebra ops

This commit is contained in:
Roland Grinis 2021-07-08 11:03:20 +01:00
parent 2a97aca9b6
commit 623c96e4bb

View File

@ -1,6 +1,6 @@
/* /*
* Copyright 2018-2021 KMath contributors. * Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this.cast() source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
package space.kscience.kmath.noa package space.kscience.kmath.noa
@ -20,18 +20,56 @@ constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
protected abstract fun wrap(tensorHandle: TensorHandle): TensorType protected abstract fun wrap(tensorHandle: TensorHandle): TensorType
/** /**
* A scalar tensor in this implementation must have empty shape * A scalar tensor in this.cast() implementation must have empty shape
*/ */
override fun Tensor<T>.valueOrNull(): T? = override fun Tensor<T>.valueOrNull(): T? =
try { this.cast().item() } catch (e: NoaException) { null } try {
this.cast().cast().item()
} catch (e: NoaException) {
null
}
override fun Tensor<T>.value(): T = this.cast().item() override fun Tensor<T>.value(): T = this.cast().cast().item()
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
return wrap(JNoa.timesTensor(this.cast().tensorHandle, other.cast().tensorHandle))
}
override operator fun Tensor<T>.timesAssign(other: Tensor<T>): Unit {
JNoa.timesTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
}
override operator fun Tensor<T>.plus(other: Tensor<T>): TensorType {
return wrap(JNoa.plusTensor(this.cast().tensorHandle, other.cast().tensorHandle))
}
override operator fun Tensor<T>.plusAssign(other: Tensor<T>): Unit {
JNoa.plusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
}
override operator fun Tensor<T>.minus(other: Tensor<T>): TensorType {
return wrap(JNoa.minusTensor(this.cast().tensorHandle, other.cast().tensorHandle))
}
override operator fun Tensor<T>.minusAssign(other: Tensor<T>): Unit {
JNoa.minusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
}
override operator fun Tensor<T>.unaryMinus(): TensorType =
wrap(JNoa.unaryMinus(this.cast().tensorHandle))
} }
public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>> public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>, internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>,
AnalyticTensorAlgebra<T> { AnalyticTensorAlgebra<T> {
override operator fun Tensor<T>.div(other: Tensor<T>): TensorType {
return wrap(JNoa.divTensor(this.cast().tensorHandle, other.cast().tensorHandle))
}
override operator fun Tensor<T>.divAssign(other: Tensor<T>): Unit {
JNoa.divTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
}
} }