Baekjoon

[Baekjoon] 행렬 제곱(python)

김철현 2022. 9. 28. 21:21

https://www.acmicpc.net/problem/10830

문제 접근

N x N 행렬의 곱셈을 구하는 코드는 보통 아래처럼 구현할 것이다.

for i in range(N):
    for j in range(N):
        for k in range(N):
            matrix[i][j] += a[i][k] * b[k][j]

시간 복잡도는 O(N^3)이고 N x N 행렬의 B 제곱을 구하는 것이니 O(BN^3)이 될 것이다.

 

당연히 시간 초과가 날 것이다...

 

그렇다면 어떻게 효율적으로 행렬의 제곱을 할 수 있을까?

행렬 제곱은 N x N x N... x N처럼 순서대로 처리할 필요는 없다.

 

N x N x N = N x N^2

N x N x N x N x N = N^2 x N^3

와 같이 전체 제곱 수를 반으로 분할하여 최소 단위까지 쪼갠 후 계산한 결과를 합치면 효율적으로 행렬 연산을 수행할 수 있다.

 

즉, 분할 정복을 이용하는 것이다.


풀이 과정

행렬 제곱을 구하기 위해서 반으로 분할하여 최소 단위까지 쪼갠 후 행렬 곱셈을 수행하면 된다.

 

제곱 수가 홀수인 경우는 반으로 나눠지지 않으므로 짝수로 만들어주도록 하자.

 

N이 홀수인 경우는 N -> (N - 1) * 1로 분할하면 된다.


정답 코드

"""
    백준 10830번 행렬 제곱
"""

import sys


def multiply(a, b):
    result = [[0] * N for _ in range(N)]

    for i in range(N):
        for j in range(N):
            for k in range(N):
                result[i][j] += (a[i][k] * b[k][j]) % 1000
                result[i][j] %= 1000

    return result


def divide_and_conquer(n):
    if n == 1:
        return matrix

    if n % 2 == 0:
        m = divide_and_conquer(n // 2)

        return multiply(m, m)
    else:
        return multiply(divide_and_conquer(n - 1), matrix)


N, B = list(map(int, sys.stdin.readline().split()))
matrix = []

for _ in range(N):
    matrix.append(list(map(int, sys.stdin.readline().split())))

matrix = divide_and_conquer(B)

for row in matrix:
    for col in row:
        print(col % 1000, end=" ")
    print()

'Baekjoon' 카테고리의 다른 글

[Baekjoon] 미네랄(python)  (0) 2022.10.02
[Baekjoon] 피보나치 수 3(python)  (0) 2022.10.01
[Baekjoon] 백조의 호수(python)  (1) 2022.09.25
[Baekjoon] 가운데를 말해요(python)  (1) 2022.09.24
[Baekjoon] 평범한 배낭(python)  (0) 2022.09.24