모던 자바 인 액션 - Chapter 7 <병렬 데이터 처리와 성능>

자바 7이 등장하기 전까지는 데이터 컬렉션을 병렬로 처리하기 어려웠습니다.이번 장에서는 자바7에 포함된 포크 조인 프레임워크와 병렬 스트림이 어떻게 처리되는지 알아보겠습니다.

병렬 스트림

컬렉션에 stream() 대신 parallelStream()을 호출하면 쉽게 병렬 스트림을 생성 할 수 있습니다.병렬 스트림이란 각각의 쓰레드에서 스트림 요소를 처리할 수 있도록 스트림 요소를 여러 청크로 분할한 스트림입니다.따라서 병렬 스트림을 활용하면 모든 멀티코어 프로세서가 각각의 분할된 청크를 처리하도록 할당시킬 수 있습니다.

숫자 n을 인수로 받아 1부터 n까지의 모든 합계를 구하는 메서드를 통해 병렬 스트림을 활용해 보겠습니다.아래와 같이 간단하게 순차 스트림을 통해 위 메서드를 구현할 수 있습니다.

	@Test
    void noneParallel() {
        long sum = sequentialSum(10);
        Assertions.assertThat(sum).isEqualTo(55);
    }

    private long sequentialSum(long n) {
        return Stream.iterate(1L, i -> i + 1)
                .limit(n)
                .reduce(0L, Long::sum);
    }

	private long beforeStream(long n) {
        long sum = 0;
        for (int i = 1; i <= n; i++) {
            sum += i;
        }
        return sum;
    }

크게 두가지의 순차적인 방식을 통해 숫자의 합을 구하는 메서드를 구현하였습니다.하나는 순차 스트림,나머지 하나는 전통적인 방식의 합계를 구하는 for loop를 사용했습니다.

위의 두 메서드들에서 인수인 n인 기하급수적으로 커진다면 연산을 병렬로 처리하는것이 훨씬 효율적일 것입니다.지금부터 병렬스트림을 통해 이를 사용해보겠습니다.

순차 스트림을 병렬 스트림으로 변환하기

위의 sequentialSum 메서드는 기존의 리듀싱 연산이 수행되어 순차적으로 스트림 요소에 접근하여 요소를 소비합니다.이를 아래의 parallelSum으로 바꿀 경우 스트림이 여러 청크로 나뉘게 된다는 것입니다.

	private long parallelSum(long n) {
        long start = System.currentTimeMillis();
        Long sum = Stream.iterate(1L, i -> i + 1)
                .limit(n)
                .parallel()
                .reduce(0L, Long::sum);
        System.out.println("parallelSum: " + (System.currentTimeMillis() - start));
        return sum;
    }

스트림의 요소들이 아래 그림처럼 하나의 청크 단위로 나뉘고 해당 청크의 첫번째 요소들이 병렬적으로 스레드에 의해 연산되어집니다.마지막으로 리듀싱 연산으로 생성된 chunk들의 부분 결과를 다시 리듀싱 연산으로 합쳐서 전체 스트림의 연산 결과를 반환합니다.

그림1

사실 parallelSum과 같이 순차 스트림에 parallel()을 호출한다고 해서 스트림 자체에는 아무런 변화가 일어나지는 않습니다.다만 내부적으로 해당 리듀싱 연산이 병렬로 수행되어야함을 나타내는 불리언 flag가 설정됩니다.위 과정과 반대로 sequential()를 호출해 병렬 스트림을 순차 스트림으로도 변경 가능합니다.

스트림 성능 측정하기

