Pythonで強連結成分分解するのにscipyが高速に動く(らしい)話

お久しぶりです。kawap23です。

10月から就職したため精進できていません。痛勤電車内でいかに精進できるかがこれ以降のノビに効いてきそうです。

さて今回は強連結成分分解も、前記事の最短距離問題と同じくscipyで速くなるんじゃねって思って調べ検証してみました。

基本的には、正確性より問題で使えるようになること目指しました。正確性は保証しません(大事なことなので(略))。

2019年10月06日 夜 追記

maspyさんに自前実装の最適化がされていないと教えて頂いたので、プログラムを変更し時間を測り直しました。maspyさん、ありがとうございます。

参考にしたもの

強連結成分分解とは

まぁ蟻本やWikiでも読んでください。不正確に簡単に言うと、お互いに行き来することができる点を一つにまとめるって感じです。使う問題としては旧ARCにはなりますが、ARC030-C 有向グラフなどがあると思います。

そもそもScipyを使った実装って速いの?

測定に使ったコードは最後に貼っておきますが、頂点数Nに対して、辺の数が2倍、3倍の各パターンに対して自前実装とScipy実装とで計算時間を比べました。

 はっきり言って、自前実装とか相手にもなりません。touristと僕以上の差があります。

 

頂点数N 辺の数M 自前実装 scipy実装
10^1 2 * 10^1 0.00 0.00
10^2 2 * 10^2 0.00 0.00
10^3 2 * 10^3 0.01 0.00
10^4 2 * 10^4 0.06 0.00
10^5 2 * 10^5 0.73 0.01
10^6 2 * 10^6 7.86 0.33
10^7 2 * 10^7 83.60 4.16
       
頂点数N 辺の数M 自前実装 scipy実装
10^1 3 * 10^1 0.00 0.00
10^2 3 * 10^2 0.00 0.00
10^3 3 * 10^3 0.01 0.00
10^4 3 * 10^4 0.07 0.00
10^5 3 * 10^5 0.81 0.01
10^6 3 * 10^6 9.15 0.51
10^7 3 * 10^7 100.35 6.38

 ※頂点数10^6, 辺の数3 * 10 ^ 6は計算が終わりません。数時間たってPCがスリープになっている事故をやらかしてます。近いうちに計算させます。プログラム(実装)が悪くて時間がかかっているだけでした。10^7のまで普通に計算できました。

f:id:kawap23:20191006203145p:plain

自前実装 vs scipy

見ての通り、Scipy実装が圧倒的に速いです。最大で数万倍速度差があります。意味がわかんないです。(自前実装が最適化されていない可能性も大いにあります)

さらに両対数グラフにしてみると…

f:id:kawap23:20191006203225p:plain

自前実装 vs scipy (両対数グラフ)

となりO(N + M)で比例の関係にありそうなことがわかります。多分アルゴリズムあってそうです。ここまできれいに見えると嬉しいですね。

使い方

グラフの受け取り

最短距離問題のときと同じく、疎行列で受け取るのが楽でいいと思います。

csr_matrixには重みを与える必要がありますが、0以外ではなんでも同じように扱われるので、とりあえず1にしときます。多重辺があるときは重みが勝手に増えますが問題ありません。scipy内部で勝手に1になってます。

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
N = int(input())
M = int(input())
edge = np.array([input().split() for _ in range(M)], dtype = np.int64).T
tmp = np.ones(M, dtype = np.int64).T
graph = csr_matrix((tmp, (edge[:] -1)), (N, N))

print (graph)

# 入力例
# 10
# 20
# 1 6
# 3 8 
# 3 4
# 3 9
# 3 2
# 4 7
# 4 5
# 5 1
# 5 10
# 5 9
# 7 8
# 7 9
# 8 1
# 8 10
# 8 9
# 9 10
# 10 8
# 10 1
# 10 6
# 10 6

# 出力
#   (0, 5)        1
#   (2, 1)        1
#   (2, 3)        1
#   (2, 7)        1
#   (2, 8)        1
#   (3, 4)        1
#   (3, 6)        1
#   (4, 0)        1
#   (4, 8)        1
#   (4, 9)        1
#   (6, 7)        1
#   (6, 8)        1
#   (7, 0)        1
#   (7, 8)        1
#   (7, 9)        1
#   (8, 9)        1
#   (9, 0)        1
#   (9, 5)        2
#   (9, 7)        1
関数の使い方

