点击左侧箭头调整阅读页面大小

关于CS231N-Assignment1-KNN中no-loop矩阵乘法代码的讲解

在使用无循环的算法进行计算距离的效率是很高的
可以看到No loop算法使用的时间远远小于之前两种算法

Two loop version took 56.785069 seconds
One loop version took 136.449761 seconds
No loop version took 0.591535 seconds   #很快!

实现代码主要为以下这一段:
其中X为500×3072的矩阵(测试矩阵)
X_train为5000×3072的矩阵(训练矩阵)
dists 为500×5000的矩阵(距离矩阵)
题中的目的就是将X中每一行的像素数值与X_train中每一行的像素数值(3072个)进行距离运算得出欧氏距离(L2)再储存到dists中
核心公式

test_sum = np.sum(np.square(X), axis=1)  # num_test x 1
train_sum = np.sum(np.square(self.X_train), axis=1)  # num_train x 1
inner_product = np.dot(X, self.X_train.T)  # num_test x num_train
dists = np.sqrt(-2 * inner_product + test_sum.reshape(-1, 1) + train_sum)  # broadcast

公式讲解:
假设现在有三个矩阵:A(X)、B(X_train)、C(dists )
将维数缩小以方便操作,稍微进行推导,就可以得出上面的公式了
推导过程如下:
《关于CS231N-Assignment1-KNN中no-loop矩阵乘法代码的讲解》

  点赞
本篇文章采用 知识共享署名-相同方式共享 4.0 国际许可协议 进行许可
转载请务必注明来源: https://oldpan.me/archives/knn-no-loop

   欢迎关注Oldpan博客微信公众号,同步更新博客深度学习文章。


说点什么吧(邮箱、网址选填)

avatar
280
  订阅  
提醒