src/SciPhp/NumPhp/ArithmeticTrait.php
<?php
declare(strict_types=1);
namespace SciPhp\NumPhp;
use RecursiveArrayIterator;
use RecursiveIteratorIterator;
use SciPhp\Exception\Message;
use SciPhp\NdArray;
use Webmozart\Assert\Assert;
trait ArithmeticTrait
{
    /**
     * Return the reciprocal (1 / x) of the argument, element-wise.
     *
     * A numeric argument is handled directly; anything else is converted
     * to an NdArray and divided into an array of ones of the same shape.
     *
     * @param \SciPhp\NdArray|array|float|int $m
     * @return \SciPhp\NdArray|float|int
     * @throws \InvalidArgumentException When a numeric argument equals zero
     * @link http://sciphp.org/numphp.reciprocal Documentation
     * @api
     */
    final public static function reciprocal($m)
    {
        if (is_numeric($m)) {
            Assert::notEq(0, $m);

            return 1 / $m;
        }

        static::transform($m, true);

        return static::ones($m->shape)->divide($m);
    }

    /**
     * Subtract two array_like, element-wise (m - n).
     *
     * @param \SciPhp\NdArray|array|float|int $m Minuend
     * @param \SciPhp\NdArray|array|float|int $n Subtrahend
     * @return \SciPhp\NdArray|float|int
     * @link http://sciphp.org/numphp.subtract Documentation
     * @api
     */
    final public static function subtract($m, $n)
    {
        if (static::allNumeric($m, $n)) {
            return $m - $n;
        }

        static::transform($n);

        // lambda - array: expand the scalar to n's shape, then subtract
        if (is_numeric($m) && $n instanceof NdArray) {
            return static::full_like($n, $m)->subtract($n);
        }

        // array - array OR array - lambda
        static::transform($m, true);

        // m - n is computed as -((-m) + n).
        // NOTE(review): add() copies its left operand; if negative() works
        // in place rather than returning a copy, this mutates the caller's
        // $m — confirm against NdArray::negative().
        return $m->negative()->add($n)->negative();
    }

    /**
     * Add two array_like, element-wise (m + n).
     *
     * @param \SciPhp\NdArray|array|int|float $m
     * @param \SciPhp\NdArray|array|int|float $n
     * @return \SciPhp\NdArray|int|float
     * @link http://sciphp.org/numphp.add Documentation
     * @api
     */
    final public static function add($m, $n)
    {
        if (static::allNumeric($m, $n)) {
            return $m + $n;
        }

        static::transform($n);

        // lambda + array: addition commutes, so add the scalar to a copy of n
        if (is_numeric($m) && $n instanceof NdArray) {
            return $n->copy()->add($m);
        }

        // array + array OR array + lambda
        static::transform($m, true);

        // Work on a copy so the caller's $m is left untouched
        return $m->copy()->add($n);
    }

    /**
     * Divide two array_like, element-wise (m / n).
     *
     * Supported operand combinations:
     * - scalar / scalar
     * - array / scalar and scalar / array
     * - vector / vector of identical shape
     * - (2+ dim) array / vector, the vector length matching m's 2nd dimension
     * - vector / (2+ dim) array, the vector length matching n's 2nd
     *   dimension (m is resized to n's shape)
     * - arrays with the same ndim and first dimension, where n has fewer
     *   entries along the 2nd dimension (n is broadcast to m's shape)
     * - arrays of identical shape
     *
     * @param \SciPhp\NdArray|array|float|int $m Dividend
     * @param \SciPhp\NdArray|array|float|int $n Divisor
     * @return \SciPhp\NdArray|float|int
     * @throws \InvalidArgumentException When shapes are not aligned or a
     *  divisor equals zero
     * @link http://sciphp.org/numphp.divide Documentation
     * @api
     */
    final public static function divide($m, $n)
    {
        if (static::allNumeric($m, $n)) {
            Assert::notEq(0, $n);

            return $m / $n;
        }

        static::transform($m);
        static::transform($n);

        // array / lambda
        if (is_numeric($n) && $m instanceof NdArray) {
            return $m->copy()->divide($n);
        }

        // lambda / array: expand the scalar to n's shape first
        if (is_numeric($m) && $n instanceof NdArray) {
            return static::full_like($n, $m)->divide($n);
        }

        // array / array
        Assert::isInstanceOf($m, NdArray::class);
        Assert::isInstanceOf($n, NdArray::class);

        $shape_m = $m->shape;
        $shape_n = $n->shape;

        if (count($shape_m) === 1 && $m->ndim === $n->ndim) {
            // n & m are vectors
            Assert::eq($shape_m, $shape_n, Message::MAT_NOT_ALIGNED);
        } elseif (! isset($shape_n[1])) {
            // n is a vector
            Assert::eq($shape_m[1], $shape_n[0], Message::MAT_NOT_ALIGNED);
        } elseif (! isset($shape_m[1])) {
            // m is a vector: stretch it to n's shape
            Assert::eq($shape_m[0], $shape_n[1], Message::MAT_NOT_ALIGNED);
            $m = $m->resize($shape_n);
        } elseif ($m->ndim === $n->ndim && $shape_m[0] === $shape_n[0] && $shape_m[1] > $shape_n[1]) {
            // same rows, fewer columns in n: broadcast n to m's shape
            $n = static::broadcast_to($n, $shape_m);
        } elseif ($m->ndim === $n->ndim) {
            // plain array / array
            Assert::eq($shape_m, $shape_n, Message::MAT_NOT_ALIGNED);
        }

        // Walk m's leaves and divide each one by the next leaf of n.
        // NOTE(review): the vector cases rely on NdArray::iterate() cycling
        // back to n's first leaf once exhausted — confirm.
        $iterator = new RecursiveIteratorIterator(
            new RecursiveArrayIterator($n->data),
            RecursiveIteratorIterator::LEAVES_ONLY
        );

        // $iterator is captured by reference in case iterate() replaces it
        $func = static function (&$item) use (&$iterator, $n): void {
            Assert::notEq(0, $value = $n->iterate($iterator));
            $item /= $value;
        };

        return $m->copy()->walk_recursive($func);
    }

