문제 상황
길이 $N$인 배열이 주어졌을 때, 구간 $[l, r]$의 최댓값을 빠르게 구해야 한다.
방법전처리쿼리
| 브루트포스 | $O(1)$ | $O(N)$ |
| 세그먼트 트리 | $O(N)$ | $O(\log N)$ |
핵심 아이디어
배열을 이진 트리 형태로 분할하여, 각 노드가 해당 구간의 최댓값을 저장한다. 쿼리 시 필요한 구간만 방문하므로 $O(\log N)$에 처리된다.
예) 배열 [2, 5, 1, 4, 9, 3]
9 [0,5]
/ \
5 [0,2] 9 [3,5]
/ \ / \
5[0,1] 1[2,2] 9[3,4] 3[5,5]
/ \ / \
2[0] 5[1] 4[3] 9[4]
구현
1. 트리 초기화 (Build)
리프 노드에 원소를 넣고, 부모 노드는 자식들의 최댓값을 저장한다.
import sys
input = sys.stdin.readline
def build(node, start, end):
if start == end:
tree[node] = arr[start]
return tree[node]
mid = (start + end) // 2
left = build(node * 2, start, mid)
right = build(node * 2 + 1, mid + 1, end)
tree[node] = max(left, right)
return tree[node]
2. 구간 최댓값 쿼리
쿼리 구간 $[l, r]$과 현재 노드가 담당하는 구간 $[start, end]$를 비교한다.
- 구간이 겹치지 않음: $-\infty$ 반환
- 구간이 완전히 포함됨: 현재 노드 값 반환
- 일부만 겹침: 양쪽 자식에게 재귀
def query(node, start, end, l, r):
if r < start or end < l:
return -float('inf')
if l <= start and end <= r:
return tree[node]
mid = (start + end) // 2
left = query(node * 2, start, mid, l, r)
right = query(node * 2 + 1, mid + 1, end, l, r)
return max(left, right)
전체 코드 (2357번)
import sys
input = sys.stdin.readline
def build(node, start, end):
if start == end:
min_tree[node] = max_tree[node] = arr[start]
return
mid = (start + end) // 2
build(node * 2, start, mid)
build(node * 2 + 1, mid + 1, end)
min_tree[node] = min(min_tree[node * 2], min_tree[node * 2 + 1])
max_tree[node] = max(max_tree[node * 2], max_tree[node * 2 + 1])
def query_min(node, start, end, l, r):
if r < start or end < l:
return float('inf')
if l <= start and end <= r:
return min_tree[node]
mid = (start + end) // 2
return min(query_min(node * 2, start, mid, l, r),
query_min(node * 2 + 1, mid + 1, end, l, r))
def query_max(node, start, end, l, r):
if r < start or end < l:
return -float('inf')
if l <= start and end <= r:
return max_tree[node]
mid = (start + end) // 2
return max(query_max(node * 2, start, mid, l, r),
query_max(node * 2 + 1, mid + 1, end, l, r))
n, m = map(int, input().split())
arr = [int(input()) for _ in range(n)]
min_tree = [0] * (4 * n)
max_tree = [0] * (4 * n)
build(1, 0, n - 1)
for _ in range(m):
a, b = map(int, input().split())
print(query_min(1, 0, n - 1, a - 1, b - 1),
query_max(1, 0, n - 1, a - 1, b - 1))
시간복잡도
연산시간복잡도
| Build | $O(N)$ |
| Query | $O(\log N)$ |
| 총 (M개 쿼리) | $O(N + M \log N)$ |