计算机系统应用教程网站

网站首页 > 技术文章 正文

如何评估预测值与真实值之间的匹配质量

btikc 2024-10-23 09:14:48 技术文章 14 ℃ 0 评论

1 代码实现

来自opencompass/opencompass/datasets/teval/evaluators/planning_evaluator.py
这段代码是一个比较复杂的过程,用于评估预测计划与真实计划之间的匹配质量。它包括多个关键步骤:计算相似度分数、创建匹配图、执行匹配算法、生成匹配映射,以及使用最长递增子序列(LIS)算法来计算正确匹配的节点数。

预测值如下
“predictions”: “[{\n “id”: 0,\n “name”: “AirbnbSearch.search_property_by_place”,\n “args”: “{‘place’: ‘柏林’}”\n}, {\n “id”: 1,\n “name”: “AirbnbSearch.get_property_details”,\n “args”: “{‘propertyId’: 1}”\n}, {\n “id”: 2,\n “name”: “AirbnbSearch.get_property_details”,\n “args”: “{‘propertyId’: 2}”\n}, {\n “id”: 3,\n “name”: “AirbnbSearch.get_property_details”,\n “args”: “{‘propertyId’: 3}”\n}, {\n “id”: 4,\n “name”: “ArxivSearch.get_arxiv_article_information”,\n “args”: “{‘query’: ‘太阳能’}”\n}]”,
真实值如下
“references”: [
{
“name”: “AirbnbSearch.search_property_by_place”,
“id”: 0,
“args”: “{‘place’: ‘Berlin’}”
},
{
“name”: “AirbnbSearch.get_property_reviews”,
“id”: 1,
“args”: “{‘propertyId’: 36339325}”
},
{
“name”: “AirbnbSearch.get_property_details”,
“id”: 2,
“args”: “{‘propertyId’: 45475252}”
},
{
“name”: “AirbnbSearch.get_property_checkout_price”,
“id”: 3,
“args”: “{‘propertyId’: 47215807, ‘checkIn’: ‘2023-10-20’, ‘checkOut’: ‘2023-10-23’}”
},
{
“name”: “ArxivSearch.get_arxiv_article_information”,
“id”: 4,
“args”: “{‘query’: ‘solar energy’}”
},
{
“name”: “ArxivSearch.get_arxiv_article_information”,
“id”: 5,
“args”: “{‘query’: ‘solar energy’}”
},
{
“name”: “FinishAction”,
“id”: 6,
“args”: “{}”
}
],
最终匹配结果如下
“precision”: 0.8

