본문 바로가기
카테고리 없음

세그먼트-트리와-lazy-propagation을-이용한-구간-합-구하기

by redcubes 2024. 11. 14.

세그먼트 트리와 Lazy Propagation을 이용한 구간 합 구하기

1. 문제 상황과 해결 전략

문제의 핵심 요구사항

  1. 긴 수열에서 구간 합을 구해야 함
  2. 특정 구간의 모든 수를 갱신하는 연산이 빈번함
  3. 두 연산이 번갈아가며 실행됨

왜 일반적인 방법으로는 해결이 어려운가?

  • 배열을 직접 수정하는 경우: 구간 업데이트에 O(N)
  • 일반 세그먼트 트리: 구간 업데이트에 여전히 O(N)
  • 누적 합 배열: 한 값이 변경되면 모든 누적 합을 다시 계산해야 함

해결 전략: Lazy Propagation이 적용된 세그먼트 트리

  • 구간 업데이트를 지연시켜 실제 필요할 때만 수행
  • 각 연산의 시간 복잡도를 O(logN)으로 최적화

2. 알고리즘 상세 설명

세그먼트 트리 기본 구조

  1. 노드가 담당하는 정보
    • 구간 [start, end]의 합
    • lazy 값 (지연된 업데이트 정보)
  2. 트리의 구조
    • 루트: 전체 구간
    • 각 노드: 구간을 이진 분할한 부분 구간
    • 리프: 원본 배열의 각 원소

Lazy Propagation의 동작 원리

  1. 업데이트 지연
    • 구간 업데이트 시 실제로 값을 바로 변경하지 않음
    • 대신 lazy 배열에 변경할 값을 기록
  2. 전파 과정
    • 노드 접근 시 lazy 값 확인
    • lazy 값이 있다면:
      a. 현재 노드의 값 갱신
      b. 자식 노드로 lazy 값 전달
      c. 현재 노드의 lazy 값 초기화

3. 핵심 연산 구현

초기화 함수

세그먼트 트리의 기본 구조를 생성하고 초기값을 설정합니다.

def init(node, start, end):
    if start == end:  # 리프 노드인 경우
        tree[node] = arr[start]
        return tree[node]
    
    mid = (start + end) // 2
    # 왼쪽 자식과 오른쪽 자식 노드를 재귀적으로 초기화하고 그 합을 저장
    tree[node] = init(node*2, start, mid) + init(node*2+1, mid+1, end)
    return tree[node]

propagate 함수: lazy 값 전파

현재 노드에 지연된 업데이트가 있다면 이를 처리합니다.

def propagate(node, start, end):
    if lazy[node] != 0:
        # 현재 구간에 대한 업데이트 적용
        tree[node] += (end - start + 1) * lazy[node]
        if start != end:  # 리프 노드가 아니면 자식에게 전파
            lazy[node*2] += lazy[node]
            lazy[node*2+1] += lazy[node]
        lazy[node] = 0    # lazy 값 초기화

update_range 함수: 구간 업데이트

구간 [left, right]의 모든 원소에 diff를 더합니다.

def update_range(node, start, end, left, right, diff):
    propagate(node, start, end)  # 우선 propagate 호출
    
    if left > end or right < start:  # 구간이 겹치지 않는 경우
        return
        
    if left <= start and end <= right:  # 구간이 완전히 포함되는 경우
        tree[node] += (end - start + 1) * diff
        if start != end:  # 리프 노드가 아니면 lazy 값 설정
            lazy[node*2] += diff
            lazy[node*2+1] += diff
        return
    
    # 구간이 부분적으로 겹치는 경우
    mid = (start + end) // 2
    update_range(node*2, start, mid, left, right, diff)
    update_range(node*2+1, mid+1, end, left, right, diff)
    tree[node] = tree[node*2] + tree[node*2+1]  # 자식 노드의 합으로 갱신

query 함수: 구간 합 조회

구간 [left, right]의 합을 구합니다.

def query(node, start, end, left, right):
    propagate(node, start, end)  # 우선 propagate 호출
    
    if left > end or right < start:  # 구간이 겹치지 않는 경우
        return 0
        
    if left <= start and end <= right:  # 구간이 완전히 포함되는 경우
        return tree[node]
    
    # 구간이 부분적으로 겹치는 경우
    mid = (start + end) // 2
    return query(node*2, start, mid, left, right) + \
           query(node*2+1, mid+1, end, left, right)

4. 성능 분석

시간 복잡도

  • 초기화: O(N)
  • 구간 업데이트: O(logN)
  • 구간 합 조회: O(logN)
  • 전체: O(N + (M+K)logN)

공간 복잡도

  • 세그먼트 트리: 4N 크기의 배열
  • lazy 배열: 4N 크기의 배열
  • 총 O(N)의 추가 공간 필요

5. 구현시 주의사항

  1. 자료형 관리
    • 입력값 범위가 매우 큰 경우 오버플로우 주의
    • Python의 경우 기본 int 타입으로 충분
  2. 초기화 순서
    • 트리 배열 생성 → 초기값 설정 → lazy 배열 생성
    • lazy 배열은 반드시 0으로 초기화
  3. propagate 호출 시점
    • 노드에 접근하기 전 반드시 호출
    • update_range와 query 함수 모두에서 필요
  4. 구간 인덱스 관리
    • 0-based 인덱스와 1-based 인덱스 구분
    • 문제의 입력에 맞게 적절히 변환

6. 전체 구현 코드

import sys
input = sys.stdin.readline

def init(node, start, end):
    if start == end:
        tree[node] = arr[start]
        return tree[node]
    
    mid = (start + end) // 2
    tree[node] = init(node*2, start, mid) + init(node*2+1, mid+1, end)
    return tree[node]

def propagate(node, start, end):
    if lazy[node] != 0:
        tree[node] += (end - start + 1) * lazy[node]
        if start != end:
            lazy[node*2] += lazy[node]
            lazy[node*2+1] += lazy[node]
        lazy[node] = 0

def update_range(node, start, end, left, right, diff):
    propagate(node, start, end)
    
    if left > end or right < start:
        return
        
    if left <= start and end <= right:
        tree[node] += (end - start + 1) * diff
        if start != end:
            lazy[node*2] += diff
            lazy[node*2+1] += diff
        return
        
    mid = (start + end) // 2
    update_range(node*2, start, mid, left, right, diff)
    update_range(node*2+1, mid+1, end, left, right, diff)
    tree[node] = tree[node*2] + tree[node*2+1]

def query(node, start, end, left, right):
    propagate(node, start, end)
    
    if left > end or right < start:
        return 0
        
    if left <= start and end <= right:
        return tree[node]
        
    mid = (start + end) // 2
    return query(node*2, start, mid, left, right) + \
           query(node*2+1, mid+1, end, left, right)

# 입력 처리
N, M, K = map(int, input().split())
arr = [int(input()) for _ in range(N)]

# 세그먼트 트리와 lazy 배열 초기화
tree = [0] * (4 * N)
lazy = [0] * (4 * N)
init(1, 0, N-1)

# 쿼리 처리
for _ in range(M + K):
    query_type, *params = map(int, input().split())
    
    if query_type == 1:  # 업데이트 쿼리
        b, c, d = params
        update_range(1, 0, N-1, b-1, c-1, d)
    else:  # 구간 합 쿼리
        b, c = params
        print(query(1, 0, N-1, b-1, c-1))