본문 바로가기

알고리즘/문제풀이

[백준] Q11049 행렬 곱셈 순서

https://www.acmicpc.net/problem/11049

 

11049번: 행렬 곱셈 순서

첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같

www.acmicpc.net

 

1. 문제 유형과 이해

DP

크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 행렬 N개를 곱하는데 필요한 곱셈 연산의 수는 행렬을 곱하는 순서에 따라 달라지게 된다.

예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.

  • AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
  • BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.

같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.

행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하는 프로그램을 작성하시오. 입력으로 주어진 행렬의 순서를 바꾸면 안 된다.

 

문제에서 얻을 수 있는 힌트는 'AB*C 와 A*BC = ABC' 입니다. 

우리가 구하고자 하는 행렬의 범위 (0,N-1)에 대해서 반복하여 계산하여 구할 수 있다는 것을 알 수 있습니다.

 

2. 문제 접근 방법

dp를 풀기 위해 먼저 점화식을 세워보도록 합시다.

N가지의 행렬에서 첫번째 시작점 0을 두고 만들 수 있는 방법의 수는 총 N가지가 됩니다.

dp[0][N] = dp[0][0] ~[1][N] // [0][1] ~ [2][N] .... [0][N-1] *[N][N] 이중에서 가장 작은 값을 찾으면 됩니다.

어떠한 중간값 mid ( start<=mid<end) 에 대해 계산을 해나가면 되겠습니다.

 

우리는 '연산의 최소값'을 구하고자 합니다. 어떠한 행렬 start부터 end까지의 연산의 최소값을

dp[start][end]로 정의합니다. 이후, 위에서의 가짓수 처럼 존재하는 mid들에 대해 계산해줍니다.

 

두 연산횟수를 더하는 것 뿐 아니라 새롭게 곱해지는 연산의 횟수도 더해주어야 합니다.

행렬에서 행렬의 크기는 A[X][Y] * B[Y][Z]일 때 [X][Z]의 사이즈를 가지게 됩니다.

시작점 Start의 행의 크기  * 중간행렬 mid의 열의 크기(==마지막 행렬 행의 크기) * 마지막 행렬 end의 열의 크기

 

결론적으로 점화식은

dp[start][end] = dp[start][mid-1] + dp[mid][end] + (size[start][0] * size[mid][1] * size[end][1])

(start<mid<=end)

 

Top-Down방식으로 구할 수도 있지만 저는 bottom-up 방식으로

행렬의 크기가 2일때 3일때 ... N일때로 구하였습니다.

 

문제에서 값의 범위는 $2^31$-1 보다 작다고 하였으므로 int형으로 선언해주어도 무방합니다.

 

Code
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;

/**
 * @문제번호 : Q11049
 * @문제이름 : 행렬곱셈순서
 * @난이도 : Gold III
 * @date : 2022-02-11 오후 12:01
 * @author : pcrmcw0486
 * @문제이해
 * 크기가 N*M인 A와 M*K를 곱할 때 곱셉 연산의 수는 N*M*K번.
 * ABC AB*C와 A*BC 에 따라 행렬 곱 횟수가 달라진다.
 * N개의 행렬 크기가 주어질 때 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수 최소값.
 * 행렬의 순서를 바꾸면 안된다.
 * N <= 500  (r,c) <= 500
 * 정답은 int형 안에 가능하다.
 * @알고리즘
 * dp
 * @접근방법
 * (0,N)에 대해 (0,x)(x+1,N)을 곱하게 되면 (0,N)이 만들어진다.
 * 이 때, (0,N)의 최소값을 찾는 문제이다.
 * TSP랑 비슷한 구조같은데
 * dp[i][j]는 (i~j)까지 곱한 횟수 (i==j라면 0)
  * dp[i][j] = dp[i][i] + dp[i+1][j] + num[i][0]*num[i+1][0]*num[i+1][1];
 *              (i=0~j)
*/
public class Q11049 {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine());
        StringTokenizer st;
        int[][] size = new int[N][2];
        int[][] dp = new int[N][N];
        for(int i =0;i<N;i++){
            st = new StringTokenizer(br.readLine());
            size[i][0] = Integer.parseInt(st.nextToken());
            size[i][1] = Integer.parseInt(st.nextToken());
        }
        for(int length=1;length<N;length++){
            //자신 기준 오른쪽으로 lengh만큼
            for(int start =0;start<N-length;start++){
                dp[start][start+length] = Integer.MAX_VALUE;
                //시작점
                for(int end =start;end<(start+length);end++){
                    //여러 끝점들
              		int cost = dp[start][end] + dp[end+1][start + length] + (size[start][0] * size[end][1] * size[start+length][1])
                    dp[start][start+length] = Math.min(dp[start][start+length], cost);
                }
            }
        }
        System.out.println( dp[0][N-1]);
    }
}