/*
  Softshell: Dynamic Scheduling on GPUs.
  http://www.icg.tugraz.at/project/mvp

  Copyright (C) 2012 Institute for Computer Graphics and Vision,
                     Graz University of Technology

  Author(s):  Markus Steinberger - steinberger ( at ) icg.tugraz.at
              Bernhard Kainz - kainz ( at ) icg.tugraz.at
              Michael Kenzel - kenzel ( at ) icg.tugraz.at
              Stefan Hauswiesner - hauswiesner ( at ) icg.tugraz.at
              Bernhard Kerbl - kerbl ( at ) icg.tugraz.at
              Dieter Schmalstieg - schmalstieg ( at ) icg.tugraz.at

  Permission is hereby granted, free of charge, to any person obtaining a copy
  of this software and associated documentation files (the "Software"), to deal
  in the Software without restriction, including without limitation the rights
  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  copies of the Software, and to permit persons to whom the Software is
  furnished to do so, subject to the following conditions:

  The above copyright notice and this permission notice shall be included in
  all copies or substantial portions of the Software.

  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  THE SOFTWARE.
*/


/*
* file created by    Markus Steinberger / steinberger ( at ) icg.tugraz.at
*
* modifications by
*/

#ifndef UNORDERED_MAP_CUH
#define UNORDERED_MAP_CUH

#include "pair.cuh"
#include "hash.cuh"
#include "asms.cuh"
#include <cstdio>

namespace std_gpu
{
  // Fixed-capacity open-addressing hash map (linear probing) for device code.
  // Storage is an inline array of Size slots; each slot is a (state, key/value)
  // pair. Slot states are changed with atomicCAS so insert/find/erase can be
  // called by many threads at once. The class has no constructor: call
  // clear()/init() (or init(linId, threads) cooperatively, followed by the
  // caller's own synchronisation) before first use.
  // Hash is used as Hash::compute(key) (hash value, reduced modulo Size) and
  // Hash::same(a, b) (key equality) -- presumably provided by hash.cuh.
  template<class Key, class Value, uint Size, class Hash = hash<Key> >
  class _unordered_map
  {
    friend class Iterator;
  public:
    typedef pair<Key,Value> Pair_t;
  private:
    // slot = (state word, key/value pair)
    typedef pair<uint, Pair_t> Map_t;
  protected:
    // Slot states. EntryRemoved must differ from EntryFree: find() stops at
    // the first free slot, so an erased slot has to remain a tombstone or
    // entries further along the same probe chain would become unreachable.
    static const uint EntryFree = 0;
    static const uint EntryUsed = 1;
    static const uint EntryRemoved = 2;
    static const uint EntryChanging = 3;

    // number of entries, including capacity reserved by in-flight inserts
    uint _size;
    Map_t _map[Size];

    // home slot of key k
    __device__ uint hashed(Key k)
    {
      return Hash::compute(k) % Size;
    }
    // next slot of the probe sequence (linear probing; hashed and k are unused)
    __device__ uint probe(uint old, uint hashed, Key k)
    {
      return (old + 1) % Size;
    }

  public:

    // Iterator over used slots. end() is represented by position Size.
    // An iterator is just (slot array, slot index); it is not synchronised
    // with concurrent modifications of the map.
    class Iterator
    {
      friend class _unordered_map<Key, Value, Size, Hash>;
      Map_t* _map;
      //debug
    public:
      int _pos;
    private:

      // If search is set, advance from pos to the first used slot (or Size).
      __device__ Iterator(Map_t* map, int pos = 0, bool search = true) : _map(map), _pos(pos)
      {
        if(search)
          while(_pos < (int)Size && _map[_pos].first != _unordered_map::EntryUsed)
            ++_pos;  // advance the member, not the constructor argument
      }
      __device__ Map_t& getMapEntry() const
      {
        return _map[_pos];
      }
    public:
      __device__ Iterator() : _map(0), _pos(0) { }
      __device__ Iterator(const Iterator& other) : _map(other._map), _pos(other._pos) { }
      __device__ const Iterator& operator = (const Iterator& other)
      {
        _map = other._map;
        _pos = other._pos;
        return *this;
      }
      __device__ bool operator == (const Iterator& other) const
      {
        return _map == other._map && _pos == other._pos;
      }
      __device__ bool operator != (const Iterator& other) const
      {
        return !(*this == other);
      }

      __device__ Pair_t& operator * () const
      {
        return _map[_pos].second;
      }
      __device__ Pair_t* operator -> () const
      {
        return &_map[_pos].second;
      }

