본문 바로가기
카테고리 없음

구간 최댓값 세그먼트 트리

by redcubes 2026. 1. 23.

2357번: 최솟값과 최댓값

문제 상황

길이 $N$인 배열이 주어졌을 때, 구간 $[l, r]$의 최댓값을 빠르게 구해야 한다.

방법전처리쿼리

브루트포스 $O(1)$ $O(N)$
세그먼트 트리 $O(N)$ $O(\log N)$

핵심 아이디어

배열을 이진 트리 형태로 분할하여, 각 노드가 해당 구간의 최댓값을 저장한다. 쿼리 시 필요한 구간만 방문하므로 $O(\log N)$에 처리된다.

예) 배열 [2, 5, 1, 4, 9, 3]

            9 [0,5]
          /        \
      5 [0,2]      9 [3,5]
      /    \       /    \
   5[0,1] 1[2,2] 9[3,4] 3[5,5]
   /   \         /   \
 2[0] 5[1]     4[3] 9[4]

구현

1. 트리 초기화 (Build)

리프 노드에 원소를 넣고, 부모 노드는 자식들의 최댓값을 저장한다.

import sys
input = sys.stdin.readline

def build(node, start, end):
    if start == end:
        tree[node] = arr[start]
        return tree[node]
    mid = (start + end) // 2
    left = build(node * 2, start, mid)
    right = build(node * 2 + 1, mid + 1, end)
    tree[node] = max(left, right)
    return tree[node]

2. 구간 최댓값 쿼리

쿼리 구간 $[l, r]$과 현재 노드가 담당하는 구간 $[start, end]$를 비교한다.

  • 구간이 겹치지 않음: $-\infty$ 반환
  • 구간이 완전히 포함됨: 현재 노드 값 반환
  • 일부만 겹침: 양쪽 자식에게 재귀
def query(node, start, end, l, r):
    if r < start or end < l:
        return -float('inf')
    if l <= start and end <= r:
        return tree[node]
    mid = (start + end) // 2
    left = query(node * 2, start, mid, l, r)
    right = query(node * 2 + 1, mid + 1, end, l, r)
    return max(left, right)

전체 코드 (2357번)

import sys
input = sys.stdin.readline

def build(node, start, end):
    if start == end:
        min_tree[node] = max_tree[node] = arr[start]
        return
    mid = (start + end) // 2
    build(node * 2, start, mid)
    build(node * 2 + 1, mid + 1, end)
    min_tree[node] = min(min_tree[node * 2], min_tree[node * 2 + 1])
    max_tree[node] = max(max_tree[node * 2], max_tree[node * 2 + 1])

def query_min(node, start, end, l, r):
    if r < start or end < l:
        return float('inf')
    if l <= start and end <= r:
        return min_tree[node]
    mid = (start + end) // 2
    return min(query_min(node * 2, start, mid, l, r),
               query_min(node * 2 + 1, mid + 1, end, l, r))

def query_max(node, start, end, l, r):
    if r < start or end < l:
        return -float('inf')
    if l <= start and end <= r:
        return max_tree[node]
    mid = (start + end) // 2
    return max(query_max(node * 2, start, mid, l, r),
               query_max(node * 2 + 1, mid + 1, end, l, r))

n, m = map(int, input().split())
arr = [int(input()) for _ in range(n)]
min_tree = [0] * (4 * n)
max_tree = [0] * (4 * n)

build(1, 0, n - 1)

for _ in range(m):
    a, b = map(int, input().split())
    print(query_min(1, 0, n - 1, a - 1, b - 1),
          query_max(1, 0, n - 1, a - 1, b - 1))

시간복잡도

연산시간복잡도

Build $O(N)$
Query $O(\log N)$
총 (M개 쿼리) $O(N + M \log N)$