// ext/triez.cc
#include <hat-trie.h>
#include <ruby.h>
#include <ruby/encoding.h>
// Compatibility shim (the original note says this is for Rubinius): when no
// rb_enc_fast_mbclen macro is defined, fall back to rb_enc_mbclen.
#ifndef rb_enc_fast_mbclen
# define rb_enc_fast_mbclen rb_enc_mbclen
#endif
// The Ruby class object for Triez; assigned in Init_triez.
static VALUE hat_class;
// UTF-8 encoding handle; assigned in Init_triez.
static rb_encoding* u8_enc;
// ASCII-8BIT (binary) encoding handle; assigned in Init_triez.
static rb_encoding* bin_enc;
// Normalizes the encoding of a key string.
// Keys already tagged as UTF-8 or binary are returned untouched; any other
// encoding is converted by calling String#encode with the UTF-8 encoding.
static inline VALUE unify_key(VALUE key) {
    rb_encoding* key_enc = rb_enc_get(key);
    if (key_enc == u8_enc || key_enc == bin_enc) {
        return key;
    }
    VALUE target = rb_enc_from_encoding(u8_enc);
    return rb_funcall(key, rb_intern("encode"), 1, target);
}
// Converts a Ruby VALUE into the long long representation stored in a trie
// slot.
// A plain integer conversion is used, matching how hat_set and hat_get move
// object values in and out of slots. Punning through a union would read
// bytes that were never written whenever VALUE is narrower than long long.
static inline long long V2LL(VALUE v) {
    return (long long)v;
}
// Converts a stored long long back into the Ruby VALUE it was created from.
// Inverse of V2LL: keeps the low-order bits that make up a VALUE.
static inline VALUE LL2V(long long l) {
    return (VALUE)l;
}
struct HatTrie {
hattrie_t* p;
VALUE default_value;
bool obj_value;
bool initialized;
HatTrie() : default_value(Qnil), obj_value(false), initialized(false) {
p = hattrie_create();
}
~HatTrie() {
hattrie_free(p);
}
};
static void hat_mark(void* p_ht) {
HatTrie* ht = (HatTrie*)p_ht;
if (!IMMEDIATE_P(ht->default_value)) {
rb_gc_mark(ht->default_value);
}
if (!ht->obj_value) {
return;
}
hattrie_t* p = ht->p;
hattrie_iter_t* it = hattrie_iter_begin(p, false);
while (!hattrie_iter_finished(it)) {
value_t* v = hattrie_iter_val(it);
if (!IMMEDIATE_P(*v)) {
rb_gc_mark(*v);
}
hattrie_iter_next(it);
}
hattrie_iter_free(it);
}
static void hat_free(void* p) {
delete (HatTrie*)p;
}
// Allocator registered through rb_define_alloc_func.
// `self` is the class being instantiated. A fresh HatTrie is created and
// wrapped with hat_mark / hat_free as its GC callbacks.
static VALUE hat_alloc(VALUE self) {
    HatTrie* ht = new HatTrie();
    // Wrap with the class that was actually asked for rather than the fixed
    // hat_class, so that subclasses of Triez allocate instances of themselves.
    return Data_Wrap_Struct(self, hat_mark, hat_free, ht);
}
// Common prologue for methods that take a key.
// Expects `self` and `key` to be in scope. Declares `ht` (the wrapped
// HatTrie) and `p` (its hat-trie handle), type-checks `key` as a String via
// Check_Type, and reassigns `key` to the result of unify_key.
#define PRE_HAT\
hattrie_t* p;\
HatTrie* ht;\
Data_Get_Struct(self, HatTrie, ht);\
p = ht->p;\
Check_Type(key, T_STRING);\
key = unify_key(key);
// Triez#_internal_set_type(obj_value, default_value)
// One-shot configuration: records whether slots hold Ruby objects (truthy
// obj_value) or integers, and the value reported for missing keys.
// Raises RuntimeError when called a second time on the same object.
static VALUE hat_set_type(VALUE self, VALUE obj_value, VALUE default_value) {
    HatTrie* trie;
    Data_Get_Struct(self, HatTrie, trie);

    if (trie->initialized) {
        // rb_raise does not return.
        rb_raise(rb_eRuntimeError, "Already initialized");
    }

    trie->initialized = true;
    trie->obj_value = RTEST(obj_value);
    trie->default_value = default_value;
    return self;
}
// Triez#value_type
// Returns :object when slots hold Ruby objects, :int64 otherwise.
static VALUE hat_value_type(VALUE self) {
    HatTrie* trie;
    Data_Get_Struct(self, HatTrie, trie);
    if (trie->obj_value) {
        return ID2SYM(rb_intern("object"));
    }
    return ID2SYM(rb_intern("int64"));
}
// Triez#size
// Returns the count reported by hattrie_size as a Ruby Integer.
static VALUE hat_size(VALUE self) {
    HatTrie* trie;
    Data_Get_Struct(self, HatTrie, trie);
    unsigned long long count = hattrie_size(trie->p);
    return ULL2NUM(count);
}
// Triez#[]=(key, value)
// Stores `value` under `key`. Object-valued tries keep the VALUE itself;
// integer-valued tries convert with NUM2LL first (which raises for
// non-numeric input before the trie is touched). Returns self.
static VALUE hat_set(VALUE self, VALUE key, VALUE value) {
    PRE_HAT;

    long long raw;
    if (ht->obj_value) {
        raw = value;
    } else {
        raw = NUM2LL(value);
    }

    char* bytes = RSTRING_PTR(key);
    size_t nbytes = RSTRING_LEN(key);
    value_t* slot = hattrie_get(p, bytes, nbytes);
    slot[0] = raw;
    return self;
}
// Yields the current value for the key (s, len) to the block and stores the
// block's result under that key. A missing key yields the default value.
//
// The lookup and the store are deliberately two separate trie accesses: the
// block may itself modify the trie, so the slot is fetched again after the
// yield rather than reusing the pointer obtained before it.
static inline void hat_change(HatTrie* ht, hattrie_t* p, char* s, size_t len) {
    value_t* existing = hattrie_tryget(p, s, len);
    long long updated;

    if (!ht->obj_value) {
        VALUE current = ht->default_value;
        if (existing) {
            current = LL2NUM(existing[0]);
        }
        updated = NUM2LL(rb_yield(current));
    } else {
        VALUE current = ht->default_value;
        if (existing) {
            current = LL2V(existing[0]);
        }
        updated = V2LL(rb_yield(current));
    }

    value_t* slot = hattrie_get(p, s, len);
    slot[0] = updated;
}
// Applies hat_change to (s, len) and then to successively shorter prefixes
// of it. `rs` points at the reversed text; the length of its leading
// character is the number of bytes trimmed from the end of the prefix on
// each step.
static inline void hat_change_prefix(HatTrie* ht, hattrie_t* p, char* s, size_t len, char* rs) {
    char* const rs_stop = rs + len;
    while (rs < rs_stop) {
        hat_change(ht, p, s, len);
        // Original author's note: no encoding check is needed here because
        // the reverse call already succeeded.
        long step = rb_enc_fast_mbclen(rs, rs_stop, u8_enc);
        rs += step;
        len -= step;
    }
}
// Triez#change_all(type, key)
// Yields and updates (via hat_change) the value of every affix of `key`
// selected by `type`:
//   :suffix    - key, then key with leading characters removed one by one
//   :prefix    - key, then key with trailing characters removed one by one
//   :substring - every prefix of every suffix
// Any other symbol does nothing. Returns self.
// Raises TypeError if key is not a String (PRE_HAT) or type is not a Symbol.
static VALUE hat_change_all(VALUE self, VALUE type, VALUE key) {
    PRE_HAT;
    // unify_key may have returned a freshly encoded string that is referenced
    // only from this frame. The raw pointer below stays in use across
    // rb_funcall / rb_yield, both of which can run the GC, so keep the string
    // pinned in a volatile local for the whole call (same idiom as `reversed`).
    volatile VALUE pinned_key = key;
    char* s = RSTRING_PTR(pinned_key);
    size_t len = RSTRING_LEN(pinned_key);
    ID ty = SYM2ID(type);
    if (ty == rb_intern("suffix")) {
        char* s_end = s + len;
        long n;
        for (; s < s_end; s += n, len -= n) {
            hat_change(ht, p, s, len);
            n = rb_enc_mbclen(s, s_end, u8_enc);
        }
    } else if (ty == rb_intern("prefix")) {
        // The reversed copy is only used to measure trailing character widths.
        volatile VALUE reversed = rb_funcall(key, rb_intern("reverse"), 0);
        hat_change_prefix(ht, p, s, len, RSTRING_PTR(reversed));
    } else if (ty == rb_intern("substring")) {
        volatile VALUE reversed = rb_funcall(key, rb_intern("reverse"), 0);
        char* rs = RSTRING_PTR(reversed);
        char* s_end = s + len;
        long n;
        for (; s < s_end; s += n, len -= n) {
            // The first `len` bytes of the reversed text are the reverse of
            // the current suffix, so `rs` itself never needs to advance.
            hat_change_prefix(ht, p, s, len, rs);
            n = rb_enc_fast_mbclen(s, s_end, u8_enc);
        }
    }
    return self;
}
// Triez#<<(key)
// Inserts `key` with the trie's default value by delegating to hat_set.
static VALUE hat_append(VALUE self, VALUE key) {
    HatTrie* trie;
    Data_Get_Struct(self, HatTrie, trie);
    VALUE fallback = trie->default_value;
    return hat_set(self, key, fallback);
}
// Triez#[](key)
// Returns the value stored under `key`, or the default value when the key is
// absent. Integer slots are converted to a Ruby Integer.
static VALUE hat_get(VALUE self, VALUE key) {
    PRE_HAT;
    value_t* slot = hattrie_tryget(p, RSTRING_PTR(key), RSTRING_LEN(key));
    if (!slot) {
        return ht->default_value;
    }
    if (ht->obj_value) {
        return *slot;
    }
    return LL2NUM(*slot);
}
// Triez#delete(key)
// Removes `key` and returns the value it held, or the default value when the
// key was not present.
static VALUE hat_del(VALUE self, VALUE key) {
    PRE_HAT;
    const char* s = RSTRING_PTR(key);
    size_t len = RSTRING_LEN(key);
    value_t* vt = hattrie_tryget(p, s, len);
    if (!vt) {
        return ht->default_value;
    }
    // Copy the stored value out BEFORE deleting: once the entry is removed
    // the slot pointer must not be dereferenced any more.
    value_t stored = *vt;
    hattrie_del(p, s, len);
    return ht->obj_value ? (VALUE)stored : LL2NUM(stored);
}
// Triez#has_key?(key)
// Returns true when `key` is present in the trie, false otherwise.
static VALUE hat_check(VALUE self, VALUE key) {
    PRE_HAT;
    value_t* slot = hattrie_tryget(p, RSTRING_PTR(key), RSTRING_LEN(key));
    if (slot) {
        return Qtrue;
    }
    return Qfalse;
}
// Arguments bundled for hat_search_callback, which is invoked through
// rb_protect and therefore receives a single VALUE-sized argument.
// Field order matters: hat_search aggregate-initializes this struct.
struct SearchCbData {
// Object that receives `call(suffix, value)`.
VALUE callback;
// Key text produced by the iterator for the current entry.
VALUE suffix;
// Value of the current entry, already converted to a Ruby VALUE.
VALUE value;
};
// Trampoline run under rb_protect: `data` is a SearchCbData pointer passed
// as a VALUE. Invokes callback.call(suffix, value) and returns its result.
static VALUE hat_search_callback(VALUE data) {
    SearchCbData* args = reinterpret_cast<SearchCbData*>(data);
    VALUE receiver = args->callback;
    return rb_funcall(receiver, rb_intern("call"), 2, args->suffix, args->value);
}
// Triez#_internal_search(key, limit, sort, callback)
// Iterates the entries whose key starts with `key` and invokes
// callback.call(suffix, value) for each one.
//   vlimit - nil for no limit, otherwise the maximum number of entries
//   vsort  - truthy to request sorted iteration from the trie
// Returns self. Anything raised or thrown by the callback is propagated
// after the iterator has been released.
static VALUE hat_search(VALUE self, VALUE key, VALUE vlimit, VALUE vsort, VALUE callback) {
    PRE_HAT;
    long limit = 0;
    if (vlimit != Qnil) {
        limit = NUM2LONG(vlimit);
    }
    hattrie_iter_t* it = hattrie_iter_with_prefix(p, RTEST(vsort), RSTRING_PTR(key), RSTRING_LEN(key));
    int error = 0;
    SearchCbData data = {callback};
    while (!hattrie_iter_finished(it)) {
        if (vlimit != Qnil && limit-- <= 0) {
            break;
        }
        size_t suffix_len;
        const char* suffix_s = hattrie_iter_key(it, &suffix_len);
        value_t* v = hattrie_iter_val(it);
        data.suffix = rb_enc_str_new(suffix_s, suffix_len, u8_enc);
        data.value = ht->obj_value ? (*v) : LL2NUM(*v);
        // Run the callback protected so the iterator can always be freed.
        rb_protect(hat_search_callback, (VALUE)&data, &error);
        if (error) {
            break;
        }
        hattrie_iter_next(it);
    }
    hattrie_iter_free(it);
    if (error) {
        // Resume exactly the non-local exit that rb_protect intercepted.
        // A bare Kernel#raise only re-raises the current exception and would
        // turn any other kind of jump (e.g. throw) into an unrelated error.
        rb_jump_tag(error);
    }
    return self;
}
// State shared between hat_walk and hat_walk_cb.
// Field order matters: hat_walk aggregate-initializes this struct.
typedef struct {
// Copy of HatTrie::obj_value: whether slots hold Ruby objects.
bool obj_value;
// Ruby Array that collects one [key, value] pair per visited entry.
VALUE arr;
} HatWalkData;
static int hat_walk_cb(const char* key, size_t len, value_t* v, void* data_p) {
HatWalkData* data = (HatWalkData*)data_p;
volatile VALUE r = rb_ary_new();
rb_ary_push(r, rb_str_new(key, len));
rb_ary_push(r, data->obj_value ? (*v) : LL2NUM(*v));
rb_ary_push(data->arr, r);
return hattrie_walk_continue;
}
// Triez#_internal_walk(key)
// Runs hattrie_walk for `key` and returns an Array of the [key, value]
// pairs gathered by hat_walk_cb.
static VALUE hat_walk(VALUE self, VALUE key) {
    PRE_HAT;
    // Results are collected into an array first instead of yielding from
    // inside the walk; the original author notes this prevents a leak when
    // a block exits early through break/next.
    volatile HatWalkData collected = {ht->obj_value, rb_ary_new()};
    size_t key_len = (size_t)RSTRING_LEN(key);
    char* key_bytes = RSTRING_PTR(key);
    hattrie_walk(p, key_bytes, key_len, (void*)&collected, hat_walk_cb);
    return collected.arr;
}
// Shorthand for rb_define_method: class k, method name n, C function f
// (wrapped in RUBY_METHOD_FUNC), argument count c.
#define DEF(k,n,f,c) rb_define_method(k,n,RUBY_METHOD_FUNC(f),c)
extern "C"
void Init_triez() {
hat_class = rb_define_class("Triez", rb_cObject);
u8_enc = rb_utf8_encoding();
bin_enc = rb_ascii8bit_encoding();
rb_define_alloc_func(hat_class, hat_alloc);
DEF(hat_class, "_internal_set_type", hat_set_type, 2);
DEF(hat_class, "value_type", hat_value_type, 0);
DEF(hat_class, "size", hat_size, 0);
DEF(hat_class, "[]=", hat_set, 2);
DEF(hat_class, "change_all", hat_change_all, 2);
DEF(hat_class, "<<", hat_append, 1);
DEF(hat_class, "[]", hat_get, 1);
DEF(hat_class, "has_key?", hat_check, 1);
DEF(hat_class, "delete", hat_del, 1);
DEF(hat_class, "_internal_search", hat_search, 4);
DEF(hat_class, "_internal_walk", hat_walk, 1);
}