首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >张量广播机制【Ai Infra 3.0】[PyTorch Java 硕士研一课程]

张量广播机制【Ai Infra 3.0】[PyTorch Java 硕士研一课程]

原创
作者头像
用户12258095
发布2026-03-13 13:09:33
发布2026-03-13 13:09:33
1240
举报

概述

[PyTorch Java 硕士研一课程]

展开

代码语言:

TXT

自动换行

AI代码解释

# 理解广播机制

当对张量执行逐元素操作(如加法、减法或乘法)时,它们的形状通常需要对齐。但是,手动调整或重复张量以匹配形状可能会很繁琐且效率低下,尤其是在处理大型数据集时。PyTorch 通过一种称为**广播(broadcasting)**的机制解决了这个问题。

广播提供了一套规则,允许 PyTorch 在执行操作时自动扩展张量维度,前提是它们的形状满足特定的兼容标准。这在许多常见情况下省去了显式维度扩展的需要,使得代码更简洁,内存使用更优化,因为实际数据并未重复;只有计算行为像数据重复了一样。

### 广播规则

PyTorch 通过逐元素比较两个张量的形状来判断它们是否“可广播”,比较从*末尾*(最右侧)维度开始。如果满足以下条件,则两个张量可兼容进行广播(从右到左比较每个维度对):

1. **维度相等:** 维度大小相等。

2. **其中一个维度为 1:** 两个维度中的一个为 1。

3. **缺少维度:** 一个张量不具备该维度(在此比较中,其大小被视为 1)。

如果所有维度对都满足这些条件,则张量是可广播的。结果张量的形状将沿每个维度对取最大尺寸。如果任何维度对不满足条件(即,维度不同且都不为 1),则会引发 `RuntimeError`。

我们来分析一下这个过程:

1. **对齐形状:** 张量根据它们的末尾维度进行对齐。如果一个张量的维度少于另一个,那么为了对齐,会在其形状前面添加大小为 1 的维度。

2. **检查兼容性并确定结果形状:** 从最右侧维度开始,比较尺寸:

- 如果维度相等,则结果维度大小就是该尺寸。

- 如果一个维度为 1,则结果维度大小是另一个(较大)维度的大小。

- 如果一个张量缺少某个维度(由于对齐),则结果维度大小是另一个张量中该维度的大小。

3. **执行操作:** 操作的执行方式,就像是沿着给定维度大小为 1 的张量,其值被复制以匹配另一个张量中对应维度的大小一样。

### 广播示例

我们用代码示例来说明。

#### 标量与张量

将标量(一个 0 维张量)添加到任何张量时,总是通过广播机制生效。标量会有效地扩展以匹配张量的形状。

