import numpy as np

INPUT_SIZE = 3
OUTPUT_SIZE = 3

# forward() stacks the previous output on top of the input, so every gate
# weight matrix maps a (CONCAT_SIZE, 1) column to an (OUTPUT_SIZE, 1) column.
CONCAT_SIZE = OUTPUT_SIZE + INPUT_SIZE  # Size of concatenate(y_prev, x)

# Gate parameters, randomly initialised. In forward(): f = sigmoid(W_f...),
# i = sigmoid(W_i...), c_bar = tanh(W_c...), o = sigmoid(W_o...).
W_f = np.random.randn(OUTPUT_SIZE, CONCAT_SIZE)
print('W_f shape', W_f.shape)
b_f = np.random.randn(OUTPUT_SIZE, 1)

W_i = np.random.randn(OUTPUT_SIZE, CONCAT_SIZE)
b_i = np.random.randn(OUTPUT_SIZE, 1)

W_c = np.random.randn(OUTPUT_SIZE, CONCAT_SIZE)
b_c = np.random.randn(OUTPUT_SIZE, 1)

W_o = np.random.randn(OUTPUT_SIZE, CONCAT_SIZE)
b_o = np.random.randn(OUTPUT_SIZE, 1)

# Recurrent state, overwritten by each forward() call. Both vectors live in
# the OUTPUT_SIZE space: y_prev is stacked with x to make the CONCAT_SIZE
# column, and c_prev is multiplied element-wise with OUTPUT_SIZE-row gates.
# (Sizing them with INPUT_SIZE only worked because both sizes are equal.)
c_prev = np.zeros((OUTPUT_SIZE, 1), dtype=float)
y_prev = np.zeros((OUTPUT_SIZE, 1), dtype=float)


def forward(x):
    """Run one LSTM time step on the input column ``x`` and return its output.

    Reads the module-level weights (W_f, W_i, W_c, W_o), biases
    (b_f, b_i, b_c, b_o) and recurrent state ``y_prev`` / ``c_prev``.
    Both state globals are overwritten with this step's values, so
    successive calls advance through a sequence.

    Args:
        x: NumPy array of shape (INPUT_SIZE, 1).

    Returns:
        The step output ``y = o * tanh(c)``.

    Raises:
        ValueError: if ``x`` does not have shape (INPUT_SIZE, 1).
    """
    global y_prev
    global c_prev

    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if x.shape != (INPUT_SIZE, 1):
        raise ValueError(
            f"x must have shape ({INPUT_SIZE}, 1), got {x.shape}"
        )

    # Every gate sees the same stacked column [y_prev; x].
    # np.vstack replaces np.row_stack, which is deprecated in NumPy 2.0.
    concat_y_prev_x = np.vstack((y_prev, x))

    f = sigmoid(W_f @ concat_y_prev_x + b_f)       # forget gate
    i = sigmoid(W_i @ concat_y_prev_x + b_i)       # input gate
    c_bar = np.tanh(W_c @ concat_y_prev_x + b_c)   # candidate cell values

    # New cell state: keep part of the old state, add part of the candidate.
    c = f * c_prev + i * c_bar
    c_prev = c

    o = sigmoid(W_o @ concat_y_prev_x + b_o)       # output gate
    y = o * np.tanh(c)

    y_prev = y
    return y


def sigmoid(s):
    """Return the element-wise logistic function ``1 / (1 + exp(-s))``.

    Accepts scalars or NumPy arrays. For very negative ``s``, ``exp(-s)``
    overflows to ``inf`` and the expression correctly evaluates to 0, so
    the benign overflow warning NumPy would emit is suppressed locally.
    """
    with np.errstate(over='ignore'):
        return 1 / (1 + np.exp(-s))

