netlink socket(linux内核模块与用户态之间通信实例)

本文通过一个编程实例来更深入地了解netlink。

1.1        实现内容

netlink socket(linux内核模块与用户态之间通信实例)

1.          功能

Ø  实现一个并发的echo服务器程序,它(接收端)将接收到字符串进行转换处理后,返回给发送端;允许有多个发送端同时存在;

Ø  字符串的处理包括:直接返回、全部转换为大写、全部转换为小写;处理方式应可以配置,配置方式包括全局(缺省)及每“发送-接收对(xmit/recv peer)”的。配置转换时,不应该影响正在处理的字符串;

Ø  为模拟处理过程的延时,接收端中每个字符转换添加200ms的延迟;

Ø  接收端中需统计其处理字符的数量,并且给外部提供命令形式的查询手段;

Ø  具备必要的例外处理,如内存不足等、client未关闭链接即意外退出等;  

1.2        验收成果

1.        程序结构(应用专业方向)

Ø  以有链接的client-server方式实现echo程序  

Ø  存在多个client,client与server在同一个系统中,它们之间采用domain socket或netlink socket相连;client之间没有关联

Ø  client是一个程序的多个实例,而server只允许运行一个实例  

2.        程序结构(内核专业方向)

Ø  内核模块作为接收端(服务器),而发送端(客户端)是用户空间的应用程序;

Ø  内核模块可动态加载与卸载 

Ø  驱动相关组:

·          创建虚拟字符型设备,用于接受来自客户端的字符

·          使用ioctl控制接口来配置echo处理方式,ioctl的cmd自定义;

·          并发使用驱动的多实例来模拟;

Ø  非驱动相关组:

·          创建内核线程,作为内核中的server,用于,通信可采用netlink socket或其它可用的方式;

·          使用netlink接口来配置echo处理转换方式;


用户态netlink socket源码

#include <stdio.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <string.h>
#include <sys/time.h>
#include <linux/netlink.h>
#include <signal.h>

#define NETLINK_TEST    17
#define MSG_LEN         125
#define BUF_LEN         125 
#define TIME            210

int skfd;
struct sockaddr_nl local;
struct sockaddr_nl dest;
struct nlmsghdr *message;

struct msg_to_kernel
{
    struct nlmsghdr hdr;
    char data[MSG_LEN];
};
struct u_packet_info
{
    struct nlmsghdr hdr;
    char msg[MSG_LEN];
};

static void sig_pipe(int sign)
{
    printf("Catch a SIGPIPE signal\n");
    close(skfd);
    kill(local.nl_pid,SIGUSR1);
    exit(-1);
}

int init_netlink(void)
{
    //char *send_data = "aaaaaaaaaaaaBBBBBBBBBBBBBBBBBBaaaaaaaaaaaaaaaaBBBBBBBBBBBB11111112222";
    char send_data[BUF_LEN];
    message = (struct nlmsghdr *)malloc(1);       
    skfd = socket(PF_NETLINK, SOCK_RAW, NETLINK_TEST);
    
    if(skfd < 0){
        printf("can not create a netlink socket\n");
        return -1;
    }    
    
    memset(&local, 0, sizeof(local));
    local.nl_family = AF_NETLINK;
    local.nl_pid = getpid();    
    local.nl_groups = 0;
    if(bind(skfd, (struct sockaddr *)&local, sizeof(local)) != 0){
        printf("bind() error\n");
        return -1;
    }
    memset(&dest, 0, sizeof(dest));
    dest.nl_family = AF_NETLINK;
    dest.nl_pid = 0;
    dest.nl_groups = 0;

    memset(message, '\0', sizeof(struct nlmsghdr));
    message->nlmsg_len = NLMSG_SPACE(MSG_LEN);
    message->nlmsg_flags = 0;
    message->nlmsg_type = 0;
    message->nlmsg_seq = 0;
    message->nlmsg_pid = local.nl_pid;    
    
    while(1) {
        printf("input  the  data: ");
        fgets(send_data, MSG_LEN, stdin);
        if(0 == (strlen(send_data)-1)) 
            continue;
        else
            break;
    }
    memcpy(NLMSG_DATA(message), send_data, strlen(send_data)-1);
    printf("send  to  kernel: %s,  send_id: %d   send_len: %d\n", \
        (char *)NLMSG_DATA(message),local.nl_pid, strlen(send_data)-1);
    return 0;
}

