본문 바로가기
Tech/Coding

AtCoder🐨Beginner Contest 362 - C. Sum = 0

by redcubes 2024. 7. 14.

L_i와  R_i 범위에서 X_i를 선택해서 수열을 만들고 합을 0으로 만드는 문제다.
처음엔 만들수 있나? 만 물어보는 줄 알았는데.. 그래서 쉽다고 생각했다
만들 수 있으려면 L의 합들과 R의 합들 범위에 0이 있어야 한다.

그런데 실제 조건을 만족하는 수열 X를 구해야 했다.
L의 합으로 X를 정해두고 그리디하게 0이 될 때 까지 최대한 더해주면 된다고 생각했다.

n, *lr = map(int, open(0).read().split())
sum_l, sum_r = 0, 0
res = []
diff = []
for i in range(0, n<<1, 2):
    left,right = lr[i],lr[i + 1]
    res.append(left)
    diff.append(right-left)
    sum_l += left
    sum_r += right

if not(sum_l <= 0 <= sum_r):
    print("No")
else:
    print("Yes")
    total = sum_l

    for i in range(n):
        if total == 0:
            print(" ".join(map(str, res)))
            break

        if total < 0:
            increase_amount = min(-total, diff[i])
            res[i] += increase_amount
            total += increase_amount


정말 잘 풀었는데 50여개 케이스 중 3개가 통과하지 않는 거다...
느낌상 경과의 원소가 하나뿐인 경우가 아닌가 해서 계속 예외처리를 하다가 풀지 못했다.


에디토리얼을 보고 나랑 로직이 같다는 것을 깨달았고 내 코드가 안 된 이유를 알게 되었다.....
그건 바로........

n, *lr = map(int, open(0).read().split())
sum_l, sum_r, res, diff = 0, 0, [], []
for i in range(0, n << 1, 2):
    left, right = lr[i], lr[i + 1]
    res.append(left)
    diff.append(right - left)
    sum_l += left
    sum_r += right

if sum_l > 0 or sum_r < 0:
    open(1,"w").write("No\n")
else:
    for i in range(n):
        if sum_l == 0:break
        increase = min(-sum_l, diff[i])
        res[i] += increase
        sum_l += increase
    open(1,"w").write("Yes\n"+" ".join(map(str, res))+"\n")

이것이 나의 1솔을 날려먹은 그 부분이다.....

저렇게 해 놓으면 브레이킹 없이 풀로 돌아야 답이 나오는 X는 for문을 브레이킹 없이 끝까지 돌기 때문에 출력이 없다.....
저 3개의 예제는 바로 그런 예제였다. 리스트 끝까지 가야 조정되는 .......
위 코드처럼 포문바깥으로 프린트를 옮기자 바로 AC..

이런 류의 실수를 하지 않게 포문과 브레이크가 있으면 꼼꼼하게 체크해야 하겠다는 생각이 들었다.
반복문 브레이킹을 할 때 출력하는 것과 브레이킹 후 출력이 다르다.

허탈한 마음으로 있다가 또 다른 아이디어가 떠올랐다.

L의 합을 구해서 최소값을 구한 뒤, 최소값을 구하는 과정에서 R에서 L을 빼서 이동가능한 범위를 구했었다.
그럼 그냥 한 번 순회하는 김에
$R-L$의 누적합을 리스트에 저장하고 이분탐색해서 L리스트의 합계와 더한 누적합이
0과 같거나 최초로 0보다 큰 것을 찾는 아이이디어다.

이러면 두 번 순회하지 않고 로그복잡도만에 찾을 수 있지 않을까? 

n, *lr = map(int, open(0).read().split())
sum_l, sum_r = 0, 0
res, prefix_diff = [], [0] * n

# left와 right를 계산하고 누적합을 계산합니다.
for i in range(n):
    left, right = lr[2 * i], lr[2 * i + 1]
    res.append(left)
    diff = right - left
    sum_l += left
    sum_r += right
    prefix_diff[i] = diff + (prefix_diff[i - 1] if i > 0 else 0)

def binary_search(target):
    low, high = 0, n - 1
    while low < high:
        mid = (low + high) // 2
        if prefix_diff[mid] >= target:
            high = mid
        else:
            low = mid + 1
    return low

if sum_l > 0 or sum_r < 0:
    open(1, "w").write("No\n")
else:
    if sum_l < 0:
        pos = binary_search(-sum_l)
        for i in range(pos):
            res[i] += (lr[2 * i + 1] - lr[2 * i])
        res[pos] += -sum_l - (prefix_diff[pos - 1] if pos > 0 else 0)
    open(1, "w").write("Yes\n" + " ".join(map(str, res)) + "\n")

효과는 미미했다.

from bisect import bisect_left

n, *lr = map(int, open(0).read().split())
sum_l, sum_r = 0, 0
res, prefix_diff = [], [0] * n

# left와 right를 계산하고 누적합을 계산합니다.
for i in range(n):
    left, right = lr[2 * i], lr[2 * i + 1]
    res.append(left)
    diff = right - left
    sum_l += left
    sum_r += right
    prefix_diff[i] = diff + (prefix_diff[i - 1] if i > 0 else 0)

if sum_l > 0 or sum_r < 0:
    open(1, "w").write("No\n")
else:
    if sum_l < 0:
        pos = bisect_left(prefix_diff, -sum_l)
        for i in range(pos):
            res[i] += (lr[2 * i + 1] - lr[2 * i])
        res[pos] += -sum_l - (prefix_diff[pos - 1] if pos > 0 else 0)
    open(1, "w").write("Yes\n" + " ".join(map(str, res)) + "\n")