/* $Id: halotest.c 1324 2006-11-28 16:07:45Z olau $ 
 * Copyright (c) 2006 Oliver Lau <ola@ctmagazin.de>
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <mpi.h>

#include "../globaldefs.h"
#include "../helper.h"

// Message tag shared by the point-to-point messages that distribute the
// world to the processes and gather the results back on ROOT
#define TAG_DIST_GATHER  (31337)

// Rank (in the Cartesian communicator) that holds the complete world,
// sends out the sections and collects the results
#define ROOT                 (0)
// Rank whose local section is dumped with print_world()
#define TESTRANK             (0)

// Initial size of the global world in cells; main() shrinks both values
// until they are divisible by the process grid dimensions
#define WIDTH                (8)
#define HEIGHT               (8)

// Defaults for the optional command-line arguments: total number of
// iterations (argv[1]) and iterations per gather/print step (argv[2])
#define DEFAULT_T            (1)
#define DEFAULT_TICK         (1)

// Number of entries in the direction table that main() loops over
#define MAX_NEIGHBORS        (8)

// Indices into the direction table built in main().
// W/E are the two neighbors along Cartesian dimension 0, N/S the two
// neighbors along dimension 1; each of them exchanges a whole column or row.
// NW/NE/SW/SE are the diagonal neighbors; each exchanges a single corner cell.
enum {
    W  = 0,
    E  = 1,
    N  = 2,
    S  = 3,
    NW = 4,
    NE = 5,
    SW = 6,
    SE = 7
};

// const char *dirstr[8] = { "W", "E", "N", "S", "NW", "NE", "SW", "SE" };

/**
 * Description of one halo transfer (one MPI_Put) towards a neighbor.
 *
 * Offsets are counted in ints from the start of a section buffer that
 * includes the halo ring (row stride: local width + 2).
 *
 * The struct tag deliberately carries no leading underscore: file-scope
 * identifiers that start with an underscore are reserved for the
 * implementation.
 */
typedef struct direction_t {
    int target_rank;          // rank whose window is written to
    int target_offset;        // displacement inside the target's window
    int orig_offset;          // offset of the first cell to send in the local buffer
    MPI_Datatype orig_type;   // layout of the data read from the local buffer
    MPI_Datatype target_type; // layout of the data written into the target window
} direction_t;


/**
 * Print one local section, including its halo ring, to stdout.
 *
 * @param gen     generation number shown in the header line
 * @param rank    process rank shown in the header line
 * @param cell    pointer to the first interior cell; the buffer must have a
 *                one-cell border on every side and a row stride of width + 2
 * @param width   number of interior columns
 * @param height  number of interior rows
 */
void print_world(int gen, int rank, int *cell, int width, int height) {
    const int stride = width + 2;
    // Start at the top-left halo cell, one row above and one column left
    // of the first interior cell.
    const int *line = cell - stride - 1;

    printf("\nGeneration %d (Prozess %d)", gen, rank);
    printf("\nvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv\n");
    for (int r = 0; r < height + 2; ++r, line += stride) {
        for (int c = 0; c < stride; ++c)
            printf(" %5d", line[c]);
        printf("\n");
    }
    printf("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n");
}


/**
 * Compute one step: every interior cell of cell_new receives the sum of the
 * 3x3 neighborhood (center included) of the same cell in `cell`.
 *
 * @param cell      pointer to the first interior cell of the input section;
 *                  the halo ring around the interior is read as well
 * @param cell_new  pointer to the first interior cell of the output section;
 *                  only interior cells are written
 * @param width     number of interior columns (row stride is width + 2)
 * @param height    number of interior rows
 * @param n         not used by the computation
 */
void iterate(int *cell, int *cell_new, int width, int height, int n) {
    const int stride = width + 2;

    (void) n;

    for (int row = 0; row < height; ++row) {
        const int *above = cell + (row - 1) * stride;
        const int *here  = above + stride;
        const int *below = here + stride;
        int *out = cell_new + row * stride;

        for (int col = 0; col < width; ++col) {
            out[col] = above[col - 1] + above[col] + above[col + 1]
                     + here[col - 1]  + here[col]  + here[col + 1]
                     + below[col - 1] + below[col] + below[col + 1];
        }
    }
}


/**
 * Halo-exchange test driver.
 *
 * Splits a small periodic 2D world across all MPI processes, exchanges the
 * halo cells with the eight neighbors through one-sided MPI_Put calls
 * between two fences, runs iterate() on the local section and gathers the
 * sections back on ROOT.
 *
 * Command line (both arguments optional):
 *   argv[1]  total number of iterations         (default DEFAULT_T)
 *   argv[2]  iterations per gather/print step   (default DEFAULT_TICK, >= 1)
 *
 * Returns EXIT_SUCCESS, or EXIT_FAILURE when argv[2] is smaller than 1 or
 * when the process grid is larger than the world in some dimension.
 */