/**
 * NAME: ngsa_test_init 
 *
 * DESCRIPTION:
 *      ngsa test model 初始化
 * @*psdhdr   
 * @*addr     
 * @size      
 * 
 * RETURN: 
 */
int main(int argc, char* argv[]) 
{    
    int ret,len; 
    socklen_t destlen = sizeof(struct sockaddr_nl);    
    struct u_packet_info info; 
    fd_set fd_sets;
    struct timeval select_time;
    
    signal(SIGINT, sig_pipe);
    signal(SIGINT, sig_pipe);
    
    ret = init_netlink();
    if(ret<0) {
        close(skfd);
        perror("netlink failure!");
        exit(-1);
    }
    
    FD_ZERO( &fd_sets );
    FD_SET( skfd, &fd_sets);
    
    len = sendto(skfd, message, message->nlmsg_len, 0,(struct sockaddr *)&dest, sizeof(dest));
    if(!len){
        perror("send pid:");
        exit(-1);
    }

    select_time.tv_sec = TIME;
    select_time.tv_usec = 0;
    ret = select(skfd+1,&fd_sets,NULL,NULL,&select_time);

    if(ret > 0){
        /* 接受内核态确认信息 */        
        len = recvfrom(skfd, &info, sizeof(struct u_packet_info),0, (struct sockaddr*)&dest, &destlen);        
        printf("recv from kernel: %s,  recv_len: %d\n",(char *)info.msg, strlen(info.msg));
    }else if(ret < 0) {
        perror("\n error! \n");
        exit(-1);
    }else {
        printf("\n kernel server disconncet! \n");
        kill(local.nl_pid, SIGUSR1);
    }
        
    /* 内核和用户关闭通信 */
    close(skfd);
    return 0;
}

#########################################################################

linux内核源码(module.c)

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/types.h>
#include <linux/sched.h>
#include <net/sock.h>
#include <linux/netlink.h>
#include <linux/kthread.h>
#include <linux/sched.h>
#include <linux/err.h>
#include <linux/fs.h>
#include <linux/init.h>
#include <linux/cdev.h>

#define MEMDEV_MAJOR    255   /* 预设的mem的主设备号 */
#define MEMDEV_NR_DEVS  1     /* 设备数 */
#define MEMDEV_SIZE     1024  /* 分配内存的大小 */
#define NETLINK_TEST    17
#define UP_TO_LOW       0
#define LOW_TO_UP       1
#define MAX_PID_COUNT   100
#define MSG_LEN         125 

#ifndef SLEEP_MILLI_SEC
#define SLEEP_MILLI_SEC(nMillisec) \
    do { \
        long timeout = (nMillisec) * HZ /1000; \
        while (timeout > 0) { \
            timeout = schedule_timeout(timeout); \
        } \
    }while(0);
#endif

static int echo_major = MEMDEV_MAJOR;
module_param(echo_major, int, S_IRUGO);
struct echo_dev *echo_devp;                   /*设备结构体指针*/
struct cdev cdev;
char *echo_dev = "echodev";
static struct sock *netlinkfd = NULL;
static struct task_struct *task_test[MAX_PID_COUNT];
static int pid_index = 0;
static int char_num = 0;
static int char_cnvt_flag = 0;

/* mem设备描述结构体 */
struct echo_dev                                     
{                                                        
  char *data;          /* 分配到的内存的起始地址 */                    
  unsigned long size;  /* 内存的大小 */
};

struct{
    __u32 pid;
}user_process;


/* netlink */
struct echo_netlink                                     
{  
  __u32 pid;            /* netlink pid */
  char  buf[MSG_LEN];   /* data */  
  int   length;         /* buf len  */
};

struct echo_netlink client_netlink[MAX_PID_COUNT];

