フィボナッチ数列を計算するデバイスドライバ

Amazon から プログラミング言語Erlang入門 が届きました。

どんな構成だろうね、と会社で同僚数人とわいわいやっていたら、「フィボナッチ数列を計算するサーバー」という例があって、みんなのツボに入りました。Erlang の並列計算処理能力とネットワークプログラミングのしやすさを示すという上で良い例だと思うのですが、「フィボナッチ数列を計算する」というのと「ネットワークサーバーを書く」、という二つのテーマの不思議なギャップが面白いのでしょう。

そういえば関数型言語が得意な id:maoe は、はてなの採用面接の際に、はてなのボーナス計算を計算するシステムを作ってきたのですが、なぜかクライアント/サーバシステム、ネットワークサーバーを Haskell で、クライアントを Scheme で書き、プロトコルが S 式という実装をみんなの前で披露して、周囲の笑いを誘っていました。

ちょっとしたプログラムの実行に大げさなインタフェース/実行環境を用意する、というネタは作るのも結構楽しいものです。そんなわけで、ついカッとなってフィボナッチ数を出力する Linuxデバイスドライバを作りました。

% cat /proc/fib
10

と、proc 経由で項数が分かります。そして /dev/fib を read すると

% cat /dev/fib
55

と項数 10 のフィボナッチ数が得られます。項数を変更するのは ioctl(2) で行います。ここでは ioctl(2) を呼ぶプログラム fibdev_ioctl を実行して変更します。

% ./fibdev_ioctl 20
% cat /proc/fib
20
% cat /dev/fib
6765

項数を大きくすると計算に時間がかかって OS の反応が鈍くなってしまいす。カーネル空間ではそんなにがんばって計算してはいけません、という良い(悪い)例です。おバカですいません。でもいいんです。

ソースを晒しておきます。

fibdev.c

デバイスドライバ本体です。キャラクタデバイスです。CentOS 4.4 のカーネル 2.6.9-42 を前提にしています。フィボナッチ数の計算は一応、デバイスバッファ上でメモ化しています。メジャー番号は動的に割り当てているので、dmesg や /var/log/messages などで printk の出力を確認した後

% sudo mknod /dev/fib c 254 0

と mknod でデバイスファイルを作りました。

#include <linux/module.h>
#include <linux/init.h>
#include <linux/fs.h>
#include <linux/cdev.h>
#include <linux/proc_fs.h>
#include <asm/uaccess.h>
#include <asm/semaphore.h>

#include "fibdev_ioctl.h"

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Naoya Ito");

enum {
    MAXFIBS    = 99,
    MAXCOPYBUF = 10,
};

struct fibdev_t
{
    dev_t            id;
    int              n;
    struct semaphore sem;
    struct cdev      cdev;
    unsigned long    fibs[MAXFIBS];
};

static const char *msg = "fibdev";
static struct fibdev_t *fibdev_p;

static int     fibdev_open (struct inode *inode, struct file *filp);
static ssize_t fibdev_read (struct file *filp, char __user *user_buf, size_t count, loff_t *offset);
static int     fibdev_ioctl (struct inode *inode, struct file *filp, unsigned int cmd, unsigned long arg);

static struct file_operations fibdev_fops = {
    .owner = THIS_MODULE,
    .open  = fibdev_open,
    .read  = fibdev_read,
    .ioctl = fibdev_ioctl,
};

static unsigned long fib(int n) 
{
    if (n <= 2) return 1L;
    return fib(n -1) + fib(n -2);
}

static unsigned long fib_memoized (int n, struct fibdev_t *dev) 
{
    if (dev->fibs[n]) return dev->fibs[n];
    return dev->fibs[n] = fib(n);
}

static int fibdev_open (struct inode *inode, struct file *filp) 
{
    filp->private_data = container_of(inode->i_cdev, struct fibdev_t, cdev);
    return 0;
}

static ssize_t fibdev_read (struct file *filp, char __user *user_buf, size_t count, loff_t *offset)
{
    struct fibdev_t *dev = filp->private_data;
    
    if (*offset > 0)
        return 0;

    if (down_interruptible(&dev->sem))
        return -ERESTARTSYS;

    char copy_buf[MAXCOPYBUF];
    int copy_len = sprintf(copy_buf, "%lu\n", fib_memoized(dev->n, dev));
    
    if (copy_to_user(user_buf, copy_buf, copy_len)) {
        up(&dev->sem);
        printk(KERN_ALERT "%s: copy_to_user failed\n", msg);
        return -EFAULT;
    }

    up(&dev->sem);
    *offset += copy_len;
    
    return copy_len;
}