関数の形は connected_componentsで、返り値は強連結成分を一つとみなしたときの、成分の個数(int) 及び numpy.ndarrayです。返り値が2つあるのに気をつけましょう(自戒)

  • グラフ:上で作ったscr_matrixあるいは隣接行列
  • 有向・無向:directed = Trueで有向、Falseで無向
  • 連結条件:connection = 'strong'で両方向に繋がっているときのみ連結、'weak'で片方向で連結でもOK
  • 有向グラフの強連結成分分解はdirected = True, connection = 'strong' でOK
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
N = int(input())
M = int(input())
edge = np.array([input().split() for _ in range(M)], dtype = np.int64).T
tmp = np.ones(M, dtype = np.int64).T
graph = csr_matrix((tmp, (edge[:] -1)), (N, N))

print (connected_components(graph, directed = True, connection = 'strong'))

# 入力
# 10
# 20
# 1 6
# 3 8 
# 3 4
# 3 9
# 3 2
# 4 7
# 4 5
# 5 1
# 5 10
# 5 9
# 7 8
# 7 9
# 8 1
# 8 10
# 8 9
# 9 10
# 10 8
# 10 1
# 10 6
# 10 6

# 出力
# (8, array([1, 2, 7, 6, 5, 0, 4, 3, 3, 3]))

結果の見方としては、10頂点だが8, 9, 10の3頂点が強連結成分であり、8 ~ 10を一つとみなすと8個の成分であるということを示しています。基本的には必要な情報が返ってくるのでAtCoderでも使えると思います。

まとめ

Scipyが有能だということがわかりました。Scipyやnumpy等を実装?している人たちすごいですね。感謝しながら使っていこうと思います。

付録

計算時間の測定に使ったコードです。

def measure(n, times): #頂点の数10 **  n, 10 ** n * times = 辺の数
    import time
    from random import randint
    N = 10 ** n
    M = N * times
    G = [[] for _ in range(N)] #0-indexでの隣接リスト
    RG = [[] for _ in range(N)] #0-indexでの隣接リスト #逆辺用
    G2 = [] #scipy用
    for _ in range(M):
        A = randint(0, N -1)
        while True:
            B = randint(0, N - 1) #自己辺を許さないが、多重辺は許す
            if A != B:
                break
        G[A].append(B)
        RG[B].append(A)
        G2.append([A, B])
    # ----------------------------------------------------------
    # 以下自前実装用
    # ----------------------------------------------------------
    s_mine = time.time()    

    def scc(): #非再帰関数で実装
        def dfs(v):
            stack = [v]
            used[v] = True
            while len(stack) != 0:
                tmp = stack[-1]
                flag = True
                for i in G[tmp]:
                    if not used[i]:
                        flag = False
                        used[i] = True
                        stack.append(i)
                        break
                if flag: #どこにも行かなかった時
                    stack.pop()
                    # stack = stack[:-1] #一行上に最適化
                    vs.append(tmp)

        def rdfs(v, k):
            stack = [v]
            used[v] = True
            cmp[v] = k
            while len(stack) != 0:
                tmp = stack[-1]
                stack.pop()
                # stack = stack[:-1] #一行上に最適化
                used[tmp] = True
                for i in RG[tmp]:
                    if not used[i]:
                        cmp[i] = k
                        stack.append(i)

        used = [False] * N #既に調べたかどうか
        vs = [] #帰りがけの並び
        cmp = [-1] * N            
        for i in range(N):
            if not used[i]:
                dfs(i)
        k = 0
        used = [False] * N #既に調べたかどうか
        for i in vs[::-1]:
            if not used[i]:
                rdfs(i, k)
                k += 1
        return k, cmp #強連結成分分解をしたあとの要素数kとそれぞれの点がどこに位置するか
    l1, lst1 = scc()

    f_mine = time.time()

    # ----------------------------------------------------------
    # 以下scipy用
    # ----------------------------------------------------------
    import numpy as np
    from scipy.sparse import csr_matrix
    G2 = np.array(G2, dtype = np.int64).T
    tmp = np.ones(M, dtype = np.int64).T
    graph = csr_matrix((tmp, (G2)), (N, N))

    s_scipy = time.time()

    import numpy as np
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components
    l2, lst2 = connected_components(graph, directed = True, connection = 'strong')

    f_scipy = time.time()

    if l1 == l2: #答えが一致している時
        flag = True
        print ('OK')
        print ('頂点数: 10 **', n)
        print ('頂点数 : 辺の数 = 1 :', times)
    else: #答えが一致していない時
        flag = False
        print ('NO')
        print ('スタック l1 =', l1)
        print ('scipy l2 =', l2)
        # print ('再帰関数 l3 =', l3)
        print (G)
        # print (graph)
   
    print ('自前実装  :', format(f_mine - s_mine, '.10f'))
    print ('scipy 実装:', format(f_scipy - s_scipy, '.10f'))
    print ()
    return 

for i in range(2, 4):
    for j in range(1, 8):
        measure(j, i)