由于我工作的项目,我不处理类型擦除所有的太多。这就是说,有一种方法让我心烦意乱,我想出了另一种解决方案。我正在进行一个使用大量矩阵乘法的项目,并且使用fommil的netlib-java进行本地blas操作。下面是有问题的方法:
def gemm[A: ClassTag: TypeTag](
transA : String,
transB : String,
m : Int,
n : Int,
k : Int,
alpha : A,
a : Array[A],
b : Array[A],
beta : A) = {
val lda = if (transA == "N" || transA == "n") k else m
val ldb = if (transB == "N" || transA == "n") n else k
typeOf[A] match {
case t if t =:= typeOf[Float] =>
val _alpha = alpha.asInstanceOf[Float]
val _beta = beta.asInstanceOf[Float]
val _a = a.asInstanceOf[Array[Float]]
val _b = b.asInstanceOf[Array[Float]]
val outArray = new Array[Float](m * n)
blas.sgemm(transA, transB, m, n, k, _alpha, _a, lda, _b, ldb, _beta, outArray, m)
outArray.asInstanceOf[Array[A]]
case t if t =:= typeOf[Double] =>
val _alpha = alpha.asInstanceOf[Double]
val _beta = beta.asInstanceOf[Double]
val _a = a.asInstanceOf[Array[Double]]
val _b = b.asInstanceOf[Array[Double]]
val outArray = new Array[Double](m * n)
blas.dgemm(transA, transB, m, n, k, _alpha, _a, lda, _b, ldb, _beta, outArray, m)
outArray.asInstanceOf[Array[A]]
case _ =>
val outArray = Predef.implicitly[ClassTag[A]].newArray(m * n)
gemm_ref(transA, transB, m, n, k, alpha, a, b, beta, outArray)
outArray
}
}我已经考虑过的另一种选择是使用无形状的“可键入/原型”的类型安全的转换。根据我的理解,这是通过遍历集合中的每个元素来确保类型的一致性。与此相关的开销是存在的,而且由于我处理的数组通常有很多元素,所以我不需要任何额外的开销。
发布于 2016-10-05 13:54:33
像这样的怎么样?
trait Blas[A] {
def gemm(transA: String, transB: String, m: Int, n: Int, k: Int, alpha: A, beta: A, a: Array[A], b: Array[A]): Array[A]
}
object Blas {
implicit def floatBlas: Blas[Float] = new Blas[Float] {
override def gemm(transA: String, transB: String, m: Int, n: Int, k: Int, alpha: Float, beta: Float, a: Array[Float], b: Array[Float]): Array[Float] = {
val outArray = new Array[Float](m * n)
blas.sgemm(transA, transB, m, n, k, _alpha, _a, lda, _b, ldb, _beta, outArray, m)
outArray
}
}
implicit def doubleBlas: Blas[Double] = ???
// etc.
}
def gemm[A](
transA: String,
transB: String,
m: Int,
n: Int,
k: Int,
alpha: A,
a: Array[A],
b: Array[A],
beta: A
)(implicit blas: Blas[A]) = {
val lda = if (transA == "N" || transA == "n") k else m
val ldb = if (transB == "N" || transA == "n") n else k
blas.gemm(transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, m)
}(您必须自己修复对象和方法名称,我不知道它们指的是什么。)
其思想是传递一个额外的隐式参数,即自动查找。在定义这些实例时,您拥有可用的完整类型信息,然后不需要在typeOf上匹配。
https://stackoverflow.com/questions/39875066
复制相似问题