In [ ]:
import xarray as xr

%matplotlib inline

import matplotlib.pyplot as plt

import numpy as np
import keras
from keras.layers import Dense, Input
from keras.models import Model, Sequential
from keras.optimizers import Adam
from keras.losses import MSE

from sklearn.externals import joblib
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import VarianceThreshold

from sklearn.pipeline import make_pipeline
from sklearn.metrics import r2_score


from lib.util import output_to_xr, dict_to_xr, swap_coord
In [ ]:
# Preprocessing for the model inputs: first drop low-variance features
# (threshold 0.001), then standardise the surviving columns.
preproc = make_pipeline(
    VarianceThreshold(threshold=0.001),
    StandardScaler(),
)
In [ ]:
# Load the pickled train/test/weight splits.
# NOTE(review): joblib.load unpickles arbitrary objects -- only use it on
# trusted, locally produced files.
data = joblib.load("../data/ml/ngaqua/data.pkl")

# Read the `p` variable from stat.nc inside a context manager so the netCDF
# file handle is released; `.load()` pulls the values into memory first so
# `p` remains usable after the file is closed.
with xr.open_dataset("../data/raw/ngaqua/stat.nc") as stat:
    p = stat.p.load()

# Fit the preprocessing pipeline on the training inputs only, then apply the
# already-fitted transform to the test inputs. Targets are left untouched.
x_train, y_train = preproc.fit_transform(data['train'][0]), data['train'][1]
x_test, y_test = preproc.transform(data['test'][0]), data['test'][1]

# Output weights used by `myloss` and `weighted_r2_score`; presumably one
# weight per target column -- confirm against the data-prep script.
w_out = data['w'][1]
In [ ]:
# Layer widths for the network: the size of axis 1 of the preprocessed inputs
# and of the targets.
n_in = x_train.shape[1]
n_out = y_train.shape[1]
In [ ]:
# Fully connected regressor: two 256-unit ReLU hidden layers followed by a
# linear output layer with `n_out` units.
model = Sequential([
    # Only the first layer of a Sequential model takes `input_shape`.
    Dense(256, activation='relu', input_shape=(n_in,)),
    Dense(256, activation='relu'),
    Dense(n_out, activation='linear')])

# NOTE(review): `lr` is the older Keras argument name; newer Keras releases
# expect `learning_rate` -- update this if the Keras version is bumped.
optimizer = Adam(lr=.001)

import tensorflow as tf


def myloss(true, pred):
    """Weighted mean-squared-error loss passed to ``model.compile``.

    Squares the residual ``true - pred``, multiplies it by the module-level
    ``w_out`` weights (assumed to broadcast against the model output -- TODO
    confirm the shape), and averages over every element to give a scalar.
    """
    squared_error = tf.pow(true - pred, 2)
    return tf.reduce_mean(squared_error * w_out)


model.compile(optimizer=optimizer, loss=myloss)
In [ ]:
# Random sample of 10,000 training-row indices. np.random.choice samples with
# replacement by default, so duplicate indices are possible.
# NOTE(review): `inds` is not used by any visible cell -- confirm it is still
# needed. The model is already compiled in the previous cell, so no re-compile
# is done here.
inds = np.random.choice(x_train.shape[0], 10000)
In [ ]:
# Train on the preprocessed inputs. The epoch count and batch size are
# literal values; consider hoisting them into a config cell.
model.fit(x=x_train, y=y_train, epochs=3, batch_size=40)
In [ ]:
from lib.models import weighted_r2_score

# Score held-out predictions with the same per-output weights the loss uses;
# the score is the cell's displayed result.
pred = model.predict(x_test)
weighted_r2_score(y_test, pred, w_out)
In [ ]:
# Wrap the predicted and true targets as xarray objects on the test
# coordinates ("pred" first, then "true").
predictions = {
    "pred": output_to_xr(pred, y_test.coords),
    "true": output_to_xr(y_test, y_test.coords),
}

# Combine both along a new "model" dimension, then pass the result to
# swap_coord with z=p (presumably relabelling z with the pressure values --
# see lib.util.swap_coord).
preds_xr = swap_coord(dict_to_xr(predictions, dim_name="model"), z=p)
In [ ]:
# Q1c at grid column x=0, y=8: one facet per "model" entry (pred / true),
# stacked in a single column and sharing a fixed colour range.
axs = preds_xr.isel(x=0, y=8).Q1c.plot(
    col='model',
    col_wrap=1,
    cmap="inferno",
    vmin=-20,
    vmax=100,
    figsize=(8, 4),
)
# Invert the y-axis of the current (last-drawn) facet. With the facet grid's
# default shared y-axis this presumably flips every panel -- verify visually.
plt.gca().invert_yaxis()