Pythonで最短距離問題解くのにscipyが高速に動く(らしい)話

お久しぶりです。kawap23です。

入社が近づいてきてドキドキしています

kawap23.hatenablog.com

ARC025-C: ウサギとカメやARC035-C: アットコーダー王国の交通事情のように最短距離問題を使う問題でPython (pypy)じゃ遅すぎる!ってなっていろいろな方のブログ等を参考にして、なんとかscipyが使えたのでメモがてら残しておこうと思います。

自分が忘れていたときに、ここを見て思い出すことが目的ですので、正確性や厳密性は保証しません。

参考にしたブログ

いきなりですが、参考にさせていただいた方のブログを書いておきます。

・scipyのインストールについて

Masato's IT Library:

WindowsにPython3系とnumpy・scipyをインストールする方法(2/3 ライブラリ編) - Masato's IT Library

・scipyの使い方について

じゅっぴーダイアリー:

scipyのFloyd-WarshallとDijkstraのすすめ Python 競技プログラミング Atcoder - じゅっぴーダイアリー

これとは別に、maspyさんがAtCoderで提出されていたコードなども参考にさせていただきました。

ありがとうございます。

 

あと、良くわからなかったらこれ読んだらいいと思います。英語なのでいい感じに和訳されているサイトあったら教えて下さい。Numpy and Scipy Documentation — Numpy and Scipy documentation

下準備 -scipyのインストール-

 上で紹介したブログですが、実は2017年4月の記事で2年以上前のものでした。が、2019年9月に行っても問題なく行えました。自分の使っているPythonのバージョンとあったnumpy・scipyをダウンロードすることだけ気をつけたら大丈夫だと思います。

(そろそろ自分がPCを買い換えるので自分のためのメモ)

AtCoderはPython3.4で自分は3.6使ってますが、今のところ問題は起きていません。

そもそもscipy使った実装って早いの?

前述のじゅっぴーさんはFloyd-Warshall法の速度がくっそ早いとアピールしていましたので、私はDijkstra法の速度を載せときます。

まずDijkstra法はO(辺の数M × log (頂点数N))で、ある頂点から他の頂点までの最短距離を求める手法です。調べるのに使ったコードや詳しい条件は一番最後に載せておきます。どこまで含めて時間とするかは個々考えがあると思いますので、そんなもんねぐらいの気持ちでお願いします。

頂点数

N

辺の数

M

自前

実装

scipy

dijkstra

dijkstra

全体

scipy

auto

auto

全体

10^3 10^4 0.22秒 0.02秒 0.31秒 0.03秒 0.03秒
10^4 10^5 2.74秒 0.15秒 0.52秒 0.14秒 0.22秒
10^5 10^6 34.56秒 2.52秒 3.51秒 2.64秒 3.31秒
10^6 10^7 429.58秒 41.03秒 49.05秒 41.97秒 49.73秒

 

f:id:kawap23:20190909160824p:plain

自前実装 vs scipy

scipyの早いですね。自作の関数捨てて scipy 使います。

使い方

グラフの受け取り

AtCoderでよくある、頂点数N、辺の数M、Ai, Bi 間の重みがC_i のグラフの受け取り方

入力は1-indexが多いので途中で-1しています。

※入力例は蟻本から(以降全て同じ)

# 入力
# N M
# A0 B0 C0
# A1 B1 C1
# ....
# A(M-1) B(M-1) C(M-1)

import numpy as np
from scipy.sparse import csr_matrix
N, M = map(int, input().split())
edge = np.array([input().split() for _ in range(M)], dtype = np.int64).T
graph = csr_matrix((edge[2], (edge[:2] - 1)), (N, N))

print (graph)

# 入力例
# 7 10
# 1 2 2
# 1 3 5
# 2 3 4
# 2 4 6
# 2 5 10
# 3 4 2
# 4 6 1
# 5 6 3
# 5 7 5
# 6 7 9

# 出力例
#   (0, 1)        2
#   (0, 2)        5
#   (1, 2)        4
#   (1, 3)        6
#   (1, 4)        10
#   (2, 3)        2
#   (3, 5)        1
#   (4, 5)        3
#   (4, 6)        5
#   (5, 6)        9

隣接行列用意してM回値を更新してもいいですが、めんどくさい & Nが大きくMが小さいとき殆どが不要な情報(疎行列)となるのでこっちのほうが効率的かなと思います。何より短く書ける。

 使い方-1 dijkstra

 関数の形はdijkstra(グラフ, 有向 or 無向, 始点)で、返り値は numpy.ndarray です。

  • グラフ:上で作ったもの、あるいは隣接行列
  • 有向・無向:directed = Trueで有向、Falseで無向グラフ
  • 始点:indicesで表す(0 - index)

 下記の例では、無向グラフで頂点0からの最短距離、有向グラフで頂点3からの最短距離を計算しています。行くことが出来ない頂点に関してはinfとなるみたいですね。

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import dijkstra