    /**
     * Multiply two array_like, element-wise (m * n).
     *
     * Supported operand combinations:
     * - scalar * scalar
     * - array * scalar and scalar * array
     * - vector * vector of identical shape
     * - (2+ dim) array * vector, the vector length matching m's 2nd dimension
     * - vector * (2+ dim) array, the vector length matching n's 2nd
     *   dimension (m is resized to n's shape)
     * - arrays with the same ndim and first dimension, where n has fewer
     *   entries along the 2nd dimension (n is broadcast to m's shape, as
     *   divide() does)
     * - arrays of identical shape
     *
     * @param \SciPhp\NdArray|array|float|int $m
     * @param \SciPhp\NdArray|array|float|int $n
     * @return \SciPhp\NdArray|float|int
     * @throws \InvalidArgumentException When shapes are not aligned
     * @link http://sciphp.org/numphp.multiply Documentation
     * @api
     */
    final public static function multiply($m, $n)
    {
        if (static::allNumeric($m, $n)) {
            return $m * $n;
        }

        static::transform($m);
        static::transform($n);

        // array * lambda: delegates to NdArray::dot() with a scalar
        if (is_numeric($n) && $m instanceof NdArray) {
            return $m->copy()->dot($n);
        }

        // lambda * array: multiplication commutes
        if (is_numeric($m) && $n instanceof NdArray) {
            return $n->copy()->dot($m);
        }

        // array * array
        Assert::isInstanceOf($m, NdArray::class);
        Assert::isInstanceOf($n, NdArray::class);

        $shape_m = $m->shape;
        $shape_n = $n->shape;

        if (count($shape_m) === 1 && $m->ndim === $n->ndim) {
            // n & m are vectors
            Assert::eq($shape_m, $shape_n, Message::MAT_NOT_ALIGNED);
        } elseif (! isset($shape_n[1])) {
            // n is a vector
            Assert::eq($shape_m[1], $shape_n[0], Message::MAT_NOT_ALIGNED);
        } elseif (! isset($shape_m[1])) {
            // m is a vector: stretch it to n's shape
            Assert::eq($shape_m[0], $shape_n[1], Message::MAT_NOT_ALIGNED);
            $m = $m->resize($shape_n);
        } elseif ($m->ndim === $n->ndim && $shape_m[0] === $shape_n[0] && $shape_m[1] > $shape_n[1]) {
            // same rows, fewer columns in n: broadcast n to m's shape
            // (mirrors divide(); these inputs were previously rejected)
            $n = static::broadcast_to($n, $shape_m);
        } elseif ($m->ndim === $n->ndim) {
            // plain array * array
            Assert::eq($shape_m, $shape_n, Message::MAT_NOT_ALIGNED);
        }

        // Walk m's leaves and multiply each one by the next leaf of n.
        // NOTE(review): the vector cases rely on NdArray::iterate() cycling
        // back to n's first leaf once exhausted — confirm.
        $iterator = new RecursiveIteratorIterator(
            new RecursiveArrayIterator($n->data),
            RecursiveIteratorIterator::LEAVES_ONLY
        );

        // $iterator is captured by reference in case iterate() replaces it
        $func = static function (&$item) use (&$iterator, $n): void {
            $item *= $n->iterate($iterator);
        };

        return $m->copy()->walk_recursive($func);
    }

