
Decision tree for binary classification

Code Review user
Asked 2015-03-02 01:31:47

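I want to become a good Python programmer, so I'd like to know what I can improve in my coding practices. Overall I consider myself a fairly solid programmer, but writing this code felt very "Java", so I may still be following bad habits rather than Python practices.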

__author__ = "arthur"

import pandas as pd

pd.set_option('display.max_rows', 1000)

def indices_next(ls):
    indices = []
    for i, element in enumerate(ls):
        if i != 0:
            if element != ls[i-1]:
                indices.append(i)
    return indices

def summed_list(ls):
    for i, elt in enumerate(ls):
        if i != 0:
            ls[i] += ls[i-1]
    return ls

class TreeNode(object):
    class_counter = 0
    def __init__(self):
        self.name = TreeNode.class_counter
        TreeNode.class_counter += 1
        self.split_gini = -1000
        self.data = pd.DataFrame()
        self.node_type = "Node"
        self.node_gini = 1.0
        self.split_value = -1000
        self.split_attribute = ""
        self.parent = None
        self.left_child_node = None
        self.left_child_complete = False
        self.split_dict = dict()

        self.right_child_node = None
        self.right_child_complete = False
        self.level = 0

    def compute_gini_new_node(self):
        split_dict = self.data["OK"].value_counts().to_dict()
        self.split_dict = split_dict
        if len(split_dict) == 2:
            # print "Size of split_dict is 2"
            zero_count = float(split_dict[0])
            one_count = float(split_dict[1])
            gini = 1 - (zero_count/(zero_count+one_count))**2 - (one_count/(zero_count+one_count))**2
            self.node_gini = gini

class DecisionTree(object):
    def __init__(self):
        self.root = TreeNode()
        self.attributes = []
        self.used_attributes = set()

    def is_leaf_node(self, node):
        result = False
        data_ct = len(node.data)
        if len(node.split_dict) == 1:
            result = True
        elif len(node.split_dict) == 2:
            zeroes = node.split_dict[0]
            ones = node.split_dict[1]
            if zeroes == 0 or ones == 0:
                result = True
            else:
                result = False
            if zeroes == 0 and ones == 0:
                result = False
        else:
            result = False
        return result

    def create(self, filename):
        current_node = self.root
        current_node.data = pd.read_csv(filename)
        self.attributes = current_node.data.columns

    def create_test(self, filename):
        csv = pd.read_csv(filename)
        current_node = self.root
        current_node.data = csv[:int((0.8*len(csv)))]
        self.attributes = current_node.data.columns

        testing_data = csv[int((0.8*len(csv))):]

        self.train_tree()

        for i in testing_data.index:
            current_node = self.root
            record = testing_data.loc[i]
            while current_node.left_child_node is not None and current_node.right_child_node is not None:
                attr = current_node.split_attribute
                if current_node.split_value <= record[attr]:
                    current_node = current_node.right_child_node
                    continue
                else:
                    current_node = current_node.left_child_node
                    continue
            if len(current_node.split_dict) == 2:
                zero = current_node.split_dict[0]
                #print current_node.split_dict
                one = current_node.split_dict[1]
                zeropc = zero/(float(zero + one))

                onepc = one/(float(zero + one))
                if max(zeropc, onepc) == zeropc:
                    print(str(record['ID']) + " is a 0 " + str(zeropc))
                else:
                    print(str(record['ID']) + " is a 1 " + str(onepc))


    def train_tree(self):
        self.train_tree_hidden(self.root)

    def get_attribute_ginis(self, current_node):
        attribute_ginis = dict()
        hold_ginis = []
        for attribute in self.attributes[1:10]:
            if attribute not in self.used_attributes:
                # print self.used_attributes
                attribute_ginis[attribute] = []
                # hold_ginis will hold the gini coefficients of every possible splitting condition to find the best one
                # attribute_df uses built in pandas functions to sort by attribute, THEN by ID
                attribute_df = current_node.data.sort_values([attribute, "ID"])

                # attribute_vals are the actual sorted values of the individual attribute
                attribute_vals = attribute_df[attribute]

                # buckets is a histogram of the different attribute value counts.
                buckets = attribute_vals.value_counts()

                #print attribute

                # since attribute_vals is sorted, we can use this to know the offset
                series = buckets.sort_index()

                # series_keys = series.keys().tolist()
                # for key in series_keys:
                #     offset = series[key]
                #     print str(key) + " " + str(offset)

                summedlist = summed_list(series.tolist())
                #print summedlist
                #return
                count = 0
                for element in summedlist:

                    # we get the sorted list of attribute values, and using elt, the summed indices, we grab the
                    # data that's been sectioned up (see attribute_vals, buckets, etc)
                    subsection = attribute_df[:element]
                    # last_val = subsection[-1:]

                    # this is a series
                    val_counts = subsection["OK"].value_counts()
                    series_size = val_counts.size
                    if series_size == 2:
                        # then we know this node will split and it is not a leaf node
                        left = val_counts[0]
                        right = val_counts[1]
                        if left != 0 or right != 0:
                            gini = 1 - (float(left)/(left+right))**2 - (float(right)/(left+right))**2

                            # hold_ginis.append(tpl)
                    elif series_size == 1:
                        if len(subsection) > 0:
                            gini = 0
                    tpl = (gini, count, element, attribute)
                    attribute_ginis[attribute].append(tpl)
                    count += 1
        return attribute_ginis


    def split(self, current_node):
        attributes_start = current_node.data.columns[1:10]

        attribute_ginis = self.get_attribute_ginis(current_node)

        # print attribute_ginis
        tuple_dict = dict()
        slimmer_ginis = []
        for attribute in attributes_start:
            tuple_dict[attribute] = []
            if attribute not in self.used_attributes:
                # print attribute_ginis
                tuple_list = sorted(attribute_ginis[attribute],  key=lambda x: x[0])
                if len(tuple_list) > 0:
                    slimmer_ginis.append(tuple_list[0])
                    # No, the first tuple is not necessarily the best.
                    # first_tuple_is_best = sorted(slimmer_ginis, key= lambda x: x[0])
                for tpl in tuple_list:
                    if tpl[2] != 500:
                        left_df = current_node.data.sort_values([attribute, "ID"])[:tpl[2]]
                        right_df = current_node.data.sort_values([attribute, "ID"])[tpl[2]:]
                        #print "first half"
                        #left_freq = left_df[attribute].value_counts()
                        #print "second half"

                        # right_freq = right_df[attribute].value_counts()

                        left_freq = left_df["OK"].value_counts()

                        right_freq = right_df["OK"].value_counts()
                        leftsum = 0
                        rightsum = 0
                        if len(left_freq) == 0:
                            left_gini = 0
                        elif len(left_freq) == 1:
                            left_gini = 0
                        else:

                            left_gini = 1 - (left_freq[0]/float(left_freq[0] + left_freq[1]))**2 - (left_freq[1]/float(left_freq[0] + left_freq[1]))**2
                            leftsum = left_freq[0] + left_freq[1]

                        #print "LG " + str(left_gini)

                        if len(right_freq) == 0:
                            right_gini = 0
                        elif len(right_freq) == 1:
                            right_gini = 0
                        else:
                            right_gini = 1 - (right_freq[0]/float(right_freq[0] + right_freq[1]))**2 - (right_freq[1]/float(right_freq[0] + right_freq[1]))**2
                            rightsum = right_freq[0] + right_freq[1]

                        #print "RG " + str(right_gini)
                        if left_gini == 0 or right_gini == 0:
                            continue
                        else:

                            split_gini = left_gini*(leftsum/float(leftsum+rightsum)) + right_gini*(rightsum/float(leftsum+rightsum))
                            #print "SG " + str(split_gini)
                            info_gain = current_node.node_gini - split_gini
                            #print info_gain

                            tuple_dict[attribute].append((tpl[0], tpl[1], tpl[2], attribute, left_gini, right_gini, split_gini, info_gain))
                        # print tuple_dict
        total_list = []
        #print tuple_dict
        for key in tuple_dict.keys():
            for val in tuple_dict[key]:
                total_list.append(val)
        candidates = sorted(total_list, key=lambda x: x[7], reverse=True)  # highest info_gain first

        if len(candidates) == 0:
            return

        x = candidates[0]

        gini = x[0]
        valsplit = x[1]
        ind = x[2]

        attribute_to_split_on = x
        # smallest_gini_value_tuple = start_tuple[0]
        current_node.node_gini = gini
        best_attribute_to_split_on_based_on_gini = x[3]
        best_attribute = best_attribute_to_split_on_based_on_gini
        attribute_df = current_node.data.sort_values([best_attribute, "ID"])
        # print len(attribute_df)


        current_node.split_value = x[1]
        # attribute_df[:ind][-1:][best_attribute].tolist()[0]


        current_node.split_attribute = best_attribute

        current_node.left_child_node = TreeNode()
        current_node.left_child_node.parent = current_node

        current_node.right_child_node = TreeNode()
        current_node.right_child_node.parent = current_node

        current_node.left_child_node.data = attribute_df[:ind]
        current_node.right_child_node.data = attribute_df[ind:]

        # print best_attribute
        # print "left node data count " + str(len(current_node.left_child_node.data))
        # print "LEFT NODE DATA BEGINS"
        # print "--------------"
        # #print current_node.left_child_node.data
        # print "--------------"
        # print "LEFT NODE DATA ENDS"

        lnode_data = current_node.left_child_node.data
        #
        # print "right node data count " + str(len(current_node.right_child_node.data))
        # print "RIGHT NODE DATA BEGINS"
        # print "--------------"
        # #print current_node.right_child_node.data
        # print "--------------"
        # print "RIGHT NODE DATA ENDS"
        rnode_data = current_node.right_child_node.data

        current_node.left_child_node.compute_gini_new_node()
        current_node.right_child_node.compute_gini_new_node()

        left_half_split_gini = (float(len(lnode_data))/len(current_node.data))*current_node.left_child_node.node_gini
        right_half_split_gini = (float(len(rnode_data))/len(current_node.data))*current_node.right_child_node.node_gini

        current_node.split_gini = left_half_split_gini + right_half_split_gini

        # print "Split gini " + str(current_node.split_gini)
        # print current_node.split_attribute
        # print "I am " + str(current_node.name)
        # if current_node.parent is not None:
        #   print "My parent is " + str(current_node.parent.name)
        # print "LTE " + str(current_node.split_value)
        # print "left is " + str(current_node.left_child_node.data["OK"].value_counts().to_dict())
        # print "right is " + str(current_node.right_child_node.data["OK"].value_counts().to_dict())
        #
        #
        # left_gini = current_node.left_child_node.node_gini
        # right_gini = current_node.right_child_node.node_gini
        #
        # print "left gini is " + str(left_gini)
        #
        # print "right gini is " + str(right_gini)
        # print ""

        self.used_attributes.add(x[3])

    def train_tree_hidden(self, current_node):
        not_done = True

        while not_done:
            self.split(current_node)

            if current_node.left_child_node is None and current_node.right_child_node is None:
                print("Current node is a leaf node!!!")
                return
            else:
                left_gini = current_node.left_child_node.node_gini
                right_gini = current_node.right_child_node.node_gini
                split_gini = current_node.split_gini
                # print "do i get here"
                left_gain = (left_gini - split_gini)
                right_gain = (right_gini - split_gini)

                if max(left_gain, right_gain) == left_gain:
                    current_node = current_node.left_child_node
                else:
                    current_node = current_node.right_child_node
                continue



    def test(self, filename):
        current_node = self.root

        csv = pd.read_csv(filename)
        freq = dict()
        freq[0] = 0
        freq[1] = 0
        print("ID,OK")
        for i in range(0, len(csv)):
            current_node = self.root
            record = csv.loc[i]
            while current_node.left_child_node is not None and current_node.right_child_node is not None:
                attr = current_node.split_attribute
                if current_node.split_value <= record[attr]:
                    current_node = current_node.right_child_node
                    continue
                else:
                    current_node = current_node.left_child_node
                    continue
            if len(current_node.split_dict) == 2:
                zero = current_node.split_dict[0]
                # print current_node.split_dict
                one = current_node.split_dict[1]
                zeropc = zero/(float(zero + one))

                onepc = one/(float(zero + one))
                if max(zeropc, onepc) == zeropc:
                    print(str(record['ID']) + ",0")
                    freq[0] += 1
                else:
                    print(str(record['ID']) + ",1")  # + str(onepc)
                    freq[1] += 1
        #print freq






        # print csv

def main(filename):

    dt = DecisionTree()
    dt.create(filename)
    dt.train_tree()
    dt.test("test.csv")

    # dt.create_test(filename)




# dt.test("test.csv")

if __name__ == "__main__":
    main("training.csv")
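The two helpers at the top of the listing, indices_next and summed_list, hand-roll index loops that Python already covers. A minimal sketch of more idiomatic equivalents follows; the names change_indices and running_totals are hypothetical, chosen for illustration. Unlike summed_list, running_totals returns a new list instead of mutating its argument.

```python
from itertools import accumulate


def change_indices(values):
    """Return the indices i where values[i] differs from values[i - 1].

    Same result as indices_next() in the listing, without manual index bookkeeping.
    """
    return [i for i, (prev, cur) in enumerate(zip(values, values[1:]), start=1)
            if cur != prev]


def running_totals(values):
    """Return the cumulative sums of values as a new list (input is not modified)."""
    return list(accumulate(values))


print(change_indices([0, 0, 1, 1, 1, 0]))  # [2, 5]
print(running_totals([3, 1, 4]))           # [3, 4, 8]
```

Since get_attribute_ginis already holds a pandas Series at that point, series.cumsum().tolist() would be another way to get the same running totals there.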

1 Answer

Code Review user

Answered 2015-03-02 02:12:34

Disclaimer: I care less about whether the code is Pythonic and more about how easy it is to understand.

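A few observations

  1. There are about 400 lines of code, and it is hard to see how it all fits together.
  2. There are some comments, but none of them describes the big picture. Many of the comments are just commented-out dead code, and these get in the way of clarity.
  3. Test code is mixed in with the production code. This adds bulk and makes the main logic harder to follow, not easier.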
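To illustrate the modularity point, here is a sketch of pulling the pure logic into small functions that can be tested on their own. One gini_impurity helper replaces the four inline copies of the Gini formula in the listing. A predict function walks the tree and returns a result instead of printing it, so the traversal duplicated in create_test and test lives in one place. Node is a hypothetical stripped-down stand-in for TreeNode, not the author's class.

```python
from dataclasses import dataclass
from typing import Mapping, Optional


def gini_impurity(counts):
    """Gini impurity 1 - sum(p_k ** 2) for a {class_label: count} mapping.

    Returns 0.0 for an empty node; a single-class node also gives 0.0.
    """
    total = sum(counts.values())
    if total == 0:
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in counts.values())


@dataclass
class Node:
    """Hypothetical minimal node; internal nodes have both children set."""
    counts: Mapping[int, int]  # class label -> number of training rows
    split_attribute: Optional[str] = None
    split_value: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def predict(root, record):
    """Return (majority_label, fraction) for the leaf that record reaches.

    Same routing rule as the listing: go right when
    split_value <= record[split_attribute], otherwise go left.
    Returns None if the leaf has no training rows.
    """
    node = root
    while node.left is not None and node.right is not None:
        if node.split_value <= record[node.split_attribute]:
            node = node.right
        else:
            node = node.left
    if not node.counts:
        return None
    label = max(node.counts, key=node.counts.get)
    return label, node.counts[label] / sum(node.counts.values())


# Usage: a one-split tree on a hypothetical attribute "x".
root = Node(counts={0: 4, 1: 5}, split_attribute="x", split_value=2.5,
            left=Node(counts={0: 3, 1: 1}), right=Node(counts={0: 1, 1: 4}))
print(predict(root, {"x": 1.0}))    # (0, 0.75)
print(predict(root, {"x": 3.0}))    # (1, 0.8)
print(gini_impurity({0: 2, 1: 2}))  # 0.5
```

Test code can then call predict on held-out rows and check the results, while output formatting (such as the ID,OK lines) stays in the caller.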

--

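A few suggestions

  1. Make the code more modular by separating the tests from the production logic.
  2. Delete dead code. Consider using version control to preserve the history of your experiments.
  3. Consider writing an overview of the program, so that anyone reading it, including yourself a week from now, can quickly understand what the code does and how the pieces fit together to do it.
  4. Consider better names.
    • What does tuple_dict represent, i.e., what are these tuples?
    • zero might be named zero_coefficient, unless it really is 0.
    • Why TreeNode rather than Node? The rest of the program uses "node" to refer to nodes.

  5. Why does TreeNode contain so many magic numbers? Perhaps they should be passed in as parameters from a file?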
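One way to act on point 5 is sketched below. The names TreeParams, NodeSplit and load_params, and the JSON layout, are hypothetical. The literals they replace ("OK", "ID", 0.8, the upper bound of the [1:10] slice) come from the listing. None stands in for the -1000 and "" "not set yet" sentinels in TreeNode.

```python
import json
import os
import tempfile
from dataclasses import dataclass, fields
from typing import Optional


@dataclass(frozen=True)
class TreeParams:
    """Hypothetical named settings for literals hard-coded in the listing."""
    label_column: str = "OK"     # class-label column, written as "OK" above
    id_column: str = "ID"        # secondary sort key, written as "ID" above
    train_fraction: float = 0.8  # train/test split used in create_test()
    max_feature_index: int = 10  # upper bound of the attributes[1:10] slice


@dataclass
class NodeSplit:
    """None means "no split chosen yet", instead of -1000 or ""."""
    attribute: Optional[str] = None
    value: Optional[float] = None

    @property
    def is_set(self):
        return self.attribute is not None


def load_params(path):
    """Read TreeParams from a JSON object; missing keys keep their defaults."""
    with open(path) as f:
        raw = json.load(f)
    unknown = set(raw) - {field.name for field in fields(TreeParams)}
    if unknown:
        raise ValueError(f"unknown parameter(s): {sorted(unknown)}")
    return TreeParams(**raw)


# Usage: write a small JSON file, then load it.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    json.dump({"train_fraction": 0.75}, tmp)
params = load_params(tmp.name)
os.remove(tmp.name)
print(params.train_fraction, params.label_column)  # 0.75 OK
print(NodeSplit().is_set)  # False
```

Rejecting unknown keys is a design choice here: a typo in the parameter file then fails loudly instead of being silently ignored.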
Page content originally from Code Review; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://codereview.stackexchange.com/questions/82957