```scala 3

import torch.*

// 张量 A: 形状 [2, 3]

val a = torch.tensor([[1, 2, 3], [4, 5, 6]])

// 标量 B: 形状 [] (0 维度)

val b = torch.tensor(10)

// 将标量添加到张量

val c = a + b

println(f"Shape of a: {a.shape}")

// 张量 a 的形状: torch.Size([2, 3])

println(f"Shape of b: {b.shape}")

// 标量 b 的形状: torch.Size([])

println(f"Shape of c: {c.shape}")

// 张量 c 的形状: torch.Size([2, 3])

println(f"Result c:\n{c}")

// 结果 c:

// tensor([[11, 12, 13],

// [14, 15, 16]])

展开

代码语言:

Java

自动换行

AI代码解释

// 18-21. 广播机制

LongPointer matrixData3 = new LongPointer(1, 2, 3, 4, 5, 6);

Tensor a = tensor(new LongArrayRef(matrixData3, new LongPointer(2, 3)));

LongPointer scalarData = new LongPointer(10);

Tensor b = tensor(scalarData);

Tensor c = a.add(b);

System.out.printf("Shape of a: %s%n", a.sizes());

System.out.printf("Shape of b: %s%n", b.sizes());

System.out.printf("Shape of c: %s%n", c.sizes());

System.out.printf("Result c:\n %s%n", c);

这里,b(形状 [])被广播到形状 [2, 3] 以匹配 a。

行向量与矩阵

考虑将一个行向量(形状 [3])添加到一个矩阵(形状 [2, 3])中。

展开

代码语言:

TXT

自动换行

AI代码解释

import torch.*

// 张量 A: 形状 [2, 3]

val a = torch.tensor([[1, 2, 3],

[4, 5, 6]])

// 张量 B: 形状 [3] (为了广播,可以视为 [1, 3])

val b = torch.tensor([10, 20, 30])

// 将行向量添加到矩阵

val c = a + b

println(f"Shape of a: {a.shape}") // torch.Size([2, 3])

println(f"Shape of b: {b.shape}") // torch.Size([3])

println(f"Shape of c: {c.shape}") // torch.Size([2, 3])

println(f"Result c:\n{c}")

// 结果 c:

// tensor([[11, 22, 33],

// [14, 25, 36]])

对齐: a 的形状为 [2, 3]。b 的形状为 [3]。右侧对齐结果如下:

代码语言:

TXT

自动换行

AI代码解释

张量 A: 2 x 3

张量 B: 3

兼容性检查:

末尾维度:3 等于 3。兼容。结果维度大小为 3。

下一个维度:a 为 2,b 在此处没有维度(隐式大小为 1)。兼容。结果维度大小为 2。

结果形状: [2, 3]。

扩展: 张量 b 被视为形状 [1, 3],并且其单行沿第一个维度复制以匹配 a 的形状 [2, 3]。

列向量与矩阵

现在,我们来将一个列向量(形状 [2, 1])添加至同一矩阵(形状 [2, 3])。

展开

代码语言:

TXT

自动换行

AI代码解释

import torch.*

// 张量 A: 形状 [2, 3]

val a = torch.tensor(Seq(Seq(1, 2, 3),

Seq(4, 5, 6)))

// 张量 B: 形状 [2, 1]

val b = torch.tensor(Seq(Seq(10), Seq(20)))

// 将列向量添加到矩阵

val c = a + b

println(f"Shape of a: {a.shape}") // torch.Size([2, 3])

println(f"Shape of b: {b.shape}") // torch.Size([2, 1])

println(f"Shape of c: {c.shape}") // torch.Size([2, 3])

println(f"Result c:\n{c}")

// 结果 c:

// tensor([[11, 12, 13],

// [24, 25, 26]])

对齐:

代码语言:

TXT

自动换行

AI代码解释

张量 A: 2 x 3

张量 B: 2 x 1

兼容性检查:

末尾维度:a 为 3,b 为 1。兼容(其中一个为 1)。结果维度大小为 3。

下一个维度:a 为 2,b 为 2。兼容(相等)。结果维度大小为 2。

结果形状: [2, 3]。

扩展: 张量 b 中大小为 1 的维度(列维度)通过跨列复制值来扩展,以匹配 a 的形状 [2, 3]。

可视化示例

我们来可视化张量 A(形状 [3, 1])和 B(形状 [4])的广播过程。

张量 A (形状: [3, 1])张量 B (形状: [4])广播 A + B -> 结果 (形状: [3, 4])A1A2A3B1B2B3B4A1, A1, A1, A1A2, A2, A2, A2A3, A3, A3, A3扩展维度 1(大小 1 -> 4)B1, B2, B3, B4B1, B2, B3, B4B1, B2, B3, B4添加维度 0 并扩展(大小 [4] -> [1,4] -> [3,4])A1+B1, A1+B2, A1+B3, A1+B4A2+B1, A2+B2, A2+B3, A2+B4A3+B1, A3+B2, A3+B3, A3+B4++

张量 A (形状 [3, 1]) 和张量 B (形状 [4]) 进行广播加法的示意图。张量 A 的第二个维度(大小 1)扩展到 4。张量 B 获得一个大小为 1 的前置维度(变为形状 [1, 4]),然后扩展到大小 3。两者都有效地变为形状 [3, 4] 以进行逐元素加法。

不兼容的形状

如果非匹配维度不为 1,则广播会失败。

展开

代码语言:

TXT

自动换行

AI代码解释

import torch.*

// 张量 A: 形状 [2, 3]

val a = torch.tensor(Seq(Seq(1, 2, 3),

Seq(4, 5, 6)))

// 张量 B: 形状 [2]

val b = torch.tensor(Seq(10, 20))

try:

val c = a + b

except Exception as e:

println(f"Error: {e}")

// 错误: 张量 a (3) 的大小必须与张量 b (2) 在非单例维度 1 处匹配

对齐:

代码语言:

TXT

自动换行

AI代码解释

张量 A: 2 x 3

张量 B: 2

兼容性检查:

末尾维度:a 为 3,b 为 2。两者都不为 1。不兼容。 操作失败。

常见用途

广播在神经网络中经常使用:

添加偏置: 将偏置向量(形状 [output_features])添加到线性层的输出(形状 [batch_size, output_features])。

归一化: 从一批数据中减去均值(标量或按特征向量)并除以标准差(标量或按特征向量)。

应用掩码: 将数据与可能具有较少维度的布尔掩码进行逐元素乘法。

理解广播对于编写简洁高效的 PyTorch 代码非常重要。它允许你自然地对不同形状的张量执行操作,只要它们遵守兼容性规则,从而简化了许多常见的数据处理和建模任务。

张量数据类型

张量具有数据类型,通常称为 dtype。数据类型决定了张量可以存储的数值种类(例如整数或浮点数)以及每个元素占用多少内存。选择合适的数据类型对于管理计算资源和确保深度学习模型中的数值精度非常重要。

PyTorch 支持多种数值数据类型,与 NumPy 中的类似。每种类型都有不同的用途,平衡了内存使用、计算速度以及可表示数字的范围或精度。

理解 dtype

每个张量都有一个 dtype 属性,用于指定其元素的类型。默认情况下,PyTorch 创建浮点张量时使用 torch.float32,整数张量时使用 torch.int64。你可以这样检查张量的数据类型:

展开

代码语言:

TXT

自动换行

AI代码解释

import torch.*

// 默认浮点张量

val a = torch.tensor(Seq(1.0, 2.0, 3.0))

println(f"Tensor a: {a}")

println(f"dtype of a: {a.dtype}")

// 默认整数张量

val b = torch.tensor(Seq(1, 2, 3))

println(f"\nTensor b: {b}")

println(f"dtype of b: {b.dtype}")

展开

代码语言:

Java

自动换行

AI代码解释

// ========== 1. 创建默认浮点张量(对应 Scala 的 torch.tensor(Seq(1.0,2.0,3.0))) ==========

// Java 中需先构造浮点数组,PyTorch 默认浮点类型为 Float32(和 Scala 一致)

double[] floatData = new double[]{1.0, 2.0, 3.0};

Tensor a = torch.tensor(floatData); // 不指定dtype,使用默认浮点类型

// 打印张量内容和数据类型(对齐原代码输出格式)

System.out.printf("Tensor a: %s%n", a);

System.out.printf("dtype of a: %s%n", a.scalar_type());

// ========== 2. 创建默认整数张量(对应 Scala 的 torch.tensor(Seq(1,2,3))) ==========

long[] intData = new long[]{1, 2, 3}; // Java 中默认整数张量对应 Long(int64),和 Scala 一致

Tensor b = torch.tensor(intData); // 不指定dtype,使用默认整数类型

// 打印张量内容和数据类型

System.out.printf("%nTensor b: %s%n", b);

System.out.printf("dtype of b: %s%n", b.scalar_type());

输出:

代码语言:

TXT

自动换行

AI代码解释

Tensor a: tensor([1., 2., 3.])

dtype of a: torch.float32

Tensor b: tensor([1, 2, 3])

dtype of b: torch.int64

在创建张量时,你也可以明确指定 dtype:

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建一个64位浮点数张量

val c = torch.tensor(Seq(1.0, 2.0), dtype=torch.float64)

println(f"\nTensor c: {c}")

println(f"dtype of c: {c.dtype}")

// 创建一个32位整数张量

val d = torch.ones(2, 2, dtype=torch.int32)

println(f"\nTensor d:\n{d}")

println(f"dtype of d: {d.dtype}")

展开

代码语言:

Java

自动换行

AI代码解释

// ========== 1. 创建64位浮点数张量(torch.float64) ==========

// 步骤1:构造double数组(对应float64数据)

double[] float64Data = new double[]{1.0, 2.0};

// 步骤2:指定dtype=ScalarType.Double(对应torch.float64)

Tensor c = torch.tensor(float64Data,

torch.tensorOptions().dtype(ScalarType.Double));

// 打印张量内容和数据类型(对齐原代码输出格式)

System.out.printf("%nTensor c: %s%n", c);

System.out.printf("dtype of c: %s%n", c.scalar_type());

// ========== 2. 创建32位整数张量(torch.int32)的2x2全一张量 ==========

// 步骤1:先创建默认全一张量,再指定dtype=ScalarType.Int(对应torch.int32)

// 或直接通过tensorOptions指定dtype(更高效)

Tensor d = torch.ones(2, 2,

torch.tensorOptions().dtype(ScalarType.Int));

// 打印张量内容和数据类型(换行对齐原代码格式)

System.out.printf("%nTensor d:%n%s%n", d);

System.out.printf("dtype of d: %s%n", d.scalar_type());

输出:

展开

代码语言:

TXT

自动换行

AI代码解释

Tensor c: tensor([1., 2.], dtype=torch.float64)

dtype of c: torch.float64

Tensor d:

tensor([[1, 1],

[1, 1]], dtype=torch.int32)

dtype of d: torch.int32

你可以使用 torch.get_default_dtype() 查看 PyTorch 默认使用的浮点类型。

常用数据类型

以下是 PyTorch 中一些最常用的数据类型:

浮点类型:

torch.float32 (或 torch.float):标准的32位单精度浮点数。由于它在CPU和GPU上兼顾了精度和性能,因此是模型参数和一般计算中最常见的类型。

torch.float64 (或 torch.double):64位双精度浮点数。提供更高的精度,但占用两倍内存,并且速度可能明显较慢,尤其是在未针对双精度进行优化的GPU上。当绝对需要高数值精度时使用它。

torch.float16 (或 torch.half):16位半精度浮点数。占用更少内存,并可在现代GPU(如NVIDIA Tensor Cores)上显著加快计算速度。然而,其有限的范围和精度有时可能导致数值不稳定(溢出或下溢)。常用于混合精度训练。

torch.bfloat16:一种替代的16位格式(脑浮点)。它与 float32 具有相似的范围,但精度较低。它正变得越来越受兼容硬件(例如,较新的NVIDIA GPU、Google TPU)上深度学习训练的欢迎,因为它提供了内存节省和速度提升,同时通常比 float16 保持更好的稳定性。

整数类型:

torch.int64 (或 torch.long):64位有符号整数。默认整数类型。常用于张量索引和分类任务中表示类别标签。

torch.int32 (或 torch.int):32位有符号整数。

torch.int16:16位有符号整数。

torch.int8:8位有符号整数。较小的整数类型可以节省内存,并对某些操作更快,常用于模型量化。

相应的无符号整数类型也存在 (torch.uint8)。

布尔类型:

torch.bool:表示布尔值 True 或 False。对于逻辑操作、使用掩码进行索引以及条件逻辑非常重要。

类型转换

你经常需要将张量从一种数据类型转换为另一种。这就是所谓的类型转换。转换张量的主要方式是使用 .to() 方法,该方法我们在张量在设备(CPU/GPU)之间移动的背景下也见过。

展开

代码语言:

TXT

自动换行

AI代码解释

import torch.*

// 原始浮点张量

val float_tensor = torch.tensor(Seq(1.1, 2.2, 3.3), dtype=torch.float32)

println(f"Original tensor: {float_tensor}, dtype: {float_tensor.dtype}")

// 使用 .to() 转换为 int64

val int_tensor = float_tensor.to(torch.int64)

println(f"Casted to int64: {int_tensor}, dtype: {int_tensor.dtype}") // 注意截断

// 使用 .to() 转换回 float16

val half_tensor = int_tensor.to(dtype=torch.float16) // 可以只指定 dtype

println(f"Casted to float16: {half_tensor}, dtype: {half_tensor.dtype}")

展开

代码语言:

Java

自动换行

AI代码解释

// ========== 1. 创建原始浮点张量(float32 类型) ==========

// 步骤1:构造 float32 数据数组

float[] floatData = new float[]{1.1f, 2.2f, 3.3f};

// 步骤2:创建张量(指定 dtype=Float32,对应 torch.float32)

Tensor floatTensor = torch.tensor(floatData,

torch.tensorOptions().dtype(ScalarType.Float));

// 打印原始张量信息(模仿 Scala 的输出格式)

System.out.printf("Original tensor: %s, dtype: %s%n",

floatTensor, floatTensor.scalar_type());

// ========== 2. 转换为 int64 类型(注意:浮点转整数会截断小数部分) ==========

Tensor intTensor = floatTensor.to(ScalarType.Long); // Long 对应 torch.int64

System.out.printf("Casted to int64: %s, dtype: %s%n",

intTensor, intTensor.scalar_type());

// ========== 3. 转换回 float16 类型(Half 对应 torch.float16) ==========

// 方式1:直接指定 dtype(和 Scala 风格一致)

Tensor halfTensor = intTensor.to(ScalarType.Half);

System.out.printf("Casted to float16: %s, dtype: %s%n",

halfTensor, halfTensor.scalar_type());

输出:

代码语言:

TXT

自动换行

AI代码解释

Original tensor: tensor([1.1000, 2.2000, 3.3000]), dtype: torch.float32

Casted to int64: tensor([1, 2, 3]), dtype: torch.int64

Casted to float16: tensor([1., 2., 3.], dtype=torch.float16), dtype: torch.float16

注意,从浮点数转换为整数会截断小数部分。

PyTorch 还提供了便捷方法来进行常见的类型转换:

展开

代码语言:

TXT

自动换行

AI代码解释

import torch.*

// 原始整数张量

val tensor_a = torch.tensor(Seq(0, 1, 0, 1))

println(f"Original tensor: {tensor_a}, dtype: {tensor_a.dtype}")

// 使用 .float() 转换为浮点数

val tensor_b = tensor_a.float() // 等同于 .to(torch.float32)

println(f".float(): {tensor_b}, dtype: {tensor_b.dtype}")

// 使用 .long() 转换为长整型

val tensor_c = tensor_b.long() // 等同于 .to(torch.int64)

println(f".long(): {tensor_c}, dtype: {tensor_c.dtype}")

// 使用 .bool() 转换为布尔型

val tensor_d = tensor_a.bool() // 等同于 .to(torch.bool)

println(f".bool(): {tensor_d}, dtype: {tensor_d.dtype}")

展开

代码语言:

Java

自动换行

AI代码解释

// 25-29. 数据类型转换

LongPointer tensorAData2 = new LongPointer(0, 1, 0, 1);

Tensor tensor_a = tensor(new LongArrayRef(tensorAData2, new LongPointer(4)));

System.out.printf("Original tensor: %s, dtype: %s%n", tensor_a, tensor_a.dtype());

Tensor tensor_b = tensor_a.to(ScalarType.Float); //.toTensor();

System.out.printf(".float(): %s, dtype: %s%n", tensor_b, tensor_b.dtype());

Tensor tensor_c = tensor_b.to(ScalarType.Long); //.toTensor();

System.out.printf(".long(): %s, dtype: %s%n", tensor_c, tensor_c.dtype());

Tensor tensor_d = tensor_a.to(ScalarType.Bool); //.toTensor();

System.out.printf(".bool(): %s, dtype: %s%n", tensor_d, tensor_d.dtype());

输出:

代码语言:

TXT

自动换行

AI代码解释

.float(): tensor([0., 1., 0., 1.]), dtype: torch.float32

.long(): tensor([0, 1, 0, 1]), dtype: torch.int64

.bool(): tensor([False, True, False, True]), dtype: torch.bool

请记住,类型转换通常会在内存中创建一个具有指定数据类型的新张量,而不是就地修改原始张量。

操作中的类型提升

当你对不同数据类型的张量执行操作时,PyTorch 通常会自动提升类型以确保兼容性。一般规则是,整数类型与浮点类型进行操作时,结果将是浮点类型。不同浮点类型之间的操作通常会得到更高精度的类型。

展开

代码语言:

TXT

自动换行

AI代码解释

val int_t = torch.tensor(Seq(1, 2), dtype=torch.int32)

val float_t = torch.tensor(Seq(0.5, 0.5), dtype=torch.float32)

val double_t = torch.tensor(Seq(0.1, 0.1), dtype=torch.float64)

// int32 + float32 -> float32

val result1 = int_t + float_t

println(f"\nint32 + float32 = {result1}, dtype: {result1.dtype}")

// float32 + float64 -> float64

val result2 = float_t + double_t

println(f"float32 + float64 = {result2}, dtype: {result2.dtype}")

展开

代码语言:

Java

自动换行

AI代码解释

// ========== 1. 创建不同数据类型的张量 ==========

// int32 张量(对应 torch.int32)

int[] intData = new int[]{1, 2}; // 用int[]匹配int32类型

Tensor int_t = torch.tensor(intData,

torch.tensorOptions().dtype(ScalarType.Int));

// float32 张量(对应 torch.float32)

float[] floatData = new float[]{0.5f, 0.5f}; // 用float[]匹配float32类型

Tensor float_t = torch.tensor(floatData,

torch.tensorOptions().dtype(ScalarType.Float));

// float64 张量(对应 torch.float64)

double[] doubleData = new double[]{0.1, 0.1}; // 用double[]匹配float64类型

Tensor double_t = torch.tensor(doubleData,

torch.tensorOptions().dtype(ScalarType.Double));

// ========== 2. 张量相加 & 验证类型提升规则 ==========

// 规则1:int32 + float32 → 自动提升为 float32

Tensor result1 = int_t.add(float_t); // Java中"+"运算符不可用,用add()方法替代

System.out.printf("%nint32 + float32 = %s, dtype: %s%n",

result1, result1.scalar_type());

// 规则2:float32 + float64 → 自动提升为 float64

Tensor result2 = float_t.add(double_t);

System.out.printf("float32 + float64 = %s, dtype: %s%n",

result2, result2.scalar_type());

输出:

代码语言:

TXT

自动换行

AI代码解释

int32 + float32 = tensor([1.5000, 2.5000]), dtype: torch.float32

float32 + float64 = tensor([0.6000, 0.6000], dtype=torch.float64), dtype: torch.float64

虽然方便,但要注意自动类型提升,因为它如果未被预料到,可能会导致意外结果或性能影响。使用 .to() 进行显式类型转换可以让你对计算中所需的数据类型有更清晰的控制。

理解和管理张量数据类型是高效 PyTorch 编程的重要组成部分。它能让你控制内存占用,运用硬件加速(如GPU上的FP16),并为你的特定深度学习任务保持必要的数值精度。

CPU 与 GPU 张量

深度学习计算,特别是涉及大型张量和复杂模型的计算,需要强大的计算能力。中央处理器(CPU)用途广泛,而图形处理器(GPU)提供大规模并行能力,可以大幅加速构成神经网络核心的矩阵和向量运算。PyTorch 提供简单直接的机制来管理张量的存放位置和计算发生地。了解如何在 CPU 和 GPU 之间移动张量是高效模型训练和推理的必备技能。

CPU:默认计算中心

默认情况下,当您创建 PyTorch 张量时未指定设备,它会分配在 CPU 上。

代码语言:

TXT

自动换行

AI代码解释

import torch.*

// 默认在 CPU 上创建张量

val cpu_tensor = torch.tensor(Seq(1.0, 2.0, 3.0))

println(f"默认张量设备:{cpu_tensor.device}")

展开

代码语言:

Java

自动换行

AI代码解释

// ========== 创建默认 CPU 张量(和 Scala 逻辑一致) ==========

double[] data = new double[]{1.0, 2.0, 3.0};

Tensor cpu_tensor = torch.tensor(data); // 不指定设备,默认创建在 CPU 上

// ========== 获取并打印张量的设备信息(对齐原代码输出格式) ==========

// 获取设备对象 → 转为字符串输出(匹配 Scala 的 cpu_tensor.device 格式)

Device device = cpu_tensor.device();

System.out.printf("默认张量设备:%s%n", device);

CPU 完全适用于许多任务,包括预处理步骤、较小规模的计算,或在没有兼容 GPU 时运行模型。然而,对于训练大型深度学习模型,由于 CPU 的顺序处理特性与 GPU 的并行架构相比,仅依赖 CPU 常常会导致训练时间过长,难以承受。

GPU:使用并行加速深度学习

GPU 包含数百或数千个核心,旨在同时执行大量计算。这种架构非常适合深度学习中常见的运算类型,如大型矩阵乘法和卷积。PyTorch 通过 CUDA(统一计算设备架构)平台使用 NVIDIA GPU。

要使用 GPU,您需要:

一个兼容 CUDA 的 NVIDIA GPU。

已安装适当的 CUDA 工具包。

安装了 CUDA 支持的 PyTorch 版本。

检查 GPU 可用性并设置设备

在尝试使用 GPU 之前,最好检查它是否可用且已为 PyTorch 正确配置。torch.cuda.is_available() 函数在 PyTorch 可以访问支持 CUDA 的 GPU 时返回 True。

之后我们可以创建一个 torch.device 对象来表示我们的目标计算设备(CPU 或 GPU)。这使得代码具有适应性,如果 GPU 可用则自动使用 GPU,否则回退到 CPU。

展开

代码语言:

TXT

自动换行

AI代码解释

import torch.*

// 检查 CUDA 可用性并相应设置设备

if torch.cuda.is_available() then

val device = torch.device("cuda") // 使用第一个可用的 CUDA 设备

println(f"CUDA (GPU) 可用。使用设备:{device}")

// 您也可以指定特定 GPU,例如 torch.device("cuda:0")

else:

val device = torch.device("cpu")

println(f"CUDA (GPU) 不可用。使用设备:{device}")

// device 现在包含 torch.device('cuda') 或 torch.device('cpu')

展开

代码语言:

Java

自动换行

AI代码解释

Device device5 = new Device(cuda_is_available() ? DeviceType.CUDA : DeviceType.CPU);

System.out.printf("正在使用设备: %s%n", device5);

Device device4 = new Device(DeviceType.CPU);

Tensor cpu_a = torch.randn(new LongArrayRef(new LongPointer(2, 2)));

Tensor gpu_b = torch.randn(new LongArrayRef(new LongPointer(2, 2))).to(device4, ScalarType.Long);

try {

Tensor c = cpu_a.add(gpu_b);

} catch (RuntimeException e) {

System.out.printf("在不同设备上执行操作时出错:%s%n", e.getMessage());

}

直接在设备上创建张量

您可以在张量创建期间使用 device 参数直接指定目标设备。这通常比在 CPU 上创建然后再移动更高效。

展开

代码语言:

TXT

自动换行

AI代码解释

// 直接在选定设备上创建张量

try:

// 如果 device='cpu',此张量将在 CPU 上;如果 device='cuda',则在 GPU 上

val device_tensor = torch.randn(3, 4, device=device)

println(f"张量创建于:{device_tensor.device}")

catch RuntimeError as e:

println(f"无法直接在 {device} 上创建张量:{e}") // 处理未找到 GPU 等情况

展开

代码语言:

Java

自动换行

AI代码解释

// 1. 定义要使用的设备(可修改为 "cuda" 测试 GPU 场景)

String deviceName = "cpu"; // 可选值:"cpu" / "cuda" / "cuda:0"

Device device = new Device(deviceName);

try {

// 2. 直接在选定设备上创建 3x4 的随机正态分布张量(randn)

// 方式:通过 tensorOptions 指定 device,再传入 randn

Tensor device_tensor = torch.randn(

new long[]{3, 4}, // 张量形状 3x4

torch.tensorOptions().device(device) // 指定创建设备

);

// 3. 打印张量所在设备(对齐原代码输出格式)

System.out.printf("张量创建于:%s%n", device_tensor.device());

// 释放张量资源

device_tensor.close();

} catch (Exception e) { // 捕获设备不可用等运行时异常(对应 Scala 的 RuntimeError)

// 4. 处理异常:如 GPU 未安装/未找到等情况

System.out.printf("无法直接在 %s 上创建张量:%s%n", deviceName, e.getMessage());

} finally {

// 释放 Device 对象资源

device.close();

}

在 CPU 和 GPU 之间移动张量

通常,您需要在设备之间传输现有张量。例如,从磁盘加载的数据通常位于 CPU 上,但您的模型可能在 GPU 上以加快计算。移动张量的主要方法是 .to() 方法。

.to() 方法接受 torch.device 对象、设备字符串(例如 'cuda'、'cpu'),甚至另一个张量(在这种情况下,张量会移动到与参数张量相同的设备上)作为输入。它在指定设备上返回一个 新 张量。原始张量在其原始设备上保持不变。

展开

代码语言:

TXT

自动换行

AI代码解释

// 从 CPU 张量开始

val cpu_tensor = torch.ones(2, 2)

println(f"原始张量:{cpu_tensor.device}")

// 将张量移动到选定设备(如果 GPU 可用,则为 GPU;否则为 CPU)

// 请记住,'device' 是根据可用性预先设置的

val moved_tensor = cpu_tensor.to(device)

println(f"移动后的张量:{moved_tensor.device}")

// 如果张量在 GPU 上,则显式移回 CPU

if moved_tensor.is_cuda then // 检查张量是否在 CUDA 设备上

val back_to_cpu = moved_tensor.to("cpu")

println(f"张量移回至:{back_to_cpu.device}")

展开

代码语言:

Java

自动换行

AI代码解释

// ========== 1. 预先设置目标设备(根据 GPU 可用性) ==========

// 逻辑:GPU 可用则用 CUDA,否则用 CPU(对齐原代码的 'device' 预设逻辑)

Device device = torch.cuda.is_available()

? new Device("cuda:0")

: new Device("cpu");

String deviceName = device.str().getString(); // 获取设备名称(cpu/cuda:0)

// ========== 2. 创建原始 CPU 张量 ==========

Tensor cpu_tensor = torch.ones(2, 2); // 默认创建在 CPU 上

System.out.printf("原始张量:%s%n", cpu_tensor.device());

// ========== 3. 将张量移动到选定设备 ==========

Tensor moved_tensor = cpu_tensor.to(device); // 迁移到目标设备

System.out.printf("移动后的张量:%s%n", moved_tensor.device());

// ========== 4. 检查是否在 CUDA 上,若在则移回 CPU ==========

// Java 中用 is_cuda() 检查张量是否在 CUDA 设备上(对应 Scala 的 is_cuda)

if (moved_tensor.is_cuda()) {

Tensor back_to_cpu = moved_tensor.to(new Device("cpu")); // 移回 CPU

System.out.printf("张量移回至:%s%n", back_to_cpu.device());

// 释放移回 CPU 的张量资源

back_to_cpu.close();

}

PyTorch 还提供便利方法:.cpu() 和 .cuda()。它们分别是 .to('cpu') 和 .to('cuda:0')(或当前默认的 CUDA 设备)的简写。

展开

代码语言:

TXT

自动换行

AI代码解释

// 使用便利方法(假设 GPU 可用且 'device' 为 'cuda')

if device.type == 'cuda' then

// 将 cpu_tensor 移动到 GPU

val gpu_tensor_alt = cpu_tensor.cuda()

println(f"使用 .cuda():{gpu_tensor_alt.device}")

// 将 gpu_tensor_alt 移回 CPU

val cpu_tensor_alt = gpu_tensor_alt.cpu()

println(f"使用 .cpu():{cpu_tensor_alt.device}")

展开

代码语言:

Java

自动换行

AI代码解释

// 1. 先创建一个示例 CPU 张量(模拟原代码中的 cpu_tensor)

float[] data = {1.0f, 2.0f, 3.0f, 4.0f};

Tensor cpuTensor = Tensor.fromBlob(data, new long[]{2, 2});

// 2. 检查 CUDA 是否可用(核心前提)

if (Device.isAvailable()) {

// 将 CPU 张量移动到 GPU(对应原代码的 .cuda())

Tensor gpuTensorAlt = cpuTensor.to(Device.Type.CUDA);

System.out.println("使用 to(Device.Type.CUDA):" + gpuTensorAlt.device());

// 将 GPU 张量移回 CPU(对应原代码的 .cpu())

Tensor cpuTensorAlt = gpuTensorAlt.to(Device.Type.CPU);

System.out.println("使用 to(Device.Type.CPU):" + cpuTensorAlt.device());

} else {

System.out.println("CUDA 不可用,跳过 GPU 张量迁移");

}

设备管理的重要考量

设备一致性: 涉及多个张量的操作(例如,加法、矩阵乘法)通常要求所有参与的张量都在 同一 设备上。尝试在 CPU 张量和 GPU 张量之间进行操作将导致 RuntimeError。在执行操作前,请确保您的数据和模型位于同一设备上。

展开

代码语言:

TXT

自动换行

AI代码解释

// 错误示例(假设 device='cuda')

val cpu_a = torch.randn(2, 2)

val gpu_b = torch.randn(2, 2, device=device)

try:

// 如果 device 是 'cuda',这很可能会失败

val c = cpu_a + gpu_b

catch RuntimeError as e:

println(f"在不同设备上执行操作时出错:{e}")

展开

代码语言:

Java

自动换行

AI代码解释

// 1. 定义目标设备(模拟原代码中的 device='cuda')

Device targetDevice = torch.is_available() ?

new Device(torch.DeviceType.CUDA) :

new Device(torch.DeviceType.CPU);

// 2. 创建 CPU 张量(对应原代码的 cpu_a)

Tensor cpuA = randn(new long[]{2, 2}); // 默认创建在 CPU 上

System.out.println("cpuA 设备: " + cpuA.device());

// 3. 创建指定设备的张量(对应原代码的 gpu_b)

var device = torch.device(targetDevice);

Tensor gpuB = randn(new long[]{2, 2}, device);

System.out.println("gpuB 设备: " + gpuB.device());

// 4. 尝试执行跨设备加法(错误示例)

try {

// 注意:Java 中 PyTorch 张量加法需用 add() 方法,而非 + 运算符

Tensor c = cpuA.add(gpuB);

System.out.println("加法结果设备: " + c.device());

} catch (RuntimeException e) {

// 捕获设备不匹配的运行时异常(对应原代码的 RuntimeError)

System.out.println("在不同设备上执行操作时出错:" + e.getMessage());

} finally {

// 释放所有张量资源,避免内存泄漏

cpuA.close();

gpuB.close();

}

数据传输开销: 在 CPU 和 GPU 内存之间移动数据并非瞬间完成。虽然 GPU 计算速度快,但如果管理不当,数据的来回传输可能成为瓶颈。为了最佳性能,请尝试在 GPU 上尽可能多地执行操作,然后在需要时(例如,保存到磁盘或转换为 NumPy)再将最终结果移回 CPU。

模型放置: 与张量一样,使用 torch.nn.Module 定义的神经网络模型也需要使用 .to(device) 方法移动到适当的设备。这确保模型的参数(本身也是张量)位于目标设备上进行计算。在讨论模型构建时,会更详细地介绍这一点。

掌握使用 device 和 .to() 方法进行张量放置是发挥 GPU 计算能力和编写高效、硬件感知型 PyTorch 代码不可或缺的要素。请记住始终检查设备一致性并注意数据传输成本。

练习:张量操作技巧

练习张量操作技术,包括索引、形状改变、组合、广播、数据类型以及张量在不同设备间的移动。通过这些示例进行练习是巩固理解的最佳途径。请确保已导入PyTorch。

代码语言:

TXT

自动换行

AI代码解释

import torch.*

import numpy as np

// 检查CUDA是否可用并设置设备

val device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

println(f"正在使用设备: {device}")

索引和切片练习

索引和切片是访问和修改张量部分内容的基本操作。让我们尝试选择特定的数据点。

任务1: 创建一个二维张量,并选择第二行第三列的元素。

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建一个示例二维张量(3行,4列)

val data = Seq(Seq(1, 2, 3, 4), Seq(5, 6, 7, 8), Seq(9, 10, 11, 12))

val tensor_2d = torch.tensor(data)

println("原始张量:\n", tensor_2d)

// 选择行索引为1、列索引为2的元素

val element = tensor_2d(1, 2)

println("\n在 [1, 2] 的元素:", element)

println("值:", element.item()) // 使用 .item() 获取Scala数值

展开

代码语言:

Java

自动换行

AI代码解释

// 34-37. 张量索引与掩码

LongPointer data2D = new LongPointer(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);

Tensor tensor_2d = tensor(new LongArrayRef(data2D, new LongPointer(3, 4)));

System.out.printf("原始张量:\n %s%n", tensor_2d);

var index = new TensorIndexArrayRef(new LongPointer(1, 2));

Tensor element = tensor_2d.index(index);

System.out.printf("\n在 [1, 2] 的元素: %s%n", element);

System.out.printf("值: %f%n", element.item().toDouble());

任务2: 选择 tensor_2d 的整个第二行。

展开

代码语言:

TXT

自动换行

AI代码解释

// 选择索引为1的行

val row_1 = tensor_2d(1)

println("\n第二行(索引1):\n", row_1)

// 使用切片的替代方法(选择第1行,所有列)

val row_1_alt = tensor_2d(1, ::)

println("\n第二行(替代方法):\n", row_1_alt)

代码语言:

Java

自动换行

AI代码解释

Tensor row_1 = tensor_2d.index(new TensorIndexArrayRef(new LongPointer(1)));

System.out.printf("\n第二行(索引1):\n %s%n", row_1);

任务3: 选择 tensor_2d 的第三列。

代码语言:

TXT

自动换行

AI代码解释

// 选择所有行,列索引为2

val col_2 = tensor_2d(::, 2)

println("\n第三列(索引2):\n", col_2)

代码语言:

Java

自动换行

AI代码解释

Tensor col_2 = tensor_2d.index(new TensorIndexArrayRef(new TensorIndex(new EllipsisIndexType()), new TensorIndex(2)));

System.out.printf("\n第三列(索引2):\n %s%n", col_2);

任务4: 创建一个布尔掩码以选择 tensor_2d 中所有大于7的元素,然后使用该掩码提取这些元素。

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建布尔掩码

val mask = tensor_2d > 7

println("\n布尔掩码(张量 > 7):\n", mask)

// 应用掩码

val selected_elements = tensor_2d(mask)

println("\n大于7的元素:\n", selected_elements)

展开

代码语言:

Java

自动换行

AI代码解释

Tensor mask = tensor_2d.gt(new Scalar(7));

System.out.printf("\n布尔掩码(张量 > 7):\n %s%n", mask);

Tensor selected_elements = masked_select(tensor_2d, mask);

System.out.printf("\n大于7的元素:\n %s%n", selected_elements);

这些练习说明了标准Python索引如何结合类似NumPy的切片和布尔掩码,以提供灵活的数据访问方式。

形状改变与重排练习

在不改变张量数据的情况下改变其形状是很常见的,尤其是在为不同神经网络层准备输入时。

任务1: 创建一个包含12个元素的一维张量,并将其形状改为3x4的张量。

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建一个值从0到11的张量

val tensor_1d = torch.arange(12)

println("\n原始一维张量:", tensor_1d)

// 使用 reshape() 改变形状

val reshaped_tensor = tensor_1d.reshape(3, 4)

println("\n改变形状为3x4:\n", reshaped_tensor)

// 使用 view() 改变形状 - 注意 view 适用于内存连续的张量

// arange 创建的是连续张量,因此 view 在这里适用。

val view_tensor = tensor_1d.view(3, 4)

println("\n视为3x4:\n", view_tensor)

展开

代码语言:

Java

自动换行

AI代码解释

// 38-39. 张量变形与维度置换

Tensor tensor_1d = torch.arange(new Scalar(12));

System.out.printf("\n原始一维张量:\n %s%n", tensor_1d);

Tensor reshaped_tensor = tensor_1d.reshape(new LongArrayRef(new LongPointer(3, 4)));

System.out.printf("\n改变形状为3x4:\n %s%n", reshaped_tensor);

Tensor view_tensor = tensor_1d.view(new LongArrayRef(new LongPointer(3, 4)));

System.out.printf("\n视为3x4:\n %s%n", view_tensor);

请记住,view 要求张量在内存中是连续的,并共享底层数据。reshape 可能会返回一个副本或视图,具体取决于连续性。

任务2: 给定 reshaped_tensor(3x4),使用 permute 交换其维度,得到一个4x3的张量。

展开

代码语言:

TXT

自动换行

AI代码解释

// 原始3x4张量

println("\n原始3x4张量:\n", reshaped_tensor)

// 交换维度0和1

val permuted_tensor = reshaped_tensor.permute(1, 0)

println("\n置换为4x3:\n", permuted_tensor)

println("原始形状:", reshaped_tensor.shape)

println("置换后形状:", permuted_tensor.shape)

代码语言:

Java

自动换行

AI代码解释

Tensor permuted_tensor = reshaped_tensor.permute(new LongArrayRef(new LongPointer(1, 0)));

System.out.printf("\n置换为4x3:\n %s%n", permuted_tensor);

System.out.printf("原始形状: %s%n", reshaped_tensor.sizes());

System.out.printf("置换后形状: %s%n", permuted_tensor.sizes());

permute 在改变图像维度顺序等任务中有用(例如,从通道数 x 高度 x 宽度 变为 高度 x 宽度 x 通道数)。

连接与拆分练习

合并或拆分张量通常是必要的,尤其是在处理批次或不同特征集时。

任务1: 创建两个2x3张量,并沿维度0(行)连接它们。

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建两个2x3张量

val tensor_a = torch.tensor(Seq(Seq(1, 2, 3), Seq(4, 5, 6)))

val tensor_b = torch.tensor(Seq(Seq(7, 8, 9), Seq(10, 11, 12)))

println("\n张量A:\n", tensor_a)

println("张量B:\n", tensor_b)

// 沿维度0连接(堆叠行)

val concatenated_rows = torch.cat((tensor_a, tensor_b), dim=0)

println("\n沿行连接(dim=0):\n", concatenated_rows)

println("形状:", concatenated_rows.shape) // 应该是 4x3

展开

代码语言:

Java

自动换行

AI代码解释

LongPointer tensorAData = new LongPointer(1, 2, 3, 4, 5, 6);

LongPointer tensorBData = new LongPointer(7, 8, 9, 10, 11, 12);

Tensor tensor_a1 = torch.tensor(new LongArrayRef(tensorAData, new LongPointer(2, 3)));

Tensor tensor_b1 = torch.tensor(new LongArrayRef(tensorBData, new LongPointer(2, 3)));

System.out.printf("\n张量A:\n %s%n", tensor_a1);

System.out.printf("张量B:\n %s%n", tensor_b1);

TensorVector catTensors = new TensorVector(tensor_a1, tensor_b1);

Tensor concatenated_rows = cat(catTensors, 0);

System.out.printf("\n沿行连接(dim=0):\n %s%n", concatenated_rows);

System.out.printf("形状: %s%n", concatenated_rows.sizes());

任务2: 沿维度1(列)连接 tensor_a 和 tensor_b。

代码语言:

TXT

自动换行

AI代码解释

// 沿维度1连接(连接列)

val concatenated_cols = torch.cat((tensor_a, tensor_b), dim=1)

println("\n沿列连接(dim=1):\n", concatenated_cols)

println("形状:", concatenated_cols.shape) // 应该是 2x6

代码语言:

Java

自动换行

AI代码解释

Tensor concatenated_cols = cat(catTensors, 1);

System.out.printf("\n沿列连接(dim=1):\n %s%n", concatenated_cols);

System.out.printf("形状: %s%n", concatenated_cols.sizes());

任务3: 使用 stack 组合 tensor_a 和 tensor_b,形成一个形状为2x2x3的新张量。

展开

代码语言:

TXT

自动换行

AI代码解释

// 堆叠张量 - 创建一个新维度(默认 dim=0)

val stacked_tensor = torch.stack((tensor_a, tensor_b), dim=0)

println("\n堆叠的张量(dim=0):\n", stacked_tensor)

println("形状:", stacked_tensor.shape) // 应该是 2x2x3

// 沿维度1堆叠

val stacked_tensor_dim1 = torch.stack((tensor_a, tensor_b), dim=1)

println("\n堆叠的张量(dim=1):\n", stacked_tensor_dim1)

println("形状:", stacked_tensor_dim1.shape) // 应该是 2x2x3

展开

代码语言:

Java

自动换行

AI代码解释

Tensor stacked_tensor = stack(catTensors, 0);

System.out.printf("\n堆叠的张量(dim=0):\n %s%n", stacked_tensor);

System.out.printf("形状: %s%n", stacked_tensor.sizes());

Tensor stacked_tensor_dim1 = stack(catTensors, 1);

System.out.printf("\n堆叠的张量(dim=1):\n %s%n", stacked_tensor_dim1);

System.out.printf("形状: %s%n", stacked_tensor_dim1.sizes());

请注意 stack 如何添加一个新维度,而 cat 沿着现有维度连接。

任务4: 创建一个6x4张量,并沿维度0将其拆分为三个等大小的部分。

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建一个值从0到23的张量

val tensor_to_split = torch.arange(24).reshape(6, 4)

println("\n待拆分张量(6x4):\n", tensor_to_split)

// 沿维度0拆分为3个部分

val chunks = torch.chunk(tensor_to_split, chunks=3, dim=0)

println("\n拆分为3个部分:")

for (i, chunk) <- chunks.zipWithIndex {

println(f"部分 {i}(形状 ${chunk.shape}):\n", chunk)

}

展开

代码语言:

Java

自动换行

AI代码解释

// 43. 张量拆分

Tensor tensor_to_split = torch.arange(new Scalar(24)).reshape(new LongArrayRef(new LongPointer(6, 4)));

System.out.printf("\n待拆分张量(6x4):\n %s%n", tensor_to_split);

TensorVector chunks = chunk(tensor_to_split, 3, 0);

System.out.println("\n拆分为3个部分:");

for (int i = 0; i < chunks.size(); i++) {

System.out.printf("部分 %d(形状 %s):\n %s%n", i, chunks.get(i).sizes(), chunks.get(i));

}

广播练习

广播简化了不同形状张量之间的操作。

任务1: 创建一个3x3张量和一个1x3张量(行向量)。将它们相加。

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建一个3x3张量和一个1x3张量(行向量)

val matrix = torch.tensor(Seq(Seq(1, 2, 3), Seq(4, 5, 6), Seq(7, 8, 9)))

val row_vector = torch.tensor(Seq(Seq(10, 20, 30))) // 形状 1x3

println("\n矩阵(3x3):\n", matrix)

println("行向量(1x3):\n", row_vector)

// 广播加法: row_vector 被扩展以匹配 matrix 的形状

val result = matrix + row_vector

println("\n矩阵 + 行向量(广播):\n", result)

展开

代码语言:

Java

自动换行

AI代码解释

// 44. 广播加法 - 矩阵+行向量

LongPointer rowVectorData = new LongPointer(10, 20, 30);

Tensor row_vector = torch.tensor(new LongArrayRef(rowVectorData, new LongPointer(1, 3)));

System.out.printf("\n矩阵(3x3):\n %s%n", matrix2);

System.out.printf("行向量(1x3):\n %s%n", row_vector);

Tensor result = matrix2.add(row_vector);

System.out.printf("\n矩阵 + 行向量(广播):\n %s%n", result);

PyTorch 自动扩展了 row_vector(形状1x3)的行,使其形状变为3x3,从而允许与 matrix 进行逐元素相加。

任务2: 创建一个3x3张量和一个3x1张量(列向量)。将它们相加。

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建一个3x3张量和一个3x1张量(列向量)

val matrix = torch.tensor(Seq(Seq(1, 2, 3), Seq(4, 5, 6), Seq(7, 8, 9)))

val col_vector = torch.tensor(Seq(Seq(100), Seq(200), Seq(300))) // 形状 3x1

println("\n矩阵(3x3):\n", matrix)

println("列向量(3x1):\n", col_vector)

// 广播加法: col_vector 被扩展以匹配 matrix 的形状

val result_col = matrix + col_vector

println("\n矩阵 + 列向量(广播):\n", result_col)

展开

代码语言:

Java

自动换行

AI代码解释

// 45. 广播加法 - 矩阵+列向量

LongPointer matrixData2 = new LongPointer(1, 2, 3, 4, 5, 6, 7, 8, 9);

Tensor matrix2 = torch.tensor(new LongArrayRef(matrixData2, new LongPointer(3, 3)));

LongPointer colVectorData = new LongPointer(100, 200, 300);

Tensor col_vector = torch.tensor(new LongArrayRef(colVectorData, new LongPointer(3, 1)));

System.out.printf("\n矩阵(3x3):\n %s%n", matrix2);

System.out.printf("列向量(3x1):\n %s%n", col_vector);

Tensor result_col = matrix2.add(col_vector);

System.out.printf("\n矩阵 + 列向量(广播):\n %s%n", result_col);

展开

代码语言:

Java

自动换行

AI代码解释

// 1. 定义矩阵和列向量的原始数据

// 3x3矩阵数据

long[] matrixData = {1, 2, 3, 4, 5, 6, 7, 8, 9};

// 3x1列向量数据

long[] colVectorData = {100, 200, 300};

// 2. 创建3x3矩阵张量(对应原代码的 matrix)

Tensor matrix = Tensor.fromBlob(matrixData, new long[]{3, 3});

// 3. 创建3x1列向量张量(对应原代码的 col_vector)

Tensor colVector = Tensor.fromBlob(colVectorData, new long[]{3, 1});

// 打印张量信息(模拟原代码的 println)

System.out.println("\n矩阵(3x3):\n" + tensorToString(matrix));

System.out.println("列向量(3x1):\n" + tensorToString(colVector));

// 4. 执行广播加法(对应原代码的 matrix + col_vector)

// Java中用add()方法实现加法,PyTorch会自动处理广播逻辑

Tensor resultCol = matrix.add(colVector);

// 打印广播加法结果

System.out.println("\n矩阵 + 列向量(广播):\n" + tensorToString(resultCol));

这里,col_vector(形状3x1)被广播到各列,以匹配 matrix 的3x3形状。

数据类型练习

管理数据类型对于内存效率和数值稳定性很重要。

任务1: 创建一个整数张量并检查其 dtype。然后将其转换为浮点张量。

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建一个整数张量

val int_tensor = torch.tensor(Seq(1, 2, 3, 4))

println("\n整数张量:", int_tensor)

println("数据类型:", int_tensor.dtype)

// 转换为 float32

val float_tensor = int_tensor.to(torch.float32)

// 替代方法: float_tensor = int_tensor.float()

println("\n转换为浮点张量:", float_tensor)

println("数据类型:", float_tensor.dtype)

展开

代码语言:

Java

自动换行

AI代码解释

// 46. 整数张量转换为float32

LongPointer intData1 = new LongPointer(1, 2, 3, 4);

Tensor int_tensor2 = torch.tensor(new LongArrayRef(intData1, new LongPointer(4)));

System.out.printf("\n整数张量:%s%n", int_tensor2);

System.out.printf("数据类型:%s%n", int_tensor2.dtype());

Tensor float_tensor2 = int_tensor2.to(ScalarType.Float);

System.out.printf("\n转换为浮点张量: %s%n", float_tensor2);

System.out.printf("数据类型:%s%n", float_tensor2.dtype());

任务2: 创建一个浮点张量并将其转换为整数张量。观察任何变化。

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建一个浮点张量

val float_tensor_orig = torch.tensor(Seq(1.1, 2.7, 3.5, 4.9))

println("\n原始浮点张量:", float_tensor_orig)

println("数据类型:", float_tensor_orig.dtype)

// 转换为 int32

val int_tensor_cast = float_tensor_orig.to(torch.int32)

// 替代方法: int_tensor_cast = float_tensor_orig.int()

println("\n转换为整数张量:", int_tensor_cast)

println("数据类型:", int_tensor_cast.dtype)

展开

代码语言:

Java

自动换行

AI代码解释

// 47. 浮点张量转换为int32

FloatPointer floatData1 = new FloatPointer(1.1f, 2.7f, 3.5f, 4.9f);

Tensor float_tensor_orig = torch.tensor(new FloatArrayRef(floatData1, new FloatPointer(4)));

System.out.printf("\n原始浮点张量: %s%n", float_tensor_orig);

System.out.printf("数据类型: %s%n", float_tensor_orig.dtype());

Tensor int_tensor_cast = float_tensor_orig.to(ScalarType.Int));

System.out.printf("\n转换为整数张量: %s%n", int_tensor_cast);

System.out.printf("数据类型: %s%n", int_tensor_cast.dtype());

请注意,从浮点数转换为整数会截断小数部分。请注意可能的数据精度损失。

CPU 与 GPU 练习

将张量移动到合适的设备(CPU或GPU)是必要的,以发挥硬件加速的优势。

任务1: 创建一个张量并检查其默认设备。然后,将其移动到GPU(如果可用),再移回CPU。

展开

代码语言:

TXT

自动换行

AI代码解释

// 创建张量(除非另有说明,否则默认为CPU)

val cpu_tensor = torch.randn(2, 2)

println(f"\n张量在CPU上: {cpu_tensor.device}\n", cpu_tensor)

// 移动到配置的设备(如果可用则为GPU,否则为CPU)

val device_tensor = cpu_tensor.to(device)

println(f"\n张量已移动到 {device_tensor.device}:\n", device_tensor)

// 明确移回CPU

val cpu_tensor_again = device_tensor.to("cpu")

println(f"\n张量已移回CPU: {cpu_tensor_again.device}\n", cpu_tensor_again)

// 执行操作 - 需要张量在同一设备上

if device_tensor.device != cpu_tensor.device then

println("\n在不同设备上的张量相加会导致错误。")

// 这会失败: cpu_tensor + device_tensor

// 正确方法:

val result_on_device = device_tensor + device_tensor

println(f"在 {result_on_device.device} 上的操作结果:\n", result_on_device)

else

println("\n两个张量都在CPU上,相加没问题。")

val result_on_cpu = cpu_tensor + cpu_tensor_again

println(f"在 {result_on_cpu.device} 上的操作结果:\n", result_on_cpu)

展开

代码语言:

Java

自动换行

AI代码解释

// 48. 设备管理与张量移动

Device device6 = new Device(cuda_is_available() ? DeviceType.CUDA : DeviceType.CPU);

Tensor cpu_tensor_a = torch.randn(new LongArrayRef(new LongPointer(2, 2)));

System.out.printf("\n张量在CPU上: %s\n %s%n", cpu_tensor_a.device(), cpu_tensor_a);

Tensor device_tensor = cpu_tensor_a.to(device6, ScalarType.Float);

System.out.printf("\n张量已移动到 %s:\n %s%n", device_tensor.device(), device_tensor);

Tensor cpu_tensor_again = device_tensor.to(new Device(DeviceType.CPU), ScalarType.Float);

System.out.printf("\n张量已移回CPU: %s\n %s%n", cpu_tensor_again.device(), cpu_tensor_again);

if (!device_tensor.device().equals(cpu_tensor_a.device())) {

System.out.println("\n在不同设备上的张量相加会导致错误。");

Tensor result_on_device = device_tensor.add(device_tensor);

System.out.printf("在 %s 上的操作结果:\n %s%n", result_on_device.device(), result_on_device);

} else {

System.out.println("\n两个张量都在CPU上,相加没问题。");

Tensor result_on_cpu = cpu_tensor_a.add(cpu_tensor_again);

System.out.printf("在 %s 上的操作结果:\n %s%n", result_on_cpu.device(), result_on_cpu);

}

请记住,张量之间的操作通常要求它们在同一设备上进行。使用 .to(device) 显式移动张量是 PyTorch 代码中常见的做法,尤其是在为GPU训练准备数据和模型时。

本次实践涵盖了操作张量的主要技巧。随着您构建更复杂的模型,熟练掌握索引、形状改变、组合张量、理解广播、管理数据类型以及控制设备放置将变得愈发重要。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档