우선 병렬 스트림의 성능을 측정하기 위해 책에서는 jmh라는 라이브러리를 사용하고 있습니다.벤치마킹 툴은 통해 생성한 메서드들을 벤치마크해보겠습니다.jmh를 설치하고 사용하는 포스팅은 추후에 게시하겠습니다.(https://github.com/melix/jmh-gradle-plugin 참조)

우선 아래와 같이 벤치마킹의 대상이 되는 메서드들을 작성해줍니다.

package parallel;

import org.openjdk.jmh.annotations.*;


@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(java.util.concurrent.TimeUnit.MILLISECONDS)
@Fork(value = 2, jvmArgs = {"-Xms4G", "-Xmx4G"})
public class ParallelStreamBenchmark {
    private static final long N = 10_000_000L;

    @Benchmark
    public long sequentialSum() {
        return java.util.stream.Stream.iterate(1L, i -> i + 1)
                .limit(N)
                .reduce(0L, Long::sum);
    }

    @TearDown(Level.Invocation)
    public void tearDown() {
        System.gc();
    }
}

테스트 용으로 순차 스트림을 통해 앞서 예제의 성능을 측정하면 아래와 같은 결과가 발생합니다.

이번에는 parallel 스트림을 통해 병렬로 위 로직을 실행해보겠습니다.

놀랍게도 parallel이 더 나쁜 성능을 나타내고 있습니다.사실 교재의 결과와 M1 맥북에서의 결과가 상당히 다릅니다.원래는 순차 스트림이 병렬 스트림보다 5배 가량 좋은 성능이 놔와야합니다.하지만 M1 맥북에 의해 위와 같이 아주 근소하게 좋은 성능을 순차스트림이 보여주고 있습니다.하지만 공부하는 입장이니 공짜 점심은 없다고 생각하겠습니다.

우선 위와 같이 병렬 스트림이 안좋은 성능을 가지는것에는 다음과 같은 이유가 있습니다.숫자의 합계를 구하는 연산은 이전 연산의 결과에 따라 다음 번 함수의 입력이 달라집니다.이와 같은 상황에서는 앞선 그림1에서 보여주는 것처럼 리듀싱 연산이 수행되지 않습니다.리듀싱 과정을 시작하는 시점에 모든 연산 대상이 되는 숫자 리스트가 준비되지 않으므로 스트림을 병렬로 처리하기 위한 청크 분할을 수행할 수 없습니다.

스트림이 병렬로 처리되도록 구현했지만 본질적으로 순차적을 동작하는 interate() 연산에 의해 순차적으로 스트림 요소를 추가하는 과정을 거칩니다.결국 순차적으로 완성된 스레드에서 멀티 쓰레드를 이용해 각각의 합계를 구합니다.하지만 순차 스트림에 여러 스레드를 할당하는 오버헤드만 발생했다고 볼 수 있습니다.

결국 위 벤치마킹의 요점은 병렬 처리라고 해서 모든 상황에서 좋은 성능을 발생 시키는 것이 아니다라는 것입니다.꼭 parallel의 내부 동작을 알고 사용해야한다는 의미이기도 합니다.

이번에는 위 병렬 함수를 특화된 메서드인 LongStream.rangeClosed를 사용하여 개선시켜보겠습니다.우선 위 함수는 아래와 같은 이점을 갖고 있습니다.

  • 기본형 long 타입을 직접 사용하기에 스트림 요소의 boxing과 unboxing 오버헤드가 없습니다.
  • 청크 단위로 쉽게 분할할 수 있는 숫자 범위를 생산합니다.예를 들어 1 - 20의 범위의 숫자를 각각 1-5,6-10,11-15,16-20의 범위의 숫자로 분할 가능합니다.

위의 메서드를 사용하여 벤치마킹을 진행해보겠습니다.아래와 같이 병렬 처리가 더 좋은 성능을 가지게 되었습니다.

포크 / 조인 프레임워크

포크/조인 프레임워크는 병렬화 할 수 있는 작업을 재귀적으로 작은 작업으로 분할한 다음에 분할된 서브태스크 각각의 결과를 합쳐서 전체 결과를 만들도록 설계되었습니다.해당 프레임워크는 분할되는 서브태스크를 ThreadPool의 worker thread에 분산 할당하는 ExecutorService를 구현합니다. ExecutorService는 간단하게 말하여 java에서 스레드풀을 생성하기 위해 구현하는 인터페이스라고 이해하시면됩니다.

Recursive Task 활용

스레드 풀을 이용하려면 ReculsiveTask<R>의 구현 클래스를 만들어야 합니다.여기서 R은 병렬화된 Task가 생성하는 결과의 class Type입니다. Return의 줄임말이라고 이해하셔도 됩니다. ReculsiveTask<R>를 정의하기 위해서는 내부의 추상 메서드인 compute() 메서드를 구현해야합니다.메서드의 시그니처는 아래와 같습니다.

	protected abstract R compute();

compute 메서드는 Task를 SubTask로 분할시키는 로직과 Task를 더 이상 분할할 수 없을 때 쪼개진 SubTask의 결과를 생산할 로직을 가져야합니다.

지금부터 숫자 배열이 주어질 경우, start 인덱스부터 end 인덱스까지의 합을 구하는 로직을 RecursiveTask을 활용해서 구현해보겠습니다.

import java.util.concurrent.RecursiveTask;

public class ForkJoinSumCalc extends RecursiveTask<Long> {
    private final long[] numbers;
    /*
    * start,end -> 연산을 시작할 numbers 배열의 시작과 끝 인덱스
    * */
    private final int start;
    private final int end;
    public static final long THRESHOLD = 10_000;

    public ForkJoinSumCalc(long[] numbers) {
        this(numbers, 0, numbers.length);
    }

    private ForkJoinSumCalc(long[] numbers, int start, int end) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
    }


    @Override
    public Long compute() {
        int length = end - start;
        if (length <= THRESHOLD) {
            return computeSequentially();
        }
        /*
        * Task split to half chuck.
        * */
        ForkJoinSumCalc leftTask = new ForkJoinSumCalc(numbers, start, start + length / 2);
        //start the Task asynchronously, specifically push this job to queue.
        leftTask.fork();

        ForkJoinSumCalc rightTask = new ForkJoinSumCalc(numbers, start + length / 2, end);
        Long compute = rightTask.compute();//main thread will compute rightTask.
        Long join = leftTask.join();//get result of leftTask.
        return compute + join;
    }

    private long computeSequentially() {
        long sum = 0;
        for (int i = start; i < end; i++) {
            sum += numbers[i];

        }
        return sum;
    }
}

위 코드에서 compute 메서드를 보면 length <= THRESHOLD 조건을 통해 해당 Task를 더 쪼갤지 말지를 결정합니다. 이후 쪼개진 Task의 결과를 수행하는 로직을 수행힙니다.좀 더 자세하게 ForkJoinSumCalc 어떻게 동작하는 확인해보겠습니다.

ForkJoinSumCalcForkJoinPool에 전달하게 되면 CommonPool 내의 스레드가 compute 메서드를 실행하며 Task를 수행합니다.compute 메서드는 함수 내의 Task의 크기가 더 쪼개질 수 있는지 확인하고 숫자 배열을 둘로 나누어 Task를 분할합니다.그러면 다시 새롭게 생긴 ForkJoinSumCalc라는 Task가 Common Pool에 전달되고 이는 스레드에 의해 compute() 메서드를 수행합니다. 더 이상 분할되지 않을 때까지 확인을 거친 뒤 해당 subTask들은 fork()를 통해 각자의 스레드에 의해 순차적으로 처리되며 각각의 subTask들의 결과가 합해지면서 최종 결과를 계산합니다.

책의 뒷부분에도 spliterator,작업 훔치기 등 병렬 처리와 관련된 내용이 더 존재한다.해당 내용들은 병렬 처리에 대해 좀 더 익숙해진 다음 포스팅에 업데이트 하도록 한다.