Tensor.kthvalueの使い方について公式のDocumentではよく分からなかったので自分で動作確認した結果をまとめる。
・公式ドキュメント
本記事のテーマ
「Tensor.kthvalueの動作概要」
について実動作確認した結果をまとめる。
関数の説明
x.kthvalue(n, dim=1, keepdim=True)で動作確認をおこなった。
kthvalueはn番目に小さな値を抽出する関数です。
dimで選択した次元を探索します。
動作確認
動作確認した結果を紙にまとめました。
例として作ったTensorは
x: Tensor
x.shape: (2, 3, 6)
x[0]:
[[ 1, 2, 3, 1000, 5, 6]
[ 7, 8, 9, 10, 11, 12]
[13, 14, 15, 16, 17, 1000]]
x[1]:
[[ 19, 20, 21, 22, 23, 24]
[ 25, 26, 27, 28, 29, 30]
[ 31, 32, 33, 34, 1000, 36]]
x.kthvalue(3, dim=1, keepdim=True)を実行すると、写真の左の図に記載されているように1次元目を縦に探索して3番めに小さい数字、1次元目のサイズは3なため、最も大きい数字を探索していることになる。(x[0次元目][1次元目][2次元目])
結果は3番めに小さい数字とその数字が格納されているIndexの配列をTupleとして出力してくれる。
torch.return_types.kthvalue(
values=tensor([ 13., 14., 15., 1000., 17., 1000.,
31., 32., 33., 34., 1000., 36.]),
indices=tensor([2, 2, 2, 0, 2, 2,
2, 2, 2, 2, 2, 2]))、
■indecesの見方
dimで1次元目を探索するようにしているため、x[1次元目]に当たるIndexの数字が出力される。
つまり、唯一0になっている箇所を確認すると、x[0][0][3]=1000で1次元目のIndexは0なので出力結果のindecesも0になる。他も同様に確認すると全て2になることが分かる。