package org.eclipse.january.dataset;

import org.eclipse.january.asserts.TestUtils;
import org.eclipse.january.dataset.LinearAlgebra;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link LinearAlgebra}: tensor, dot, outer, cross and Kronecker products, norms,
 * determinant, trace, matrix power and linear solve.
 *
 * <p>This class was recovered from compiled bytecode, so dtype arguments appear as raw integer
 * codes (e.g. {@code 5}; the {@code Float} comparison in {@link #testDot()} suggests it denotes a
 * 32-bit float type — confirm against the {@code Dataset} dtype constants).
 */
public class LinearAlgebraTest {
    /** Relative tolerance (absolute when the expected value is zero) used by {@link #close}. */
    private static final double CLOSE_TOLERANCE = 1.0e-5;

    /** Tolerance for results that should be exact up to floating-point rounding. */
    private static final double TIGHT = 1.0e-12;

    /** Tolerance for small integer-valued norm and cross-product results. */
    private static final double EPS = 1.0e-15;

    /**
     * Returns whether {@code actual} agrees with {@code expected} to within a relative tolerance of
     * 1e-5, or an absolute tolerance of 1e-5 when {@code expected} is exactly zero.
     *
     * @param expected reference value
     * @param actual value under test
     * @return {@code true} if the two values are close
     */
    private static boolean close(Number expected, double actual) {
        double reference = expected.doubleValue();
        if (reference == 0.0) {
            return Math.abs(actual) < CLOSE_TOLERANCE;
        }
        // Scale by |reference|: a negative reference would otherwise give a negative bound
        // that no difference could ever satisfy.
        return Math.abs(reference - actual) < CLOSE_TOLERANCE * Math.abs(reference);
    }

    /**
     * Asserts that {@code a x b} equals {@code expected} and that {@code b x a} is its negation.
     *
     * @param expected expected value of {@code a x b}
     * @param a left operand
     * @param b right operand
     */
    private static void assertCrossAntiCommutes(Dataset expected, Dataset a, Dataset b) {
        Dataset ab = LinearAlgebra.crossProduct(a, b);
        TestUtils.assertDatasetEquals(expected, ab, EPS, EPS);
        TestUtils.assertDatasetEquals(Maths.negative(ab), LinearAlgebra.crossProduct(b, a), EPS, EPS);
    }

    @Test
    public void testTensorDot() {
        Dataset a = DatasetFactory.createRange(60.0, 5).reshape(new int[] {3, 4, 5});
        Dataset b = DatasetFactory.createRange(24.0, 2).reshape(new int[] {4, 3, 2});
        long start = System.nanoTime();
        // Contract a's axes (1, 0) against b's axes (0, 1); both axis lists must have equal length.
        Dataset product = LinearAlgebra.tensorDotProduct(a, b, new int[] {1, 0}, new int[] {0, 1});
        System.out.printf("Time taken %dus%n", (System.nanoTime() - start) / 1000);
        Assert.assertArrayEquals("Shape", new int[] {5, 2}, product.getShape());
        Assert.assertEquals("Type", 5, product.getDType());
        Dataset expected = DatasetFactory.createFromObject(
                new double[] {4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306},
                new int[] {5, 2}).cast(product.getDType());
        Assert.assertEquals("Data does not match", expected, product);

        // The int-axis overload must agree with the axis-array overload.
        Dataset c = DatasetFactory.createRange(20 * 16, 5).reshape(new int[] {16, 4, 5});
        Dataset d = DatasetFactory.createRange(8 * 16, 2).reshape(new int[] {4, 16, 2});
        start = System.nanoTime();
        Dataset byInts = LinearAlgebra.tensorDotProduct(c, d, 0, 1);
        long intNanos = System.nanoTime() - start;
        start = System.nanoTime();
        Dataset byArrays = LinearAlgebra.tensorDotProduct(c, d, new int[] {0}, new int[] {1});
        long arrayNanos = System.nanoTime() - start;
        System.out.printf("Time taken %dus %dus%n", intNanos / 1000, arrayNanos / 1000);
        Assert.assertEquals("Data does not match", byArrays, byInts);
    }

    @Test
    public void testDot() {
        Dataset a = DatasetFactory.createRange(10.0, 5);
        Dataset b = DatasetFactory.createRange(-6.0, 4.0, 1.0, 2);
        long start = System.nanoTime();
        Dataset dot = LinearAlgebra.dotProduct(a, b);
        long dotNanos = System.nanoTime() - start;
        start = System.nanoTime();
        // Reference: element-wise product summed, narrowed to float for comparison.
        Float expected = ((Number) Maths.multiply(a, b).sum()).floatValue();
        long referenceNanos = System.nanoTime() - start;
        System.out.printf("Time taken %dus %dus%n", dotNanos / 1000, referenceNanos / 1000);
        Assert.assertEquals("Data does not match", expected, dot.getObjectAbs(0));
        Assert.assertEquals("Data does not match", expected, dot.getObject());
    }

    @Test
    public void testRandomDot() {
        DoubleDataset x = Random.randn(123.5, 23.4, new int[] {100});
        Dataset xSquared = Maths.square(x);

        Dataset second = LinearAlgebra.tensorDotProduct(x, x, 0, 0);
        Number expectedSecond = (Number) xSquared.sum();
        Assert.assertTrue("Second moment does not match: " + expectedSecond + " cf " + second.getObject(),
                close(expectedSecond, second.getDouble()));

        Dataset third = LinearAlgebra.dotProduct(xSquared, x);
        Number expectedThird = (Number) Maths.multiply(x, xSquared).sum();
        Assert.assertTrue("Third moment does not match: " + expectedThird + " cf " + third.getObject(),
                close(expectedThird, third.getDouble()));

        Dataset fourth = LinearAlgebra.dotProduct(xSquared, xSquared);
        Number expectedFourth = (Number) Maths.multiply(xSquared, xSquared).sum();
        Assert.assertTrue("Fourth moment does not match: " + expectedFourth + " cf " + fourth.getObject(),
                close(expectedFourth, fourth.getDouble()));
    }

    @Test
    public void testOuter() {
        Dataset expectedSmall = DatasetFactory.createFromObject(new double[] {0, 0, 0, 0, 1, 2}, new int[] {2, 3});
        Dataset small = LinearAlgebra.outerProduct(DatasetFactory.createRange(DoubleDataset.class, 2.0),
                DatasetFactory.createRange(DoubleDataset.class, 3.0));
        TestUtils.assertDatasetEquals(expectedSmall, small, TIGHT, TIGHT);

        final int rows = 10;
        final int cols = 7;
        DoubleDataset a = Random.randn(123.5, 23.4, new int[] {rows});
        DoubleDataset b = Random.randn(-31.2, 12.4, new int[] {cols});
        Dataset outer = LinearAlgebra.outerProduct(a, b);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                Assert.assertEquals("Element (" + i + ", " + j + ")", a.getDouble(i) * b.getDouble(j),
                        outer.getDouble(i, j), TIGHT);
            }
        }
    }

    @Test
    public void testCross() {
        Dataset a3 = DatasetFactory.createFromObject(new int[] {2, 3, 5}, new int[] {3});
        Dataset b3 = DatasetFactory.createFromObject(new float[] {1, 4, 7, 2, 5, 8}, new int[] {2, 3});

        // 3-vector crossed with a stack of two 3-vectors.
        assertCrossAntiCommutes(
                DatasetFactory.createFromObject(new double[] {1, -9, 5, -1, -6, 4}, new int[] {2, 3}), a3, b3);

        // Same operands with the result vectors placed along axis 0.
        Dataset ab = LinearAlgebra.crossProduct(a3, b3, -1, -1, 0);
        TestUtils.assertDatasetEquals(
                DatasetFactory.createFromObject(new double[] {1, -1, -9, -6, 5, 4}, new int[] {3, 2}), ab, EPS, EPS);
        TestUtils.assertDatasetEquals(Maths.negative(ab), LinearAlgebra.crossProduct(b3, a3, -1, -1, 0), EPS, EPS);

        // A 3-vector with zero third component and the equivalent 2-vector give the same result.
        Dataset expectedPadded = DatasetFactory.createFromObject(new double[] {21, -14, 5, 24, -16, 4}, new int[] {2, 3});
        assertCrossAntiCommutes(expectedPadded, DatasetFactory.createFromObject(new int[] {2, 3, 0}, new int[] {3}), b3);
        assertCrossAntiCommutes(expectedPadded, DatasetFactory.createFromObject(new int[] {2, 3}, new int[] {2}), b3);

        // Likewise on the other operand: 3-vectors with zero third component vs. 2-vectors.
        Dataset expectedFlat = DatasetFactory.createFromObject(new double[] {-20, 5, 5, -25, 10, 4}, new int[] {2, 3});
        assertCrossAntiCommutes(expectedFlat, a3,
                DatasetFactory.createFromObject(new float[] {1, 4, 0, 2, 5, 0}, new int[] {2, 3}));
        Dataset b2 = DatasetFactory.createFromObject(new float[] {1, 4, 2, 5}, new int[] {2, 2});
        assertCrossAntiCommutes(expectedFlat, a3, b2);

        // 2-vector crossed with 2-vectors yields one scalar per pair.
        assertCrossAntiCommutes(DatasetFactory.createFromObject(new double[] {5, 4}, new int[] {2}),
                DatasetFactory.createFromObject(new int[] {2, 3}, new int[] {2}), b2);
    }

    @Test
    public void testNorm() {
        Dataset v = DatasetFactory.createRange(9.0, 3);
        v.isubtract(4); // expected to hold -4 .. 4
        Dataset m = v.reshape(new int[] {3, 3});

        Assert.assertEquals(7.745966692414834, LinearAlgebra.norm(v, LinearAlgebra.NormOrder.DEFAULT), EPS);
        Assert.assertEquals(7.745966692414834, LinearAlgebra.norm(m, LinearAlgebra.NormOrder.DEFAULT), EPS);
        Assert.assertEquals(4.0, LinearAlgebra.norm(v, LinearAlgebra.NormOrder.POS_INFINITY), EPS);
        Assert.assertEquals(9.0, LinearAlgebra.norm(m, LinearAlgebra.NormOrder.POS_INFINITY), EPS);
        Assert.assertEquals(0.0, LinearAlgebra.norm(v, LinearAlgebra.NormOrder.NEG_INFINITY), EPS);
        Assert.assertEquals(2.0, LinearAlgebra.norm(m, LinearAlgebra.NormOrder.NEG_INFINITY), EPS);

        Assert.assertEquals(20.0, LinearAlgebra.norm(v, 1), EPS);
        Assert.assertEquals(7.0, LinearAlgebra.norm(m, 1), EPS);
        // Looser tolerance: the reference value lies within 1e-9 of zero.
        Assert.assertEquals(-4.656612877414201E-10, LinearAlgebra.norm(v, -1), 1.0E-9);
        Assert.assertEquals(6.0, LinearAlgebra.norm(m, -1), EPS);
        Assert.assertEquals(7.745966692414834, LinearAlgebra.norm(v, 2), EPS);
        Assert.assertEquals(7.3484692283495345, LinearAlgebra.norm(m, 2), EPS);
        Assert.assertEquals(0.0, LinearAlgebra.norm(v, -2), EPS);
        Assert.assertEquals(1.8570331885190563E-16, LinearAlgebra.norm(m, -2), EPS);
        Assert.assertEquals(5.848035476425731, LinearAlgebra.norm(v, 3), EPS);
        Assert.assertEquals(0.0, LinearAlgebra.norm(v, -3), EPS);
    }

    @Test
    public void testDeterminant() {
        // Leading 4x4 block of a 4x5 range matrix; its determinant is expected to vanish.
        Dataset m = DatasetFactory.createRange(1.0, 21.0, 1.0, 3).reshape(new int[] {4, 5})
                .getSliceView(new Slice[] {null, new Slice(4)});
        Assert.assertEquals(0.0, LinearAlgebra.calcDeterminant(m), 1.0E-8);
    }

    @Test
    public void testTrace() {
        Dataset m = DatasetFactory.createRange(20.0, 3).reshape(new int[] {4, 5});
        Assert.assertEquals(36, LinearAlgebra.trace(m).getInt());
        Assert.assertEquals(33, LinearAlgebra.trace(m, -1, 0, 1).getInt());
        Assert.assertEquals(40, LinearAlgebra.trace(m, 1, 0, 1).getInt());

        // Trace over the last two axes of a 3-D dataset gives one value per leading index.
        Dataset cube = DatasetFactory.createRange(60.0, 3).reshape(new int[] {3, 4, 5});
        TestUtils.assertDatasetEquals(DatasetFactory.createFromObject(new int[] {36, 116, 196}, (int[]) null),
                LinearAlgebra.trace(cube, 0, 1, 2), true, TIGHT, TIGHT);
    }

    @Test
    public void testKronecker() {
        Dataset a = DatasetFactory.createFromObject(3, new int[] {1, 10, 100});
        Dataset b = DatasetFactory.createFromObject(3, new int[] {5, 6, 7});
        TestUtils.assertDatasetEquals(
                DatasetFactory.createFromObject(new int[] {5, 6, 7, 50, 60, 70, 500, 600, 700}, (int[]) null),
                LinearAlgebra.kroneckerProduct(a, b), true, TIGHT, TIGHT);
        TestUtils.assertDatasetEquals(
                DatasetFactory.createFromObject(new int[] {5, 50, 500, 6, 60, 600, 7, 70, 700}, (int[]) null),
                LinearAlgebra.kroneckerProduct(b, a), true, TIGHT, TIGHT);

        Dataset expectedBlock = DatasetFactory.createFromObject(
                new float[] {1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1}, new int[] {4, 4});
        TestUtils.assertDatasetEquals(expectedBlock,
                LinearAlgebra.kroneckerProduct(DatasetUtils.eye(2, 2, 0, 2), DatasetFactory.ones(new int[] {2, 2}, 5)),
                true, TIGHT, TIGHT);
    }

    @Test
    public void testPower() {
        // [[0, 1], [-1, 0]] is a quarter-turn rotation, so A^3 = A^-1 and A^-3 = A.
        Dataset a = DatasetFactory.createFromObject(new int[] {0, 1, -1, 0}, new int[] {2, 2});
        TestUtils.assertDatasetEquals(DatasetFactory.createFromObject(new int[] {0, -1, 1, 0}, new int[] {2, 2}),
                LinearAlgebra.power(a, 3), true, TIGHT, TIGHT);
        TestUtils.assertDatasetEquals(DatasetFactory.createFromObject(new double[] {0, 1, -1, 0}, new int[] {2, 2}),
                LinearAlgebra.power(a, -3), true, TIGHT, TIGHT);
        // Power 0 yields the identity, even for the zero matrix.
        TestUtils.assertDatasetEquals(DatasetUtils.eye(4, 4, 0, 3),
                LinearAlgebra.power(DatasetFactory.zeros(new int[] {4, 4}, 3), 0), true, TIGHT, TIGHT);
    }

    @Test
    public void testSolve() {
        // [[3, 1], [1, 2]] x = [9, 8]  =>  x = [2, 3]
        Dataset a = DatasetFactory.createFromObject(new int[] {3, 1, 1, 2}, new int[] {2, 2});
        Dataset b = DatasetFactory.createFromObject(new int[] {9, 8}, (int[]) null);
        TestUtils.assertDatasetEquals(DatasetFactory.createFromObject(new double[] {2.0, 3.0}),
                LinearAlgebra.solve(a, b), true, TIGHT, TIGHT);
    }
}