    /**
     * Dot product of two array_like.
     *
     * Result by operand shape:
     * - scalar . scalar               -> scalar product
     * - array . scalar / scalar . array -> scaled copy of the array
     * - vector (k) . vector (k)       -> scalar (sum of products)
     * - matrix (r, k) . vector (k)    -> vector (r)
     * - vector (k) . matrix (k, p)    -> vector (p)
     * - matrix (r, k) . matrix (k, p) -> matrix (r, p)
     *
     * @param \SciPhp\NdArray|array|float|int $m
     * @param \SciPhp\NdArray|array|float|int $n
     * @return \SciPhp\NdArray|float|int
     * @throws \InvalidArgumentException When shapes are not aligned
     * @link http://sciphp.org/numphp.dot Documentation
     * @api
     */
    final public static function dot($m, $n)
    {
        if (static::allNumeric($m, $n)) {
            return $m * $n;
        }

        static::transform($m);
        static::transform($n);

        // array . lambda
        if (is_numeric($n) && $m instanceof NdArray) {
            return $m->copy()->dot($n);
        }

        // lambda . array
        if (is_numeric($m) && $n instanceof NdArray) {
            return $n->copy()->dot($m);
        }

        // array . array
        Assert::isInstanceOf($m, NdArray::class);
        Assert::isInstanceOf($n, NdArray::class);

        $shape_m = $m->shape;
        $shape_n = $n->shape;

        // vector (k) . vector (k) -> scalar inner product
        if (count($shape_m) === 1 && $m->ndim === $n->ndim) {
            Assert::eq($shape_m, $shape_n, Message::MAT_NOT_ALIGNED);

            return array_sum(
                array_map(
                    static function ($el_m, $el_n) {
                        return $el_m * $el_n;
                    },
                    $m->data,
                    $n->data
                )
            );
        }

        // matrix (r, k) . vector (k) -> vector (r)
        if (! isset($shape_n[1])) {
            Assert::eq($shape_m[1], $shape_n[0], Message::MAT_NOT_ALIGNED);

            // Treat n as a (k, 1) column, fill an (r, 1) result, then
            // flatten it back to (r).
            // NOTE(review): if NdArray::reshape() works in place, this also
            // reshapes the caller's $n — confirm.
            return static::zeros($shape_m[0], 1)
                ->walk(
                    self::rowDot(
                        $m,
                        $n->reshape($shape_n[0], 1)
                    )
                )->reshape($shape_m[0]);
        }

        // vector (k) . matrix (k, p) -> vector (p)
        if (! isset($shape_m[1])) {
            Assert::eq($shape_m[0], $shape_n[0], Message::MAT_NOT_ALIGNED);

            // Output cell $col_n is m . (column $col_n of n)
            $callback = static function (&$item, $col_n) use ($m, $n): void {
                $item = array_sum(
                    array_map(
                        static function ($el_m, $el_n) {
                            return $el_m * $el_n;
                        },
                        $m->data,
                        array_column($n->data, $col_n)
                    )
                );
            };

            return static::zeros($shape_n[1])->walk($callback);
        }

        // matrix (r, k) . matrix (k, p) -> matrix (r, p)
        Assert::eq($shape_m[1], $shape_n[0], Message::MAT_NOT_ALIGNED);

        return static::zeros($shape_m[0], $shape_n[1])->walk(
            self::rowDot($m, $n)
        );
    }

    /**
     * Build the callback that fills one row of a dot-product result.
     *
     * The returned callback receives a result row by reference together
     * with its row index, and fills every cell of that row via colDot().
     *
     * @param \SciPhp\NdArray $m Left operand (rows are read from it)
     * @param \SciPhp\NdArray $n Right operand (columns are read from it)
     */
    final protected static function rowDot(NdArray $m, NdArray $n): callable
    {
        return static function (&$row, $row_m) use ($m, $n): void {
            array_walk(
                $row,
                self::colDot($row_m, $m, $n)
            );
        };
    }

    /**
     * Build the callback that computes one cell of a dot-product result.
     *
     * For a cell at column index $col_n, the value is the sum over k of
     * m[$row_m][k] * n[k][$col_n].
     *
     * @param int|string $row_m Row index into $m->data
     * @param \SciPhp\NdArray $m Left operand
     * @param \SciPhp\NdArray $n Right operand
     */
    final protected static function colDot($row_m, NdArray $m, NdArray $n): callable
    {
        return static function (&$item, $col_n) use ($row_m, $m, $n): void {
            $item = array_sum(
                array_map(
                    static function ($el_m, $row_n) use ($col_n) {
                        return $el_m * $row_n[$col_n];
                    },
                    $m->data[$row_m],
                    $n->data
                )
            );
        };
    }
}