
How to fix the NPE when converting a RasterFrameLayer to a raster for writing a TIF?

Stack Overflow user
Asked 2021-01-06 14:07:03
1 answer · 58 views · 0 followers · Score 1

After training a machine learning model, I'm trying to convert the predicted RasterFrameLayer from RasterFrames into a GeoTiff file. With the RasterFrames demo data Elkton-VA, it works fine.

However, when I use a cropped Sentinel-2A tif with an NDVI index (normalized to the range -1000 to 1000), it fails with a NullPointerException in the toRaster step.

I suspect this is caused by the NoData values outside the ROI. The test data is here, along with the geojson and the log.
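A minimal plain-Scala sketch of that suspicion, with no Spark or GeoTrellis involved. Tiles are modelled as row-major `Array[Int]` with `Int.MinValue` standing in for Int NoData, and `RetileSketch`, `explode` and `assemble` are hypothetical names. It mimics the retile step in the code below, where prediction cells are grouped by `crs`/`extent` and passed to `rf_assemble_tile`. The fitted pipeline still contains `NoDataFilter`, so cells outside the ROI are presumably dropped before that step; if every cell of a tile is dropped, its extent has no group, no tile is assembled for it, and the layer ends up with a hole. This illustrates one plausible mechanism and is not a confirmed diagnosis. It is written as a script (top-level statements, e.g. a `.sc` file).

```scala
// Plain-Scala model of the retile step (no Spark, no GeoTrellis).
object RetileSketch {
  // Stand-in for GeoTrellis' Int NoData sentinel.
  val NoData: Int = Int.MinValue

  // One exploded cell: which tile it came from, its position, and its value.
  final case class Cell(key: (Int, Int), col: Int, row: Int, value: Int)

  // Explode row-major tiles into one Cell per pixel, in the spirit of TileExploder.
  def explode(tiles: Map[(Int, Int), Array[Int]], cols: Int): Seq[Cell] =
    tiles.toSeq.flatMap { case (key, cells) =>
      cells.toSeq.zipWithIndex.map { case (v, i) => Cell(key, i % cols, i / cols, v) }
    }

  // Group cells by tile key and rebuild each tile; cells that are absent become NoData.
  def assemble(cells: Seq[Cell], cols: Int, rows: Int): Map[(Int, Int), Array[Int]] =
    cells.groupBy(_.key).map { case (key, cs) =>
      val out = Array.fill(cols * rows)(NoData)
      cs.foreach(c => out(c.row * cols + c.col) = c.value)
      key -> out
    }
}

val tileCols = 2
val tileRows = 2
val tiles = Map(
  (0, 0) -> Array(100, 200, RetileSketch.NoData, 300),           // partly inside the ROI
  (1, 0) -> Array.fill(tileCols * tileRows)(RetileSketch.NoData) // entirely outside the ROI
)

// Drop NoData cells before regrouping (the role NoDataFilter plays in the pipeline).
val kept = RetileSketch.explode(tiles, tileCols).filter(_.value != RetileSketch.NoData)
val retiled = RetileSketch.assemble(kept, tileCols, tileRows)

println(retiled.keySet) // Set((0,0)): tile (1, 0) is never assembled
```

Tile (0, 0) comes back with its NoData cell restored, while tile (1, 0) simply disappears, which is the kind of gap a later layer-to-raster conversion has to cope with.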

GeoTrellis version: 3.3.0

RasterFrames version: 0.9.0

```scala
import geotrellis.proj4.LatLng
import geotrellis.raster._
import geotrellis.raster.io.geotiff.{MultibandGeoTiff, SinglebandGeoTiff}
import geotrellis.raster.io.geotiff.reader.GeoTiffReader
import geotrellis.raster.render.{ColorRamps, Png}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql._
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.ml.{NoDataFilter, TileExploder}

object ClassificiationRaster extends App {

  def readTiff(name: String) = GeoTiffReader.readSingleband(getClass.getResource(s"/$name").getPath)

  def readMtbTiff(name: String): MultibandGeoTiff = GeoTiffReader.readMultiband(getClass.getResource(s"/$name").getPath)

  implicit val spark = SparkSession.builder()
    .master("local[*]")
    .appName(getClass.getName)
    .withKryoSerialization
    .getOrCreate()
    .withRasterFrames

  import spark.implicits._

  val filenamePattern = "xiangfuqu_202003_mask_%s.tif"
  val bandNumbers = "ndvi".split(",").toSeq
  val bandColNames = bandNumbers.map(b ⇒ s"band_$b").toArray
  val tileSize = 256

  val joinedRF: RasterFrameLayer = bandNumbers
    .map { b ⇒ (b, filenamePattern.format(b)) }
    .map { case (b, f) ⇒ (b, readTiff(f)) }
    .map { case (b, t) ⇒ t.projectedRaster.toLayer(tileSize, tileSize, s"band_$b") }
    .reduce(_ spatialJoin _)
    .withCRS()
    .withExtent()

  val tlm = joinedRF.tileLayerMetadata.left.get

//  println(tlm.totalDimensions.cols)
//  println(tlm.totalDimensions.rows)

  joinedRF.printSchema()

  val targetCol = "label"

  val geojsonPath = "/Users/ethan/work/data/L2a10m4326/zds/test.geojson"
  spark.sparkContext.addFile(geojsonPath)
  import org.locationtech.rasterframes.datasource.geojson._

  val jsonDF: DataFrame = spark.read.geojson.load(geojsonPath)
  val label_df: DataFrame = jsonDF
    .select($"CLASS_ID", st_reproject($"geometry",LatLng,LatLng).alias("geometry"))
    .hint("broadcast")

  val df_joined = joinedRF.join(label_df, st_intersects(st_geometry($"extent"), $"geometry"))
    .withColumn("dims",rf_dimensions($"band_ndvi"))

  val df_labeled: DataFrame = df_joined.withColumn(
    "label",
    rf_rasterize($"geometry", st_geometry($"extent"), $"CLASS_ID", $"dims.cols", $"dims.rows")
  )

  df_labeled.printSchema()

  val tmp = df_labeled.filter(rf_tile_sum($"label") > 0).cache()

  val exploder = new TileExploder()

  val noDataFilter = new NoDataFilter().setInputCols(bandColNames :+ targetCol)

  val assembler = new VectorAssembler()
    .setInputCols(bandColNames)
    .setOutputCol("features")

  val classifier = new DecisionTreeClassifier()
    .setLabelCol(targetCol)
    .setFeaturesCol(assembler.getOutputCol)

  val pipeline = new Pipeline()
    .setStages(Array(exploder, noDataFilter, assembler, classifier))

  val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol(targetCol)
    .setPredictionCol("prediction")
    .setMetricName("f1")

  val paramGrid = new ParamGridBuilder()
    //.addGrid(classifier.maxDepth, Array(1, 2, 3, 4))
    .build()

  val trainer = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(evaluator)
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(4)

  val model = trainer.fit(tmp)

  val metrics = model.getEstimatorParamMaps
    .map(_.toSeq.map(p ⇒ s"${p.param.name} = ${p.value}"))
    .map(_.mkString(", "))
    .zip(model.avgMetrics)
  metrics.toSeq.toDF("params", "metric").show(false)

  val scored = model.bestModel.transform(joinedRF)

  scored.groupBy($"prediction" as "class").count().show

  scored.show(20)


  val retiled: DataFrame = scored.groupBy($"crs", $"extent").agg(
    rf_assemble_tile(
      $"column_index", $"row_index", $"prediction",
      tlm.tileCols, tlm.tileRows, IntConstantNoDataCellType
    )
  )

  val rf: RasterFrameLayer = retiled.toLayer(tlm)

  val raster: ProjectedRaster[Tile] = rf.toRaster($"prediction", 5848, 4189)

  SinglebandGeoTiff(raster.tile,tlm.extent, tlm.crs).write("/Users/ethan/project/IdeaProjects/learn/spark_ml_learn.git/src/main/resources/easy_b1.tif")

  val clusterColors = ColorRamp(
    ColorRamps.Viridis.toColorMap((0 until 1).toArray).colors
  )

//  val pngBytes = retiled.select(rf_render_png($"prediction", clusterColors)).first  //It can output the png.
//  retiled.tile.renderPng(clusterColors).write("/Users/ethan/project/IdeaProjects/learn/spark_ml_learn.git/src/main/resources/classified2.png")

//  Png(pngBytes).write("/Users/ethan/project/IdeaProjects/learn/spark_ml_learn.git/src/main/resources/classified2.png")

  spark.stop()
}
```

1 Answer

Stack Overflow user

Answered 2021-01-08 02:58:47

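I suspect there is a bug in how the toLayer extension method works. I will follow up with a bug report to the RasterFrames project, which I expect will take a bit more effort.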
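To illustrate how a missing tile can surface as a NullPointerException, here is a hypothetical plain-Scala model of stitching a sequence of tiles into one raster. `naiveStitch` and `safeStitch` are illustrative names and are not RasterFrames' `toRaster` implementation; they only show that any code path that dereferences every tile slot fails on a null slot, while skipping or substituting missing tiles avoids the failure. That is the same idea as the null filter in the workaround. Written as a script with top-level statements.

```scala
// Hypothetical model of stitching tiles into one raster; not RasterFrames' toRaster.
// Each slot is a row-major tile, or null when no tile was assembled for that key.
val NoData: Int = Int.MinValue

// Naive stitch: reads every cell of every slot, so a null slot throws NullPointerException.
def naiveStitch(slots: Seq[Array[Int]]): Seq[Int] =
  slots.flatMap(tile => (0 until tile.length).map(i => tile(i)))

// Defensive stitch: a missing tile becomes a block of NoData cells instead.
def safeStitch(slots: Seq[Array[Int]], cellsPerTile: Int): Seq[Int] =
  slots.flatMap(tile => Option(tile).map(_.toSeq).getOrElse(Seq.fill(cellsPerTile)(NoData)))

val slots: Seq[Array[Int]] = Seq(Array(1, 2), null, Array(3, 4))

val naiveFailed =
  try { naiveStitch(slots); false }
  catch { case _: NullPointerException => true }

val stitched = safeStitch(slots, cellsPerTile = 2)

println(naiveFailed)             // true
println(stitched.mkString(", ")) // 1, 2, -2147483648, -2147483648, 3, 4
```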

Here is a possible workaround that works at a slightly lower level. In this case, it writes out 25 non-overlapping GeoTiffs.

```scala
import geotrellis.layer.SpatialKey   // SpatialKey lives in geotrellis.layer in GeoTrellis 3.x
import geotrellis.spark.ContextRDD
import geotrellis.store.hadoop.{SerializableConfiguration, _}
import geotrellis.spark.Implicits._
import org.apache.hadoop.fs.Path

// Need this to write local files from spark
val hconf = SerializableConfiguration(spark.sparkContext.hadoopConfiguration)

// filter returns a plain RDD, so wrap it in ContextRDD to re-attach the layer metadata
ContextRDD(
    rf.toTileLayerRDD($"prediction")
      .left.get
      .filter{
        case (_: SpatialKey, null) ⇒ false  // remove any null Tiles
        case _ ⇒ true
      },
    tlm)
    .regrid(1024)  // Regrid the Tiles so that they are 1024 x 1024
    .toGeoTiffs()
    .foreach{ case (sk: SpatialKey, gt: SinglebandGeoTiff) ⇒
        // One GeoTiff per tile, named after its column and row in the layout
        val path = new Path(new Path("file:///tmp/output"), s"${sk.col}_${sk.row}.tif")
        gt.write(path, hconf.value)
      }
```
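The `.regrid(1024)` call determines how many files are written. A quick ceiling-division check, assuming the layer covers the 5848 x 4189 pixels passed to `toRaster` in the question (`tilesAlong` is a hypothetical helper, written as a script):

```scala
// Ceiling division: how many tiles of `tileSize` pixels are needed to cover `size` pixels.
def tilesAlong(size: Int, tileSize: Int): Int = (size + tileSize - 1) / tileSize

// Assumption: the layer covers the 5848 x 4189 pixels passed to toRaster in the question.
val rasterCols = 5848
val rasterRows = 4189
val tileSize   = 1024 // same as .regrid(1024) in the workaround

val gridCols = tilesAlong(rasterCols, tileSize)
val gridRows = tilesAlong(rasterRows, tileSize)
val maxFiles = gridCols * gridRows

println(s"$gridCols x $gridRows = $maxFiles tiles at most") // 6 x 5 = 30 tiles at most
```

This is only an upper bound. The answer reports 25 files; presumably the difference comes from the null tiles filtered out before regridding, since regridding only produces tiles where input tiles exist.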
Score: 1
Original page content provided by Stack Overflow. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/65590971
