[SWEA][D4][#05250] 최소 비용

작성:    

업데이트:

카테고리:

태그: , ,

출처

학습용 포스트입니다. 본 포스트가 문제가 될 시 수정 또는 삭제하겠습니다.


문제

출발에서 최종 도착지까지 경유하는 지역의 높이 차이에 따라 연료 소비량이 달라지기 때문에, 최적의 경로로 이동하면 최소한의 연료로 이동할 수 있다.

다음은 각 지역의 높이를 기록한 표의 예로, 항상 출발은 맨 왼쪽 위, 도착지는 가장 오른쪽 아래이며, 각 칸에서는 상하좌우 칸이 나타내는 인접 지역으로만 이동할 수 있다.

(표에 표시되지 않은 지역이나 대각선 방향으로는 이동 불가.)

그림 생략

인접 지역으로 이동시에는 기본적으로 1만큼의 연료가 들고, 더 높은 곳으로 이동하는 경우 높이 차이만큼 추가로 연료가 소비된다.

그림 생략

색이 칠해진 칸을 따라 이동하는 경우 기본적인 연료 소비량 4에, 높이가 0에서 1로 경우만큼 추가 연료가 소비되므로 최소 연료 소비량 5로 목적지에 도착할 수 있다.

이동 가능한 지역의 높이 정보에 따라 최소 연료 소비량을 출력하는 프로그램을 만드시오.


입력

첫 줄에 테스트 케이스의 개수 T가 주어지고, 테스트 케이스 별로 첫 줄에 표의 가로, 세로 칸수N, 다음 줄부터 N개 지역의 높이 H가 N개의 줄에 걸쳐 제공된다.

1<=T<=50, 3<=N<=100, 0<=H<1000


출력

각 줄마다 “#T” (T는 테스트 케이스 번호)를 출력한 뒤, 답을 출력한다.


예제

입력

3
3
0 2 1
0 1 1
1 1 1
5
0 0 0 0 0 
0 1 2 3 0 
0 2 3 4 0 
0 3 4 5 0 
0 0 0 0 0 
5
0 1 1 1 0 
1 1 0 1 0 
0 1 0 1 0 
1 0 0 1 1 
1 1 1 1 1


출력

#1 5
#2 8
#3 9


My Sol

from heapq import heappop, heappush

didj = ((-1,0), (0,-1), (1,0), (0,1))
def check(i, j):
    global path
    for di, dj in didj:
        si, sj = i+di, j+dj
        if not (0<=si<N and 0<=sj<N): continue
        nh, sh = high[i][j], high[si][sj]
        sw = 1 if nh > sh else sh-nh+1
        path[i][j].append((sw, si, sj))


for tc in range(1, int(input())+1):
    N = int(input())
    high = [list(map(int, input().split())) for _ in range(N)]
    memo = [[0xffffff]*N for _ in range(N)]
    memo[0][0] = 0
    path = []
    for _ in range(N):
        lst = [[] for _ in range(N)]
        path.append(lst)

    for i in range(N):
        for j in range(N):
            check(i, j)

    Q = []
    heappush(Q, (0,0,0))

    while Q:
        w, i, j = heappop(Q)
        for sw, si, sj in path[i][j]:
            if memo[si][sj] > memo[i][j] + sw:
                memo[si][sj] = memo[i][j] + sw
                heappush(Q, (memo[i][j]+sw, si, sj))

    print(f'#{tc} {memo[N-1][N-1]}')

다익스트라 알고리즘을 이용해 푸는 문제였다. heap을 사용하고, 시작점으로부터 가장 적은 가중치를 memo에 저장하며 우선순위를 이용해 가장 작은 가중치의 i, j를 선택하며 이동한다.

w와 memo를 혼용해서 조금 헷갈리는 코드였는데, 처음 적용해보다보니 생긴 시행착오였나보다. 더 깊은 이해와 활용이 필요한 것 같다.


결과

PASS


모범답안

T = int(input())
 
for tc in range(1, T + 1):
    N = int(input())
 
    arr = [list(map(int, input().split())) for _ in range(N)]
 
    G = [[] for _ in range(N * N)]
 
    for i in range(N):
        for j in range(N):
 
            for d_i, d_j in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                n_i = i + d_i
                n_j = j + d_j
 
                if not (0 <= n_i < N and 0 <= n_j < N):
                    continue
 
                if arr[n_i][n_j] > arr[i][j]:
                    G[i*N + j].append((n_i*N + n_j, arr[n_i][n_j] - arr[i][j] + 1))
                else:
                    G[i * N + j].append((n_i * N + n_j, 1))
 
    D = [0xfffff] * (N * N)
    D[0] = 0
 
    while True:
        STOP = True
 
        for u in range(N * N):
            for v, w in G[u]:
                if D[v] > D[u] + w:
                    D[v] = D[u] + w
                    STOP = False
 
        if STOP:
            break
 
    print(f'#{tc} {D[N*N - 1]}')

간선완화를 이용하는 방법인데, 간단함에도 불구하고 상당히 빠른 연산시간을 보여준 코드이다. while문을 반복하면서 모든 간선에 대해 갱신할 간선이 없다면 탈출하는 방식이다.

댓글남기기