ext/nmatrix/data/ruby_object.h
/////////////////////////////////////////////////////////////////////
// = NMatrix
//
// A linear algebra library for scientific computation in Ruby.
// NMatrix is part of SciRuby.
//
// NMatrix was originally inspired by and derived from NArray, by
// Masahiro Tanaka: http://narray.rubyforge.org
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
// == Contributing
//
// By contributing source code to SciRuby, you agree to be bound by
// our Contributor Agreement:
//
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
//
// == ruby_object.h
//
// Functions and classes for dealing with Ruby objects.
#ifndef RUBY_OBJECT_H
#define RUBY_OBJECT_H
/*
* Standard Includes
*/
#include <ruby.h>
#include <iostream>
#include <type_traits>
/*
* Project Includes
*/
#include "ruby_constants.h"
/*
* Macros
*/
#define NM_RUBYVAL_IS_NUMERIC(val) (FIXNUM_P(val) or RB_FLOAT_TYPE_P(val) or RB_TYPE_P(val, T_COMPLEX))
/*
* Classes and Functions
*/
namespace nm {
template<typename T, typename U>
struct made_from_same_template : std::false_type {};
template<template<typename> class Templ, typename Arg1, typename Arg2>
struct made_from_same_template<Templ<Arg1>, Templ<Arg2> > : std::true_type {};
class RubyObject {
public:
VALUE rval;
/*
* Value constructor.
*/
inline RubyObject(VALUE ref = Qnil) : rval(ref) {}
/*
* Complex number constructor.
*/
template <typename FloatType, typename = typename std::enable_if<std::is_floating_point<FloatType>::value>::type>
inline RubyObject(const Complex<FloatType>& other) : rval(rb_complex_new(rb_float_new(other.r), rb_float_new(other.i))) {}
/*
* Integer constructor.
*
* Does not work as a template.
*/
inline RubyObject(uint8_t other) : rval(INT2FIX(other)) {}
inline RubyObject(int8_t other) : rval(INT2FIX(other)) {}
inline RubyObject(int16_t other) : rval(INT2FIX(other)) {}
inline RubyObject(uint16_t other) : rval(INT2FIX(other)) {}
inline RubyObject(int32_t other) : rval(INT2FIX(other)) {}
// there is no uint32_t here because that's a Ruby VALUE type, and we need the compiler to treat that as a VALUE.
inline RubyObject(int64_t other) : rval(INT2FIX(other)) {}
// inline RubyObject(uint64_t other) : rval(INT2FIX(other)) {}
/*
* Float constructor.
*
* Does not work as a template.
*/
inline RubyObject(float other) : rval(rb_float_new(other)) {}
inline RubyObject(double other) : rval(rb_float_new(other)) {}
/*
* Operators for converting RubyObjects to other C types.
*/
#define RETURN_OBJ2NUM(mac) if (this->rval == Qtrue) return 1; else if (this->rval == Qfalse) return 0; else return mac(this->rval);
inline operator int8_t() const { RETURN_OBJ2NUM(NUM2INT) }
inline operator uint8_t() const { RETURN_OBJ2NUM(NUM2UINT) }
inline operator int16_t() const { RETURN_OBJ2NUM(NUM2INT) }
inline operator uint16_t() const { RETURN_OBJ2NUM(NUM2UINT) }
inline operator int32_t() const { RETURN_OBJ2NUM(NUM2LONG) }
inline operator VALUE() const { return rval; }
//inline operator uint32_t() const { return NUM2ULONG(this->rval); }
inline operator int64_t() const { RETURN_OBJ2NUM(NUM2LONG) }
//inline operator uint64_t() const { RETURN_OBJ2NUM(NUM2ULONG) }
inline operator double() const { RETURN_OBJ2NUM(NUM2DBL) }
inline operator float() const { RETURN_OBJ2NUM(NUM2DBL) }
inline operator Complex64() const { return this->to<Complex64>(); }
inline operator Complex128() const { return this->to<Complex128>(); }
/*
* Copy constructors.
*/
inline RubyObject(const RubyObject& other) : rval(other.rval) {}
/*
* Inverse operator.
*/
inline RubyObject inverse() const {
rb_raise(rb_eNotImpError, "RubyObject#inverse needs to be implemented");
}
/*
* Absolute value.
*/
inline RubyObject abs() const {
return RubyObject(rb_funcall(this->rval, rb_intern("abs"), 0));
}
/*
* Binary operator definitions.
*/
inline RubyObject operator+(const RubyObject& other) const {
return RubyObject(rb_funcall(this->rval, nm_rb_add, 1, other.rval));
}
inline RubyObject& operator+=(const RubyObject& other) {
this->rval = rb_funcall(this->rval, nm_rb_add, 1, other.rval);
return *this;
}
inline RubyObject operator-(const RubyObject& other) const {
return RubyObject(rb_funcall(this->rval, nm_rb_sub, 1, other.rval));
}
inline RubyObject& operator-=(const RubyObject& other) {
this->rval = rb_funcall(this->rval, nm_rb_sub, 1, other.rval);
return *this;
}
inline RubyObject operator*(const RubyObject& other) const {
return RubyObject(rb_funcall(this->rval, nm_rb_mul, 1, other.rval));
}
inline RubyObject& operator*=(const RubyObject& other) {
this->rval = rb_funcall(this->rval, nm_rb_mul, 1, other.rval);
return *this;
}
inline RubyObject operator/(const RubyObject& other) const {
return RubyObject(rb_funcall(this->rval, nm_rb_div, 1, other.rval));
}
inline RubyObject& operator/=(const RubyObject& other) {
this->rval = rb_funcall(this->rval, nm_rb_div, 1, other.rval);
return *this;
}
inline RubyObject operator%(const RubyObject& other) const {
return RubyObject(rb_funcall(this->rval, nm_rb_percent, 1, other.rval));
}
inline bool operator>(const RubyObject& other) const {
return rb_funcall(this->rval, nm_rb_gt, 1, other.rval) == Qtrue;
}
inline bool operator<(const RubyObject& other) const {
return rb_funcall(this->rval, nm_rb_lt, 1, other.rval) == Qtrue;
}
template <typename OtherType>
inline bool operator<(const OtherType& other) const {
return *this < RubyObject(other);
}
inline bool operator==(const RubyObject& other) const {
return rb_funcall(this->rval, nm_rb_eql, 1, other.rval) == Qtrue;
}
template <typename OtherType>
inline bool operator==(const OtherType& other) const {
return *this == RubyObject(other);
}
inline bool operator!=(const RubyObject& other) const {
return rb_funcall(this->rval, nm_rb_neql, 1, other.rval) == Qtrue;
}
template <typename OtherType>
inline bool operator!=(const OtherType& other) const {
return *this != RubyObject(other);
}
inline bool operator>=(const RubyObject& other) const {
return rb_funcall(this->rval, nm_rb_gte, 1, other.rval) == Qtrue;
}
template <typename OtherType>
inline bool operator>=(const OtherType& other) const {
return *this >= RubyObject(other);
}
inline bool operator<=(const RubyObject& other) const {
return rb_funcall(this->rval, nm_rb_lte, 1, other.rval) == Qtrue;
}
template <typename OtherType>
inline bool operator<=(const OtherType& other) const {
return *this <= RubyObject(other);
}
////////////////////////////
// RUBY-NATIVE OPERATIONS //
////////////////////////////
/*
template <typename NativeType, typename = typename std::enable_if<std::is_arithmetic<NativeType>::value>::type>
inline bool operator==(const NativeType& other) const {
return *this == RubyObject(other);
}
template <typename NativeType, typename = typename std::enable_if<std::is_arithmetic<NativeType>::value>::type>
inline bool operator!=(const NativeType& other) const {
return *this != RubyObject(other);
}
*/
//////////////////////////////
// RUBY-COMPLEX OPERATIONS //
//////////////////////////////
template <typename FloatType, typename = typename std::enable_if<std::is_floating_point<FloatType>::value>::type>
inline bool operator==(const Complex<FloatType>& other) const {
return *this == RubyObject(other);
}
template <typename FloatType, typename = typename std::enable_if<std::is_floating_point<FloatType>::value>::type>
inline bool operator!=(const Complex<FloatType>& other) const {
return *this != RubyObject(other);
}
/*
* Convert a Ruby object to an integer.
*/
template <typename IntType>
inline typename std::enable_if<std::is_integral<IntType>::value, IntType>::type to(void) {
return NUM2INT(this->rval);
}
/*
* Convert a Ruby object to a floating point number.
*/
template <typename FloatType>
inline typename std::enable_if<std::is_floating_point<FloatType>::value, FloatType>::type to(void) {
return NUM2DBL(this->rval);
}
/*
* Convert a Ruby object to a complex number.
*/
template <typename ComplexType>
inline typename std::enable_if<made_from_same_template<ComplexType, Complex64>::value, ComplexType>::type to(void) const {
if (FIXNUM_P(this->rval) or TYPE(this->rval) == T_FLOAT) {
return ComplexType(NUM2DBL(this->rval));
} else if (TYPE(this->rval) == T_COMPLEX) {
return ComplexType(NUM2DBL(rb_funcall(this->rval, nm_rb_real, 0)), NUM2DBL(rb_funcall(this->rval, nm_rb_imag, 0)));
} else {
rb_raise(rb_eTypeError, "Invalid conversion to Complex type.");
}
}
};
// Negative operator
inline RubyObject operator-(const RubyObject& rhs) {
return RubyObject(rb_funcall(rhs.rval, nm_rb_negate, 0));
}
////////////////////////////
// NATIVE-RUBY OPERATIONS //
////////////////////////////
template <typename NativeType, typename = typename std::enable_if<std::is_arithmetic<NativeType>::value>::type>
inline RubyObject operator/(const NativeType left, const RubyObject& right) {
return RubyObject(left) / right;
}
template <typename NativeType, typename = typename std::enable_if<std::is_arithmetic<NativeType>::value>::type>
inline bool operator==(const NativeType left, const RubyObject& right) {
return RubyObject(left) == right;
}
template <typename NativeType, typename = typename std::enable_if<std::is_arithmetic<NativeType>::value>::type>
inline bool operator!=(const NativeType left, const RubyObject& right) {
return RubyObject(left) != right;
}
template <typename NativeType, typename = typename std::enable_if<std::is_arithmetic<NativeType>::value>::type>
inline bool operator<=(const NativeType left, const RubyObject& right) {
return RubyObject(left) <= right;
}
template <typename NativeType, typename = typename std::enable_if<std::is_arithmetic<NativeType>::value>::type>
inline bool operator>=(const NativeType left, const RubyObject& right) {
return RubyObject(left) >= right;
}
template <typename NativeType, typename = typename std::enable_if<std::is_arithmetic<NativeType>::value>::type>
inline bool operator<(const NativeType left, const RubyObject& right) {
return RubyObject(left) < right;
}
template <typename NativeType, typename = typename std::enable_if<std::is_arithmetic<NativeType>::value>::type>
inline bool operator>(const NativeType left, const RubyObject& right) {
return RubyObject(left) > right;
}
/////////////////////////////
// COMPLEX-RUBY OPERATIONS //
/////////////////////////////
template <typename FloatType, typename = typename std::enable_if<std::is_floating_point<FloatType>::value>::type>
inline bool operator==(const Complex<FloatType>& left, const RubyObject& right) {
return RubyObject(left) == right;
}
template <typename FloatType, typename = typename std::enable_if<std::is_floating_point<FloatType>::value>::type>
inline bool operator!=(const Complex<FloatType>& left, const RubyObject& right) {
return RubyObject(left) != right;
}
template <typename FloatType, typename = typename std::enable_if<std::is_floating_point<FloatType>::value>::type>
inline bool operator<=(const Complex<FloatType>& left, const RubyObject& right) {
return RubyObject(left) <= right;
}
template <typename FloatType, typename = typename std::enable_if<std::is_floating_point<FloatType>::value>::type>
inline bool operator>=(const Complex<FloatType>& left, const RubyObject& right) {
return RubyObject(left) >= right;
}
template <typename FloatType, typename = typename std::enable_if<std::is_floating_point<FloatType>::value>::type>
inline bool operator<(const Complex<FloatType>& left, const RubyObject& right) {
return RubyObject(left) < right;
}
template <typename FloatType, typename = typename std::enable_if<std::is_floating_point<FloatType>::value>::type>
inline bool operator>(const Complex<FloatType>& left, const RubyObject& right) {
return RubyObject(left) > right;
}
} // end of namespace nm
namespace std {
inline nm::RubyObject abs(const nm::RubyObject& obj) {
return obj.abs();
}
inline nm::RubyObject sqrt(const nm::RubyObject& obj) {
VALUE cMath = rb_const_get(rb_cObject, rb_intern("Math"));
return nm::RubyObject(rb_funcall(cMath, rb_intern("sqrt"), 1, obj.rval));
}
}
#endif // RUBY_OBJECT_H