def bertscore_match(self, pred_plan, gt_plan) -> dict:
        """
            Calculate the similarity between predicted plan and golden answer,
            A plan can be regarded a sequence of actions, and each action has a name and args.
            Firstly, use bertscore to calculate pointwise similarity by:
                similarity(u, v) = bertscore(u.name, v.name) * name_weight + bertscore(u.args, v.args) * args_weight;
            Secondly, use Hungarian matching to match the points;
            Finally, use LIS to calculate the number of matched nodes.
        """
        if len(pred_plan) == 0 or len(gt_plan) == 0:
            return {
                'precision': 0,
                'recall': 0,
                'f1_score': 0
            }

        pred_plan = copy.deepcopy(sorted(pred_plan, key=lambda x: x['id']))
        gt_plan = copy.deepcopy(sorted(gt_plan, key=lambda x: x['id']))

        #Add end action
        #Currently it is hard-code
        if pred_plan[-1]['name'] == 'FinishAction':
            pred_plan = pred_plan[:-1]
        if gt_plan[-1]['name'] == 'FinishAction':
            gt_plan = gt_plan[:-1]
        #The total counts of nodes and edges.
        len_pred = len(pred_plan)
        len_gt = len(gt_plan)

        bert_score_matrix = np.zeros((len_pred, len_gt))
        name_pred, args_pred = [], []
        name_gt, args_gt = [], []
        for i in range(len_pred):
            name_pred.append(pred_plan[i]['name'])
            args_pred.append(str(pred_plan[i]['args']))
        for i in range(len_gt):
            name_gt.append(gt_plan[i]['name'])
            args_gt.append(str(gt_plan[i]['args']))
        
        name_pred_emb = self.sentence_model.encode(name_pred, convert_to_tensor=True)
        name_gt_emb = self.sentence_model.encode(name_gt, convert_to_tensor=True)
        args_pred_emb = self.sentence_model.encode(args_pred, convert_to_tensor=True)
        args_gt_emb = self.sentence_model.encode(args_gt, convert_to_tensor=True)
        name_cosine_scores = np.maximum(util.cos_sim(name_pred_emb, name_gt_emb).cpu().numpy(), 0)
        args_cosine_scores = np.maximum(util.cos_sim(args_pred_emb, args_gt_emb).cpu().numpy(), 0)
        for i in range(len_pred):
            for j in range(len_gt):
                bert_score_matrix[i][j] = \
                    name_cosine_scores[i][j] * self.name_weight \
                    + args_cosine_scores[i][j] * self.args_weight
        G = nx.Graph()
        for i in range(len_pred):
            for j in range(len_gt):
                if bert_score_matrix[i][j] > self.match_threshold:
                    G.add_edge(i, str(j), weight=bert_score_matrix[i][j])
        max_weight_matching = nx.max_weight_matching(G)

        pred_to_gt_mapping = dict()
        for key in max_weight_matching:
            if type(key[0]) == int:
                pred_to_gt_mapping[int(key[0])] = int(key[1])
            else:
                pred_to_gt_mapping[int(key[1])] = int(key[0])

        #If a prediction node does not match any golden answer node, we mark the node as -1.
        for i in range(len_pred):
            if i not in pred_to_gt_mapping:
                pred_to_gt_mapping[i] = -1
        #Calculate how many nodes are matched by Longest Increasing Subsequence (LIS)
        dp = np.ones(len_pred)
        for i in range(len_pred):
            for j in range(i):
                if pred_to_gt_mapping[i] == -1 or pred_to_gt_mapping[j] == -1:
                    continue
                if pred_to_gt_mapping[i] > pred_to_gt_mapping[j]:
                    dp[i] = max(dp[i], dp[j] + 1)
        correct_count = int(max(dp))

        precision =  correct_count / len(pred_plan)

代码中关键变量的取值示例

gt_plan [{‘name’: ‘AirbnbSearch.search_property_by_place’, ‘id’: 0, ‘args’: “{‘place’: ‘Berlin’}”}, {‘name’: ‘AirbnbSearch.get_property_reviews’, ‘id’: 1, ‘args’: “{‘propertyId’: 36339325}”}, {‘name’: ‘AirbnbSearch.get_property_details’, ‘id’: 2, ‘args’: “{‘propertyId’: 45475252}”}, {‘name’: ‘AirbnbSearch.get_property_checkout_price’, ‘id’: 3, ‘args’: “{‘propertyId’: 47215807, ‘checkIn’: ‘2023-10-20’, ‘checkOut’: ‘2023-10-23’}”}, {‘name’: ‘ArxivSearch.get_arxiv_article_information’, ‘id’: 4, ‘args’: “{‘query’: ‘solar energy’}”}, {‘name’: ‘ArxivSearch.get_arxiv_article_information’, ‘id’: 5, ‘args’: “{‘query’: ‘solar energy’}”}]

pred_plan [{‘name’: ‘AirbnbSearch.search_property_by_place’, ‘id’: 0, ‘args’: “{‘place’: ‘柏林’}”}, {‘name’: ‘AirbnbSearch.get_property_details’, ‘id’: 1, ‘args’: “{‘propertyId’: 1}”}, {‘name’: ‘AirbnbSearch.get_property_details’, ‘id’: 2, ‘args’: “{‘propertyId’: 2}”}, {‘name’: ‘AirbnbSearch.get_property_details’, ‘id’: 3, ‘args’: “{‘propertyId’: 3}”}, {‘name’: ‘ArxivSearch.get_arxiv_article_information’, ‘id’: 4, ‘args’: “{‘query’: ‘太阳能’}”}]

args_gt 和 args_pred 分别是真实参数列表和预测参数列表。
name_gt 和 name_pred 分别是真实动作名称列表和预测动作名称列表。

