- 배열에서 연속된 구간의 합을 빠르게 구하기 위한 알고리즘 기법
- 보통 반복문으로 매번 구간의 합을 직접 계산하면 O(N)이 걸리지만 누적합 배열을 미리 만들어두면 O(1)에 구간합을 구할 수 있어 시간을 획기적으로 줄일 수 있습니다.
예시 문제1)
https://www.acmicpc.net/problem/11659
문제
수 N개가 주어졌을 때, i번째 수부터 j번째 수까지 합을 구하는 프로그램을 작성하시오.
시간복잡도를 생각하지 않고 문제를 단순히 봤을 때 반복문을 사용해서 구현할 수 있음
n, m = map(int, input().split())
data = list(map(int, input().split()))
ran = []
for _ in range(m):
s, e = map(int, input().split())
ran.append([s, e])
for i in range(m):
result = 0
while ran[i][1] >= ran[i][0]:
result += data[ran[i][1]-1]
ran[i][1] -= 1
print(result)
제약 다시 보기:
- N ≤ 100,000
- M ≤ 100,000
- 1 ≤ i ≤ j ≤ N
- 1초
현재 코드 동작 방식:
- for i in range(m) 루프 → m번 반복
- 그 안에서 while ran[i][1] >= ran[i][0] → 최대 n번까지 더함
즉, 최악의 경우 M × N = 10⁵ × 10⁵ = 10¹⁰ 연산 발생
시간 제한이 1초인 경우 -> 10⁷~ 10⁸ 사이의 연산이 가능
-> 일반적인 반복문을 통해 구현하면 시간 초과 발생
이때, 위와 같이 연속된 구간의 합을 구할때 누적합을 사용하면 시간 복잡도를 줄일 수 있음
누적합 배열을 만드는데 걸리는 시간은 O(N)이지만, 구간 합을 구할때는 O(M)에 구할 수 있음
누적합 적용 예시)
1. 예시 배열
A = [1, 2, 3, 4, 5]
2. 누적합 배열 생성
S[0] = 0
S[1] = A[0] = 1
S[2] = A[0] + A[1] = 3
S[3] = A[0] + A[1] + A[2] = 6
...
S[i] = A[0] + A[1] + ... + A[i-1]
3. i번째부터 j번째까지 합 구하기
A[i] + A[i+1] + ... + A[j] = S[j+1] - S[i]
예시 문제에 적용해보기)
import sys
input = sys.stdin.readline
n, m = map(int, input().split())
data = list(map(int, input().split()))
#누적합 배열 생성
prefix = [0] * (n+1)
for i in range(1, n+1):
prefix[i] = prefix[i - 1] + data[i - 1]
#구간별 구간의 합 계산
for _ in range(m):
s, e = map(int, input().split())
result = prefix[e] - prefix[s-1]
print(result)
2차원 누적합
예시 문제2)
https://www.acmicpc.net/problem/11660
문제
N×N개의 수가 N×N 크기의 표에 채워져 있다. (x1, y1)부터 (x2, y2)까지 합을 구하는 프로그램을 작성하시오. (x, y)는 x행 y열을 의미한다.
누적합 배열을 생성하고 누적합을 구하면 (x1, y1)부터 (x2, y2)까지의 합을 O(1)에 처리할 수 있음
S[i][j]: (1,1)부터 (i,j)까지의 사각형 영역의 합
# (x1, y1)부터 (x2, y2)까지의 합
S[x2][y2] - S[x1-1][y2] - S[x2][y1-1] + S[x1-1][y1-1]
예시 코드)
import sys
input = sys.stdin.readline
n, m = map(int, input().split())
arr = [[0] * (n + 1)] # 1행부터 시작하기 위해 0번째 행 추가
for _ in range(n):
row = [0] + list(map(int, input().split())) # 1열부터 시작하기 위해 0 추가
arr.append(row)
# 누적합 계산
prefix = [[0] * (n + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
for j in range(1, n + 1):
prefix[i][j] = arr[i][j]
+ prefix[i - 1][j]
+ prefix[i][j - 1]
- prefix[i - 1][j - 1]
# 질의 처리
for _ in range(m):
x1, y1, x2, y2 = map(int, input().split())
result = prefix[x2][y2]
- prefix[x1 - 1][y2]
- prefix[x2][y1 - 1]
+ prefix[x1 - 1][y1 - 1]
print(result)