static int fibdev_ioctl (struct inode *inode, struct file *filp, unsigned int cmd, unsigned long arg)
{
    int retval = 0;
    struct fibdev_t *dev = filp->private_data;
    
    switch (cmd) {
    case FIBDEV_IOC_RESET:        
        dev->n = 1;
        break;
    case FIBDEV_IOC_SETN:
        retval = get_user(dev->n, (int __user *)arg);
        if (dev->n > MAXFIBS)
            dev->n = 1;
        break;
    case FIBDEV_IOC_GETN:
        retval = put_user(dev->n, (int __user *)arg);
        break;
    default:
        retval = -ENOTTY;
        break;
    }

    return retval;
}

static int fibdev_current_n (char *page, char **start, off_t offset, int count, int *eof, void *data) 
{
    return sprintf(page, "%d\n", fibdev_p->n);
}

static void fibdev_setup(struct fibdev_t *dev) 
{
    int err = alloc_chrdev_region(&dev->id, 0, 1, "fibdev");
    if (err) {
        printk(KERN_ALERT "%s: alloc_chrdev_region() failed (%d)\n", msg, err);
        return;
    } else
        printk(KERN_INFO "%s : Major number: %d\n", msg, MAJOR(dev->id));

    cdev_init(&dev->cdev, &fibdev_fops);
    dev->cdev.owner = THIS_MODULE;
    dev->n = 1;
    init_MUTEX(&dev->sem);
    
    err = cdev_add(&dev->cdev, dev->id, 1);
    if (err) {
        printk(KERN_ALERT "%s: cdev_add() failed (%d)\n", msg, err);
        return;
    }
}

static int __init fibdev_init (void) 
{
    fibdev_p = kmalloc(sizeof(struct fibdev_t), GFP_KERNEL);
    if (!fibdev_p) {
        printk(KERN_ALERT "%s: kmalloc() failed\n", msg);
        return -ENOMEM;
    }
    memset(fibdev_p, 0, sizeof (struct fibdev_t));
    fibdev_setup(fibdev_p);

    create_proc_read_entry("fib", 0, NULL, fibdev_current_n, NULL);

    return 0;
}

static void __exit fibdev_exit (void) 
{
    remove_proc_entry("fib", NULL);
    cdev_del(&fibdev_p->cdev);
    unregister_chrdev_region(fibdev_p->id, 1);
    kfree(fibdev_p);
}

module_init(fibdev_init);
module_exit(fibdev_exit);

まともなデバイスドライバの開発の経験がないので変てこかもですが、そのあたりはご愛嬌。

fibdev_ioctl.h

ioctl(2) のための定数を定義した fibdev_ioctl.h。コマンドはリセット、項数設定、項数取得の三つのみ。

#ifndef FIBDEV_IOCTL_H
#define FIBDEV_IOCTL_H

#define FIBDEV_IOC_MAGIC 'k'
#define FIBDEV_IOC_RESET _IO(FIBDEV_IOC_MAGIC, 0)
#define FIBDEV_IOC_SETN  _IOW(FIBDEV_IOC_MAGIC, 1, int)
#define FIBDEV_IOC_GETN  _IOR(FIBDEV_IOC_MAGIC, 2, int)

#endif

Makefile

make ファイル。Linuxデバイスドライバ 第3版 から拝借。

ifneq ($(KERNELRELEASE),)
	obj-m :=fibdev.o
else
	KERNELDIR ?= /lib/modules/$(shell uname -r)/build
	PWD := $(shell pwd)

default:
	$(MAKE) -C $(KERNELDIR) M=$(PWD) modules

clean:
	$(MAKE) -C $(KERNELDIR) -r M=$(PWD) clean

endif

fibdev_ioctl.c

項数を設定する ioctl(2) を実行するユーザープログラムです。

#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <sys/ioctl.h>
#include <unistd.h>
#include <errno.h>

#include "fibdev_ioctl.h"

int main (int argc, char **argv) 
{
    if (argc != 2) {
        fprintf(stderr, "usage: fibdev_ioctl <num>\n");
        exit(-1);
    }
    
    int fd;
    if ((fd = open("/dev/fib", O_RDONLY)) < 0) {
        perror(argv[0]);
        exit(errno);
    }

    int n = atoi(argv[1]);
    if (ioctl(fd, FIBDEV_IOC_SETN, &n) < 0) {
        perror(argv[0]);
        exit(errno);
    }
    close(fd);

    return 0;
}