
/* Copyright (C) 2002-2008 Free Software Foundation, Inc.
   Contributed by Andy Vaught

  This file is part of g95.

  G95 is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2, or (at your option)
  any later version.

  G95 is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with g95; see the file COPYING.  If not, write to
  the Free Software Foundation, 59 Temple Place - Suite 330,
  Boston, MA 02111-1307, USA.

  In addition to the permissions in the GNU General Public License, the
  Free Software Foundation gives you unlimited permission to link the
  compiled version of this file into combinations with other programs,
  and to distribute those combinations without any restriction coming
  from the use of this file.  (The General Public License restrictions
  do apply in other respects; for example, they cover modification of
  the file, and distribution when not linked into a combined executable.)
*/


/* Subroutines supporting masked assignment in FORALL and WHERE
 * statements. */

#include <string.h>
#include "runtime.h"



/* Our task here is to store a set of logical values corresponding to
 * values of the FORALL predicate for cases where the variables
 * composing the predicate are modified by the FORALL body.  In this
 * case, an additional pass computes the values of the predicates,
 * which are stored here, and the next passes evaluate the FORALL body
 * for true values of the predicate.  FORALL statements can be nested
 * but cannot exit early. */

/* Number of bits held by one bitvector chunk. */
#define CHUNK_OF_BITS 8192

/* One fixed-size chunk of bit storage.  Chunks are chained through
 * 'next' to form an arbitrarily long vector of bits. */
typedef struct bitvector {
    struct bitvector *next;                   /* Following chunk, if any */
    unsigned char bits[CHUNK_OF_BITS / 8];    /* Packed bits, 8 per byte */
} bitvector;


/* Per-FORALL bookkeeping.  Headers are stacked through 'next' so that
 * the innermost active FORALL is always at forall_head. */
typedef struct forall_header {
    struct forall_header *next;   /* Enclosing FORALL's header */
    int rcount, wcount;           /* Bit index for the next read (within
				   * 'current') and the next write
				   * (within 'tail') */

    bitvector *head, *tail, *current;  /* First chunk, chunk being
					 * written, chunk being read */
} forall_header;


/* Per-WHERE bookkeeping.  Headers are stacked through 'next' so that
 * the innermost active WHERE is always at where_head.  Each mask
 * element occupies two bits, so rcount and wcount count bit pairs and
 * a chunk holds CHUNK_OF_BITS/2 of them. */
typedef struct where_header {
    struct where_header *next;    /* Enclosing WHERE's header */
    int rcount, wcount;           /* Pair index for the next read (within
				   * 'current') and the next write
				   * (within 'tail') */
    enum { WHERE_INITIAL, WHERE_ELSEWHERE1, WHERE_ELSEWHERE2 } state;

    /* head/tail/current are the first chunk, the chunk being written
     * and the chunk being read.  prev_current is the chunk of the
     * enclosing WHERE's mask that lines up with 'tail'; it is NULL when
     * there is no enclosing WHERE. */
    bitvector *head, *tail, *prev_current, *current;
} where_header;


/* A saved array value waiting to be copied to its destination.  The
 * structure is allocated with extra room after it so that 'data' is
 * the start of a buffer holding all of the array's elements. */
typedef struct array_temp {
    struct array_temp *next;      /* Next pending array copy */

    g95_array_descriptor desc;    /* Copy of the destination descriptor */
    char data[1];                 /* First byte of the element buffer */
} array_temp;


/* A saved scalar value waiting to be copied to its destination.  The
 * structure is allocated with 'len' extra bytes so that 'data' is the
 * start of a buffer holding the value. */
typedef struct scalar_temp {
    struct scalar_temp *next;     /* Next pending scalar copy */

    G95_DINT len;                 /* Number of bytes saved in 'data' */
    void *dest;                   /* Where the bytes are finally copied */

    char data[1];                 /* First byte of the saved value */
} scalar_temp;



