我用Rust编写了一个简单的多线程应用程序,将数字从1添加到x。(我知道这是一个公式,但重点是用Rust编写一些多线程代码,而不是为了得到结果。)它工作得很好,但是在我将它重构为一种功能更强的风格而不是命令式之后,多线程就不再加速了。在检查CPU使用情况时,似乎只使用了我的4核心/8线程CPU的一个核心。原始代码的CPU使用率为790%,而重构版本仅为99%。
原始代码:
use std::thread;
fn main() {
let mut handles: Vec<thread::JoinHandle<u64>> = Vec::with_capacity(8);
const thread_count: u64 = 8;
const batch_size: u64 = 20000000;
for thread_id in 0..thread_count {
handles.push(thread::spawn(move || {
let mut sum = 0_u64;
for i in thread_id * batch_size + 1_u64..(thread_id + 1) * batch_size + 1_u64 {
sum += i;
}
sum
}));
}
let mut total_sum = 0_u64;
for handle in handles.into_iter() {
total_sum += handle.join().unwrap();
}
println!("{}", total_sum);
}重构代码:
use std::thread;
fn main() {
const THREAD_COUNT: u64 = 8;
const BATCH_SIZE: u64 = 20000000;
// spawn threads that calculate a part of the sum
let handles = (0..THREAD_COUNT).map(|thread_id| {
thread::spawn(move ||
// calculate the sum of all numbers from assigned to this thread
(thread_id * BATCH_SIZE + 1 .. (thread_id + 1) * BATCH_SIZE + 1)
.fold(0_u64,|sum, number| sum + number))
});
// add the parts of the sum together to get the total sum
let sum = handles.fold(0_u64, |sum, handle| sum + handle.join().unwrap());
println!("{}", sum);
}程序的输出是相同的(12800000080000000),但重构版本是5-6倍的速度。
迭代器似乎是懒惰地评估的。如何强制对整个迭代器进行评估?我试图将它收集到一个类型为[thread::JoinHandle<u64>; THREAD_COUNT as usize]的数组中,但随后我得到了以下错误:
--> src/main.rs:14:7
|
14 | ).collect::<[thread::JoinHandle<u64>; THREAD_COUNT as usize]>();
| ^^^^^^^ a collection of type `[std::thread::JoinHandle<u64>; 8]` cannot be built from `std::iter::Iterator<Item=std::thread::JoinHandle<u64>>`
|
= help: the trait `std::iter::FromIterator<std::thread::JoinHandle<u64>>` is not implemented for `[std::thread::JoinHandle<u64>; 8]`收集到向量确实有效,但这似乎是一个奇怪的解决方案,因为它的大小是预先知道的。有比使用向量更好的方法吗?
发布于 2019-04-03 09:12:34
Rust中的迭代器很懒,所以在handles.fold试图访问迭代器的相应元素之前不会启动线程。基本上发生的事情是:
handles.fold尝试访问迭代器的第一个元素。handles.fold调用它的闭包,为第一个线程调用handle.join()。handle.join等待第一个线程完成。handles.fold尝试访问迭代器的第二个元素。在折叠结果之前,应该将句柄收集到向量中:
let handles: Vec<_> = (0..THREAD_COUNT)
.map(|thread_id| {
thread::spawn(move ||
// calculate the sum of all numbers from assigned to this thread
(thread_id * BATCH_SIZE + 1 .. (thread_id + 1) * BATCH_SIZE + 1)
.fold(0_u64,|sum, number| sum + number))
})
.collect();或者您可以使用像人丝这样的机箱,它提供并行迭代器。
https://stackoverflow.com/questions/55490906
复制相似问题