static int echo_open(struct inode *inode, struct file *filp);
static ssize_t echo_read(struct file *filp, char __user *buf, size_t size, loff_t *ppos);
static long echo_ioctl(struct file *filp, unsigned int cmd, unsigned long arg);

static const struct file_operations echo_fops =
{
    .owner = THIS_MODULE,
    .open = echo_open,
    .read = echo_read,
    .unlocked_ioctl = echo_ioctl,
};

static int echo_open(struct inode *inode, struct file *filp)
{
    /*获取次设备号*/
    printk(KERN_DEBUG"[kernel space] open char device!!\n");
    return 0;
}

static ssize_t echo_read(struct file *filp, char __user *buf, size_t size, loff_t *ppos)
{
    printk(KERN_DEBUG"[kernel space] test_netlink_exit!!\n");
    return char_num;
}

static long echo_ioctl(struct file *filp, unsigned int cmd, unsigned long arg)
{
    int result = 0;
    
    switch(cmd) {
        case UP_TO_LOW:
            char_cnvt_flag = 0;
            break;
        case LOW_TO_UP:
            char_cnvt_flag = 1;
            break;
        default       :
            result = -1;
            break;
    }
    printk(KERN_DEBUG"[kernel space] ioctl cmd: %d\n",char_cnvt_flag);
    return result;
}

int init_char_device(void)
{
    int i,result;
    dev_t devno = MKDEV(echo_major, 0);
        
    if (echo_major)
        /* 静态申请设备号*/
        result = register_chrdev_region(devno, 2, "echodev");
    else {
        /* 动态分配设备号 */
        result = alloc_chrdev_region(&devno, 0, 2, "echodev");
        echo_major = MAJOR(devno);
    }

    if ( result<0 )
        return result;

    /* 初始化cdev结构 */
    cdev_init(&cdev, &echo_fops);
    cdev.owner = THIS_MODULE;
    cdev.ops = &echo_fops;

    /* 注册字符设备 */    
    cdev_add(&cdev, MKDEV(echo_major, 0), MEMDEV_NR_DEVS);

    /* 为设备描述结构分配内存 */
    echo_devp = kmalloc(MEMDEV_NR_DEVS * sizeof(struct echo_dev), GFP_KERNEL);

    /* 申请失败 */
    if (!echo_devp)
    {
        result = -1;
        goto fail_malloc;
    }

    memset(echo_devp, 0, sizeof(struct echo_dev));

    /* 为设备分配内存 */
    for(i= 0; i < MEMDEV_NR_DEVS; i++) {
        echo_devp[i].size = MEMDEV_SIZE;
        echo_devp[i].data = kmalloc(MEMDEV_SIZE, GFP_KERNEL);
        memset(echo_devp[i].data, 0, MEMDEV_SIZE);
    }
    
    printk(KERN_ERR"[kernel space] create char device successfuly!\n");
    return 0;
    
    fail_malloc: 
        unregister_chrdev_region(devno, 1);
    return result;
}

void delete_device(void)
{
    /* 注销设备 */
    cdev_del(&cdev);
    
    /* 释放设备号 */
    unregister_chrdev_region(MKDEV(echo_major, 0), 2);
    printk(KERN_DEBUG"[kernel space] echo_cdev_del!!\n");
}