/* Heads of the four singly linked lists maintained by this file: the
 * stacks of active FORALL and WHERE statements, and the lists of
 * pending array and scalar copies. */
static forall_header *forall_head;
static where_header *where_head;
static array_temp *array_head;
static scalar_temp *scalar_head;


/* Allocate one bitvector chunk with the runtime allocator. */
#define get_bitvector() get_mem(sizeof(bitvector))



/* forall_start()-- Start a new forall statement. */

#define forall_start prefix(forall_start)

void forall_start(void) {
forall_header *h;

    h = get_mem(sizeof(forall_header));
    h->next = forall_head;
    forall_head = h;

    h->head = h->tail = h->current = NULL;
    h->rcount = h->wcount = 0;

    forall_head->head = forall_head->tail = forall_head->current =
	get_bitvector();
}



/* forall_save()-- Store a predicate bit for the innermost FORALL,
 * appending a new bitvector chunk when the current tail chunk is
 * full.  The bit is set when *value is nonzero and cleared otherwise. */

#define forall_save prefix(forall_save)

void forall_save(G95_DINT *value) {
unsigned char bit, *byte;
forall_header *h;
bitvector *chunk;

    h = forall_head;

    if (h->wcount == CHUNK_OF_BITS) {
	chunk = get_bitvector();

	/* The new chunk becomes the end of the list.  Mark it as such
	 * explicitly; get_mem() is not defined in this file and may not
	 * zero the memory it returns. */
	chunk->next = NULL;

	h->tail->next = chunk;
	h->tail = chunk;
	h->wcount = 0;
    }

    byte = &h->tail->bits[h->wcount >> 3];
    bit = 1 << (h->wcount & 0x07);

    /* Both branches write the bit, so the chunk need not start out
     * cleared. */

    if (*value)
	*byte |= bit;
    else
	*byte &= ~bit;

    h->wcount++;
}



/* forall_get()-- Retrieve a predicate bit, deallocating bitvector
 * structures as they are used up.  The last bit in a sequence
 * restarts the sequence. */

#define forall_get prefix(forall_get)

G95_DINT forall_get(void) {
unsigned char mask, *p;
int value;

    if (forall_head->rcount == CHUNK_OF_BITS) {
	forall_head->current = forall_head->current->next;
	forall_head->rcount = 0;
    }

    if (forall_head->current == forall_head->tail &&
	forall_head->rcount == forall_head->wcount) {
	forall_head->current = forall_head->head;
	forall_head->rcount = 0;
    }

    p = &forall_head->current->bits[0] + (forall_head->rcount >> 3); 
    mask = 1 << (forall_head->rcount & 0x07);
    value = !!(*p & mask);

    forall_head->rcount++;

    return value;
}



/* forall_done()-- Finish the innermost FORALL: pop its header off the
 * stack and release the header together with every bitvector chunk
 * attached to it. */

#define forall_done prefix(forall_done)

void forall_done(void) {
forall_header *finished;
bitvector *chunk, *following;

    finished = forall_head;
    forall_head = finished->next;

    chunk = finished->head;
    while(chunk != NULL) {
	/* Fetch the link before the chunk is released */
	following = chunk->next;
	free_mem(chunk);
	chunk = following;
    }

    free_mem(finished);
}



/* Bitvector manipulation for WHERE statements.  These are different
 * from the FORALL subroutines because two bits have to be stored for
 * every element of the mask array.  The top bit is the control mask
 * and the bottom bit is the pending control mask. */


/* where_copy()-- Record a single scalar assignment coming from a
 * forall-nested where statement.  The len bytes at src are saved in a
 * temporary which is pushed on the pending-scalar list; the copy to
 * dest happens later, when the list is flushed. */

#define where_copy prefix(where_copy)

void where_copy(void *dest, void *src, G95_DINT len) {
scalar_temp *entry;

    entry = temp_alloc(sizeof(scalar_temp) + len);

    /* Capture the value now, before anything can overwrite the source */
    memcpy(&entry->data[0], src, len);
    entry->len = len;
    entry->dest = dest;

    entry->next = scalar_head;
    scalar_head = entry;
}



