30 September 2012

Python code for non-rigid registration using OpenCV and NumPy

Here is some simple Python code for registering two images. The transform is non-rigid, and the error metric is the sum of squared differences (SSD). The code is deliberately simple and does not handle physical pixel spacing. The main aim was to get my hands dirty with NumPy/SciPy, Python, and the OpenCV cv2 interface.
One practical tip for registration: if all else fails, smooth the motion fields heavily ;-)

##################################################
"""
Created on Sun Jan 15 10:41:19 2012

@author: vsk
"""
import sys, os, cv2, numpy, matplotlib, pylab
def SSD_Register(static_image, moving_image, iterations=200, time_step=0.00010,
                 image_blur_size=11, image_blur_sigma=1.20,
                 field_blur_size=11, field_blur_sigma=2.5, display=True):
    """Non-rigidly register ``moving_image`` onto ``static_image``.

    The warp is a dense sampling map ``(x_field, y_field)`` consumed by
    ``cv2.remap``: ``registered[y, x] = moving[y_field[y, x], x_field[y, x]]``.
    It starts as the identity map. Each iteration then does three things:
    takes a gradient-descent step on the sum of squared differences between
    the warped (blurred) moving image and the (blurred) static image;
    Gaussian-smooths that step, which is the only regularisation; and
    resamples the moving image and its gradients through the updated map.

    Args:
        static_image: 2-D single-channel reference image.
        moving_image: 2-D single-channel image to warp. It is read over the
            static image's grid, so extra rows/columns are ignored.
        iterations: number of gradient-descent steps.
        time_step: step size applied to the SSD gradient.
        image_blur_size: odd Gaussian kernel size used to pre-smooth both images.
        image_blur_sigma: Gaussian sigma used to pre-smooth both images.
        field_blur_size: odd Gaussian kernel size used to smooth each update.
        field_blur_sigma: Gaussian sigma used to smooth each update.
        display: if true, show static / registered / moving images with
            pylab (``pylab.show()`` may block until the window is closed).

    Returns:
        A tuple ``(registered_image, ssd_err)``:
        - ``registered_image``: the warped, blurred moving image (float32).
        - ``ssd_err``: an ``(iterations, 1)`` float64 array. Despite its name,
          it holds the sum of *absolute* differences measured at the start of
          each iteration.

    Raises:
        ValueError: if ``static_image`` is not 2-D, or ``moving_image`` does
            not cover ``static_image``'s shape.
    """
    if static_image.ndim != 2:
        raise ValueError('static_image must be 2-D, got shape %r'
                         % (static_image.shape,))
    rows, cols = static_image.shape
    # Read the moving image over the static image's grid (extra pixels ignored).
    moving_crop = moving_image[:rows, :cols]
    if moving_crop.shape != static_image.shape:
        raise ValueError('moving_image shape %r cannot be cropped to '
                         'static_image shape %r'
                         % (moving_image.shape, static_image.shape))

    # Pre-smooth BOTH images identically so the comparison is like-for-like.
    # (Previously the blurred static image was computed but never used.)
    image_kernel = (image_blur_size, image_blur_size)
    static_blur = cv2.GaussianBlur(static_image.astype(numpy.float32),
                                   image_kernel, image_blur_sigma)
    moving_blur = cv2.GaussianBlur(moving_crop.astype(numpy.float32),
                                   image_kernel, image_blur_sigma)

    # Identity sampling map: x_field holds column indices, y_field row indices.
    grid_rows, grid_cols = numpy.indices((rows, cols))
    x_field = grid_cols.astype(numpy.float32)
    y_field = grid_rows.astype(numpy.float32)

    # Spatial gradients of the (unwarped) moving image.
    grad_x = cv2.Sobel(moving_blur, cv2.CV_32FC1, 1, 0)
    grad_y = cv2.Sobel(moving_blur, cv2.CV_32FC1, 0, 1)

    # Under the identity map, warped image/gradients equal the originals.
    # These names are only ever rebound, never mutated, so aliasing is safe.
    registered_image = moving_blur
    grad_x_warped = grad_x
    grad_y_warped = grad_y

    print(numpy.sum(numpy.abs(registered_image - static_blur)))

    field_kernel = (field_blur_size, field_blur_size)
    ssd_err = numpy.zeros((iterations, 1))
    for step in range(iterations):
        difference = registered_image - static_blur
        ssd_err[step] = numpy.sum(numpy.abs(difference))

        # d/d(field) of 0.5 * sum(diff**2) is diff * (gradient sampled at field).
        # Smoothing the step regularises the deformation.
        x_update = cv2.GaussianBlur(time_step * difference * grad_x_warped,
                                    field_kernel, field_blur_sigma)
        y_update = cv2.GaussianBlur(time_step * difference * grad_y_warped,
                                    field_kernel, field_blur_sigma)
        x_field -= x_update
        y_field -= y_update

        registered_image = cv2.remap(moving_blur, x_field, y_field,
                                     cv2.INTER_CUBIC)
        grad_x_warped = cv2.remap(grad_x, x_field, y_field, cv2.INTER_CUBIC)
        grad_y_warped = cv2.remap(grad_y, x_field, y_field, cv2.INTER_CUBIC)

    # Report the error of the FINAL registered image, not the last loop's input.
    print(numpy.sum(numpy.abs(registered_image - static_blur)))

    if display:
        _show_registration(static_image, registered_image, moving_image)
    return (registered_image, ssd_err)


def _show_registration(static_image, registered_image, moving_image):
    """Show static | registered | moving images side by side, axes hidden."""
    panels = ((static_image, 'Static Image'),
              (registered_image, 'Registered Image'),
              (moving_image, 'Moving Image'))
    for position, (image, title) in enumerate(panels, start=1):
        pylab.subplot(1, 3, position)
        pylab.imshow(image)
        axes = pylab.gca()
        axes.get_xaxis().set_visible(False)
        axes.get_yaxis().set_visible(False)
        pylab.title(title)
    pylab.show()
   
# ================  Read system arguments else generate images =================
# Usage: <script> static_image moving_image
# With any other number of arguments, a synthetic pair is generated instead:
# a 50x50 square of intensity 100 on a 256x256 black background, and a copy of
# it resampled through a constant-offset map (x + 12.00485, y - 8.35).
if len(sys.argv) != 3:
    static_im = numpy.zeros((256, 256), dtype=numpy.ubyte)
    static_im[100:150, 100:150] = 100

    # Constant-translation sampling map for cv2.remap:
    # mov_im[y, x] = static_im[y - 8.35, x + 12.00485].
    # The offsets are added in float64 and then rounded to float32.
    rows, cols = numpy.indices(static_im.shape)
    x_field = (cols + 12.00485).astype(numpy.float32)
    y_field = (rows - 8.35).astype(numpy.float32)
    mov_im = cv2.remap(static_im, x_field, y_field, cv2.INTER_CUBIC)

    SSD_Register(static_im, mov_im)
else:
    # IMREAD_GRAYSCALE replaces CV_LOAD_IMAGE_GRAYSCALE, removed in OpenCV 3.
    static_image = cv2.imread(sys.argv[1], cv2.IMREAD_GRAYSCALE)
    moving_image = cv2.imread(sys.argv[2], cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None (rather than raising) for unreadable files.
    for path, image in ((sys.argv[1], static_image),
                        (sys.argv[2], moving_image)):
        if image is None:
            sys.exit('Could not read image: %s' % path)
    SSD_Register(static_image, moving_image)