static int kernel_send_thread(void *index)

    int threadindex = *((int *)index);
    int size;
    struct sk_buff *skb;
    unsigned char *old_tail;
    struct nlmsghdr *nlh;  //报文头    
    int retval;     
    int i=0;     
    
    size = NLMSG_SPACE(client_netlink[threadindex].length);

    /* 分配一个新的套接字缓存,使用GFP_ATOMIC标志进程不会被置为睡眠 */
    skb = alloc_skb(size, GFP_ATOMIC); 

    /* 初始化一个netlink消息首部 */
    nlh = nlmsg_put(skb, 0, 0, 0, NLMSG_SPACE(client_netlink[threadindex].length)-sizeof(struct nlmsghdr), 0);
    old_tail = skb->tail;    

    //memcpy(NLMSG_DATA(nlh), client_netlink[i].buf, client_netlink[i].length);  //填充数据区
    strcpy(NLMSG_DATA(nlh), client_netlink[threadindex].buf);  //填充数据区
    nlh->nlmsg_len = skb->tail - old_tail;  //设置消息长度
    
    /* 设置控制字段 */
    NETLINK_CB(skb).pid = 0;
    NETLINK_CB(skb).dst_group = 0;

    printk(KERN_DEBUG "[kernel space] send  to  user: %s,  send_pid: %d,  send_len: %d\n", \
        (char *)NLMSG_DATA((struct nlmsghdr *)skb->data), client_netlink[threadindex].pid, \
        client_netlink[threadindex].length);    

    /* 发送数据 */
    retval = netlink_unicast(netlinkfd, skb, client_netlink[threadindex].pid, MSG_DONTWAIT);

    if (retval<0) {
        printk(KERN_DEBUG "[kernel space] client closed: \n");        
    }
    
    while(!(i = kthread_should_stop())) {
        printk(KERN_DEBUG "[kernel space] kthread_should_stop: %d\n", i);
        SLEEP_MILLI_SEC(1000*10);             
    }
    
    return 0;
}

void char_convert(int id) 
{    
    int len = client_netlink[id].length;
    int i = 0;

    client_netlink[id].buf[len] = '\0';
    if( UP_TO_LOW == char_cnvt_flag ) {
        printk(KERN_DEBUG "[kernel space] UP_TO_LOW\n");
        while(client_netlink[id].buf[i] != '\0') {           
            if(client_netlink[id].buf[i] >= 'A' && client_netlink[id].buf[i] <= 'Z') {
                client_netlink[id].buf[i] = client_netlink[id].buf[i] + 0x20;
                mdelay(200);
            }            
            i++;
        }
    }
    else if( LOW_TO_UP == char_cnvt_flag ) {
        printk(KERN_DEBUG "[kernel space] LOW_TO_UP\n");
        while(client_netlink[id].buf[i] != '\0') {            
            if(client_netlink[id].buf[i]  >= 'a' && client_netlink[id].buf[i]  <= 'z') {
                client_netlink[id].buf[i] = client_netlink[id].buf[i]  - 0x20;
                mdelay(200);
            }            
            i++;
        }
    }    
    char_num += len;    
}

void run_netlink_thread(int thread_index)
{
    int err;
    char process_name[64] = {0};
 
    void* data = kmalloc(sizeof(int), GFP_ATOMIC);
    *(int *)data = thread_index;
    snprintf(process_name, 63, "child_thread-%d", thread_index);

    task_test[thread_index] = kthread_create(kernel_send_thread, data, process_name);

    if(IS_ERR(task_test[thread_index])) {
        err = PTR_ERR(task_test[thread_index]);
        printk(KERN_DEBUG "[kernel space] creat child thread failure \n");
    } else {
        printk(KERN_DEBUG "[kernel space] creat child_thread-%d \n", thread_index);    
        wake_up_process(task_test[thread_index]);
    }      
}

void buf_deal(int id)
{  
    char_convert(id);
    
    /* 唤醒线程 */
    run_netlink_thread(id);     
}

void kernel_recv_thread(struct sk_buff *__skb)
{
    struct sk_buff *skb;
    struct nlmsghdr *nlh = NULL;
    char *recv_data = NULL;
    int pid_id = 0;
    printk(KERN_DEBUG "[kernel space] begin kernel_recv\n");
    skb = skb_get(__skb);

    if(skb->len >= NLMSG_SPACE(0)) {
        nlh = nlmsg_hdr(skb);

        if(pid_index < MAX_PID_COUNT) {
            client_netlink[pid_index].pid = nlh->nlmsg_pid;
            recv_data = NLMSG_DATA(nlh);            
            strcpy(client_netlink[pid_index].buf,recv_data);
            client_netlink[pid_index].length = strlen(recv_data);
            printk(KERN_DEBUG "[kernel space] recv from user: %s,  recv_pid: %d,  recv_len: %d\n", \
            (char *)NLMSG_DATA(nlh), client_netlink[pid_index].pid, strlen(recv_data));            
            pid_id = pid_index;
            pid_index++;
            buf_deal(pid_id);
        } else {
            printk(KERN_DEBUG "[kernel space] out of pid\n");
        }
        kfree_skb(skb);
    }   
}