/* copy_scalar()-- Deliver one saved scalar to its destination and
 * release the temporary that held it. */

static void copy_scalar(scalar_temp *s) {
scalar_temp *victim;

    victim = s;

    memcpy(victim->dest, &victim->data[0], victim->len);

    /* temp_free() is handed the address of the pointer, not the
     * pointer itself. */
    temp_free((void *) &victim);
}



/* where_assign()-- Do all pending assignments in a forall-nested
 * where statement. */

static void where_assign(void) {
scalar_temp *s;

    while(scalar_head != NULL) {
	s = scalar_head;
	scalar_head = scalar_head->next;

	copy_scalar(s);
    }
}



/* where_start()-- Start a new WHILE statement. */

#define where_start prefix(where_start)

void where_start(void) {
where_header *h;

    h = get_mem(sizeof(where_header));
    h->next = where_head;
    where_head = h;

    h->head = h->tail = NULL;
    h->rcount = h->wcount = 0;
    h->state = WHERE_INITIAL;

    h->prev_current = (h->next == NULL) ? NULL : h->next->head;

    where_head->head = where_head->tail = h->current = get_bitvector();
}



/* bump_rcount()-- Move the rcount/current pointer to the next bit pair. */

static void bump_rcount(void) {

    where_head->rcount++;

    if (where_head->rcount == CHUNK_OF_BITS/2) {
	where_head->current = where_head->current->next;
	where_head->rcount = 0;
    }

    if (where_head->current == where_head->tail &&
	where_head->rcount == where_head->wcount) {

	where_head->current = where_head->head;
	where_head->rcount = 0;
	where_assign();
    }
}



/* where_write()-- Process one mask element of a WHERE/ELSEWHERE
 * statement.  value and kind are handed to extract_logical() to obtain
 * the truth value v of the mask expression for this element.
 *
 * In the WHERE_INITIAL state a new bit pair is appended to the mask.
 * With no enclosing WHERE, control = v and pending = !v; otherwise
 * both are additionally gated by the enclosing WHERE's control bit for
 * the same element.
 *
 * In the WHERE_ELSEWHERE1 state the pair at the read position is
 * rewritten in place from its old pending bit m: control = m && v,
 * pending = m && !v, and the read position is advanced.
 *
 * Within a pair the control bit is the upper bit and the pending bit
 * the lower one. */

#define where_write prefix(where_write)

void where_write(G95_DINT *value, G95_DINT kind) {
int v, m, mask, control, pending;
unsigned char *p;
bitvector *f;

    v = extract_logical(value, kind);

    switch(where_head->state) {
    case WHERE_INITIAL:
	mask = 1 << (2*(where_head->wcount & 0x03));

	if (where_head->prev_current == NULL) {
	    control = v;
	    pending = !v;

	} else {
	    /* Control bit of the enclosing WHERE for this element */
	    p = &where_head->prev_current->bits[0] + (where_head->wcount >> 2);
	    m = !!(*p & (mask << 1));

	    control = m && v;
	    pending = m && !v;
	}

	/* Clear the pair before setting it, so the chunk need not start
	 * out zeroed. */

	p = &where_head->tail->bits[0] + (where_head->wcount >> 2);
	*p &= ~(mask | (mask << 1));

	if (control)
	    *p |= (mask << 1);

	if (pending)
	    *p |= mask;

	where_head->wcount++;

	if (where_head->wcount == CHUNK_OF_BITS/2) {
	    f = get_bitvector();

	    /* The new chunk is the end of the list.  Mark it as such
	     * explicitly; get_mem() is not defined in this file and may
	     * not zero the memory it returns. */
	    f->next = NULL;

	    where_head->tail->next = f;
	    where_head->tail = f;

	    /* Keep the enclosing mask in step with our own chunks */
	    if (where_head->prev_current != NULL)
		where_head->prev_current = where_head->prev_current->next;

	    where_head->wcount = 0;
	}

	break;

    case WHERE_ELSEWHERE1:
	/* We use the rcount/current pointers to update the pending mask. */

	p = &where_head->current->bits[0] + (where_head->rcount >> 2);
	mask = 1 << (2*(where_head->rcount & 0x03));

	/* Logical rather than bitwise tests on v, as in the initial
	 * state, so that any nonzero v counts as true. */

	m = !!(*p & mask);
	control = m && v;
	pending = m && !v;

	*p &= ~(mask | (mask << 1));

	if (pending)
	    *p |= mask;

	if (control)
	    *p |= (mask << 1);

	bump_rcount();
	break;

    case WHERE_ELSEWHERE2:  /* Doesn't happen */
	break;
    }
}