N, M = map(int, input().split())
edge = np.array([input().split() for _ in range(M)], dtype = np.int64).T
graph = csr_matrix((edge[2], (edge[:2] - 1)), (N, N))


ans = dijkstra(graph, directed = False, indices = 0)
print (type(ans))
print (ans)
print (dijkstra(graph, directed = True, indices = 3))

# 入力例
# 7 10
# 1 2 2
# 1 3 5
# 2 3 4
# 2 4 6
# 2 5 10
# 3 4 2
# 4 6 1
# 5 6 3
# 5 7 5
# 6 7 9

# 出力例
# <class 'numpy.ndarray'>
# [ 0.  2.  5.  7. 12.  8. 17.]
# [inf inf inf  0. inf  1. 10.]
使い方-2 floyd-warshall

 関数の形はfloyd_warshall(グラフ, 有向 or 無向)で、返り値は numpy.ndarray です。

  • グラフ:上で作ったもの、あるいは隣接行列
  • 有向・無向:directed = Trueで有向、Falseで無向グラフ

注意は重みが正の辺のみに対応していること*1。重み0の辺を作りたいなら、10 ** (-9)とかの重みにして、最後切り捨てでもするんか…?(試してません)

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import floyd_warshall

N, M = map(int, input().split())
edge = np.array([input().split() for _ in range(M)], dtype = np.int64).T
graph = csr_matrix((edge[2], (edge[:2] - 1)), (N, N))

ans = floyd_warshall(graph, directed = False)
print (type(ans))
print (ans)

ans = floyd_warshall(graph, directed = True)
print (ans)

# 入力
# 7 10
# 1 2 2
# 1 3 5
# 2 3 4
# 2 4 6
# 2 5 10
# 3 4 2
# 4 6 1
# 5 6 3
# 5 7 5
# 6 7 9

# 出力
# <class 'numpy.ndarray'>
# [[ 0.  2.  5.  7. 11.  8. 16.]
#  [ 2.  0.  4.  6. 10.  7. 15.]
#  [ 5.  4.  0.  2.  6.  3. 11.]
#  [ 7.  6.  2.  0.  4.  1.  9.]
#  [11. 10.  6.  4.  0.  3.  5.]
#  [ 8.  7.  3.  1.  3.  0.  8.]
#  [16. 15. 11.  9.  5.  8.  0.]]

# [[ 0.  2.  5.  7. 12.  8. 17.]
#  [inf  0.  4.  6. 10.  7. 15.]
#  [inf inf  0.  2. inf  3. 12.]
#  [inf inf inf  0. inf  1. 10.]
#  [inf inf inf inf  0.  3.  5.]
#  [inf inf inf inf inf  0.  9.]
#  [inf inf inf inf inf inf  0.]]
使い方-3 shortest_path

関数の形は shortest_path(グラフ, 方法, 有向 or 無向, (始点))で、返り値は numpy.ndarray です。これの良いところは方法に"自動 (auto)"が設定できる点。いい感じに処理してくれるらしい。自分で指定することもできるし、これ一択の気がする。理解度が深まるかは別の話として強そう。

  • グラフ:上で作ったもの、あるいは隣接行列
  • 有向・無向:directed = Trueで有向、Falseで無向グラフ
  • 始点:indicesで表す(0 - index) -->dijkstra法のときは必要

ここまでは今までのと同じ。違うのは方法のところで、method = ****

  • 'auto':下記4種類から自動で選択 (明示的に選択も可能)
  • 'FW': floyd_warshall法 (O (N ^ 3))
  • 'D': dijsktra法 (O (N (N * k + N * log(N)) ))
  • 'BF': bellman_ford法 (O (N (N ^ 2 * k))) 負の辺があっても使えるのが特徴
  • 'J': Johnson's algorithm: 知りませんm(_ _)

kは各頂点から伸びている辺の平均なので N * k = Mが成り立つ。

import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import shortest_path

N, M = map(int, input().split())
edge = np.array([input().split() for _ in range(M)], dtype = np.int64).T
graph = csr_matrix((edge[2], (edge[:2] - 1)), (N, N))

ans = shortest_path(graph, method = 'auto', directed = False)
print (type(ans))
print (ans)

ans = shortest_path(graph, method = 'auto', directed = True)
print (ans)

# 入力
# 7 10
# 1 2 2
# 1 3 5
# 2 3 4
# 2 4 6
# 2 5 10
# 3 4 2
# 4 6 1
# 5 6 3
# 5 7 5
# 6 7 9

