Ark's Blog

参加記や備忘録などを書いていきます

ようこそ

LCAのライブラリを作った

LCAのライブラリ作ってないなあと気づいたので作った

LCAとは

Lowest Common Ancestorの略で最小共通祖先です。
根付き木 T=(V, E)に対して、次のクエリを解く問題です。

  •  \text{lca}(u, v) := u\text{と}v\text{の共通祖先のうち最も近い頂点 } (u, v \in V)

定番の手法として「ダブリングによる実装」と「Euler Tourとセグ木による実装」があります。

  • ダブリングによる実装
    • 初期化:  O(n\log n)
    • クエリ: O(\log n)
  • Euler Tourとセグ木による実装
    • 初期化:  O(n)
    • クエリ: O(\log n)

詳しくは蟻本に載っているのでそちらを見るといいかと思います。
理解も兼ねてとりあえず両方共実装した。

確認としてABC 014 D - 閉路を使った。

ソースコード

ダブリングによる実装

// LCA: Lowest Common Ancestor
//   using Doubling Technique

struct LCA {

private:
    size_t _size;
    size_t _logSize;
    Vertex[] _vertices;
    size_t _rootIndex;
    bool _builded;

public:

    this(size_t size) {
        _size = size;
        _logSize = cast(size_t)(log2(_size) + 2);
        _vertices.length = _size;
        foreach(i, ref v; _vertices) {
            _vertices[i] = new Vertex;
            _vertices[i].index = i;
            _vertices[i].powParents.length = _logSize;
        }
        _rootIndex = 0;
        _builded = false;
    }

    // O(1)
    void addEdge(size_t x, size_t y) {
        _builded = false;
        Vertex vx = _vertices[x];
        Vertex vy = _vertices[y];
        vx.edges ~= Edge(vx, vy);
        vy.edges ~= Edge(vy, vx);
    }

    // O(1)
    void setRoot(size_t index) {
        _builded = false;
        _rootIndex = index;
    }

    // O(log N)
    size_t queryIndex(size_t x, size_t y) {
        return queryVertex(x, y).index;
    }

    // O(log N)
    int queryDepth(size_t x, size_t y) {
        return queryVertex(x, y).depth;
    }

    // O(1)
    int getDepth(size_t x) {
        // return queryDepth(x, x);
        if (!_builded) build();
        return _vertices[x].depth;
    }

private:

    // O(N log N)
    void build() {
        _vertices.each!(v => v.visited = false);

        void dfs(Vertex vertex, Vertex parent, int depth) {
            vertex.powParents[0] = parent;
            vertex.depth = depth;
            foreach(v; vertex.edges.map!"a.end") {
                if (v.visited) continue;
                v.visited = true;
                dfs(v, vertex, depth + 1);
            }
        }
        _vertices[_rootIndex].visited = true;
        dfs(_vertices[_rootIndex], null, 0);

        foreach(k; 1.._logSize) {
            foreach(v; _vertices) {
                if (v.powParents[k-1] is null) continue;
                v.powParents[k] = v.powParents[k-1].powParents[k-1];
            }
        }

        _builded = true;
    }

    Vertex queryVertex(size_t x, size_t y) {
        if (!_builded) build();

        Vertex vx = _vertices[x];
        Vertex vy = _vertices[y];

        if (vx.depth > vy.depth) swap(vx, vy);
        foreach(k; 0.._logSize) {
            if ((vy.depth - vx.depth)>>k&1) {
                vy = vy.powParents[k];
                assert(vy !is null);
            }
        }
        assert(vx.depth == vy.depth);
        if (vx is vy) return vx;

        foreach(k; _logSize.iota.retro) {
            if (vx.powParents[k] !is vy.powParents[k]) {
                vx = vx.powParents[k];
                vy = vy.powParents[k];
                assert(vx !is null && vy !is null);
            }
        }
        assert(vx.powParents[0] is vy.powParents[0]);

        return vx.powParents[0];
    }

    class Vertex {
        size_t index;
        int depth;
        bool visited;
        Vertex[] powParents; // powParents[i][j] = 頂点iから2^j回親を辿った頂点(いなければnull)
        Edge[] edges;
    }

    struct Edge {
        Vertex start, end;
    }

}

Euler Tourとセグ木による実装

// LCA: Lowest Common Ancestor
//   using EulerTour & SegTree

struct LCA {

private:
    alias SegT = SegTree!(SegNode, (a, b) => a.depth<b.depth ? a:b, SegNode(int.max, size_t.max));

    size_t _size;
    Vertex[] _vertices;
    size_t _rootIndex;
    SegT _segT;
    bool _builded;

public:
    // O(N)
    this(size_t size) {
        _size = size;
        _vertices.length = size;
        foreach(i, ref v; _vertices) {
            v = new Vertex(i);
        }
        _rootIndex = 0;
        _builded = false;
    }

