KNeighborsClassifierにおけるmetric=”precomputed”の使い方

本記事では主に以下の点について説明したいと思います。

本記事を経て、最終的には以下の2つの記事も執筆予定です。

パッケージのインポート

初めに必要なパッケージのインポートをまとめておきます。使用するライブラリはscipyとscikit learnです。

from sklearn.neighbors import KNeighborsClassifier
from scipy.sparse import coo_matrix
from scipy.spatial.distance import minkowski

metric=”precomputed”の使い方

KNeighborsClassifierの使い方

KNeighborsClassifierのmetric=”precomputed”については解説されている記事があまりなかったので、使い方を理解するのに時間がかかりました。今回のように自作の距離に変わるものを利用するときに便利な機能だと思いますので、ぜひこの記事で使い方を学んでみてください。

まずはKNeighborsClassifierの使い方を復習しておきます。scikit learnのページを参考に、簡単なサンプルを以下に記載します。

def kneighbors_ex():
    X = [[0], [1], [2], [3]]
    y = [0, 0, 1, 1]
    sample = [[1.1]]

    neigh = KNeighborsClassifier(n_neighbors=3, metric="minkowski")
    neigh.fit(X, y)
    print(neigh.predict(sample))
    print(neigh.kneighbors(sample))

    # >>> [0]
    # >>> (array([[0.1, 0.9, 1.1]]), array([[1, 2, 0]]))

KNeighborsClassifierは引数であるn_neighborsで使用する近傍点の数、metricで距離算出に使用する方法を指定できます。minkowskiはmetricのデフォルト値となっている手法です。

fitメソッドでは、Xとして座標を、yとして各座標のラベルを渡しています。

その後、predictメソッドで1.1の座標のラベルを予測しています。予測結果は0と返ってきました。0, 1がラベル0で2, 3がラベル1なので0と1の間の1.1がラベル0なので問題ないですね。

最後にkneighborsで近傍点までの距離および近傍点の座標を求めています。1番近いのが座標が1の点でそこまでの距離が0.1、2番目に近い座標が2でそこまでの距離が0.9と計算されています。

metric=callableの使い方

KNeighborsClassifierの使い方の復習ができたので、次はmetric=callableの使い方を見ていきます。ちなみにcallableは関数が入るという意味で書いています。

metric=”minkowski”を指定した時は、scikit learnの内部でscipy.spatial.distance.minkowskiを使って距離計算を行っています。callableの使い方が合っているかどうかを調べるために、まずはmetricでscipy.spatial.distance.minkowskiを指定して算出を行ってみたいと思います。

def kneighbors_callable_ex():
    X = [[0], [1], [2], [3]]
    y = [0, 0, 1, 1]
    sample = [[1.1]]

    print(minkowski(X[0], X[2]))
    # >>> 2.0

    neigh = KNeighborsClassifier(n_neighbors=3, metric=minkowski)
    neigh.fit(X, y)
    print(neigh.predict(sample))
    print(neigh.kneighbors(sample))

    # >>> [0]
    # >>> (array([[0.1, 0.9, 1.1]]), array([[1, 2, 0]]))

print(minkowski(X[0], X[2]))はminkowski関数の使い方を簡単に記載してるものになります。第一引数と第二引数の距離を計算しています。このminkowski関数をmetricで指定し計算した結果と、先ほどのmetric=”minkowski”の結果を比較すると全く同じものであることがわかります。

このようにmetricで関数を指定する時は第一引数と第二引数の距離を算出するものをしていすればいいことがわかりました。

metric=”precomputed”の使い方

metric=”precomputed”を指定した場合は、その名前の通り距離をあらかじめ計算し、fitの第一引数としてモデルに渡す必要があります。

また、scikit learnのドキュメントによると計算した距離は非ゼロ要素のsparse graphである必要があるとのことです。このsparse graphとは何なのでしょうか。こちらもscikit learnのドキュメントによると「ほとんどすべての要素がゼロである、対応する高密度の numpy 配列よりもメモリ効率の高い 2 次元数値データの表現。」だそうです。さらに続けてscipy.sparseフレームワークを使用していると記載されています。このページから、どうやら事前に計算した距離はscipy.sparseフレームワークを用いてsparse graphとして表現すればいいということがわかりました。

以上の情報から作成したサンプルコードを以下に記載します。

def kneighbors_precomputed_fit_ex():
    X = [[0], [1], [2], [3]]
    y = [0, 0, 1, 1]

    distances = []
    for i in range(len(X)):
        distance = []
        for j in range(len(X)):
            d = minkowski(X[i], X[j])
            distance.append(d)
        distances.append(distance)
    
    print(distances)
    # >>> [[0.0, 1.0, 2.0, 3.0], [1.0, 0.0, 1.0, 2.0], [2.0, 1.0, 0.0, 1.0], [3.0, 2.0, 1.0, 0.0]]

    coo = coo_matrix(distances)
    print(coo)
    # >>> (0, 1)        1.0
    # >>> (0, 2)        2.0
    # >>> (0, 3)        3.0
    # >>> (1, 0)        1.0
    # >>> (1, 2)        1.0
    # >>> (1, 3)        2.0
    # >>> (2, 0)        2.0
    # >>> (2, 1)        1.0
    # >>> (2, 3)        1.0
    # >>> (3, 0)        3.0
    # >>> (3, 1)        2.0
    # >>> (3, 2)        1.0
    
    neigh = KNeighborsClassifier(n_neighbors=3, metric="precomputed")
    neigh.fit(coo, y)

distancesは、全ての点との距離を要素とする二次元配列となっています。この配列をscipy.sparse.coo_matrixでsparce graphに変換すると、コード内にコメントで書いているような構造に変換されます。これにより事前の距離計算は完了で、この変数cooをfitの第一引数として渡し、第二引数には今まで通りyを入れることで学習を行うことができます。

学習が完了したので次は推論を行います。推論時にも同じように、推論したい点と学習用データの全ての点との距離を計算し、sparce graphとしてpredictメソッドに渡す必要があります。学習コードに以下のコードを追記することで推論を行うことが可能となります。

    sample = [[1.1]]
    pred_distances = []
    for i in range(len(sample)):
        distance = []
        for j in range(len(X)):
            d = minkowski(sample[i], X[j])
            distance.append(d)
        pred_distances.append(distance)

    s = coo_matrix(pred_distances)
    print(neigh.predict(s))
    print(neigh.kneighbors(s))

    # >>> [0]
    # >>> (array([[0.1, 0.9, 1.1]]), array([[1, 2, 0]]))

ラベル、距離、近傍点の全てがこれまでの値と一致していることがわかります。metric=”precomputed”のサンプルコードとしてはこれでよさそうです。