/* where_read()-- Fetch the control bit (the upper bit of the pair) at
 * the read position of the innermost WHERE, then advance the read
 * position.  Returns 1 if the bit is set and 0 if it is clear. */

#define where_read prefix(where_read)

G95_DINT where_read(void) {
unsigned char *byte;
int shift, bit;

    byte = &where_head->current->bits[where_head->rcount >> 2];
    shift = 2*(where_head->rcount & 0x03) + 1;
    bit = (*byte >> shift) & 1;

    bump_rcount();
    return bit;
}



/* where_bits()-- Count the number of ones in the control mask of the
 * innermost WHERE.  Every bit pair from the start of the mask up to
 * the write position is examined; the result is zero when nothing has
 * been written. */

#define where_bits prefix(where_bits)

G95_DINT where_bits(void) {
int mask, rcount;
G95_DINT count;
unsigned char *p;
bitvector *w;

    count = 0;
    rcount = 0;
    w = where_head->head;

    /* Test before the first pair is examined: with an empty mask the
     * scan position already equals the write position, and inspecting
     * a pair anyway would count a bit that was never written and then
     * run past the end of the chunk list. */

    while(w != where_head->tail || rcount != where_head->wcount) {
	p = &w->bits[0] + (rcount >> 2);
	mask = 1 << (2*(rcount & 0x03) + 1);

	if (*p & mask)
	    count++;

	if (++rcount == CHUNK_OF_BITS/2) {
	    w = w->next;
	    rcount = 0;
	}
    }

    return count;
}



/* where_elsewhere1()-- Move the innermost WHERE to the
 * WHERE_ELSEWHERE1 state.  In that state where_write() rewrites the
 * existing bit pairs in place instead of appending new ones. */

#define where_elsewhere1 prefix(where_elsewhere1)

void where_elsewhere1(void) {

    where_head->state = WHERE_ELSEWHERE1;
}



/* where_elsewhere2()-- Move the innermost WHERE to the
 * WHERE_ELSEWHERE2 state.  Every byte of the mask is shifted left by
 * one bit, which moves each pending bit into the control-bit position
 * of its pair, and the pending scalar assignments are then flushed. */

#define where_elsewhere2 prefix(where_elsewhere2)

void where_elsewhere2(void) {
unsigned char *p;
bitvector *w;
int i;

    for(w=where_head->head; w; w=w->next) {
	/* The bytes are accessed as unsigned char, matching the
	 * declaration of bits[]: through a plain char pointer a byte
	 * with its top bit set may read as negative, and left-shifting
	 * a negative value is undefined. */
	p = &w->bits[0];

	for(i=0; i<CHUNK_OF_BITS/8; i++) {
	    *p = *p << 1;
	    p++;
	}
    }

    where_head->state = WHERE_ELSEWHERE2;
    where_assign();
}



/* where_done()-- Finish the innermost WHERE: pop its header off the
 * stack, release the header together with every bitvector chunk
 * attached to it, then flush the pending scalar assignments. */

#define where_done prefix(where_done)