int init_netlink(void)
{    
    netlinkfd = netlink_kernel_create(&init_net,NETLINK_TEST,0,kernel_recv_thread,NULL,THIS_MODULE);
    if(!netlinkfd )
        return -1;
    else {
        printk(KERN_ERR"[kernel space] create netlink successfuly!\n");
        return 0;
    }        
}

void netlink_release(void)
{
    printk(KERN_DEBUG"[kernel space] echo_netlink_exit!\n");
    if(netlinkfd != NULL)
        sock_release(netlinkfd->sk_socket); 
}

void stop_kthread(void)
{
    int i;
    printk(KERN_ERR"[kernel space] stop kthread!\n");
    for(i=0; i != pid_index; i++) {
        if(task_test[i] != NULL) {            
            kthread_stop(task_test[i]);            
            task_test[i] = NULL;
        }
    }    
}

void  init_client(void)
{
    int i = 0;
    for(i=0; i<MAX_PID_COUNT; i++) {        
        client_netlink[i].pid = 0;        
        task_test[i] = NULL;
    }
}

/**
 * NAME: init_echo_module 
 *
 * DESCRIPTION:
 *      模块加载函数
 * @*psdhdr   
 * @*addr     
 * @size      
 * 
 * RETURN: 
 */
int __init init_echo_module(void)
{
    int result = 0;

    init_client();   

    result = init_char_device();
    if ( result<0 ) {
        printk(KERN_ERR"[kernel space] cannot create a netlinksocket!\n");
        return result;
    }
    
    result = init_netlink();
    if ( result<0 ) {
        printk(KERN_ERR"[kernel space] cannot create a netlinksocket!\n");
        return result;
    }        
    
    return result;
}

/**
 * NAME: exit_echo_module 
 *
 * DESCRIPTION:
 *      模块卸载函数
 * @*psdhdr   
 * @*addr     
 * @size      
 * 
 * RETURN: 
 */
void __exit exit_echo_module(void)
{    
    netlink_release();
    stop_kthread();
    delete_device();    
}

module_init(init_echo_module);
module_exit(exit_echo_module);
MODULE_LICENSE("GPL");
MODULE_AUTHOR("zhang");

MODULE_VERSION("V1.0");


#####################用户态文件操作源码####################

#include <stdio.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <linux/fs.h>
#include <stdlib.h>  
#include <string.h> 
#include <unistd.h>
#include <signal.h>
#include <sys/ioctl.h>

#define MAX_SIZE 1024
char buf[MAX_SIZE]; //缓冲区
char *dir = "/dev/echodev";
int fd = 0;

static void stop(int sign)
{
    printf("Catch a signal\n");
    close(fd);
    exit(0);
}

int main()
{    
    int arg = 0;
    int cmd = 0;    
    int ret = 0;
    int len = 0;
    
    /*打开设备文件*/
    fd = open(dir, O_RDWR | O_NONBLOCK);

    if(fd == -1) {
        printf("open failure: %s\n",dir);
        close(fd);
        return -1;
    }    
    signal(SIGINT, stop);
    signal(SIGTSTP, stop);
    char opt[50];
    while(1){
        printf("please input operation (read/ioctl): ");
        scanf("%s",opt);
        
        if(0 == strcmp(opt,"read")) {
            /* 读取数据 */
            len = read(fd, buf, sizeof(buf));
            printf(": %d\n", len);
        } else if(0 == strcmp(opt,"ioctl")) {
            printf("select 0(A-a)/1(a-A): ");
            while(scanf("%d",&cmd) != EOF) {
                ret = ioctl(fd,cmd,&arg);
                if (ret<0)                    
                    continue;
                else
                    break;
            }
        }else {
            printf("input error!\n");
            continue;
        }
    } 
    
    /* 关闭设备 */
    close(fd);
    return 0;

}


linux内核模块具体加载流程,详见:

linux内核模块Makefile编写格式,详见:

gcc的Makefile编写格式,详见: