Data Structure/세그먼트 트리

세그먼트 트리 (Segment Tree)

SH3542 2024. 6. 25. 15:11
 

목차

 

    세그먼트 트리(Segment Tree)

    완전 이진 트리에 기반한 자료구조이다.

    고정된 크기의 수열에서 특정 구간의 질의 및 수정을 효율적으로 수행하기 위해 사용한다.

     

    세그먼트 트리의 필요성

    길이 N=100만인 배열에서 특정 구간의 합을 구해야한다고 가정해보자.

     

    최악의 경우는 아마, 전체 원소의 합을 구하는 경우일 것이다. 이는 O(100만)이 소요된다.

    수정의 경우는 어떨까? 해당하는 인덱스의 값을 변경하면 된다. 이는 O(1)이 소요된다.

     

    세그먼트 트리를 사용한다면, 연산과 수정을 O(logN)에 수행할 수 있다. 즉, O(log100만=19.93)이다.

    합을 구하는 동작을 m번 수행한다면, 더 엄청난 차이가 발생할 것이다.

     

    유의할 점

    이진수에 기반한 컴퓨터 과학에서, log의 밑은 e가 아닌 2를 의미한다.

     

    세그먼트 트리의 특성

    1. 수정과 질의가 동시에 잦을 때에 유리하다.

    하나가 O(1)이고 나머지가 O(N)인 것과, 모두 O(logN)인 것은 차이가 크다.

     

    2. 다양한 질의가 가능하다.

    대표적으로 특정 구간의 합 연산, 최대/최소 값 구하기, XOR연산 등에 활용한다.

     

    3. 분할 정복 형식으로 구현된다.

    그러므로 또한, 재귀적으로 구현된다.

     

    4. 원소의 삭제 및 추가 연산에는 적합하지 않다.

    세그먼트 트리는 수열의 구간별 값을 저장해놓은 자료구조이다.

     

    따라서, 업데이트시엔 포함되는 구간의 값만 갱신하면 되지만, 삭제 및 추가 시엔 트리를 재조정 해야한다.

    이는 비효율적일 뿐더러, 애초에 목적과 어긋나므로 고려하지 않는다.

     

    삭제 연산을 노드를 없애는 대신, 해당 값을 무효화 시키는 연산으로 대체하여 수행할 수 있다.

    ex) 구간 합에서 삭제 대신 해당 노드의 값을 0으로 놓는다. 최대 값에서 음수인 구간이 없다면 0으로 놓는다.

     

    그러나, XOR연산 등에서는 이는 매우 복잡(사실상 불가능)하다. 적용해야 한다면 조건을 잘 고려하자.

     

    세그먼트 트리 구성하기

     

    전체 코드는 다음과 같다. 구간 합을 수행하는 세그먼트 트리의 구현 코드이다.

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.Arrays;
    import java.util.StringTokenizer;
    
    class Main {
        static class SegTree {
    
            int tree[];
            int treeSize;
    
            SegTree(int arrSize) {
                int h = (int) Math.ceil(Math.log(arrSize) / Math.log(2));
                this.treeSize = 1 << h;
    //                        treeSize = arrSize * 4;
                tree = new int[treeSize];
            }
    
            private int init(int[] a, int node, int start, int end) {
                if (start == end) return tree[node] = a[start];
    
                return tree[node] = init(a, node * 2, start, (start + end) / 2) +
                        init(a, node * 2 + 1, ((start + end) / 2) + 1, end);
            }
    
            private void update(int node, int start, int end, int idx, int diff) {
                if (idx < start || idx > end) return;
    
                tree[node] += diff;
    
                if (start != end) {
                    update(node * 2, start, (start + end) / 2, idx, diff);
                    update(node * 2 + 1, (start + end) / 2 + 1, end, idx, diff);
                }
            }
    
            private int sum(int node, int left, int right, int start, int end) {
                if (left > end || right < start) return 0;
    
                if (left <= start && right >= end) return tree[node];
    
                return sum(node * 2, left, right, start, (start + end) / 2)
                        + sum(node * 2 + 1, left, right, (start + end) / 2 + 1, end);
            }
        }
    
        public static void main(String[] args) throws IOException {
    
    
    //        input :
    //        8
    //        1 2 3 4 5 6 7 8
            BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
            int N = Integer.parseInt(br.readLine());
            StringTokenizer st = new StringTokenizer(br.readLine());
    
            int[] a = new int[N + 1];
    
            for (int i = 1; i <= N; i++) {
                a[i] = Integer.parseInt(st.nextToken());
            }
    
            SegTree tree = new SegTree(N + 1);
            tree.init(a, 1, 1, a.length - 1);
            System.out.println(Arrays.toString(tree.tree)); // [0, 36, 10, 26, 3, 7, 11, 15, 1, 2, 3, 4, 5, 6, 7, 8]
            System.out.println(tree.sum(1, 1, 5, 1, a.length - 1)); // 15
            System.out.println(tree.sum(1, 1, 4, 1, a.length - 1)); // 10
            tree.update(1, 1, a.length - 1, 3, a[3] = 10 - a[3]); // a[3] = (10 - 3) = 7
            System.out.println(tree.sum(1, 1, 4, 1, a.length - 1)); // 17
        }
    }

     

    1. 트리 크기 정하기

    출처 : https://cano721.tistory.com/38

     

     

    우선, 높이 h를 구해야 한다.

    세그먼트 트리의 구조를 살펴보면, 결국 리프 노드에는 원본 배열의 모든 원소가 각각 저장된다. 부모 노드는 특정 구간의 값이다.

    길이가 8인 배열의 세그먼트 트리 노드 개수는, 8 + 4  + 2 + 1(root node)가 될 것이다. 즉, 트리의 높이 h=4이고

    포화 이진 트리 이므로, 노드 개수는 (2^h)-1이다.

     

    위의 그림처럼, 길이가 7이라면 어떻게될까? 굳이 따지자면, h=3.xxx가 될 것이다. 하지만 이는 말이 되지 않는다.

    길이가 4인 배열의 h=3보다는 커야하고, 길이가 8인 배열의 h=4 이하면 될 것이다.

     

    따라서, h=4일 때로 구성한다면 약간의 공간복잡도와 타협하고 아래과 같이 구성할 수 있다.

     

    SegTree(int arrSize) {
        int h = (int) Math.ceil(Math.log(arrSize) / Math.log(2));
        this.treeSize = 1 << h;
        tree = new int[treeSize];
    }

    자바에서 Math.log()는 자연로그(밑이e)이므로,

    Math.log(arrSize) / Math.log(2)

    로 밑을 2로 만들어준다.

     

    Math.ceil() 함수를 통해, 노드를 충분히 담을 수 있는 높이를 구해준다. (3 -> 3 / 3.xxx -> 4)

    this.treeSize = 1 << h;

    를 통해 2^h의 크기를 할당한다. (이후에 나오는 내용이지만, 2^h 크기의 tree에서, 0번 인덱스를 쓰지 않으므로 총 2^h -1가 된다. 이는 트리가 포화 이진 트리인 경우에도 원소의 개수 2^h -1를 담을 수 있다.)

     

    이로써 treesize를 구할 수 있게 되었다.

    (단순히 원본 배열의 크기 arrSize에 4를 곱해서 treeSize를 정할 수도 있다. 이는 간편하지만, 메모리를 조금 더 쓰게된다.)

     

    2. 트리 초기화 하기

    private int init(int[] a, int node, int start, int end) {
        if (start == end) return tree[node] = a[start];
    
        return tree[node] = init(a, node * 2, start, (start + end) / 2) +
                init(a, node * 2 + 1, ((start + end) / 2) + 1, end);
    }
    tree.init(a, 1, 1, a.length - 1);

     

    root node 에서 출발한다. (부모부터 갱신을 하는 것이 아니다.)

    부모노드가 node라면, left 자식은 node*2, right 자식은 (node*2) +1이다.

    ex) root가 1이라면, 왼쪽 자식은 2, 오른쪽 자식은 3

     

    이는 분할 정복으로 재귀를 통해 리프 노드까지 내려갈 것이다. 이후,

    if (start == end) return tree[node] = a[start];

    조건에 의해, 리프 노드의 값이 원본 배열의 값으로 할당된다.

     

    이후엔 부모 노드의 값이

    return tree[node] = init(a, node * 2, start, (start + end) / 2) +
            init(a, node * 2 + 1, ((start + end) / 2) + 1, end);

     

    에 의해, 왼쪽 자식과 오른쪽 자식의 값의 합으로 이루어진다.

     

    이는 결국, 루트 노드까지 다시 올라가며 세그먼트 트리가 완성된다.

     

    3. 구간 합 구하기

    private int sum(int node, int left, int right, int start, int end) {
        if (left > end || right < start) return 0;
    
        if (left <= start && right >= end) return tree[node];
    
        return sum(node * 2, left, right, start, (start + end) / 2)
                + sum(node * 2 + 1, left, right, (start + end) / 2 + 1, end);
    }

     

    left와 right는 합을 구하려는 구간을,

    start와 end는 현재 탐색중인 node의 구간을 의미한다.

     

    구하려는 구간은 유지한채로,

    탐색중인 노드의 구간이 바뀌며 이에 포함된다면 tree[node]라는 node번째 값을 return할 것이다.

     

    출처 : https://cano721.tistory.com/38

    그림을 다시 예로 들어보자,

     

    나는 arr[1]~arr[5]번 구간의 합을 구하려고 한다. ( left=1, right=5)

    if (left <= start && right >= end) return tree[node];

    이 때, 2번노드에 도달한다면, start=1 end=4 (arr[1]~arr[4])이다.

    이는 left=1 right=5에 포함되므로, 값을 return한다. 이후의 자식 노드는 더이상 탐색하지 않는다.

     

    남은 것은 start=4 end=5인 구간이다.

    이는 1번->3번->6번->12번 노드를 거치며, 12번 노드를 결과에 sum 한다.

     

    이 과정에서, 1번->3번->7번 노드의 탐색은

    if (left > end || right < start) return 0;

     

    를 통해 sum에 영향이 없는 결과 값 0을 return하며 자식 노드를 탐색하지 않고 탐색을 종료한다.

    (해당 그림에선 리프 노드지만, 자식 노드가 있더라도 종료된다.)

     

     

    결국 우리는, 2번 노드와 12번 노드를 더한 값으로 arr[1]~arr[5]의 구간 합을 얻게 되었다.

     

    4. 값 수정하기

    public void update(int node, int start, int end, int idx, int diff) {
        if (idx < start || idx > end) return;
    
        tree[node] += diff;
    
        if (start != end) {
            update(node * 2, start, (start + end) / 2, idx, diff);
            update(node * 2 + 1, (start + end) / 2 + 1, end, idx, diff);
        }
    }

     

    sum과 기본 원리는 비슷하다.

    노드를 분할정복 및 재귀로 탐색하며, 일치하는 scope 에서만 연산을 수행한다. 다만, 리프 노드가 있는 곳 까지 계속 내려간다.

     

    이는, sum의 경우 일치하는 구간의 값을 단순히 반환하면 되지만,

    update의 경우 바뀐 노드 및 바뀐 노드가 포함되는 모든 부모 노드가 바뀌어야 하기 때문이다.

     

     

    범위를 벗어난다면,

    if (idx < start || idx > end) return;

    마찬가지로 더 이상의 탐색 없이 종료한다.

     

     

    범위 안이라면,

    tree[node] += diff;

     

    를 수행하며

     

    if (start != end) {
        update(node * 2, start, (start + end) / 2, idx, diff);
        update(node * 2 + 1, (start + end) / 2 + 1, end, idx, diff);
    }

     

     

    리프 노드까지 내려간다. 이후 탐색이 종료될 것이다.

     

     

    또 한가지 주의할 점은,

    tree[node] += diff;

    차이를 update한다는 것이다.

     

    arr[] = [0,1,2,3,4,5,6,7,8]에서,

     

    arr[3]이 3에서 10으로 바뀌었다고 가정하자.

     

    2번 노드 = arr[1]~arr[4]에 이를 반영하려면 바뀐 노드를 기반으로 재구성 해야한다. 이는 비효율적이다.

     

    대신에, diff인 (10-3)을 더한다면, 이를 쉽게 수행할 수 있다.

     

     유의할 점

    1. root node를 0으로 놓지 않는다.

    왼쪽 자식이 node(0)*2 = 0이 되며, 엉뚱한 결과를 낳는다.

    보통 트리 구현시 편의를 위해 0번 인덱스를 비우는 이유이기도 하다.

     

    2. update 작업에서, 원본 배열의 update 또한 잊지 말아야 한다.