Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/atoms/affine/right_matmul.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ static void refresh_dense_right(left_matmul_expr *lnode)

expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
{
/* Validate dimensions for y = u @ A, A is (m, n):
- Standard 2D: u is (p, m) -> y is (p, n).
- numpy 1D broadcast: u has shape (m,), stored as (1, m) -> y is (1, n).
We already do something similar in left_matmul
But it is good to catch it earlier here. */
if (u->d2 != A->m)
{
fprintf(stderr, "Error in new_right_matmul: dimension mismatch \n");
exit(1);
}

/* We can express right matmul using left matmul and transpose:
u @ A = (A^T @ u^T)^T. */
int *work_transpose = (int *) SP_MALLOC(A->n * sizeof(int));
Expand Down Expand Up @@ -82,6 +93,13 @@ expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
const double *data)
{
if (u->d2 != m)
{
fprintf(stderr,
"Error in new_right_matmul_dense: dimension mismatch \n");
exit(1);
}

/* We express: u @ A = (A^T @ u^T)^T. A is m x n, so A^T is n x m. */
double *AT = (double *) SP_MALLOC(n * m * sizeof(double));
A_transpose(AT, data, m, n);
Expand Down
5 changes: 5 additions & 0 deletions tests/all_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "forward_pass/affine/test_linear_op.h"
#include "forward_pass/affine/test_neg.h"
#include "forward_pass/affine/test_promote.h"
#include "forward_pass/affine/test_right_matmul.h"
#include "forward_pass/affine/test_sum.h"
#include "forward_pass/affine/test_upper_tri.h"
#include "forward_pass/affine/test_variable_parameter.h"
Expand Down Expand Up @@ -134,6 +135,8 @@ int main(void)
mu_run_test(test_forward_prod_axis_one, tests_run);
mu_run_test(test_matmul, tests_run);
mu_run_test(test_left_matmul_dense, tests_run);
mu_run_test(test_right_matmul, tests_run);
mu_run_test(test_right_matmul_vector, tests_run);
mu_run_test(test_diag_mat_forward, tests_run);
mu_run_test(test_upper_tri_forward_4x4, tests_run);

Expand Down Expand Up @@ -382,6 +385,8 @@ int main(void)
mu_run_test(test_param_promote_vector_mult, tests_run);
mu_run_test(test_const_sum_scalar_mult, tests_run);
mu_run_test(test_param_sum_scalar_mult, tests_run);
mu_run_test(test_const_hstack_left_matmul, tests_run);
mu_run_test(test_param_hstack_left_matmul, tests_run);
#endif /* PROFILE_ONLY */

#ifdef PROFILE_ONLY
Expand Down
95 changes: 95 additions & 0 deletions tests/forward_pass/affine/test_right_matmul.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#include <string.h>

#include "atoms/affine.h"
#include "expr.h"
#include "minunit.h"
#include "test_helpers.h"

const char *test_right_matmul(void)
{
    /* Forward-pass check of Y = u @ A:
     * u is a 2x3 variable  [[1, 2, 3], [4, 5, 6]]
     * A is a 3x2 sparse    [[1, 0], [0, 2], [3, 4]]
     * so Y = u @ A =       [[10, 16], [22, 34]].
     */

    /* Variable u (2 x 3). */
    expr *u = new_variable(2, 3, 0, 6);

    /* Assemble the constant sparse A (3 x 2) in CSR form. */
    CSR_Matrix *A = new_csr_matrix(3, 2, 4);
    const int row_ptr[4] = {0, 1, 2, 4};
    const int col_idx[4] = {0, 1, 0, 1};
    const double nz_vals[4] = {1.0, 2.0, 3.0, 4.0};
    memcpy(A->p, row_ptr, sizeof(row_ptr));
    memcpy(A->i, col_idx, sizeof(col_idx));
    memcpy(A->x, nz_vals, sizeof(nz_vals));

    /* Y = u @ A */
    expr *Y = new_right_matmul(NULL, u, A);

    /* u stored column-major: columns [1,4], [2,5], [3,6]. */
    double u_vals[6] = {1.0, 4.0, 2.0, 5.0, 3.0, 6.0};

    /* Run the forward pass. */
    Y->forward(Y, u_vals);

    /* Expected 2x2 result, column-major: columns [10,22], [16,34]. */
    double expected[4] = {10.0, 22.0, 16.0, 34.0};

    /* Dimension checks. */
    mu_assert("right_matmul result should have d1=2", Y->d1 == 2);
    mu_assert("right_matmul result should have d2=2", Y->d2 == 2);
    mu_assert("right_matmul result should have size=4", Y->size == 4);

    /* Value check. */
    mu_assert("right_matmul forward pass value mismatch",
              cmp_double_array(Y->value, expected, 4));

    free_csr_matrix(A);
    free_expr(Y);
    return 0;
}

const char *test_right_matmul_vector(void)
{
    /* numpy-style 1D broadcast: u has shape (3,), stored as (1, 3).
     * u = [1, 2, 3], A is the same 3x2 sparse matrix as in
     * test_right_matmul, so u @ A = [10, 16].
     */

    /* Variable u (1 x 3). */
    expr *u = new_variable(1, 3, 0, 3);

    /* Assemble the constant sparse A (3 x 2) in CSR form. */
    CSR_Matrix *A = new_csr_matrix(3, 2, 4);
    const int row_ptr[4] = {0, 1, 2, 4};
    const int col_idx[4] = {0, 1, 0, 1};
    const double nz_vals[4] = {1.0, 2.0, 3.0, 4.0};
    memcpy(A->p, row_ptr, sizeof(row_ptr));
    memcpy(A->i, col_idx, sizeof(col_idx));
    memcpy(A->x, nz_vals, sizeof(nz_vals));

    /* Y = u @ A */
    expr *Y = new_right_matmul(NULL, u, A);

    /* Run the forward pass at u = [1, 2, 3]. */
    double u_vals[3] = {1.0, 2.0, 3.0};
    Y->forward(Y, u_vals);

    /* Expected 1x2 result. */
    double expected[2] = {10.0, 16.0};

    /* Dimension checks. */
    mu_assert("right_matmul (vector) result should have d1=1", Y->d1 == 1);
    mu_assert("right_matmul (vector) result should have d2=2", Y->d2 == 2);

    /* Value check. */
    mu_assert("right_matmul (vector) forward pass value mismatch",
              cmp_double_array(Y->value, expected, 2));

    free_csr_matrix(A);
    free_expr(Y);
    return 0;
}
80 changes: 80 additions & 0 deletions tests/problem/test_param_broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,84 @@ const char *test_param_sum_scalar_mult(void)
return 0;
}

const char *test_const_hstack_left_matmul(void)
{
    /* minimize hstack(p1, p2) @ x, where p1 and p2 are fixed parameters.
     * Verifies the objective value and the gradient both match the
     * concatenated parameter data. */
    int n = 4;

    expr *x = new_variable(2 * n, 1, 0, 2 * n);
    double p1_vals[4] = {1.0, 2.0, 3.0, 0.0};
    expr *p1 = new_parameter(1, n, PARAM_FIXED, 2 * n, p1_vals);
    double p2_vals[4] = {4.0, 0.0, 5.0, 6.0};
    expr *p2 = new_parameter(1, n, PARAM_FIXED, 2 * n, p2_vals);
    expr *param_nodes[2] = {p1, p2};
    expr *p_hstack = new_hstack(param_nodes, 2, 2 * n);
    /* A_data is the concatenation of p1_vals and p2_vals. */
    double A_data[8] = {1.0, 2.0, 3.0, 0.0, 4.0, 0.0, 5.0, 6.0};
    expr *objective = new_left_matmul_dense(p_hstack, x, 1, 2 * n, A_data);
    problem *prob = new_problem(objective, NULL, 0, false);

    problem_init_derivatives(prob);

    /* Point for evaluating. */
    double x_vals[8] = {2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0};

    problem_objective_forward(prob, x_vals);
    /* Objective is the dot product of the concatenated parameters with x. */
    double obj_val = 1.0 * 2.0 + 2.0 * 2.0 + 3.0 * 2.0 + 0.0 * 2.0
                     + 4.0 * 1.0 + 0.0 * 1.0 + 5.0 * 1.0 + 6.0 * 1.0;
    mu_assert("const hstack left_matmul: objective value mismatch",
              fabs(prob->objective->value[0] - obj_val) < 1e-6);

    problem_gradient(prob);
    /* Gradient w.r.t. x is the concatenated parameter vector itself. */
    double grad_x[8] = {1.0, 2.0, 3.0, 0.0, 4.0, 0.0, 5.0, 6.0};
    mu_assert("const hstack left_matmul: gradient mismatch",
              cmp_double_array(prob->gradient_values, grad_x, 8));

    free_problem(prob);
    return 0;
}