args_gt [“{‘place’: ‘Berlin’}”, “{‘propertyId’: 36339325}”, “{‘propertyId’: 45475252}”, “{‘propertyId’: 47215807, ‘checkIn’: ‘2023-10-20’, ‘checkOut’: ‘2023-10-23’}”, “{‘query’: ‘solar energy’}”, “{‘query’: ‘solar energy’}”]

args_pred [“{‘place’: ‘柏林’}”, “{‘propertyId’: 1}”, “{‘propertyId’: 2}”, “{‘propertyId’: 3}”, “{‘query’: ‘太阳能’}”]

name_gt [‘AirbnbSearch.search_property_by_place’, ‘AirbnbSearch.get_property_reviews’, ‘AirbnbSearch.get_property_details’, ‘AirbnbSearch.get_property_checkout_price’, ‘ArxivSearch.get_arxiv_article_information’, ‘ArxivSearch.get_arxiv_article_information’]

name_pred [‘AirbnbSearch.search_property_by_place’, ‘AirbnbSearch.get_property_details’, ‘AirbnbSearch.get_property_details’, ‘AirbnbSearch.get_property_details’, ‘ArxivSearch.get_arxiv_article_information’]

max_weight_matching {(‘5’, 4), (‘1’, 3), (2, ‘3’), (‘2’, 1), (0, ‘0’)}

pred_to_gt_mapping {4: 5, 3: 1, 2: 3, 1: 2, 0: 0}

dp array([1., 2., 3., 2., 4.])
precision 0.8

2 计算 BertScore 矩阵

  1. 双层循环:通过两层循环,遍历预测列表(len_pred)和真实列表(len_gt)的每个元素。这是为了比较每个预测元素与每个真实元素的相似度。
  2. 相似度计算:使用余弦相似度分数(name_cosine_scoresargs_cosine_scores),结合设定的权重(name_weightargs_weight),计算出预测和真实元素间的综合相似度。这个分数存储在 bert_score_matrix 矩阵中。

3 创建图并执行匹配

  1. 图的创建:使用网络X库创建一个图 G。图中的每个节点代表一个预测或真实元素。
  2. 添加边:根据 bert_score_matrix 中的相似度分数,如果分数超过预设的阈值(match_threshold),则在对应的预测节点和真实节点之间添加一条边。边的权重是相似度分数。
  3. 执行匹配:使用网络X的 max_weight_matching 方法找出图中的最大权重匹配,这个方法会尝试找出一种匹配方式,使得所有匹配的权重总和最大。

4 生成匹配映射

  1. 创建映射表:通过解析匹配结果,创建一个映射表 pred_to_gt_mapping,记录每个预测节点所匹配到的真实节点索引。
  2. 未匹配节点标记:遍历所有预测节点,如果某个节点没有在匹配结果中,将其映射值标记为 -1,表示该预测节点没有找到对应的匹配。

5 计算正确匹配的节点数

  1. 初始化动态规划数组dp 数组用于存储每个节点为终点的最长递增子序列的长度。
  2. 计算 LIS:双层循环遍历 pred_to_gt_mapping,如果两个节点都有有效的匹配,并且它们的映射值是递增的,则更新 dp 数组。
  3. 计算正确匹配的总数:最后,通过取 dp 数组中的最大值得到正确匹配的节点总数。

这段代码实现的是一个动态规划算法,用于计算最长递增子序列(LIS)的长度。在这个上下文中,它被用来确定预测到真实匹配项之间的递增顺序的长度,从而评估预测序列的质量。

最长递增子序列(LIS)

最长递增子序列是一个在数列中找出一组递增排序的最长序列的问题。在这个应用中,pred_to_gt_mapping 映射了预测节点到真实节点的索引,而 LIS 用于找出这些映射中的最长有序(递增)关系。

初始化

dp 数组用于存储每个元素作为序列结束点时的最长递增子序列的长度。数组初始化为1,表示每个元素自身至少可以构成长度为1的序列。