int main(int argc, char *argv[]) {
    // Value stored in cells that have not been written yet
    enum { HALO_SENTINEL = 9999 };

    int size; // number of processes
    int myrank; // rank inside the Cartesian communicator
    int width; // width of the local section
    int height; // height of the local section
    int dims[2] = { 0, 0 }; // number of sections per dimension
    int periods[2] = { 1, 1 }; // world is periodic in both dimensions
    int cart_coords[2], coords[2];
    int *sec0, *sec0_o;
    int *sec1, *sec1_o;
    int *world = NULL;
    direction_t direction[MAX_NEIGHBORS];
    int neighbor_ranks[MAX_NEIGHBORS];
    int num_real_neighbors;
    MPI_Status stat;
    MPI_Request req;
    MPI_Comm comm;
    MPI_Win win;
    MPI_Group world_group, neighbors;
    bool existing;

    int globalwidth = WIDTH;
    int globalheight = HEIGHT;

    // Initialize the MPI environment
    MPI_Init(&argc, &argv);

    // Parse the arguments only after MPI_Init(), which is allowed to
    // consume implementation-specific arguments.
    int tmax = DEFAULT_T;
    if (argc > 1)
        tmax = atoi(argv[1]);

    int ttick = DEFAULT_TICK;
    if (argc > 2)
        ttick = atoi(argv[2]);

    // Number of processes in the communicator
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    // Number of sections the world is split into, horizontally and vertically
    MPI_Dims_create(size, 2, dims);

    // Attach a new communicator to the 2D world
    MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 1, &comm);

    // Own rank inside the two-dimensional world
    MPI_Comm_rank(comm, &myrank);

    // Width and height of the world must be divisible by the number of
    // sections in each dimension so that the sections fit seamlessly
    globalwidth  -= globalwidth  % dims[0];
    globalheight -= globalheight % dims[1];

    // Size of one section: world size divided by the number of sections
    // in the respective dimension
    width         = globalwidth  / dims[0];
    height        = globalheight / dims[1];

    // Refuse parameters that cannot work: ttick < 1 would keep the main
    // loop from ever advancing t, and an empty section means there are
    // more processes than cells in a dimension.
    if (ttick < 1 || width < 1 || height < 1) {
        if (myrank == ROOT) {
            if (ttick < 1)
                fprintf(stderr, "ttick must be >= 1 (got %d)\n", ttick);
            else
                fprintf(stderr,
                        "process grid %d x %d is too large for a %d x %d world\n",
                        dims[0], dims[1], WIDTH, HEIGHT);
        }
        MPI_Comm_free(&comm);
        MPI_Finalize();
        return EXIT_FAILURE;
    }

    // Horizontal halo (one row of interior cells)
    MPI_Datatype row_type;
    MPI_Type_vector(1, width, width, MPI_INT, &row_type);
    MPI_Type_commit(&row_type);

    // Vertical halo (one column of interior cells)
    MPI_Datatype column_type;
    MPI_Type_vector(height, 1, width+2, MPI_INT, &column_type);
    MPI_Type_commit(&column_type);

    // Interior rectangle of a local section (stride includes left+right halo)
    MPI_Datatype section_type;
    MPI_Type_vector(height, width, width+2, MPI_INT, &section_type);
    MPI_Type_commit(&section_type);

    // Rectangle of the same size inside the global world
    MPI_Datatype submatrix_type;
    MPI_Type_vector(height, width, globalwidth, MPI_INT, &submatrix_type);
    MPI_Type_commit(&submatrix_type);

    // Allocate the local section including its halo ring. The *_o pointers
    // address the buffer start, sec0/sec1 the first interior cell.
    const size_t sec_cells = (size_t)(width + 2) * (size_t)(height + 2);
    sec0_o = safe_malloc(sec_cells * sizeof(*sec0_o));
    sec0 = sec0_o + 1*(width+2) + 1;
    sec1_o = safe_malloc(sec_cells * sizeof(*sec1_o));
    sec1 = sec1_o + 1*(width+2) + 1;
    // memset() works on bytes and cannot store 9999 in an int, so the
    // sentinel is written cell by cell.
    for (size_t c = 0; c < sec_cells; ++c) {
        sec0_o[c] = HALO_SENTINEL;
        sec1_o[c] = HALO_SENTINEL;
    }

    if (myrank == ROOT) {
        printf("globalwidth  = %d\n", globalwidth);
        printf("globalheight = %d\n", globalheight);
        printf("dims[]       = (%d, %d)\n", dims[0], dims[1]);
        printf("size         = %d\n", size);
        printf("width        = %d\n", width);
        printf("height       = %d\n", height);
        world = safe_malloc(globalwidth * globalheight * sizeof(*world));
        // Every cell encodes its own position: (column+1)*100 + (row+1)
        for (int y = 0; y < globalheight; ++y)
            for (int x = 0; x < globalwidth; ++x)
                world[x + y * globalwidth] = (x + 1) * 100 + (y + 1);
        printf("Berechnen von %d Iterationen ..\n", tmax);
    }

    // Neighbors along dimension 0: own first/last interior column goes
    // into the opposite halo column of the neighbor
    MPI_Cart_shift(comm, 0, +1, &direction[W].target_rank, &direction[E].target_rank);
    direction[W].orig_type     = column_type;
    direction[W].target_type   = column_type;
    direction[W].orig_offset   = width + 2 + 1;
    direction[W].target_offset = width + 2 + width + 1;
    direction[E].orig_type     = column_type;
    direction[E].target_type   = column_type;
    direction[E].orig_offset   = width + 2 + width;
    direction[E].target_offset = width + 2;

    // Neighbors along dimension 1: own first/last interior row goes into
    // the opposite halo row of the neighbor
    MPI_Cart_shift(comm, 1, +1, &direction[N].target_rank, &direction[S].target_rank);
    direction[N].orig_type     = row_type;
    direction[N].target_type   = row_type;
    direction[N].orig_offset   = width + 2 + 1;
    direction[N].target_offset = (height + 1) * (width + 2) + 1;
    direction[S].orig_type     = row_type;
    direction[S].target_type   = row_type;
    direction[S].orig_offset   = (height) * (width + 2) + 1;
    direction[S].target_offset = 1;

    // Cartesian coordinates of the own section; the diagonal neighbors are
    // looked up relative to them (the topology is periodic, so coordinates
    // outside the grid wrap around)
    MPI_Cart_coords(comm, myrank, 2, cart_coords);

    coords[0] = cart_coords[0] - 1;
    coords[1] = cart_coords[1] - 1;
    MPI_Cart_rank(comm, coords, &direction[NW].target_rank);
    direction[NW].orig_type     = MPI_INT;
    direction[NW].target_type   = MPI_INT;
    direction[NW].orig_offset   = width + 2 + 1;
    direction[NW].target_offset = (height + 1) * (width + 2) + width + 1;

    coords[0] = cart_coords[0] + 1;
    coords[1] = cart_coords[1] - 1;
    MPI_Cart_rank(comm, coords, &direction[NE].target_rank);
    direction[NE].orig_type     = MPI_INT;
    direction[NE].target_type   = MPI_INT;
    direction[NE].orig_offset   = width + 2 + width;
    direction[NE].target_offset = (height + 1) * (width + 2);

    coords[0] = cart_coords[0] - 1;
    coords[1] = cart_coords[1] + 1;
    MPI_Cart_rank(comm, coords, &direction[SW].target_rank);
    direction[SW].orig_type     = MPI_INT;
    direction[SW].target_type   = MPI_INT;
    direction[SW].orig_offset   = height * (width + 2) + 1;
    direction[SW].target_offset = width + 1;

    coords[0] = cart_coords[0] + 1;
    coords[1] = cart_coords[1] + 1;
    MPI_Cart_rank(comm, coords, &direction[SE].target_rank);
    direction[SE].orig_type     = MPI_INT;
    direction[SE].target_type   = MPI_INT;
    direction[SE].orig_offset   = height * (width + 2) + width;
    direction[SE].target_offset = 0;

    // Collect the distinct neighbor ranks (on small process grids several
    // directions lead to the same rank)
    MPI_Comm_group(comm, &world_group);
    num_real_neighbors = 0;
    for (int i = 0; i < MAX_NEIGHBORS; ++i) {
        existing = FALSE;
        for (int k = 0; k < num_real_neighbors && !existing; ++k)
            if (neighbor_ranks[k] == direction[i].target_rank)
                existing = TRUE;
        if (!existing)
            neighbor_ranks[num_real_neighbors++] = direction[i].target_rank;
    }

    // Expose the complete first buffer (halo included) for one-sided access
    MPI_Win_create(sec0_o, (MPI_Aint)(sec_cells * sizeof(*sec0_o)), sizeof(*sec0_o),
                   MPI_INFO_NULL, comm, &win);

    MPI_Win_fence(0, win);

    // Group of the distinct neighbors. The fence-based exchange below does
    // not use it; it would be the argument for MPI_Win_post()/MPI_Win_start().
    MPI_Group_incl(world_group, num_real_neighbors, neighbor_ranks, &neighbors);
    MPI_Group_free(&world_group);

    // Clear the whole local section, halo included
    for (int y = -1; y <= height; ++y)
        for (int x = -1; x <= width; ++x)
            sec0[y * (width+2) + x] = 0;

    if (myrank == ROOT) {
        printf("\n");
        for (int y = 0; y < globalheight; ++y) {
            for (int x = 0; x < globalwidth; ++x) {
                printf(" %4d", world[x + y * globalwidth]);
            }
            printf("\n");
        }
    }

    // Distribute the world section by section. The receive is posted first
    // so that ROOT's blocking send to itself can complete.
    MPI_Irecv(sec0, 1, section_type, ROOT, TAG_DIST_GATHER, comm, &req);
    if (myrank == ROOT) {
        for (int y = 0; y < dims[1]; ++y) {
            int xy[2];
            xy[1] = y;
            for (int x = 0; x < dims[0]; ++x) {
                xy[0] = x;
                int destination_rank;
                MPI_Cart_rank(comm, xy, &destination_rank);
                MPI_Send(world + y * height * globalwidth + x * width,
                         1, submatrix_type,
                         destination_rank, TAG_DIST_GATHER,
                         comm);
            }
        }
        // Overwrite the world with a different pattern; presumably so that
        // cells not replaced by the gather remain recognizable
        for (int y = 0; y < globalheight; ++y)
            for (int x = 0; x < globalwidth; ++x)
                world[x + y * globalwidth] = (x + 1) * 10 + (y + 1);
    }
    // Wait for MPI_Irecv() to finish
    MPI_Wait(&req, &stat);

    MPI_Barrier(comm);

    if (myrank == TESTRANK)
        print_world(0, myrank, sec0, width, height);

    int t = 0;
    do {
        for (int dt = 0; dt < ttick; ++dt) {
            // Halo exchange: all puts happen between two fences
            MPI_Win_fence(0, win);
            for (int d = 0; d < MAX_NEIGHBORS; ++d) {
                MPI_Put(sec0_o + direction[d].orig_offset,
                        1,
                        direction[d].orig_type,
                        direction[d].target_rank,
                        direction[d].target_offset,
                        1,
                        direction[d].target_type,
                        win);
            }
            MPI_Win_fence(0, win);
            iterate(sec0, sec1, width, height, t+dt);
            // Copy the result back instead of swapping pointers: the
            // window is bound to the memory of the first buffer
            memcpy(sec0_o, sec1_o, sec_cells * sizeof(*sec0_o));
        }
        t += ttick;
        if (myrank == TESTRANK)
            print_world(t, myrank, sec0, width, height);

        // Computation done, send the result to the root process
        MPI_Isend(sec0, 1, section_type, ROOT, TAG_DIST_GATHER, comm, &req);

        // The root process collects the results ...
        if (myrank == ROOT) {
            for (int y = 0; y < dims[1]; ++y) {
                int xy[2];
                xy[1] = y;
                for (int x = 0; x < dims[0]; ++x) {
                    xy[0] = x;
                    int rank;
                    MPI_Cart_rank(comm, xy, &rank);
                    MPI_Recv(world + x * width + y * globalwidth * height,
                             1, submatrix_type, rank, TAG_DIST_GATHER, comm, &stat);
                }
            }
        }

        // Wait until the root process has received the data
        MPI_Wait(&req, &stat);

        MPI_Barrier(comm);
    }
    while (t < tmax);

    MPI_Barrier(comm);


    if (myrank == ROOT) {
        printf("\n");
        for (int y = 0; y < globalheight; ++y) {
            for (int x = 0; x < globalwidth; ++x) {
                printf(" %4d", world[x + y * globalwidth]);
            }
            printf("\n");
        }
        free(world);
        printf("\nFertig.\n");
    }

    /* Clean up. The window has to be released before its memory is freed. */
    MPI_Win_free(&win);
    MPI_Group_free(&neighbors);
    free(sec0_o);
    free(sec1_o);
    MPI_Type_free(&submatrix_type);
    MPI_Type_free(&section_type);
    MPI_Type_free(&column_type);
    MPI_Type_free(&row_type);
    MPI_Comm_free(&comm);
    MPI_Finalize();

    return EXIT_SUCCESS;
}
