不太冷的莱昂's Blog: cs231n Assignment: KNN Classifier

    Date: 2021-09-02 16:33

    Only the parts you have to complete yourself are posted below.

    k_nearest_neighbor.py

    def compute_distances_two_loops(self, X):
    
            num_test = X.shape[0]
            num_train = self.X_train.shape[0]
            dists = np.zeros((num_test, num_train))
            for i in range(num_test):
                for j in range(num_train):
                    #####################################################################
                    # TODO:                                                             #
                    # Compute the l2 distance between the ith test point and the jth    #
                    # training point, and store the result in dists[i, j]. You should   #
                    # not use a loop over dimension, nor use np.linalg.norm().          #
                    #####################################################################
                    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
                    dists[i, j] = np.sqrt(np.sum(np.square(X[i] - self.X_train[j])))
    
                    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
            return dists
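As a quick sanity check (not part of the assignment code), the per-pair formula above can be verified against `np.linalg.norm`; the shapes and random data below are made up for illustration:

```python
import numpy as np

# Hypothetical small matrices standing in for the CIFAR-10 data.
rng = np.random.default_rng(0)
X_train = rng.standard_normal((4, 6))  # 4 training points, 6 features
X_test = rng.standard_normal((3, 6))   # 3 test points

dists = np.zeros((3, 4))
for i in range(3):
    for j in range(4):
        # Squared differences over features, summed, then square-rooted.
        dists[i, j] = np.sqrt(np.sum(np.square(X_test[i] - X_train[j])))

# Reference: the library's own Euclidean norm over all pairs.
expected = np.linalg.norm(X_test[:, None, :] - X_train[None, :, :], axis=2)
```

(The assignment forbids `np.linalg.norm` inside the solution itself; here it is only used as an independent check.)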
    
    def compute_distances_one_loop(self, X):
           
            num_test = X.shape[0]
            num_train = self.X_train.shape[0]
            dists = np.zeros((num_test, num_train))
            for i in range(num_test):
                #######################################################################
                # TODO:                                                               #
                # Compute the l2 distance between the ith test point and all training #
                # points, and store the result in dists[i, :].                        #
                # Do not use np.linalg.norm().                                        #
                #######################################################################
                # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
                dists[i, :] = np.sqrt(np.sum(np.square(X[i] - self.X_train), axis=1))
    
                # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
            return dists
    def compute_distances_no_loops(self, X):
           
            num_test = X.shape[0]
            num_train = self.X_train.shape[0]
            dists = np.zeros((num_test, num_train))
            # The hint in the comments is to avoid explicit loops entirely:
            # one matrix multiplication plus two broadcast additions in numpy.
            # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
            ab = np.dot(X, self.X_train.T)
            a2 = np.sum(np.square(X), axis=1).reshape(-1, 1)
            b2 = np.sum(np.square(self.X_train.T), axis=0).reshape(1, -1)
            dists = -2 * ab + a2 + b2
            dists = np.sqrt(dists)
             
    
            # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
            return dists
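The no-loop version relies on the expansion ||a − b||² = ||a||² − 2a·b + ||b||², with the three terms broadcast into a (num_test, num_train) matrix. A minimal sketch on made-up data, checked against explicit pairwise norms:

```python
import numpy as np

rng = np.random.default_rng(1)
X_train = rng.standard_normal((5, 8))  # 5 training points, 8 features
X = rng.standard_normal((3, 8))        # 3 test points

ab = X @ X_train.T                                      # (3, 5) cross terms a.b
a2 = np.sum(np.square(X), axis=1).reshape(-1, 1)        # (3, 1) column of ||a||^2
b2 = np.sum(np.square(X_train), axis=1).reshape(1, -1)  # (1, 5) row of ||b||^2
dists = np.sqrt(-2 * ab + a2 + b2)                      # broadcast to (3, 5)

# Reference: explicit pairwise Euclidean distances.
ref = np.linalg.norm(X[:, None, :] - X_train[None, :, :], axis=2)
```

The (3, 1) column and (1, 5) row broadcast against the (3, 5) product, which is exactly the trick the comment describes.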
    def predict_labels(self, dists, k=1):
           
            num_test = dists.shape[0]
            y_pred = np.zeros(num_test)
            for i in range(num_test):
                # A list of length k storing the labels of the k nearest neighbors to
                # the ith test point.
                closest_y = []
             
                # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
                closest_y = self.y_train[np.argsort(dists[i])[0:k]]
    
                # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
               
                # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
                y_pred[i] = np.argmax(np.bincount(closest_y))  # majority vote among the k labels
    
                # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
            return y_pred
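The argsort/bincount voting step can be traced on a tiny made-up example (one row of distances and six invented training labels, not the actual CIFAR-10 data):

```python
import numpy as np

# Hypothetical distances from one test point to 6 training points.
dist_row = np.array([0.9, 0.1, 0.5, 0.3, 0.8, 0.2])
y_train = np.array([2, 1, 1, 1, 0, 2])
k = 3

# Indices of the 3 smallest distances are [1, 5, 3] -> labels [1, 2, 1].
closest_y = y_train[np.argsort(dist_row)[:k]]

# bincount([1, 2, 1]) = [0, 2, 1]; argmax picks label 1 (two votes).
pred = np.argmax(np.bincount(closest_y))
```

On a tie, `np.argmax` returns the smallest label, which matches the assignment's tie-breaking rule.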

Here there is an inline question.

    Inline Question 1

    Notice the structured patterns in the distance matrix, where some rows or columns are visibly brighter. (Note that with the default color scheme black indicates low distances while white indicates high distances.)

    • What in the data is the cause behind the distinctly bright rows?
    • What causes the columns?

    Your Answer: 1. A bright row means that test example is very different from every image in the training set, i.e. its L2 (Euclidean) distance to all of them is large. 2. A bright column means one training image is dissimilar to all of the test examples.


    Cross-validation

    num_folds = 5
    k_choices = [1, 3, 5, 8, 10, 12, 15, 20, 50, 100]
    
    X_train_folds = []
    y_train_folds = []
    
    ################################################################################
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
    X_train_folds=np.split(X_train, num_folds, axis = 0)
    y_train_folds=np.split(y_train, num_folds, axis = 0)
    
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
    k_to_accuracies = {}
    
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    for k in k_choices:
        accuracies = []
        for i in range(num_folds):
            X_train_1 = np.vstack(X_train_folds[0:i] + X_train_folds[i+1:])
            y_train_1 = np.hstack(y_train_folds[0:i] + y_train_folds[i+1:])
            X_valid = X_train_folds[i]
            y_valid = y_train_folds[i]
            
            classifier.train(X_train_1,y_train_1)
            dists = classifier.compute_distances_no_loops(X_valid)
            y_test_pred = classifier.predict_labels(dists,k)
            num_correct = np.sum(y_test_pred == y_valid)
            accuracy = float(num_correct) / y_valid.shape[0]
            accuracies.append(accuracy)
        
        k_to_accuracies[k]= accuracies
                 
    
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    
    # Print out the computed accuracies
    for k in sorted(k_to_accuracies):
        for accuracy in k_to_accuracies[k]:
            print('k = %d, accuracy = %f' % (k, accuracy))
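The fold bookkeeping (split into `num_folds` pieces, hold one out, stack the rest back together) can be sketched on a tiny made-up array; the shapes here are invented for illustration:

```python
import numpy as np

# Hypothetical 10-sample training set, 2 features each, split into 5 folds.
X_train = np.arange(20).reshape(10, 2)
y_train = np.arange(10) % 3
num_folds = 5

X_folds = np.split(X_train, num_folds, axis=0)  # 5 arrays of shape (2, 2)
y_folds = np.split(y_train, num_folds, axis=0)

i = 2  # hold out fold 2 as the validation fold
X_tr = np.vstack(X_folds[:i] + X_folds[i + 1:])  # remaining 4 folds, stacked rows
y_tr = np.hstack(y_folds[:i] + y_folds[i + 1:])  # labels are 1-D, so hstack
X_val, y_val = X_folds[i], y_folds[i]
```

Note the asymmetry the assignment code also uses: `vstack` for the 2-D data, `hstack` for the 1-D labels.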
    Cross-validation results (five runs for each value of k):
    k = 1, accuracy = 0.263000
    k = 1, accuracy = 0.257000
    k = 1, accuracy = 0.264000
    k = 1, accuracy = 0.278000
    k = 1, accuracy = 0.266000
    k = 3, accuracy = 0.239000
    k = 3, accuracy = 0.249000
    k = 3, accuracy = 0.240000
    k = 3, accuracy = 0.266000
    k = 3, accuracy = 0.254000
    k = 5, accuracy = 0.248000
    k = 5, accuracy = 0.266000
    k = 5, accuracy = 0.280000
    k = 5, accuracy = 0.292000
    k = 5, accuracy = 0.280000
    k = 8, accuracy = 0.262000
    k = 8, accuracy = 0.282000
    k = 8, accuracy = 0.273000
    k = 8, accuracy = 0.290000
    k = 8, accuracy = 0.273000
    k = 10, accuracy = 0.265000
    k = 10, accuracy = 0.296000
    k = 10, accuracy = 0.276000
    k = 10, accuracy = 0.284000
    k = 10, accuracy = 0.280000
    k = 12, accuracy = 0.260000
    k = 12, accuracy = 0.295000
    k = 12, accuracy = 0.279000
    k = 12, accuracy = 0.283000
    k = 12, accuracy = 0.280000
    k = 15, accuracy = 0.252000
    k = 15, accuracy = 0.289000
    k = 15, accuracy = 0.278000
    k = 15, accuracy = 0.282000
    k = 15, accuracy = 0.274000
    k = 20, accuracy = 0.270000
    k = 20, accuracy = 0.279000
    k = 20, accuracy = 0.279000
    k = 20, accuracy = 0.282000
    k = 20, accuracy = 0.285000
    k = 50, accuracy = 0.271000
    k = 50, accuracy = 0.288000
    k = 50, accuracy = 0.278000
    k = 50, accuracy = 0.269000
    k = 50, accuracy = 0.266000
    k = 100, accuracy = 0.256000
    k = 100, accuracy = 0.270000
    k = 100, accuracy = 0.263000
    k = 100, accuracy = 0.256000
    k = 100, accuracy = 0.263000

    Visualize the cross-validation results:

    Inline Question 3

    Which of the following statements about k-Nearest Neighbor (k-NN) are true in a classification setting, and for all k? Select all that apply.

    1. The decision boundary of the k-NN classifier is linear.
    2. The training error of a 1-NN will always be lower than that of 5-NN.
    3. The test error of a 1-NN will always be lower than that of a 5-NN.
    4. The time needed to classify a test example with the k-NN classifier grows with the size of the training set.
    5. None of the above.

    Your Answer: 2, 4

    Your Explanation: 1. False: the decision boundary of k-NN is generally not linear. 2. True: with 1-NN, every training point is its own nearest neighbor (distance 0), so the training error is 0 and can never be worse than that of 5-NN. 3. False: there is no guaranteed ordering between the test errors of 1-NN and 5-NN. 4. True: k-NN defers all computation to test time; classifying one test example requires computing its distance to every training example, so the cost grows with the size of the training set.
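To make point 2 concrete, here is a small sketch on toy random data (not the CIFAR-10 set): with distinct training points, each point's nearest neighbor in the training set is itself, so 1-NN gets every training label right.

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.standard_normal((20, 3))       # 20 distinct training points
y = rng.integers(0, 3, size=20)        # arbitrary labels in {0, 1, 2}

# Pairwise distances within the training set; the diagonal is exactly 0.
d = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)

# With distinct points, argmin of row i is i itself.
nearest = np.argmin(d, axis=1)
train_acc_1nn = np.mean(y[nearest] == y)  # training accuracy of 1-NN
```

So 1-NN training accuracy is 1.0 here, which is why statement 2 holds.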
