首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >SVM损失函数

SVM损失函数
EN

Stack Overflow用户
提问于 2021-03-22 04:42:33
回答 1查看 327关注 0票数 1
代码语言:javascript
复制
def svm_loss_naive(W, X, y):
"""
SVM loss function, naive implementation calculating loss for each sample
using loops.    

Inputs:
- X: A numpy array of shape (n, m) containing data(samples).
- y: A numpy array of shape (m, ) containing labels
- W: A numpy array of shape (p, n) containing weights.  
   
"""
# Compute the loss
num_classes = W.shape[0] # classes weights are in row wise fashion
num_samples = X.shape[1] # samples of unknown images are in column-wise fashion
loss = 0.0
delta = 1 # SVM parameter
for i in range(num_samples):
    scores = np.dot(W, X[:,i])
    correct_class_score = scores[y[i]] 
    for j in range(num_classes):       
        if j == y[i]:
            continue
        margin = max(0, scores[j] - correct_class_score + delta )
        loss = loss + margin

# Average loss
loss = loss / num_samples

return loss

根据我对python代码的理解

  1. 我们首先通过将第1行的权重乘以第1样本列

来计算第1类的得分。

然后,我们将获取存储在数组y

中的ith示例的correct_class_score。

  1. ,那么,我们在迭代类的数量(假设是3),我不明白的是j == yi在做什么?我的意思是,当j在0到2之间时,当j等于yi时,yi只是ith示例correct_class_score

的索引。

我预先理解的其余代码

EN

回答 1

Stack Overflow用户

发布于 2022-11-06 18:15:35

这是SVM丢失(数据丢失)函数的定义:

在内部求和中,明确排除了等于易的j指数。

票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66740435

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档