const char *test_param_hstack_left_matmul(void)
{
    /* minimize hstack(p1, p2) @ x, where p1 and p2 are live parameters
     * (offsets 0 and n into the parameter vector). Verifies objective and
     * gradient both before and after problem_update_params. */
    int n = 4;

    expr *x = new_variable(2 * n, 1, 0, 2 * n);
    double p1_vals[4] = {1.0, 2.0, 3.0, 0.0};
    expr *p1 = new_parameter(1, n, 0, 2 * n, p1_vals);
    double p2_vals[4] = {4.0, 0.0, 5.0, 6.0};
    expr *p2 = new_parameter(1, n, n, 2 * n, p2_vals);
    expr *param_nodes[2] = {p1, p2};
    expr *p_hstack = new_hstack(param_nodes, 2, 2 * n);
    /* A_data is the concatenation of p1_vals and p2_vals. */
    double A_data[8] = {1.0, 2.0, 3.0, 0.0, 4.0, 0.0, 5.0, 6.0};
    expr *objective = new_left_matmul_dense(p_hstack, x, 1, 2 * n, A_data);
    problem *prob = new_problem(objective, NULL, 0, false);

    problem_register_params(prob, param_nodes, 2);
    problem_init_derivatives(prob);

    /* Point for evaluating. */
    double x_vals[8] = {2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0};

    problem_objective_forward(prob, x_vals);
    /* Objective is the dot product of the concatenated parameters with x. */
    double obj_val = 1.0 * 2.0 + 2.0 * 2.0 + 3.0 * 2.0 + 0.0 * 2.0
                     + 4.0 * 1.0 + 0.0 * 1.0 + 5.0 * 1.0 + 6.0 * 1.0;
    mu_assert("param hstack left_matmul: initial objective value mismatch",
              fabs(prob->objective->value[0] - obj_val) < 1e-6);

    problem_gradient(prob);
    /* Gradient w.r.t. x is the concatenated parameter vector itself. */
    double grad_x[8] = {1.0, 2.0, 3.0, 0.0, 4.0, 0.0, 5.0, 6.0};
    mu_assert("param hstack left_matmul: initial gradient mismatch",
              cmp_double_array(prob->gradient_values, grad_x, 8));

    /* Swap in new parameter values and re-check. */
    double theta[8] = {5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0};
    problem_update_params(prob, theta);

    problem_objective_forward(prob, x_vals);
    double updated_obj_val = 5.0 * 2.0 + 4.0 * 2.0 + 3.0 * 2.0 + 2.0 * 2.0
                             + 1.0 * 1.0 + 0.0 * 1.0 + 1.0 * 1.0 + 2.0 * 1.0;
    mu_assert("param hstack left_matmul: updated objective value mismatch",
              fabs(prob->objective->value[0] - updated_obj_val) < 1e-6);

    problem_gradient(prob);
    double updated_grad_x[8] = {5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0};
    mu_assert("param hstack left_matmul: updated gradient mismatch",
              cmp_double_array(prob->gradient_values, updated_grad_x, 8));

    free_problem(prob);
    return 0;
}

#endif /* TEST_PARAM_BROADCAST_H */
Loading