알고리즘 문제

[백준] 14502번 연구소

feelcoding 2020. 1. 5. 21:23
728x90

삼성 SW 역량 테스트 기출 문제인 14502번

 

14502번: 연구소

인체에 치명적인 바이러스를 연구하던 연구소에서 바이러스가 유출되었다. 다행히 바이러스는 아직 퍼지지 않았고, 바이러스의 확산을 막기 위해서 연구소에 벽을 세우려고 한다. 연구소는 크기가 N×M인 직사각형으로 나타낼 수 있으며, 직사각형은 1×1 크기의 정사각형으로 나누어져 있다. 연구소는 빈 칸, 벽으로 이루어져 있으며, 벽은 칸 하나를 가득 차지한다.  일부 칸은 바이러스가 존재하며, 이 바이러스는 상하좌우로 인접한 빈 칸으로 모두 퍼져나갈 수 있다.

www.acmicpc.net

import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int m = in.nextInt();
        int[][] arr = new int[n][m];
        int[][] ori = new int[n][m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                arr[i][j] = in.nextInt();
                ori[i][j] = arr[i][j];
            }
        }
        int max = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                if (arr[i][j] == 1 || arr[i][j] == 2) continue;
                for (int k = i; k < n; k++) {
                    for (int l = 0; l < m; l++) {
                        if (k == i && j >= l) continue;
                        if (arr[k][l] == 1 || arr[k][l] == 2) continue;
                        int count = 0;
                        for (int o = k; o < n; o++) {
                            for (int p = 0; p < m; p++) {
                                if (o == k && p <= l) continue;
                                if (arr[o][p] == 1 || arr[o][p] == 2) continue;
                                arr[i][j] = 1;
                                arr[k][l] = 1;
                                arr[o][p] = 1;
                                for (int a = 0; a < n; a++) {
                                    for (int b = 0; b < m; b++) {
                                        if (arr[a][b] == 2) {
                                            if (a == 0 && b == 0) {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                            } else if (a == 0 && b == m - 1) {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                            } else if (a == n - 1 && b == 0) {
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                            } else if (a == n - 1 && b == m - 1) {
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                            } else if (a == 0) {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                            } else if (b == 0) {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                            } else if (a == n - 1) {
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                            } else if (b == m - 1) {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                            } else {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                            }
                                        }
                                    }
                                }
                                for (int a = n - 1; a >= 0; a--) {
                                    for (int b = m - 1; b >= 0; b--) {
                                        if (arr[a][b] == 2) {
                                            if (a == 0 && b == 0) {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                            } else if (a == 0 && b == m - 1) {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                            } else if (a == n - 1 && b == 0) {
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                            } else if (a == n - 1 && b == m - 1) {
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                            } else if (a == 0) {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                            } else if (b == 0) {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                            } else if (a == n - 1) {
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                            } else if (b == m - 1) {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                            } else {
                                                if (arr[a + 1][b] == 0) arr[a + 1][b] = 2;
                                                if (arr[a - 1][b] == 0) arr[a - 1][b] = 2;
                                                if (arr[a][b + 1] == 0) arr[a][b + 1] = 2;
                                                if (arr[a][b - 1] == 0) arr[a][b - 1] = 2;
                                            }
                                        }
                                    }
                                }
                                count = 0;
                                for (int a = 0; a < n; a++) {
                                    for (int b = 0; b < m; b++) {
                                        if(arr[a][b] == 0) count++;
                                    }
                                }
                                if(count > max) {
                                    max = count;
                                }
                                for (int a = 0; a < n; a++) {
                                    for (int b = 0; b < m; b++) {
                                        arr[a][b] = ori[a][b];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        System.out.println(max);
    }
}

살다살다 8중 for문을 쓸 줄이야
내가 조합(Combination)을 코드로 구현할 줄 몰라서 무식한 방법으로 했다

예제 입력을 입력으로 넣었을 때는 모두 옳은 답이 나오는데 제출하면 틀리다고 한다.

예제 입력1
예제 입력2
예제 입력3

 

아직도 못 풀었지만 꼭 내 손으로 풀 거다

 

2020.01.14

이후에 순열과 조합 알고리즘을 공부하고 오늘 다시 도전해보았다

def permutation(result, num, n, li):
    if num == 3:
        v = [[False] * m for i in range(n)]
        import collections
        q = collections.deque()
        for i in virus:
            q.append((i[0], i[1]))
        while q:
            cur = q.popleft()
            v[cur[0]][cur[1]] = True
            if cur[0] - 1 >= 0 and li[cur[0] - 1][cur[1]] == 0 and not v[cur[0] - 1][cur[1]]:
                q.append((cur[0] - 1, cur[1]))
            if cur[0] + 1 < n and li[cur[0] + 1][cur[1]] == 0 and not v[cur[0] + 1][cur[1]]:
                q.append((cur[0] + 1, cur[1]))
            if cur[1] - 1 >= 0 and li[cur[0]][cur[1] - 1] == 0 and not v[cur[0]][cur[1] - 1] :
                q.append((cur[0], cur[1] - 1))
            if cur[1] + 1 < m and li[cur[0]][cur[1] + 1] == 0 and not v[cur[0]][cur[1] + 1]:
                q.append((cur[0], cur[1] + 1))
        count = 0
        for i in range(n):
            for j in range(m):
                if li[i][j] == 0 and not v[i][j]:
                    count += 1
        c.append(count)
        return
    for i in range(n):
        for j in range(m):
            if not visited[i][j]:
                if li[i][j] == 0:
                    if num > 0:
                        if result[num - 1][0] < i:
                            result[num] = (i, j)
                            visited[i][j] = True
                            li[i][j] = 1
                            permutation(result, num + 1, n, li)
                            visited[i][j] = False
                            li[i][j] = 0
                        elif result[num - 1][0] == i and result[num - 1][1] < j:
                            result[num] = (i, j)
                            visited[i][j] = True
                            li[i][j] = 1
                            permutation(result, num + 1, n, li)
                            visited[i][j] = False
                            li[i][j] = 0
                        else:
                            continue
                    else:
                        result[num] = (i, j)
                        visited[i][j] = True
                        li[i][j] = 1
                        permutation(result, num + 1, n, li)
                        visited[i][j] = False
                        li[i][j] = 0
from sys import stdin
n, m = list(map(int, stdin.readline().split()))
li = []
virus = []
for i in range(n):
    temp = list(map(int, stdin.readline().split()))
    li.append(temp)
    for j in range(m):
        if temp[j] == 2:
            virus.append((i, j))
visited = [[False] * m for i in range(n)]
r = [(0, 0)] * 3
c = []
permutation(r, 0, n, li)
print(max(c))

시간초과가 뜬다... 조합이 재귀함수라서 그런걸까

 

6시간 뒤

성공! 위와 코드는 똑같은데 큐에 들어가 있는 것을 중복으로 집어넣는 것을 방지해주었더니 바로 해결됐다.

오늘 또 교훈을 얻는다. 큐에 넣을 때 visited를 True로 바꾸자

def permutation(result, num, n, li):
    if num == 3:
        v = [[False] * m for i in range(n)]
        import collections
        q = collections.deque()
        for i in virus:
            q.append((i[0], i[1]))
        while q:
            cur = q.popleft()
            if not v[cur[0]][cur[1]] and li[cur[0]][cur[1]] == 0:
                v[cur[0]][cur[1]] = True
            if cur[0] - 1 >= 0 and li[cur[0] - 1][cur[1]] == 0 and not v[cur[0] - 1][cur[1]]:
                q.append((cur[0] - 1, cur[1]))
                v[cur[0] - 1][cur[1]] = True
            if cur[0] + 1 < n and li[cur[0] + 1][cur[1]] == 0 and not v[cur[0] + 1][cur[1]]:
                q.append((cur[0] + 1, cur[1]))
                v[cur[0] + 1][cur[1]] = True
            if cur[1] - 1 >= 0 and li[cur[0]][cur[1] - 1] == 0 and not v[cur[0]][cur[1] - 1] :
                q.append((cur[0], cur[1] - 1))
                v[cur[0]][cur[1] - 1] = True
            if cur[1] + 1 < m and li[cur[0]][cur[1] + 1] == 0 and not v[cur[0]][cur[1] + 1]:
                q.append((cur[0], cur[1] + 1))
                v[cur[0]][cur[1] + 1] = True
        count = 0
        for i in range(n):
            for j in range(m):
                if li[i][j] == 0 and not v[i][j]:
                    count += 1
        c.append(count)
        return
    for i in range(n):
        for j in range(m):
            if not visited[i][j]:
                if li[i][j] == 0:
                    if num > 0:
                        if result[num - 1][0] < i:
                            result[num] = (i, j)
                            visited[i][j] = True
                            li[i][j] = 1
                            permutation(result, num + 1, n, li)
                            visited[i][j] = False
                            li[i][j] = 0
                        elif result[num - 1][0] == i and result[num - 1][1] < j:
                            result[num] = (i, j)
                            visited[i][j] = True
                            li[i][j] = 1
                            permutation(result, num + 1, n, li)
                            visited[i][j] = False
                            li[i][j] = 0
                        else:
                            continue
                    else:
                        result[num] = (i, j)
                        visited[i][j] = True
                        li[i][j] = 1
                        permutation(result, num + 1, n, li)
                        visited[i][j] = False
                        li[i][j] = 0
from sys import stdin
n, m = list(map(int, stdin.readline().split()))
li = []
virus = []
for i in range(n):
    temp = list(map(int, stdin.readline().split()))
    li.append(temp)
    for j in range(m):
        if temp[j] == 2:
            virus.append((i, j))
visited = [[False] * m for i in range(n)]
r = [(0, 0)] * 3
c = []
permutation(r, 0, n, li)
print(max(c))
728x90