首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >Spark Mlib ALS 交替最小二乘算法(学习笔记)

Spark Mlib ALS 交替最小二乘算法(学习笔记)

原创
作者头像
用户10150864
修改2026-01-22 16:01:07
修改2026-01-22 16:01:07
1150
举报
文章被收录于专栏:机器学习机器学习

一、ALS交替最小二乘法

ALS是交替最小二乘的简称(alternating least squares)的简称。在机器学习的上下文中,ALS特指使用交替最小二乘求解的一个协同推荐算法。它通过观察到的所有用户给物品的打分,来推断每个用户的喜好并向用户推荐合适的物品。

如下面的用户打分的矩阵:

图1-1 用户评分矩阵
图1-1 用户评分矩阵

三、实例代码

在IDEA 中创建scala 项目,代码:

代码语言:txt
复制

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.sql.SparkSession

object als1 {

  def main(args: Array[String]): Unit = {
    // 0. 构建Spark对象
//    val spark = SparkSession.builder()
//      .appName("ALS")
//      .getOrCreate()
    val spark = SparkSession.builder()
      .appName("ALSRecommendation")
      .master("local[*]")  // 使用本地所有可用核心
      .getOrCreate()


    Logger.getRootLogger.setLevel(Level.WARN)

    // 1. 读取样本数据
    val data = spark.read.option("header", "false").csv("H:\\zcy\\ALSrecommend\\src\\main\\scala\\luoke\\test.data")
    val ratings = data.toDF("userId", "movieId", "rating")
      .selectExpr("cast(userId as int) as userId", "cast(movieId as int) as movieId", "cast(rating as float) as rating")

    // 2. 建立模型
    val rank = 10
    val numIterations = 20
    val als = new ALS()
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRatingCol("rating")
      .setRank(rank)
      .setMaxIter(numIterations)
      .setRegParam(0.01)

    val model = als.fit(ratings)

    // 3. 预测结果
    val usersProducts = ratings.select("userId", "movieId")

    val predictions = model.transform(usersProducts)

    // 计算均方误差 (MSE)
    val ratesAndPreds = ratings.join(predictions, Seq("userId", "movieId"))
    val MSE = ratesAndPreds
      .selectExpr("power(rating - prediction, 2) as error")
      .agg(org.apache.spark.sql.functions.mean("error"))
      .first()
      .getDouble(0)

    println("Mean Squared Error = " + MSE)

    // 4. 保存/加载模型
    model.write.overwrite().save("myModelPath")
    val sameModel = ALSModel.load("myModelPath")

    spark.stop() // 结束Spark会话
  }
}

运行结果:
Mean Squared Error = 1.904915633588189E-5

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档