dot fixed for tensorflow
This commit is contained in:
parent
0f5f59e175
commit
db06d10cc2
@ -180,7 +180,9 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A: Ring<T>> internal cons
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = biOp(other) { l, r ->
|
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = biOp(other) { l, r ->
|
||||||
ops.linalg.matMul(l, r)
|
ops.linalg.matMul(
|
||||||
|
if (l.asTensor().shape().numDimensions() == 1) ops.expandDims(l,ops.constant(0)) else l,
|
||||||
|
if (r.asTensor().shape().numDimensions() == 1) ops.expandDims(r,ops.constant(-1)) else r)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): Tensor<T> = ops.run {
|
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): Tensor<T> = ops.run {
|
||||||
|
@ -208,7 +208,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
|
|||||||
*
|
*
|
||||||
* 3. If the first argument is 1-dimensional and the second argument is 2-dimensional,
|
* 3. If the first argument is 1-dimensional and the second argument is 2-dimensional,
|
||||||
* a 1 is prepended to its dimension for the purpose of the matrix multiply.
|
* a 1 is prepended to its dimension for the purpose of the matrix multiply.
|
||||||
* After the matrix multiply, the prepended dimension is removed.
|
* After the matrix multiply, depending on the implementation the prepended dimension might be removed.
|
||||||
*
|
*
|
||||||
* 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
|
* 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
|
||||||
* the matrix-vector product is returned.
|
* the matrix-vector product is returned.
|
||||||
|
Loading…
Reference in New Issue
Block a user