      // move to the next used slot, or to Size if there is none
      __device__ const Iterator& operator ++()
      {
        for(++_pos; _pos < (int)Size && _map[_pos].first != _unordered_map::EntryUsed; ++_pos);
        return *this;
      }
      // move to the previous used slot; running off the front yields Size (end)
      __device__ const Iterator& operator --()
      {
        for(--_pos; _pos >= 0 && _map[_pos].first != _unordered_map::EntryUsed; --_pos);
        if(_pos < 0) _pos = Size;
        return *this;
      }
      // advance i used slots, stopping at end
      __device__ const Iterator& operator += (uint i)
      {
        for(uint j = 0; j < i && _pos != (int)Size; ++j)
          ++*this;
        return *this;
      }
      // step back i used slots; stops once the iterator falls off the front (end)
      __device__ const Iterator& operator -= (uint i)
      {
        for(uint j = 0; j < i; ++j)
        {
          --*this;
          if(_pos == (int)Size)
            break;
        }
        return *this;
      }
      __device__ Iterator operator + (uint i) const
      {
        Iterator temp(*this);
        temp += i;
        return temp;
      }
      __device__ Iterator operator - (uint i) const
      {
        Iterator temp(*this);
        temp -= i;
        return temp;
      }

      // Number of used slots between other and *this (negative if other is
      // after *this). Returns -Size when the iterators belong to different maps.
      __device__ int operator - (const Iterator other) const
      {
        if(other._map != _map)
          return -(int)Size;
        if(other > *this)
          return -(other - *this);
        Iterator temp(other);
        int count = 0;
        // stop at end so an unreachable *this cannot cause an endless walk
        while(temp != *this && temp._pos < (int)Size) ++temp, ++count;
        return count;
      }

      __device__ bool operator > (const Iterator other) const
      {
        if(other._map != _map)
          return false;
        return _pos > other._pos;
      }
      __device__ bool operator < (const Iterator other) const
      {
        if(other._map != _map)
          return false;
        return _pos < other._pos;
      }
      __device__ bool operator >= (const Iterator other) const
      {
        if(other._map != _map)
          return false;
        return _pos >= other._pos;
      }
      __device__ bool operator <= (const Iterator other) const
      {
        if(other._map != _map)
          return false;
        return _pos <= other._pos;
      }

      __device__ Pair_t& operator[] (int i) const
      {
        Iterator temp(*this);
        if(i >= 0)
        {
          temp += i;
          return *temp;
        }
        else
        {
          temp -= -i;
          return *temp;
        }
      }
    };

    // Mark every slot free and reset the size (single thread, no sync).
    __device__ inline void clear()
    {
      for(uint i = 0; i < Size; ++i)
        _map[i].first = EntryFree;
      _size = 0;
    }
    __device__ inline void init() { clear(); }
    // Cooperative clear: thread linId of `threads` clears slots linId,
    // linId+threads, ...; thread 0 resets the size. The caller must
    // synchronise before the map is used.
    __device__ inline void init(uint linId, uint threads)
    {
      for(uint i = linId; i < Size; i+=threads)
        _map[i].first = EntryFree;
      if(linId == 0)
        _size = 0;
    }
    __device__ inline uint size() const { return _size; }

    __device__ inline Iterator begin() { return Iterator(_map, 0); }
    __device__ inline Iterator end() { return Iterator(_map, Size); }

    // Insert pair if its key is not present yet.
    // Returns (iterator to new entry, true) on success, (iterator to existing
    // entry, false) if the key is already present, and (end(), false) if the
    // map already holds Size entries. Prints and traps if no slot could be
    // claimed within Size CAS attempts.
    // Lanes of one warp are serialised, presumably so a lane never spins on
    // an EntryChanging slot owned by another lane of the same warp.
    // NOTE(review): __ballot (without _sync) is a legacy intrinsic that is not
    // available for sm_70+ with CUDA 9+; port when targeting newer GPUs.
    // NOTE: the first free *or removed* slot on the probe chain is claimed, so
    // if the key also lives further along the chain a duplicate can appear.
    __device__ inline pair<Iterator,bool> insert(Pair_t pair)
    {
      //serialize insert from same warp
      uint __lanemask = lanemask_lt(),
        __mask = __ballot(1),
        __local_id = __popc(__lanemask & __mask),
        __num = __popc(__mask);
      for (uint __active = 0; __active < __num; ++__active)
        if (__active == __local_id)
        {
          // reserve capacity first so that at most Size entries ever exist
          uint old = atomicAdd(&_size, 1);
          if(old >= Size)
          {
            atomicSub(&_size, 1);
            return make_pair(end(), false);
          }

          uint hval = hashed(pair.first);
          uint current = hval;
          for(uint searching = 0; searching < Size; ++searching)
          {
            uint state = *const_cast<volatile uint*>(&_map[current].first);
            // walk over occupied slots, waiting out slots that are being written
            while(state == EntryUsed || state == EntryChanging)
            {
              if(state == EntryChanging)
              {
                state = *const_cast<volatile uint*>(&_map[current].first);
                continue;
              }
              if(Hash::same(*const_cast<volatile Key*>(&_map[current].second.first),pair.first))
              {
                // key already present: give back the reservation
                atomicSub(&_size, 1);
                return make_pair(Iterator(_map, current), false);
              }
              current = probe(current, hval, pair.first);
              state = *const_cast<volatile uint*>(&_map[current].first);
            }

            // free or removed slot: try to claim it
            uint prev = atomicCAS((uint*)&_map[current].first, state, EntryChanging);
            if(prev != state)
              continue;  // lost the race; re-examine the same slot

            _map[current].second = pair;
            __threadfence();  // make the pair visible before marking the slot used
            *const_cast<volatile uint*>(&_map[current].first) = EntryUsed;
            return make_pair(Iterator(_map, current, false), true);
          }
        }
      printf("error unordered_map out of space!ABORT\n");
      ::trap();
      return make_pair(end(), false);
    }

