油売り算 #2

幅優先探索自体を部品化してしまうのが、関数型プログラミングなんだろうなと思ったのでやってみた。まぁ、だからなんだという話だ。

import sys
from collections import deque

class cons(object):
    """A Lisp-style pair used as a singly linked list cell.

    ``car`` holds this cell's element and ``cdr`` holds the next cell
    (or None at the end of the chain).  Iterating over a cell yields the
    ``car`` of this cell and of every cell reachable through ``cdr``.
    """
    # Fixed attribute set: no per-instance __dict__, so other attributes
    # cannot be assigned.
    __slots__ = ('car', 'cdr')

    def __init__(self, car, cdr):
        self.car, self.cdr = car, cdr

    def __iter__(self):
        node = self
        while node is not None:
            yield node.car
            node = node.cdr

def breadth_first_search(transit, enum_step, criteria, initial_state):
    """Find a shortest sequence of steps from initial_state to a goal state.

    A generic breadth-first search driven by three callables:

    transit(state, step)  returns the state reached by applying step.
    enum_step(state)      returns an iterable of steps to try from state.
    criteria(state)       returns True when state is a goal.

    States must be hashable because they are stored in a set.

    Returns (steps, goal_state), where steps is a list of steps in the order
    they must be applied.  Returns None if no goal state is reachable.
    When initial_state itself satisfies criteria, steps is an empty list.
    """
    # Already at the goal: zero steps.  Use the same (steps, state) shape as
    # the success case below so callers can always index answer[0].
    if criteria(initial_state):
        return ([], initial_state)

    # Each queue entry pairs the path taken so far with the state it reached.
    # The path is a cons list in newest-first order, so extending it is O(1)
    # and sibling paths share their common prefix.
    queue = deque([(None, initial_state)])
    # Mark the start state as visited so moves leading back to it are not
    # enqueued and expanded again.
    visited = set([initial_state])

    while queue:
        seq, state = queue.popleft()
        for step in enum_step(state):
            new_state = transit(state, step)
            if new_state in visited:
                continue

            path = cons(step, seq)
            if criteria(new_state):
                # The cons list is newest-first; reverse it into apply order.
                steps = list(path)
                steps.reverse()
                return (steps, new_state)

            visited.add(new_state)
            queue.append((path, new_state))
    return None

def solve_abura(a, b, c):
    """Solve the oil-splitting jar puzzle with breadth-first search.

    Jars A, B and C have capacities a, b and c.  A starts full; B and C
    start empty.  A move pours jar i into jar j until i is empty or j is
    full.  The goal state is (a / 2, a / 2, 0).

    Returns the result of breadth_first_search: (moves, final_state), where
    moves is a list of (i, j) index pairs (0 = A, 1 = B, 2 = C), or None
    when the goal cannot be reached.
    """
    limits = (a, b, c)
    # Pouring conserves the total (see move), so an odd a is never solvable.
    # That holds whether a / 2 floors (Python 2 ints) or yields a
    # non-integer float (Python 3).
    goal = (a / 2, a / 2, 0)

    def move(state, step):
        # Unpack here rather than in the signature: tuple parameters are a
        # syntax error in Python 3 (PEP 3113).
        i, j = step
        new_state = list(state)
        if state[i] + state[j] > limits[j]:
            # Jar j would overflow: fill it to capacity, the rest stays in i.
            new_state[i] -= limits[j] - state[j]
            new_state[j] = limits[j]
        else:
            # Everything in jar i fits into jar j.
            new_state[j] += state[i]
            new_state[i] = 0
        return tuple(new_state)

    def criteria(state):
        return state == goal

    # Every ordered pair of distinct jars is offered as a pour, whatever the
    # current state; pours that change nothing are filtered by the search's
    # visited set.
    all_pours = ((0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1))

    def enum_step(state):
        return all_pours

    initial_state = (a, 0, 0)

    return breadth_first_search(move, enum_step, criteria, initial_state)

def print_sequence(seq):
    """Print one 'move from X to Y' line per step in seq.

    seq is an iterable of (i, j) index pairs; indices 0, 1 and 2 are
    printed as jars A, B and C.
    """
    jar_names = 'ABC'
    for src, dst in seq:
        # Single-argument print(...) prints the same text under Python 2's
        # print statement and Python 3's print function.
        print('move from %c to %c' % (jar_names[src], jar_names[dst]))

def main(a, b, c):
    """Solve the jar puzzle for capacities (a, b, c) and print the moves.

    Prints one line per move, or 'cannot be solved' when solve_abura
    reports no solution (returns None).
    """
    answer = solve_abura(a, b, c)

    if answer is None:
        print('cannot be solved')
        return

    # Defensive: treat an empty (falsy) result as a zero-move solution
    # instead of indexing into it and raising IndexError.
    moves = answer[0] if answer else []
    print_sequence(moves)
                
if __name__ == '__main__':
    # Jar capacities come from exactly three integer command-line arguments;
    # any other argument count falls back to the classic 10/7/3 puzzle.
    cli_args = sys.argv[1:]
    a, b, c = map(int, cli_args) if len(cli_args) == 3 else (10, 7, 3)

    main(a, b, c)