动态规划过程

  • 外层循环 (for i in range(len_pred)):遍历每一个预测节点。
  • 内层循环 (for j in range(i)):对于每个节点 i,遍历所有在其前面的节点 j
  • 条件判断if pred_to_gt_mapping[i] == -1 or pred_to_gt_mapping[j] == -1: 如果任一节点未匹配任何真实节点(标记为 -1),则跳过当前循环迭代,因为未匹配的节点不能用来形成有效的递增序列。if pred_to_gt_mapping[i] > pred_to_gt_mapping[j]: 如果节点 i 的匹配索引大于节点 j 的匹配索引,说明这两个节点可以形成一个递增的序列。此时,需要更新 dp[i] 的值。

更新dp[i]

  • dp[i] = max(dp[i], dp[j] + 1): 这一步是核心的动态规划更新逻辑。如果节点 i 和节点 j 可以形成递增序列,那么以 i 结尾的最长递增子序列长度可以从 j 的序列长度加1得到。更新 dp[i] 为其自身的值和 dp[j] + 1 之间的最大值。

计算最长递增子序列的总长度

  • correct_count = int(max(dp)): 通过取 dp 数组中的最大值,我们得到整个预测序列中最长的递增子序列的长度,这代表了正确匹配的最大顺序数量。

示例说明

给定的映射 {4: 5, 3: 1, 2: 3, 1: 2, 0: 0} 中:

  • 开始时每个点自身至少构成长度为1的序列。
  • 当检查到节点 3 和节点 1 时,因为 1 的映射是 23 的映射是 1,且 1 < 2,不满足递增关系,因此不更新 dp[3]
  • 检查到节点 4 和之前的节点时,会发现它与之前的所有有效映射节点(0, 1, 2, 3)都形成了递增序列,所以它的 dp 值将基于最长的递增序列更新。

这样,dp 数组最终记录了以每个节点结尾的最长递增子序列的长度,其最大值代表了整个序列中正确匹配的最大顺序数。

6 通俗解释

这个过程可以通过一个比喻来通俗地解释:想象一下你在组织一个舞会,其中的每个舞者(动作)有他们特定的舞步(动作名称)和舞伴(参数)。你的任务是为每个舞者找到最佳的舞伴,以便他们的舞步能够尽可能完美地配合。

生成 BertScore 矩阵:

步骤1: 将每个舞者的舞步和他们选择的舞伴看作是一个计划中的一个动作。你需要记录下每个舞者的名字和他们喜欢的舞伴的类型。

步骤2: 使用一种特殊的“评分系统”(self.sentence_model.encode),这个系统可以评估每位舞者的舞步和舞伴选择与其他舞者的相似度。这就像给他们的舞步和舞伴选择打分,分数越高说明他们越可能搭配得好。

步骤3: 计算舞步和舞伴之间的配合度(余弦相似度),这可以帮助你看到哪些舞者可能是理想的舞伴。

步骤4: 根据舞步和舞伴的重要性(name_weightargs_weight),合成一个最终得分(BertScore)。这就像决定在评估一个舞者的时候,他们的舞步技巧和选择的舞伴哪个更重要。

匈牙利匹配算法(Hungarian Matching):

步骤5: 使用生成的得分(BertScore 矩阵)来创建一个舞会的配对图。在这个图中,每个节点代表一个舞者,每条边代表两个舞者可能成为舞伴的得分。

步骤6: 使用一个算法(匈牙利匹配算法,通过网络X的 max_weight_matching 方法实现)来找出这场舞会中所有可能的最佳舞伴组合。这保证了每个舞者都找到了最匹配的舞伴,而且整个舞会的舞伴匹配得分最高。

计算匹配和最长递增子序列 (LIS):

步骤7: 为了记录哪些舞者找到了舞伴,你创建一个映射表 (pred_to_gt_mapping),记录每个舞者匹配的结果。

步骤8: 如果某个舞者找不到任何合适的舞伴,你就把他们标记为 -1

步骤9: 最后,使用最长递增子序列(LIS)的方法来计算实际上有多少舞者与他们的舞伴舞步完美匹配。这个数值能帮你了解舞会的整体配对成功率。

这整个过程就像是组织一场完美的舞会,每个舞者都能找到与之舞步完美匹配的舞伴,使得整个舞会的舞步和谐统一。

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表