void where_done(void) {
where_header *finished;
bitvector *chunk, *following;

    finished = where_head;
    where_head = finished->next;

    chunk = finished->head;
    while(chunk != NULL) {
	/* Fetch the link before the chunk is released */
	following = chunk->next;
	free_mem(chunk);
	chunk = following;
    }

    free_mem(finished);
    where_assign();
}



/* forall_copy_array()-- When copying arrays with FORALL, a huge
 * number of temps can be necessary because the multiple arrays might
 * be going to different places.  Instead of allocating an array of
 * temporaries in the main loop, the copying is done here.
 *
 * We take two arguments, a destination and a source descriptor.  The
 * amount of memory needed to hold every element of the source is
 * worked out, a temporary of that size is allocated and pushed on a
 * linked list, and the source elements are copied into it one after
 * another.  A copy of the destination descriptor is kept with the
 * temporary.  Once done, forall_copy_array_done() is called; it
 * traverses the list, copying to the final destinations and freeing
 * the temporaries.
 *
 * A zero-sized source still gets a (data-less) temporary on the
 * list. */

#define forall_copy_array prefix(forall_copy_array)

void forall_copy_array(g95_array_descriptor *dest,
		       g95_array_descriptor *source) {
G95_AINT bytes, n, index[G95_MAX_DIMENSIONS];
array_temp *t;
int d, rank, empty;
char *out, *in;

    rank = source->rank;
    bytes = source->element_size;
    empty = 0;

    /* One pass over the dimensions: accumulate the total size, set the
     * starting index and notice zero-sized dimensions. */

    for(d=0; d<rank; d++) {
	n = source->info[d].ubound - source->info[d].lbound + 1;
	if (n < 0)
	    n = 0;

	bytes *= n;

	index[d] = source->info[d].lbound;
	if (source->info[d].ubound < source->info[d].lbound)
	    empty = 1;
    }

    bytes += sizeof(array_temp);

    t = temp_alloc(bytes);

    t->next = array_head;
    array_head = t;

    /* The array might be descriptorless, so this tells us where to put
     * the data. */

    t->desc = *dest;

    if (empty)
	return;    /* Zero sized */

    out = &t->data[0];

    /* bump_element() is defined elsewhere; the loop ends as soon as it
     * returns nonzero. */

    do {
	in = source->offset;
	for(d=0; d<rank; d++)
	    in += index[d] * source->info[d].mult;

	memcpy(out, in, source->element_size);
	out += source->element_size;

    } while(!bump_element(source, index));
}



/* copy_array()-- Work function for forall_copy_array_done().  The
 * elements saved in the temporary are copied, in order, to the
 * locations described by the saved destination descriptor, and the
 * temporary is then released.  A zero-sized destination is released
 * without copying anything. */

static void copy_array(array_temp *t) {
G95_AINT index[G95_MAX_DIMENSIONS];
g95_array_descriptor *d;
int n, rank, empty;
char *dst, *src;

    d = &t->desc;
    rank = d->rank;
    empty = 0;

    /* Set the starting index, stopping at the first empty dimension */

    for(n=0; n<rank && !empty; n++) {
	index[n] = d->info[n].lbound;
	empty = (d->info[n].ubound < d->info[n].lbound);
    }

    if (!empty) {
	src = t->data;

	/* bump_element() is defined elsewhere; the loop ends as soon
	 * as it returns nonzero. */

	do {
	    dst = d->offset;
	    for(n=0; n<rank; n++)
		dst += index[n] * d->info[n].mult;

	    memcpy(dst, src, d->element_size);
	    src += d->element_size;

	} while(!bump_element(d, index));
    }

    /* temp_free() is handed the address of the pointer, not the
     * pointer itself. */
    temp_free((void *) &t);
}



/* forall_copy_array_done()-- At this point, we've finished copying
 * the array into the temporaries.  Traverse the list copying things
 * back. */

#define forall_copy_array_done prefix(forall_copy_array_done)

void forall_copy_array_done(void) {
array_temp *t;

    while(array_head != NULL) {
	t = array_head;
	array_head = array_head->next;

	copy_array(t);
    }
}

