web-dev-qa-db-ja.com

ユーザー定義メトリックによるSklearn kNNの使用

現在、私はkNNアルゴリズムを使用して特定のポイントの上位k最近傍を検索する必要があるプロジェクトを実行しています。たとえば、P。imはpython、sklearnパッケージを使用してジョブを実行していますが、定義済みのメトリックはこれらのデフォルトの1つではありません。メトリック。 herehereで見つけることができるsklearnのドキュメントから、ユーザー定義メトリックを使用する必要があります

Sklearn kNNの最新バージョンはユーザー定義メトリックをサポートしているようですが、それを使用する方法が見つかりません。

import sklearn
from sklearn.neighbors import NearestNeighbors
import numpy as np
from sklearn.neighbors import DistanceMetric
from sklearn.neighbors.ball_tree import BallTree
BallTree.valid_metrics

mydist = max(x-y)というメトリックを定義したとしたら、DistanceMetric.get_metricを使用してそれをDistanceMetricオブジェクトにします。

dt=DistanceMetric.get_metric('pyfunc',func=mydist)

ドキュメントから、行は次のようになります

nbrs = NearestNeighbors(n_neighbors=4, algorithm='auto',metric='pyfunc').fit(A)
distances, indices = nbrs.kneighbors(A)

しかし、どこにdtを置けますか?ありがとう

26
user2926523

メトリックをmetric paramとして渡し、追加のメトリック引数をキーワードパラメタとしてNNコンストラクタに渡します。

>>> def mydist(x, y):
...     return np.sum((x-y)**2)
...
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])

>>> nbrs = NearestNeighbors(n_neighbors=4, algorithm='ball_tree',
...            metric='pyfunc', func=mydist)
>>> nbrs.fit(X)
NearestNeighbors(algorithm='ball_tree', leaf_size=30, metric='pyfunc',
         n_neighbors=4, radius=1.0)
>>> nbrs.kneighbors(X)
(array([[  0.,   1.,   5.,   8.],
       [  0.,   1.,   2.,  13.],
       [  0.,   2.,   5.,  25.],
       [  0.,   1.,   5.,   8.],
       [  0.,   1.,   2.,  13.],
       [  0.,   2.,   5.,  25.]]), array([[0, 1, 2, 3],
       [1, 0, 2, 3],
       [2, 1, 0, 3],
       [3, 4, 5, 0],
       [4, 3, 5, 0],
       [5, 4, 3, 0]]))
31
alko

前の回答への小さな追加。 追加の引数をとるユーザー定義メトリックの使用方法。

>>> def mydist(x, y, **kwargs):
...     return np.sum((x-y)**kwargs["metric_params"]["power"])
...
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
>>> Y = np.array([-1, -1, -2, 1, 1, 2])
>>> nbrs = KNeighborsClassifier(n_neighbors=4, algorithm='ball_tree',
...            metric=mydist, metric_params={"power": 2})
>>> nbrs.fit(X, Y)
KNeighborsClassifier(algorithm='ball_tree', leaf_size=30,                                                                                                                                                          
       metric=<function mydist at 0x7fd259c9cf50>, n_neighbors=4, p=2,
       weights='uniform')
>>> nbrs.kneighbors(X)
(array([[  0.,   1.,   5.,   8.],
       [  0.,   1.,   2.,  13.],
       [  0.,   2.,   5.,  25.],
       [  0.,   1.,   5.,   8.],
       [  0.,   1.,   2.,  13.],
       [  0.,   2.,   5.,  25.]]),
 array([[0, 1, 2, 3],
       [1, 0, 2, 3],
       [2, 1, 0, 3],
       [3, 4, 5, 0],
       [4, 3, 5, 0],
       [5, 4, 3, 0]]))
14
Mahmoud