Loss
Losses
- class mindnlp.common.loss.CMRC2018Loss(reduction='mean')[源代码]
基类:
Cell用于计算CMRC2018中文问答任务
- 参数
reduction (str) – 计算loss的方式,候选有 mean 和 sum. 默认:mean
- construct(target_start, target_end, context_len, pred_start, pred_end)[源代码]
计算 CMRC2018Loss
- 参数
target_start (Tensor) – size: batch_size, dtype: int.
target_end (Tensor) – size: batch_size, dtype: int.
context_len (Tensor) – size: batch_size, dtype: float.
pred_start (Tensor) – size: batch_size*max_len, dtype: float.
pred_end (Tensor) – size: batch_size*max_len, dtype: float.
- 返回
Tensor, 计算后的 CMRC2018Loss
- 抛出
ValueError – 计算方式 reduction 没有选择 sum 或 mean
示例
>>> cmrc_loss = CMRC2018Loss() >>> tensor_a = mindspore.Tensor(np.array([1, 2, 1]), mindspore.int32) >>> tensor_b = mindspore.Tensor(np.array([2, 1, 2]), mindspore.int32) >>> my_context_len = mindspore.Tensor(np.array([2., 1., 2.]), mindspore.float32) >>> tensor_c = mindspore.Tensor(np.array([ >>> [0.1, 0.2, 0.1], >>> [0.1, 0.2, 0.1], >>> [0.1, 0.2, 0.1] >>> ]), mindspore.float32) >>> tensor_d = mindspore.Tensor(np.array([ >>> [0.2, 0.1, 0.2], >>> [0.2, 0.1, 0.2], >>> [0.2, 0.1, 0.2] >>> ]), mindspore.float32) >>> my_loss = cmrc_loss(tensor_a, tensor_b, my_context_len, tensor_c, tensor_d) >>> print(my_loss)
- class mindnlp.common.loss.RDropLoss(reduction='none')[源代码]
基类:
CellR-Drop Loss 的实现。更多关于R-drop的信息请参考这篇文章:https://arxiv.org/abs/2106.14448
原始实现请参考这里的代码:https://github.com/dropreg/R-Drop
- 参数
reduction (str) –
计算loss的方式,候选有 none, batchmean, mean, sum。默认: none
mean:将返回降维后的loss均值
batchmean:将返回降维后的loss批次均值
sum:将返回降维后的loss求和
none:不采取任何降维方法
- construct(p, q, pad_mask=None)[源代码]
返回p和q的rdrop loss
- 参数
p (Tensor) – 训练样本的第一次前向向量
q (Tensor) – 训练样本的第二次前向向量
pad_mask (Tensor) – 包含要索引的二进制掩码的张量Tensor,其数据类型为 bool。默认: None
- 返回
Tensor, p和q的rdrop loss
- 抛出
ValueError – 计算方式 reduction 不是 sum , mean , batchmean 或 none 的其中一个。
示例
>>> r_drop_loss = RDropLoss() >>> p = Tensor(np.array([1., 0. , 1.]), mindspore.float32) >>> q = Tensor(np.array([0.2, 0.3 , 1.1]), mindspore.float32) >>> loss = r_drop_loss(p, q) >>> print(loss) 0.100136