알고리즘/스택

[백준] 2504 - 괄호의 값 Python

기억은 RAM, 기록은 HDD 2022. 9. 12. 00:08

https://www.acmicpc.net/problem/2504

 

2504번: 괄호의 값

4개의 기호 ‘(’, ‘)’, ‘[’, ‘]’를 이용해서 만들어지는 괄호열 중에서 올바른 괄호열이란 다음과 같이 정의된다. 한 쌍의 괄호로만 이루어진 ‘()’와 ‘[]’는 올바른 괄호열이다.  만일

www.acmicpc.net

다음 규칙에 따라 괄호열의 값을 계산해 출력하는 문제다.

  1. ‘()’ 인 괄호열의 값은 2이다.
  2. ‘[]’ 인 괄호열의 값은 3이다.
  3. ‘(X)’ 의 괄호값은 2×값(X) 으로 계산된다.
  4. ‘[X]’ 의 괄호값은 3×값(X) 으로 계산된다.
  5. 올바른 괄호열 X와 Y가 결합된 XY의 괄호값은 값(XY)= 값(X)+값(Y) 로 계산된다. 

문제 설명에 따르면 (()[[]])([]) 의 괄호값은, ()[[]] 의 괄호값이 2 + 3×3=11 이므로 (()[[]]) 의 괄호값은 2×11=22 이다. 그리고 ([]) 의 값은 2×3=6 이므로 전체 괄호열의 값은 22 + 6 = 28 이다. 이를 부분 부분으로 쪼개서 해석하면,  (()[[]])([]) 는 (()[[]]) + ()[[]] 과 값이 동일하고, (()[[]]) 는 () + [[]] 값에 2를 곱한 값과 동일하다. 이를 일반화하면, 주어진 괄호열은 부분 괄호열로 쪼갤 수 있고, 각각의 합을 더해 최종 괄호열의 값이 결정된다.

 

따라서 빈 문자열을 base case 로 설정해 1을 반환하게끔 하면, 다음과 같이 재귀적으로 답을 구할 수 있다.

- 각 재귀 호출마다 가장 바깥 괄호마다 (괄호 내부 괄호열을 계산한 값) * (괄호의 값) 을 계산해 총합을 반환한다. 

예를 들어, []() 의 경우, [] 는 괄호 내부 괄호열을 계산한 값이 base case 이므로 1 이고, 괄호의 값은 3 이므로 3이다. () 의 경우도 마찬가지로 2이다. 따라서 총합은 5 임을 알 수 있다.

 

가장 바깥 괄호임을 확인하는 방법에서 스택이 이용된다. 괄호열의 각 문자들을 스택에 집어넣으면서 () 또는 [] 가 만들어질 경우 스택에서 pop 시킨다면 가장 바깥 괄호쌍이 이루어질 때 스택이 비게 된다. 따라서 이 경우 재귀 호출을 통해 내부 괄호열을 계산한 값을 가져오고, 이를 괄호의 값과 곱해, 총합을 계산하도록 한다.

 

재귀 호출을 그림으로 그려보면 다음과 같다. 문제에서 설명한 (()[[]])([]) 의 경우를 예시로 사용했다.

번호는 각 노드의 방문 순서, 입력으로 주어지는 괄호열을 노드 옆에 표시했다. 

계산 과정은 다음과 같다.

- 1번 = 2번*2 + 6번*2 = 22 + 6 = 28

- 2번 = 3번*2 + 4번*3 = 2 + 9 = 11

- 3번 = 1

- 4번 = 5번*3 = 3

- 5번 = 1

- 6번 = 7번*3 = 3

- 7번 = 1

 

전체 코드는 아래와 같다.

from sys import stdin

S = stdin.readline().rstrip()

def dfs(S):
    if S == "":
        return 1
    stack = []
    cnt = 0
    s = 0
    for i in range(len(S)):
        if not stack:
            stack.append(S[i])
            continue
        if stack[-1] + S[i] == "()" or stack[-1] + S[i] == "[]":
            stack.pop()
        else:
            stack.append(S[i])
        if not stack:
            if S[i] == ")":
                cnt += dfs(S[s+1:i])*2
            elif S[i] == "]":
                cnt += dfs(S[s+1:i])*3
            s = i+1
    if stack:
        return 0
    return cnt

print(dfs(S))