# 出力
# <class 'numpy.ndarray'>
# [[ 0.  2.  5.  7. 11.  8. 16.]
#  [ 2.  0.  4.  6. 10.  7. 15.]
#  [ 5.  4.  0.  2.  6.  3. 11.]
#  [ 7.  6.  2.  0.  4.  1.  9.]
#  [11. 10.  6.  4.  0.  3.  5.]
#  [ 8.  7.  3.  1.  3.  0.  8.]
#  [16. 15. 11.  9.  5.  8.  0.]]

# [[ 0.  2.  5.  7. 12.  8. 17.]
#  [inf  0.  4.  6. 10.  7. 15.]
#  [inf inf  0.  2. inf  3. 12.]
#  [inf inf inf  0. inf  1. 10.]
#  [inf inf inf inf  0.  3.  5.]
#  [inf inf inf inf inf  0.  9.]
#  [inf inf inf inf inf inf  0.]]

まとめ

特に改造が必要なければshortest_path(graph, method = 'auto')で使ったらいいんじゃないですかね…

課題

ついでに最短経路を返してくれる関数ないんすかんね…

2019年9月10日追記

maspyさんに、返り値に最短経路を追加させる方法を教えていただいたのでそれについてもまとめました。

kawap23.hatenablog.com

付録

 速度測定に使ったコード

N, Mを手動で変えながら測定。場合によっては、連結じゃないときもあると思います。

ランダムに選んだ10点からの最短距離を求めました。

import time
from random import randint

#グラフの作成
N = 10 ** 6 #頂点数
M = 10 ** 7 #辺の数

G_1 = [[] for _ in range(N)] #自前用 隣接リストで管理
G_2 = [] #scipy用

for _ in range(M):
    # a <--> b間のつなぐo重みcの無向グラフ
    a = randint(0, N - 1)
    b = randint(0, N - 1)
    c = randint(1, 10 ** 2)
    G_1[a].append([c, b])
    G_1[b].append([c, a])
    G_2.append([a, b, c])

#始点として調べる点
lst = [randint(0, N - 1) for _ in range(10)]

#自前の実装 (参考: 蟻本)
def dijksrea_manu(s): #始点s
    from heapq import heappop, heappush
    INF = 10 ** 9
    d = [INF] * N
    d[s] = 0 #始点の距離を0にする
    pque = []
    heappush(pque, [0, s]) #要素を[距離、頂点]として管理 最初の位置を入れる

    while len(pque) != 0: #queの中に要素が残っている時
        p = heappop(pque) #最も距離が短いものを取り出す
        v = p[1] #距離が最も短い頂点
        if d[v] < p[0]: #取り出した値より既に小さい値がdに入っているときは無視して次へ
            continue
        for i in range(len(G_1[v])): #頂点vの隣接リストを走査
            e = G_1[v][i]
            if d[e[1]] > d[v] + e[0]: #距離が更新できるかを検討
                d[e[1]] = d[v] + e[0]
                heappush(pque, [d[e[1]], e[1]]) #更新できた場合、その値をpqueに入れる
    return d

s_mine = time.time()

for i in lst: #始点10点調べる
    d_1 = dijksrea_manu(i)

f_mine = time.time()

#scipy dijkustra
s_scipy = time.time()
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import dijkstra

G_2_mat = np.array(G_2, dtype = np.int64).T
graph = csr_matrix((G_2_mat[2], (G_2_mat[:2])), (N, N))

h_scipy = time.time()

for i in lst:
    d_2 = dijkstra(graph, directed = False, indices = i)

f_scipy = time.time()

#scipy shortest_path auto
s_scipy_auto = time.time()
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import shortest_path

G_3_mat = np.array(G_2, dtype = np.int64).T
graph = csr_matrix((G_3_mat[2], (G_3_mat[:2])), (N, N))

h_scipy_auto = time.time()

for i in lst:
    d_3 = shortest_path(graph, method = 'auto', directed = False, indices = i)

f_scipy_auto = time.time()


# print (d_1)
# print (d_2)
# print (d_3)

print ('N = ', N)
print ('M = ', M)

print ('自前の実装: ', f_mine - s_mine)
print ('import含むscipy全体 (dijkstra): ', f_scipy - s_scipy)
print ('scipyのdijkstraの計算 (dijkstra): ', f_scipy - h_scipy)
print ('import含むscipy全体 (shortest_path(auto)): ', f_scipy_auto - s_scipy_auto)
print ('scipyのdijkstraの計算 (shortest_path(auto)): ', f_scipy_auto - h_scipy_auto)