LinkedList の途中への挿入の効率

連結リストの挿入は O(1) ってのに異論はないんだけど、挿入場所が分かっていなければ O(1) じゃなくて O(n) になるってのは忘れられがち・・・というか、下手したら認識されてさえいない。


例えば List#add(int index, E element) の効率だけど、これは index の場所を見つけるまで先頭、もしくは末尾からシーケンシャルな走査が必要なため、O(1) じゃなくて O(n) となる。
Sun の実装は、

public void add(int index, E element) {
    addBefore(element, (index==size ? header : entry(index)));
}

public Entry<E> entry(int index) {
    if (index < 0 || index >= size)
        throw new IndexOutOfBoundsException("Index: "+index+
                                            ", Size: "+size);
    Entry<E> e = header;
    if (index < (size >> 1)) {
        for (int i = 0; i <= index; i++)
            e = e.next;
    } else {
        for (int i = size; i > index; i--)
            e = e.previous;
    }
    return e;
}

こんな感じに、index が size と等しくない場合は近い方から走査が走っている。
なので、リストの中央に add(int, E) で要素を挿入するような場合、ArrayList も LinkedList も O(n) となる。
さらに、ArrayList に比べて LinkedList は定数項が大きいので、基本的に ArrayList よりも低速となる。

サイズ LinkedList ArrayList
10000 1067 744
20000 2078 1139
30000 2968 1443
40000 3904 1803
50000 4815 2121
60000 5733 2560
70000 6710 2822
80000 7718 3186
90000 8640 3536
100000 9616 3886

ではどんなメソッドを使えばいいかというと・・・実は、どんぴしゃりなメソッドはない。
なので、イテレータを介することになるんだけど・・・Iterator には add 系のメソッドがないので、ListIterator を使用する。
この場合、LinkedList では O(1)、ArrayList では O(n) となり、LinkedList の方が高速になる。

サイズ LinkedList ArrayList
10000 46 529
20000 49 977
30000 48 1407
40000 49 1924
50000 49 2421
60000 47 2740
70000 49 3199
80000 51 3801
90000 50 4008
100000 51 4400

まとめ

  • LinkedList を使用した場合、挿入だけをするように見えるメソッドでも実は走査が走っているのは見落としがち
  • もし O(1) でリストの途中に挿入したい場合は、拡張 for 文やリストのメソッドを使用せずに、ListIterator を使用する必要がある
  • add(int, E) を使うなら、ArrayList を使用した方が基本的には高速

使用した環境

PC 等
CPU Core2Duo E8400
メモリ DDR2 800 2GB * 2
OS Windows Vista Ultimate SP1 (64bit)
JVM 1.6.0_14
実行オプション -Xms1024m -Xmx1024m
List#add(int index, E element) のコード

test メソッド内で List に挿入している。

import java.util.*;

public class Main {
    
    // ここと
    static final int SIZE = 10000;
    static final int INS_TIMES = 100;
    static final int INS_POS = SIZE / 2;
    
    static void test(List<String> strs) {
        for (int i = 0; i < INS_TIMES; i++)
            strs.add(INS_POS, "hoge");
    }
    
    static List<String> prepare(int mode) {
        if (mode == 0) {
            LinkedList<String> result = new LinkedList<String>();
            for (int i = 0; i < SIZE; i++)
                result.add("");
            return result;
        } else {
            ArrayList<String> result = new ArrayList<String>(SIZE);
            for (int i = 0; i < SIZE; i++)
                result.add("");
            return result;
        }
    }

    public static void main(String[] args) {
        int total = 0;
        for (int i = 0; i < 100; i++) {
            // ここを変更してコンパイルしなおしてから実行
            List<String> lst = prepare(0);
            long start = System.nanoTime();
            test(lst);
            long time = System.nanoTime() - start;
            total += time / 1000;
        }
        System.out.println(total / 100);
    }
}

ListIterator#add(E) のコード

同じく test メソッド内で挿入。

import java.util.*;

public class Main {
    
    static final int SIZE = 10000;
    static final int INS_TIMES = 100;
    static final int INS_POS = SIZE / 2;
    
    static void test(ListIterator<String> itr) {
        for (int i = 0; i < INS_TIMES; i++)
            itr.add("hoge");
    }
    
    static ListIterator<String> prepare(int mode) {
        ListIterator<String> result = null;
        if (mode == 0) {
            LinkedList<String> ll = new LinkedList<String>();
            for (int i = 0; i < SIZE; i++)
                ll.add("");
            result = ll.listIterator();
        } else {
            ArrayList<String> al = new ArrayList<String>(SIZE);
            for (int i = 0; i < SIZE; i++)
                al.add("");
            result = al.listIterator();
        }
        // イテレータをリストの真ん中まで進めておく
        for (int i = 0; i < INS_POS; i++) {
            if (result.hasNext())
                result.next();
        }
        return result;
    }

    public static void main(String[] args) {
        int total = 0;
        for (int i = 0; i < 100; i++) {
            ListIterator<String> itr = prepare(0);
            long start = System.nanoTime();
            test(itr);
            long time = System.nanoTime() - start;
            total += time / 1000;
        }
        System.out.println(total / 100);
    }
}