首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >Pytorch On Java 你的第一个java版本的【真】 神经网络 [AI Infra 3.0]

Pytorch On Java 你的第一个java版本的【真】 神经网络 [AI Infra 3.0]

原创
作者头像
用户11467648
发布2026-03-12 20:28:53
发布2026-03-12 20:28:53
50
举报

如何用Java 写 全真的神经网络 一直是java 众多程序员的梦想,奈何你们寻不到真经,被各种玩具框架 妖魔鬼怪 挟持认知,一个个都以为必须用jni 调python 要么就是冷言嘲讽java不适合做神经网络,在ai是的已经被淘汰了? 果真如此,当然是假的了,反而是 java 在AI时代有媲美Apache spark flink 神级 框架的存在,它就是pytorch

编辑版本将使用 JavaCPP Presets for PyTorch。JavaCPP 提供了 PyTorch C++ API(LibTorch)的直接映射,因此代码风格会非常接近 C++ 版的 LibTorch,但运行在 JVM 上,注意是几百万行代码的全量编译!!!

注意 不是 Pytorch 官方支持 java 版本,也不是 java Oracle支持 Pytorch,而是ByteDeco 旗下的Javacpp 支持 PyTorch ,Pytorch官方基金会在java 的支持上只限于 andriod ,其他都非常拉胯!!! 吃水不忘挖井人,你如果要感谢的话,一定要感谢 Bytedeco 这个伟大的天才开源组织

以下是针对 javacpp-pytorch 2.1.0-1.5.13 版本的完整指南。

1. Maven 配置

首先,在你的 pom.xml 中引入依赖。pytorch-platform 会自动根据你的操作系统下载对应的本地库(包含 CPU 版本,如需 GPU 需额外配置)。

代码语言:javascript
复制
xml 体验AI代码助手 代码解读复制代码<dependencies>
    <dependency>
        <groupId>org.bytedeco</groupId>
        <artifactId>pytorch-platform</artifactId>
        <version>2.1.0-1.5.13</version>
    </dependency>
</dependencies>
Image
Image

2. 张量(Tensor)操作

在 JavaCPP 中,Tensor 的操作主要通过 org.bytedeco.pytorch.global.torch 类中的静态方法实现。

代码语言:javascript
复制
java 体验AI代码助手 代码解读复制代码
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import static org.bytedeco.pytorch.global.torch.*;

public class TensorDemo {
public static void main(String[] args) {
// 1. 创建一个未初始化的 5x3 矩阵
Tensor x1 = torch.empty(new long[]{5, 3}); // 默认类型
System.out.println("Empty Tensor:\n" + x1);

// 2. 创建一个随机初始化的 5x3 矩阵
Tensor x2 = torch.rand(new long[]{5, 3});
System.out.println("Random Tensor:\n" + x2);

// 3. 创建一个全为 0,数据类型为 Long 的矩阵
Tensor x3 = torch.zeros(new long[]{5, 3}, torch.dtype(kLong()));
System.out.println("Zeros Tensor:\n" + x3);

// 4. 直接使用数据初始化 (Java 数组转 Tensor)
Tensor x4 = torch.tensor(new float[]{5.5f, 3f});
System.out.println("Data Tensor:\n" + x4);
}
}

3. 定义神经网络模型

在 JavaCPP 中定义模型需要继承 Module 类,并手动注册子模块(使用 register_module)。

代码语言:javascript
复制
java 体验AI代码助手 代码解读复制代码
import static org.bytedeco.pytorch.global.torch.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module; // 注意别导错包
import static org.bytedeco.pytorch.global.torch.*;

// 定义网络结构
class Net extends Module {
// 定义层
private Conv2dImpl conv1, conv2;
private LinearImpl fc1, fc2, fc3;

public Net() {
// 1. 初始化并注册层 (注意:JavaCPP 中使用 Options 对象配置参数)
        // Conv2d(输入通道, 输出通道, 卷积核大小)
conv1 = register_module("conv1", new Conv2dImpl(new Conv2dOptions(1, 6, 3)));
conv2 = register_module("conv2", new Conv2dImpl(new Conv2dOptions(6, 16, 3)));

// Linear(输入特征数, 输出特征数)
        // 16*6*6 是根据输入图像大小推算出的展平后的特征数
fc1 = register_module("fc1", new LinearImpl(new LinearOptions(16 * 6 * 6, 120)));
fc2 = register_module("fc2", new LinearImpl(new LinearOptions(120, 84)));
fc3 = register_module("fc3", new LinearImpl(new LinearOptions(84, 10)));
}

// 前向传播
public Tensor forward(Tensor x) {
// 第一层卷积 -> ReLU -> 2x2 最大池化
x = max_pool2d(relu(conv1.forward(x)), new long[]{2, 2});

// 第二层卷积 -> ReLU -> 2x2 最大池化
x = max_pool2d(relu(conv2.forward(x)), new long[]{2, 2});

// 展平张量 (flatten),-1 表示自动推导 batch 维度
x = x.view(new long[]{-1, 16 * 6 * 6});

// 全连接层 -> ReLU
x = relu(fc1.forward(x));
x = relu(fc2.forward(x));

// 输出层
x = fc3.forward(x);
return x;
}
}

4. 运行模型(Main 方法)

最后,我们将一切串联起来,创建一个网络实例并进行一次前向计算。

代码语言:javascript
复制
java 体验AI代码助手 代码解读复制代码

import org.bytedeco.pytorch.Adam;
import org.bytedeco.pytorch.BCELossImpl;
import org.bytedeco.pytorch.SGD;
import org.bytedeco.pytorch.Tensor;
import org.bytedeco.pytorch.global.torch;

public class Main {
public static void main(String[] args) {
// 实例化网络
Net net = new Net();
System.out.println("Network structure initialized.");

// 创建一个模拟输入:1张图像,1个通道,32x32 分辨率
        // 注意 模拟数据 代码中 6x6 的特征图推导通常对应 32x32 的输入
Tensor input = torch.rand(new long[]{1, 1, 32, 32});

Tensor target = torch.rand(new long[]{1, 1}); // 模拟二分类标签

Adam optimizer = new Adam(net.parameters());
BCELossImpl lossFn = new BCELossImpl();
// 前向传播
Tensor output = net.forward(input);
optimizer.zero_grad();
var loss = lossFn.forward(output, target);
optimizer.step();
System.out.println("Initial loss: " + loss.item_double());

System.out.println("Output Tensor:");
System.out.println(output);

// 打印输出形状
System.out.println("Output sizes: " + java.util.Arrays.toString(output.sizes().vec().get()));
}
}

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. Maven 配置
  • 2. 张量(Tensor)操作
  • 3. 定义神经网络模型
  • 4. 运行模型(Main 方法)
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档