세그먼트 트리와 Lazy Propagation을 이용한 구간 합 구하기
1. 문제 상황과 해결 전략
문제의 핵심 요구사항
- 긴 수열에서 구간 합을 구해야 함
- 특정 구간의 모든 수를 갱신하는 연산이 빈번함
- 두 연산이 번갈아가며 실행됨
왜 일반적인 방법으로는 해결이 어려운가?
- 배열을 직접 수정하는 경우: 구간 업데이트에 O(N)
- 일반 세그먼트 트리: 구간 업데이트에 여전히 O(N)
- 누적 합 배열: 한 값이 변경되면 모든 누적 합을 다시 계산해야 함
해결 전략: Lazy Propagation이 적용된 세그먼트 트리
- 구간 업데이트를 지연시켜 실제 필요할 때만 수행
- 각 연산의 시간 복잡도를 O(logN)으로 최적화
2. 알고리즘 상세 설명
세그먼트 트리 기본 구조
- 노드가 담당하는 정보
- 구간 [start, end]의 합
- lazy 값 (지연된 업데이트 정보)
- 트리의 구조
- 루트: 전체 구간
- 각 노드: 구간을 이진 분할한 부분 구간
- 리프: 원본 배열의 각 원소
Lazy Propagation의 동작 원리
- 업데이트 지연
- 구간 업데이트 시 실제로 값을 바로 변경하지 않음
- 대신 lazy 배열에 변경할 값을 기록
- 전파 과정
- 노드 접근 시 lazy 값 확인
- lazy 값이 있다면:
a. 현재 노드의 값 갱신
b. 자식 노드로 lazy 값 전달
c. 현재 노드의 lazy 값 초기화
3. 핵심 연산 구현
초기화 함수
세그먼트 트리의 기본 구조를 생성하고 초기값을 설정합니다.
def init(node, start, end):
if start == end: # 리프 노드인 경우
tree[node] = arr[start]
return tree[node]
mid = (start + end) // 2
# 왼쪽 자식과 오른쪽 자식 노드를 재귀적으로 초기화하고 그 합을 저장
tree[node] = init(node*2, start, mid) + init(node*2+1, mid+1, end)
return tree[node]
propagate 함수: lazy 값 전파
현재 노드에 지연된 업데이트가 있다면 이를 처리합니다.
def propagate(node, start, end):
if lazy[node] != 0:
# 현재 구간에 대한 업데이트 적용
tree[node] += (end - start + 1) * lazy[node]
if start != end: # 리프 노드가 아니면 자식에게 전파
lazy[node*2] += lazy[node]
lazy[node*2+1] += lazy[node]
lazy[node] = 0 # lazy 값 초기화
update_range 함수: 구간 업데이트
구간 [left, right]의 모든 원소에 diff를 더합니다.
def update_range(node, start, end, left, right, diff):
propagate(node, start, end) # 우선 propagate 호출
if left > end or right < start: # 구간이 겹치지 않는 경우
return
if left <= start and end <= right: # 구간이 완전히 포함되는 경우
tree[node] += (end - start + 1) * diff
if start != end: # 리프 노드가 아니면 lazy 값 설정
lazy[node*2] += diff
lazy[node*2+1] += diff
return
# 구간이 부분적으로 겹치는 경우
mid = (start + end) // 2
update_range(node*2, start, mid, left, right, diff)
update_range(node*2+1, mid+1, end, left, right, diff)
tree[node] = tree[node*2] + tree[node*2+1] # 자식 노드의 합으로 갱신
query 함수: 구간 합 조회
구간 [left, right]의 합을 구합니다.
def query(node, start, end, left, right):
propagate(node, start, end) # 우선 propagate 호출
if left > end or right < start: # 구간이 겹치지 않는 경우
return 0
if left <= start and end <= right: # 구간이 완전히 포함되는 경우
return tree[node]
# 구간이 부분적으로 겹치는 경우
mid = (start + end) // 2
return query(node*2, start, mid, left, right) + \
query(node*2+1, mid+1, end, left, right)
4. 성능 분석
시간 복잡도
- 초기화: O(N)
- 구간 업데이트: O(logN)
- 구간 합 조회: O(logN)
- 전체: O(N + (M+K)logN)
공간 복잡도
- 세그먼트 트리: 4N 크기의 배열
- lazy 배열: 4N 크기의 배열
- 총 O(N)의 추가 공간 필요
5. 구현시 주의사항
- 자료형 관리
- 입력값 범위가 매우 큰 경우 오버플로우 주의
- Python의 경우 기본 int 타입으로 충분
- 초기화 순서
- 트리 배열 생성 → 초기값 설정 → lazy 배열 생성
- lazy 배열은 반드시 0으로 초기화
- propagate 호출 시점
- 노드에 접근하기 전 반드시 호출
- update_range와 query 함수 모두에서 필요
- 구간 인덱스 관리
- 0-based 인덱스와 1-based 인덱스 구분
- 문제의 입력에 맞게 적절히 변환
6. 전체 구현 코드
import sys
input = sys.stdin.readline
def init(node, start, end):
if start == end:
tree[node] = arr[start]
return tree[node]
mid = (start + end) // 2
tree[node] = init(node*2, start, mid) + init(node*2+1, mid+1, end)
return tree[node]
def propagate(node, start, end):
if lazy[node] != 0:
tree[node] += (end - start + 1) * lazy[node]
if start != end:
lazy[node*2] += lazy[node]
lazy[node*2+1] += lazy[node]
lazy[node] = 0
def update_range(node, start, end, left, right, diff):
propagate(node, start, end)
if left > end or right < start:
return
if left <= start and end <= right:
tree[node] += (end - start + 1) * diff
if start != end:
lazy[node*2] += diff
lazy[node*2+1] += diff
return
mid = (start + end) // 2
update_range(node*2, start, mid, left, right, diff)
update_range(node*2+1, mid+1, end, left, right, diff)
tree[node] = tree[node*2] + tree[node*2+1]
def query(node, start, end, left, right):
propagate(node, start, end)
if left > end or right < start:
return 0
if left <= start and end <= right:
return tree[node]
mid = (start + end) // 2
return query(node*2, start, mid, left, right) + \
query(node*2+1, mid+1, end, left, right)
# 입력 처리
N, M, K = map(int, input().split())
arr = [int(input()) for _ in range(N)]
# 세그먼트 트리와 lazy 배열 초기화
tree = [0] * (4 * N)
lazy = [0] * (4 * N)
init(1, 0, N-1)
# 쿼리 처리
for _ in range(M + K):
query_type, *params = map(int, input().split())
if query_type == 1: # 업데이트 쿼리
b, c, d = params
update_range(1, 0, N-1, b-1, c-1, d)
else: # 구간 합 쿼리
b, c = params
print(query(1, 0, N-1, b-1, c-1))