    // O(1)
    void addEdge(size_t x, size_t y) {
        _builded = false;
        Vertex vx = _vertices[x];
        Vertex vy = _vertices[y];
        vx.edges ~= Edge(vx, vy);
        vy.edges ~= Edge(vy, vx);
    }

    // O(1)
    void setRoot(size_t index) {
        _builded = false;
        _rootIndex = index;
    }

    // O(log N)
    size_t queryIndex(size_t x, size_t y) {
        return querySegNode(x, y).index;
    }

    // O(log N)
    int queryDepth(size_t x, size_t y) {
        return querySegNode(x, y).depth;
    }

    // O(1)
    int getDepth(size_t x) {
        // return queryDepth(x, x);
        if (!_builded) build();
        return _segT.get(_vertices[x].id).depth;
    }

private:

    // O(N)
    void build() {
        _vertices.each!(v => v.visited = false);

        // Euler tour
        SegNode[] segNodes = new SegNode[2*_size];
        size_t dfs(Vertex v, int depth, size_t id) {
            segNodes[id].depth = depth;
            segNodes[id].index = v.index;
            v.id = id;
            foreach(u; v.edges.map!"a.end") {
                if (u.visited) continue;
                u.visited = true;
                id = dfs(u, depth + 1, id + 1) + 1;
                segNodes[id].depth = depth;
                segNodes[id].index = v.index;
            }
            return id;
        }
        _vertices[_rootIndex].visited = true;
        dfs(_vertices[_rootIndex], 0, 0);

        _segT = SegT(segNodes);

        _builded = true;
    }

    // O(log N)
    SegNode querySegNode(size_t x, size_t y) {
        if (!_builded) build();
        size_t idX = _vertices[x].id;
        size_t idY = _vertices[y].id;
        return _segT.query(min(idX, idY), max(idX, idY) + 1);
    }

    class Vertex {
        size_t index;
        size_t id;
        bool visited;
        Edge[] edges;
        this(size_t index) {
            this.index = index;
        }
    }
    struct Edge {
        Vertex start, end;
    }
    struct SegNode {
        int depth;
        size_t index;
    }
}

// SegTree (Segment Tree)
struct SegTree(T, alias fun, T initValue)
    if (is(typeof(fun(T.init, T.init)) : T)) {

private:
    Node[] _data;
    size_t _size;
    size_t _l, _r;

public:
    // size ... データ数
    // initValue ... 初期値(例えばRMQだとINF)
    this(size_t size) {
        init(size);
    }

    // 配列で指定
    this(T[] ary) {
        init(ary.length);
        update(ary);
    }

    // O(N)
    void init(size_t size){
        _size = 1;
        while(_size < size) {
            _size *= 2;
        }
        _data.length = _size*2-1;
        _data[] = Node(size_t.max, initValue);
        _l = 0;
        _r = size;
    }

    // i番目の要素をxに変更
    // O(logN)
    void update(size_t i, T x) {
        size_t index = i;
        i += _size-1;
        _data[i] = Node(index, x);
        while(i > 0) {
            i = (i-1)/2;
            Node nl = _data[i*2+1];
            Node nr = _data[i*2+2];
            _data[i] = select(nl, nr);
        }
    }

    // 配列で指定
    // O(N)
    void update(T[] ary) {
        foreach(i, e; ary) {
            _data[i+_size-1] = Node(i, e);
        }
        foreach(i; (_size-1).iota.retro) {
            Node nl = _data[i*2+1];
            Node nr = _data[i*2+2];
            _data[i] = select(nl, nr);
        }
    }

    // 区間[a, b)でのクエリ (値の取得)
    // O(logN)
    T query(size_t a, size_t b) {
        return queryRec(a, b, 0, 0, _size).value;
    }

    // 区間[a, b)でのクエリ (indexの取得)
    // O(logN)
    size_t queryIndex(size_t a, size_t b) out(result) {
        // fun == (a, b) => a+b のようなときはindexを聞くとassertion
        assert(result != size_t.max);
    } body {
        return queryRec(a, b, 0, 0, _size).index;
    }

    private Node queryRec(size_t a, size_t b, size_t k, size_t l, size_t r) {
        if (b<=l || r<=a) return Node(size_t.max, initValue);
        if (a<=l && r<=b) return _data[k];
        Node nl = queryRec(a, b, k*2+1, l, (l+r)/2);
        Node nr = queryRec(a, b, k*2+2, (l+r)/2, r);
        return select(nl, nr);
    }

    private Node select(Node nl, Node nr) {
        T v = fun(nl.value, nr.value);
        if (nl.value == v) {
            return nl;
        } else if (nr.value == v) {
            return nr;
        } else {
            return Node(size_t.max, v);
        }
    }

    // O(1)
    T get(size_t i) {
        return _data[_size-1 + i].value;
    }

    // O(N)
    T[] array() {
        return _data[_l+_size-1.._r+_size-1].map!"a.value".array;
    }

private:
    struct Node {
        size_t index;
        T value;
    }
}

参考