首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >是否有更好的方法来绕过集合上的类型擦除,特别是在将所述集合传递给Java方法的情况下?

是否有更好的方法来绕过集合上的类型擦除,特别是在将所述集合传递给Java方法的情况下?
EN

Stack Overflow用户
提问于 2016-10-05 13:17:36
回答 1查看 49关注 0票数 1

由于我工作的项目,我不处理类型擦除所有的太多。这就是说,有一种方法让我心烦意乱,我想出了另一种解决方案。我正在进行一个使用大量矩阵乘法的项目,并且使用fommil的netlib-java进行本地blas操作。下面是有问题的方法:

代码语言:javascript
复制
def gemm[A: ClassTag: TypeTag](
    transA  : String,
    transB  : String,
    m       : Int,
    n       : Int,
    k       : Int,
    alpha   : A,
    a       : Array[A],
    b       : Array[A],
    beta    : A) = {

    val lda = if (transA == "N" || transA == "n") k else m
    val ldb = if (transB == "N" || transA == "n") n else k

    typeOf[A] match {
      case t if t =:= typeOf[Float] =>
        val _alpha = alpha.asInstanceOf[Float]
        val _beta = beta.asInstanceOf[Float]
        val _a = a.asInstanceOf[Array[Float]]
        val _b = b.asInstanceOf[Array[Float]]
        val outArray = new Array[Float](m * n)
        blas.sgemm(transA, transB, m, n, k, _alpha, _a, lda, _b, ldb, _beta, outArray, m)
        outArray.asInstanceOf[Array[A]]
      case t if t =:= typeOf[Double] =>
        val _alpha = alpha.asInstanceOf[Double]
        val _beta = beta.asInstanceOf[Double]
        val _a = a.asInstanceOf[Array[Double]]
        val _b = b.asInstanceOf[Array[Double]]
        val outArray = new Array[Double](m * n)
        blas.dgemm(transA, transB, m, n, k, _alpha, _a, lda, _b, ldb, _beta, outArray, m)
        outArray.asInstanceOf[Array[A]]
      case _ =>
        val outArray = Predef.implicitly[ClassTag[A]].newArray(m * n)
        gemm_ref(transA, transB, m, n, k, alpha, a, b, beta, outArray)
        outArray
    }
  }

我已经考虑过的另一种选择是使用无形状的“可键入/原型”的类型安全的转换。根据我的理解,这是通过遍历集合中的每个元素来确保类型的一致性。与此相关的开销是存在的,而且由于我处理的数组通常有很多元素,所以我不需要任何额外的开销。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-10-05 13:54:33

像这样的怎么样?

代码语言:javascript
复制
trait Blas[A] {
  def gemm(transA: String, transB: String, m: Int, n: Int, k: Int, alpha: A, beta: A, a: Array[A], b: Array[A]): Array[A]
}

object Blas {
  implicit def floatBlas: Blas[Float] = new Blas[Float] {
    override def gemm(transA: String, transB: String, m: Int, n: Int, k: Int, alpha: Float, beta: Float, a: Array[Float], b: Array[Float]): Array[Float] = {
      val outArray = new Array[Float](m * n)
      blas.sgemm(transA, transB, m, n, k, _alpha, _a, lda, _b, ldb, _beta, outArray, m)
      outArray
    }
  }

  implicit def doubleBlas: Blas[Double] = ???

  // etc.
}

def gemm[A](
    transA: String,
    transB: String,
    m: Int,
    n: Int,
    k: Int,
    alpha: A,
    a: Array[A],
    b: Array[A],
    beta: A
)(implicit blas: Blas[A]) = {

  val lda = if (transA == "N" || transA == "n") k else m
  val ldb = if (transB == "N" || transA == "n") n else k

  blas.gemm(transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, m)

}

(您必须自己修复对象和方法名称,我不知道它们指的是什么。)

其思想是传递一个额外的隐式参数,即自动查找。在定义这些实例时,您拥有可用的完整类型信息,然后不需要在typeOf上匹配。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/39875066

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档