    // Look up k; returns end() if it is not present. Probes at most Size
    // slots so the search terminates even when no free slot remains.
    __device__ inline Iterator find(Key k)
    {
      uint hval = hashed(k);
      uint current = hval;
      for(uint probes = 0; probes < Size; ++probes)
      {
        uint state = *const_cast<volatile uint*>(&_map[current].first);
        if(state == EntryFree)
          break;  // end of the probe chain
        if(state == EntryUsed &&
           Hash::same(*const_cast<volatile Key*>(&_map[current].second.first),k))
          return Iterator(_map, current, false);
        current = probe(current, hval, k);
      }
      return end();
    }

    // Remove the entry at `what` (if still used) and return the next iterator.
    // Returns end() for iterators of another map or for end() itself.
    __device__ inline Iterator erase(Iterator what)
    {
      if(what._map != _map || what._pos < 0 || what._pos >= (int)Size)
        return end();
      uint old = atomicCAS((uint*)&(what.getMapEntry().first), EntryUsed, EntryRemoved);
      if(old == EntryUsed)
        atomicSub(&_size, 1);
      return ++what;
    }
    // Remove key k; returns the number of entries removed (0 or 1).
    __device__ inline uint erase(Key k)
    {
      Iterator found = find(k);
      if(found == end())
        return 0;
      uint old = atomicCAS((uint*)&(found.getMapEntry().first), EntryUsed, EntryRemoved);
      if(old == EntryUsed)
      {
        atomicSub(&_size, 1);
        return 1;
      }
      return 0;
    }

    // Like erase(what), but overwrites the slot's pair with `with` before the
    // slot is marked removed.
    // NOTE(review): the purpose of keeping `with` in a removed slot is not
    // visible from this file.
    __device__ inline Iterator erase(Iterator what, const Pair_t& with)
    {
      if(what._map != _map || what._pos < 0 || what._pos >= (int)Size)
        return end();

      uint old = atomicCAS((uint*)&(what.getMapEntry().first), EntryUsed, EntryChanging);
      if(old == EntryUsed)
      {
        _map[what._pos].second = with;
        __threadfence();
        *const_cast<volatile uint*>(&_map[what._pos].first) = EntryRemoved;
        atomicSub(&_size, 1);
      }
      return ++what;
    }
    // Like erase(k), but overwrites the slot's pair with `with` first.
    __device__ inline uint erase(Key k, const Pair_t& with)
    {
      Iterator found = find(k);
      if(found == end())
        return 0;

      uint old = atomicCAS((uint*)&found.getMapEntry().first, EntryUsed, EntryChanging);
      if(old == EntryUsed)
      {
        _map[found._pos].second = with;
        __threadfence();
        *const_cast<volatile uint*>(&_map[found._pos].first) = EntryRemoved;
        atomicSub(&_size, 1);
        return 1;
      }
      return 0;
    }
    // Same as find(k); returns the iterator by value (no dangling reference).
    __device__ inline Iterator operator[] (Key k) { return find(k); }
  };
  template<class Key, class Value, uint Size, class Hash = hash<Key> >
  class unordered_map : public _unordered_map<Key, Value, Size, Hash>
  {
    typedef _unordered_map<Key, Value, Size, Hash> Base;
  public:
    // Construct an empty map: every slot is reset and the size is zero.
    __device__ inline unordered_map() { Base::init(); }

    // Copy the size counter and every slot of `other`, one slot at a time.
    // The copy is not atomic with respect to concurrent changes to `other`.
    __device__ inline unordered_map(const unordered_map& other)
    {
      this->_size = other._size;
      for(uint slot = 0; slot != Size; ++slot)
        this->_map[slot] = other._map[slot];
    }

    __device__ inline ~unordered_map() { }
  };
}

#endif
