Implementing a GPU Usage Monitoring Tool for Multiple Servers

1. Introduction

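Our lab works on computer vision and has several servers, each with between one and eight GPUs, which students use to run deep learning algorithms. Whenever a conference or journal deadline approaches, the servers are packed and everyone scrambles for GPUs (occupying several GPUs at once increases parallelism and shortens training time). The result is feast or famine: some students hold more than ten cards, while others watch the servers every day and never find a free one. So the product manager (actually the lab's most senior student) came up with a requirement: I should write a GPU monitoring tool that tracks GPU usage on all servers and aggregates it into a ranking (how many cards each person is using, how much GPU memory they occupy, and so on).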

2. Implementation

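  • My idea is based on the client/server (C/S) model
    • A client runs on every machine to be monitored and, at a fixed interval, sends that machine's current GPU usage to the central server
    • The central server receives this information, then aggregates and sorts it
  • With the idea in place, and with reference to code that others have shared online, the implementation is as follows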
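To make the client/server split concrete, the report each client sends can be pictured as a small JSON document. The sketch below uses the field names from the client code in section 2.1; the concrete values (user name, PID, memory figures) are invented for illustration.

```python
import json

# Example report from one monitored machine. The keys follow the client code
# in section 2.1; every value here is invented for illustration.
example_report = {
    "gpu": [
        {"mem_usage": 10240.0, "mem_total": 11178.0},
        {"mem_usage": 0.0, "mem_total": 11178.0},
    ],
    "process": [
        {
            "gpuid": 0,
            "pid": 12345,
            "program": "python",
            "cpu_percent": 99.5,
            "mem_usage": 10230.0,
            "username": "alice",
        },
    ],
}

# The client serializes the report to JSON text before sending it ...
payload = json.dumps(example_report)

# ... and the server turns the text back into the same structure.
received = json.loads(payload)
```

Because the report is plain JSON, the server does not need to know anything about nvidia-smi; it only reads the "process" list when it builds the rankings.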

2.1 Client

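  • Run the nvidia-smi command with the Popen function and capture its output. Then use regular expressions to extract the GPU information and the process information from that output. The GPU information contains each card's total memory and used memory; the process information contains the GPU ID, process ID, command, and memory used. Finally, look up the user name that owns each process from its process ID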
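import json
import pwd
import re
import subprocess
import time

import psutil
import requests

def get_owner(pid):
    # Map a PID to its owner's user name via the real UID in /proc/<pid>/status.
    try:
        with open('/proc/%d/status' % pid) as f:
            for line in f:
                if line.startswith('Uid:'):
                    uid = int(line.split()[1])
                    return pwd.getpwuid(uid).pw_name
    except (OSError, KeyError):
        # The process has exited, or its UID has no passwd entry.
        pass
    return None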

def get_info():
    info = { 'gpu': [], 'process': [] }
    # communicate() reads all of the output and waits for nvidia-smi to exit.
    msg = subprocess.Popen(['nvidia-smi'], stdout = subprocess.PIPE).communicate()[0].decode()
    msg = msg.strip().split('\n')

    # GPU status rows: the first is expected at index 8 of the output, then one every 3 lines.
    lino = 8
    while lino < len(msg):
        status = re.findall(r'.*\d+%.*\d+C.*\d+W / +\d+W.* +(\d+)MiB / +(\d+)MiB.* +\d+%.*', msg[lino])
        if status == []: break
        mem_usage, mem_total = status[0]
        info['gpu'].append({
            'mem_usage': float(mem_usage),
            'mem_total': float(mem_total),
        })
        lino += 3

    # Process rows are read upwards, starting just above the table's bottom border.
    lino = -1
    while True:
        lino -= 1
        if -lino > len(msg): break
        status = re.findall(r'\| +(\d+) +(\d+) +\w+ +([^ ]*) +(\d+)MiB \|', msg[lino])
        if status == []: break
        gpuid, pid, program, mem_usage = status[0]
        username = get_owner(int(pid))
        if username is None:
            print('Process no longer exists')
            continue
        try:
            p = psutil.Process(int(pid))
            p.cpu_percent()
            time.sleep(0.5)
            cpu_percent = p.cpu_percent()
        except psutil.NoSuchProcess:
            print('Process no longer exists')
            continue
        info['process'].append({
            'gpuid': int(gpuid),
            'pid': int(pid),
            'program': program,
            'cpu_percent': cpu_percent,
            'mem_usage': float(mem_usage),
            'username': username,
        })
    info['process'].reverse()

    return info
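Matching the human-readable table with regular expressions depends on the exact layout that a given nvidia-smi version prints. As an alternative to the approach above (this is not what the article's code does), nvidia-smi can also emit CSV through its --query-gpu and --format options. The sketch below parses such output; the function name parse_gpu_csv and the sample text are assumptions for illustration.

```python
import csv
import io

# Command that would produce the CSV parsed below (run it with subprocess on
# a machine that has an NVIDIA driver installed):
#   nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv,noheader,nounits
def parse_gpu_csv(text):
    """Parse 'index, memory.used, memory.total' rows into a list of dicts (MiB)."""
    gpus = []
    for row in csv.reader(io.StringIO(text), skipinitialspace=True):
        if len(row) != 3:
            continue  # skip blank or malformed lines
        index, used, total = (field.strip() for field in row)
        gpus.append({
            "gpuid": int(index),
            "mem_usage": float(used),
            "mem_total": float(total),
        })
    return gpus

# Hypothetical output for a machine with two GPUs.
sample = "0, 10240, 11178\n1, 3, 11178\n"
gpus = parse_gpu_csv(sample)
```

The resulting dicts use the same mem_usage and mem_total keys as get_info, so they could be dropped into the 'gpu' list of the report without changing the server.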
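  • After collecting the GPU information, send it to the server with the HTTP library requests, then sleep for a while (so the information is collected and sent once per interval)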
while True:
    mean_info = get_info()
    data = json.dumps(mean_info)
    try:
        # Use a timeout so that a stalled server cannot block the loop forever.
        response = requests.get(url, data = data, timeout = 10)
        print('HTTP status code:', response.status_code)
    except Exception as e:
        print(e)
    time.sleep(opt.persecond)
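The loop above relies on the third-party requests package. For reference, the same kind of request can be built with only the standard library; the sketch below is such an alternative, not the article's code. The helper name build_report_request and the example URL are hypothetical, and the method defaults to GET only because the server in section 2.2 handles requests in do_GET.

```python
import json
import urllib.request

def build_report_request(url, info, method="GET"):
    """Build an HTTP request that carries `info` as a JSON body."""
    body = json.dumps(info).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method=method,
    )

# Hypothetical server address; replace it with the real monitoring server.
req = build_report_request("http://127.0.0.1:8000/", {"gpu": [], "process": []})

# Sending would look like this (left commented out so the sketch has no side effects):
# with urllib.request.urlopen(req, timeout=10) as resp:
#     print(resp.status)
```

A GET request with a body is unusual; passing method="POST" is the more conventional choice, but the server would then need a matching do_POST handler.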

2.2 Server

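  • Use the HTTPServer class from http.server to start an HTTP server, and handle the clients' GET requests in the do_GET function, which mainly updates each client's GPU information and timestamp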
class CustomHandler(BaseHTTPRequestHandler):
    alert_record = { }

    def do_GET(self):
        length = int(self.headers['content-length'])
        info = json.loads(self.rfile.read(length).decode())
        slaver_address, _ = self.client_address
        # "with" releases the lock even if an exception is raised inside the block.
        with lock:
            if slaver_address not in info_record:
                info_record[slaver_address] = {}
            info_record[slaver_address]['info'] = info
            info_record[slaver_address]['timestamp'] = time.time()
        report_user()
        self.send_response(200)
        self.end_headers()
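The handler above refers to info_record and lock, which are defined elsewhere in the full code. The following is a minimal sketch of how those pieces might fit together, assuming a module-level dict guarded by a threading.Lock. The helper names update_record and make_server, the class name SketchHandler, and the port number are assumptions, not the article's code.

```python
import json
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer

# Shared state (assumed layout): client IP -> {'info': ..., 'timestamp': ...}
info_record = {}
lock = threading.Lock()

def update_record(record, address, info, now=None):
    """Store the latest report from one client together with its arrival time."""
    entry = record.setdefault(address, {})
    entry["info"] = info
    entry["timestamp"] = time.time() if now is None else now
    return entry

class SketchHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        # A missing Content-Length header is treated as an empty body.
        length = int(self.headers.get("Content-Length") or 0)
        try:
            info = json.loads(self.rfile.read(length).decode())
        except ValueError:
            self.send_response(400)  # body was not valid JSON
            self.end_headers()
            return
        with lock:
            update_record(info_record, self.client_address[0], info)
        self.send_response(200)
        self.end_headers()

def make_server(port=8000):
    """Create (but do not start) the HTTP server; call serve_forever() to run it."""
    return HTTPServer(("", port), SketchHandler)

# Recording a report directly, without going through HTTP:
entry = update_record(info_record, "192.0.2.10", {"gpu": [], "process": []}, now=100.0)
```

Starting the server would then be a matter of calling make_server(8000).serve_forever(). Keeping the dictionary update in its own function makes it possible to exercise that logic without opening a socket.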
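  • The report_user function sorts the GPU information and prints it. It iterates over info_record (a dict that stores each client's GPU information and timestamp, keyed directly by the client's IP address string) and checks whether each client's information has expired (by default, information not updated for 100 s counts as expired). Expired entries are skipped with continue; for the rest, it accumulates the GPU memory used by each user and the set of GPUs each user occupies. Once info_record has been processed, the sorted results are printed to the terminal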
def report_user():
    usage_dict = { }
    usage_num = { }
    with lock:
        for slaver_address in sorted(info_record.keys()):
            if isExpire(info_record[slaver_address]['timestamp']):
                continue
            pi_list = info_record[slaver_address]['info']['process']
            for pi in pi_list:
                username = pi['username']
                mem_usage = pi['mem_usage']
                gpu_id = pi['gpuid']
                usage_dict[username] = usage_dict.get(username, 0) + mem_usage
                # Identify a GPU as "<client IP>:<GPU ID>" so cards on different servers stay distinct.
                usage_num.setdefault(username, set()).add(slaver_address + ':' + str(gpu_id))

    print('=' * 60)
    print(time.strftime('%Y-%m-%d %H:%M:%S\n', time.localtime(time.time())))
    usage_list = sorted(usage_dict.items(), key = lambda x: x[1], reverse=True)
    print('GPU memory usage ranking by user:')
    print('<user ID> : <GPU memory used (MB)>')
    for username, usagememory in usage_list:
        print('%s : %dM' % (username, usagememory))
    # report1 = ['%s : %dM' % (n, u) for n, u in usage_list]
    # report1 = '\n'.join(report1)
    print()

    usage_num_list = sorted(usage_num.items(), key = lambda x: len(x[1]), reverse=True)
    print('GPU count ranking by user:')
    print('<user ID> : <number of GPUs used> : <GPU IP address and ID>')
    for username, gpuset in usage_num_list:
        print('%s : %d : ' % (username, len(gpuset)), end='')
        for gpuaddr in gpuset:
            print('%s, ' % gpuaddr, end='')
        print()
    print('=' * 60)

    return
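The aggregation in report_user can also be written as a pure function that returns the rankings instead of printing them, which makes it easier to test. The sketch below does that. isExpire is not shown in the article, so the is_expired helper here is an assumption based on the 100-second default mentioned above, and the names summarize and EXPIRE_SECONDS are hypothetical.

```python
import time

EXPIRE_SECONDS = 100  # assumed default, following the 100 s mentioned in the text

def is_expired(timestamp, now=None, limit=EXPIRE_SECONDS):
    """Return True if a report is older than `limit` seconds."""
    now = time.time() if now is None else now
    return now - timestamp > limit

def summarize(info_record, now=None):
    """Return (memory ranking, GPU-count ranking) over all non-expired clients."""
    mem_by_user = {}
    gpus_by_user = {}
    for address in sorted(info_record):
        entry = info_record[address]
        if is_expired(entry["timestamp"], now):
            continue
        for proc in entry["info"]["process"]:
            user = proc["username"]
            mem_by_user[user] = mem_by_user.get(user, 0) + proc["mem_usage"]
            gpus_by_user.setdefault(user, set()).add("%s:%d" % (address, proc["gpuid"]))
    mem_rank = sorted(mem_by_user.items(), key=lambda kv: kv[1], reverse=True)
    gpu_rank = sorted(
        ((user, sorted(gpus)) for user, gpus in gpus_by_user.items()),
        key=lambda kv: len(kv[1]),
        reverse=True,
    )
    return mem_rank, gpu_rank

# Invented sample data: two fresh clients and one stale client.
records = {
    "10.0.0.1": {"timestamp": 1000.0, "info": {"process": [
        {"username": "alice", "gpuid": 0, "mem_usage": 8000.0},
        {"username": "alice", "gpuid": 1, "mem_usage": 4000.0},
        {"username": "bob", "gpuid": 1, "mem_usage": 2000.0},
    ]}},
    "10.0.0.2": {"timestamp": 990.0, "info": {"process": [
        {"username": "bob", "gpuid": 0, "mem_usage": 1000.0},
    ]}},
    "10.0.0.3": {"timestamp": 100.0, "info": {"process": [
        {"username": "carol", "gpuid": 0, "mem_usage": 9000.0},
    ]}},
}
mem_rank, gpu_rank = summarize(records, now=1010.0)
```

In this sample the report from 10.0.0.3 is far older than 100 seconds, so carol's process is left out of both rankings.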

3. Final Result

[Image: output of the multi-server GPU usage monitoring tool]

Full code