Solution Idea

50052. K-means Algorithm

題意

用 k-means 將 N 個長度為 L 的字串分成三類。每個字串可以根據組成字元的 ASCII code 表示成一個 vector,兩個字串間的距離是兩個 vector 的 Manhattan distance。

k-menas 的做法是,一開始先選前三個字串當作 leader,接著重覆以下三個步驟 R 次。

  1. 把每個字串 assign 給離自己最近的 leader (有多個選擇就選字典序小的),有同樣 leader 的字串組成一組,總共有三組。
  2. 計算每組的 mean vector,算法是把這組內的字串 vector 相加,除以組員數目。
  3. 每組內,離 mean vector 最近的字串是新的 leader (有多個選擇就選字典序小的)。

詳細步驟說明可以看題目中 sample input 底下的 "Explanation"。

範例程式說明

這裡說明一些範例程式中比較重要的 functions。

  • void kmeans(int *leaderIds, int (*memberIds)[MAXN], int *memberCnt)
    演算法主要的架構在這個 function 裡。
    用一個二維陣列 memberIds[gid][mid] 記錄 gid 這組的第 mid 個 member 是誰,並用陣列 memberCnt[gid] 紀錄 gid 這組總共有幾個 member。 在每個 round,先幫每個字串找到最接近的 leader,存在 memberIds, 再從 memberIds 分別計算三組的 mean vector 並指定新的 leaders。

  • int pickClosest(char *str, int *candidateIds, int numCandidate)
    candidateIds 陣列中,挑出一個離 str 最近的 string。
    這個 function 用在兩個地方,第一個是 step 1 幫每個字串挑 leader 的時候,str 代表字串,candidateIds 則是三個 leader 的 id。第二個用的地方是 step 3,在找離 mean vector 最近的字串的時候,str 代表 mean vector,candidateIds 則是這組內所有字串,要挑一個離 mean 最近的。

  • int assignNewLeader(int *memberIds, int memberCnt)
    先計算 memberIdsmemberCnt 個值的 mean vector,再用 pickClosest() 找離 mean vector 最近的當下一輪的 leader。
    計算 mean 的時候,要先用 int 存 sum ,除以 memberCnt 之後再轉成 char 陣列。因為 char 只能存 8 個 bits ,也就是 -128 ~ 127 。如果直接用 char 型態計算 sum 會 overflow。

程式碼

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define MAXN 50
#define MAXL 20
#define K 3 // number of groups
 
int N, L, R;
char strs[MAXN][MAXL+5];
 
int dist(char *str1, char *str2){ // Manhattan distance of two strings
    int s = 0;
    for(int i = 0; i < L; i++)
        s += abs(str1[i] - str2[i]);
    return s;
}
 
int isSmaller(char *str1, char *str2){ // check if str1 is lexigraphically smaller than str2
    return strcmp(str1, str2) < 0;
}
 
int isCloser(char *curStr, char *oriStr, char *newStr, int oriDist){ // check if newStr is closer to curStr (compare with dist(curStr, oriStr))
    int newDist = dist(curStr, newStr);
    return newDist != oriDist ? newDist < oriDist : isSmaller(newStr, oriStr); // if the distance is the same, check the lexigraphic order
}
 
int pickClosest(char *str, int *candidateIds, int numCandidate){ // pick the closest candidate from candidateIds[] which has the minimal distance to str
    int minDist = dist(str, strs[candidateIds[0]]), minId = 0;
    for(int i = 1; i < numCandidate; i++){ // for each candidate
        char *candidateStr = strs[candidateIds[i]];
        if (isCloser(str, strs[candidateIds[minId]], candidateStr, minDist)) // check if it is better than the currenct min
            minDist = dist(str, candidateStr), minId = i;
    }
    return minId;
}
 
int assignNewLeader(int *memberIds, int memberCnt){ // given the members, find the new leader
    int sum[MAXL+5] = {0};
    char meanStr[MAXL+5] = {0};
    for(int i = 0; i < L; i++){ // calculate the mean of the group (step 2.)
        for(int j = 0; j < memberCnt; j++)
            sum[i] += strs[memberIds[j]][i];
        meanStr[i] = sum[i] / memberCnt;
    }
    return memberIds[pickClosest(meanStr, memberIds, memberCnt)]; // pick the one with minimal distance to the mean (step 3.)
}
 
int sort(int *idx, int cnt){ // sort the index (selection sort)
    int tmp;
    for(int i = 0; i < cnt; i++){
        for(int j = i + 1; j < cnt; j++){
            if (isSmaller(strs[idx[j]], strs[idx[i]]))
                tmp = idx[i], idx[i] = idx[j], idx[j] = tmp;
        }
    }
}
 
void kmeans(int *leaderIds, int (*memberIds)[MAXN], int *memberCnt){
    for(int i = 0; i < K; i++) leaderIds[i] = i;
 
    for(int r = 0; r < R; r++){
        for(int i = 0; i < K; i++) memberCnt[i] = 0;
        for(int i = 0; i < N; i++){
            int gid = pickClosest(strs[i], leaderIds, K); // pick leader (step 1.)
            memberIds[gid][memberCnt[gid]++] = i; // increase group member count
        }
        for(int i = 0; i < K; i++)
            leaderIds[i] = assignNewLeader(memberIds[i], memberCnt[i]); // find new leader (step 2. 3.)
    }
}
 
 
int main(){
    scanf("%d %d %d", &N, &L, &R);
 
    for(int i = 0; i < N; i++)
        scanf("%s", strs[i]) == 1;
 
    int leaderIds[K], memberIds[K][MAXN], memberCnt[K];
    kmeans(leaderIds, memberIds, memberCnt); // iterate k-mens for R rounds
    sort(leaderIds, K); // sort the leaders
 
    for(int i = 0; i < K; i++)
        printf("%s\n", strs[leaderIds[i]]);
 
    return 0;
}

Discussion