Sample Bias 2D

The following example is a 2D regression domain adaptation problem. The goal is to learn the regression task on the target data (orange points) when labels are available only for the source data (blue points).

In this example, there is a sample bias between the source and target datasets: most of the source points lie on the line X1 = 0, whereas the target points are uniformly distributed over the input space.

The following methods are compared: Source Only (no adaptation), KMM and KLIEP.

[1]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.metrics.pairwise import rbf_kernel
from adapt.instance_based import KMM, KLIEP
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import clone_model

Setup

[2]:
def f(x0, x1):
    """Return the regression target for inputs ``(x0, x1)``.

    Both coordinates are first mapped through ``(x + 1) / 2`` (so the
    interval [-1, 1] becomes [0, 1]); the result is the Rosenbrock
    expression ``100 * (x1 - x0**2)**2 + (1 - x0)**2`` of the rescaled
    inputs, divided by 100. Only arithmetic operators are used, so scalars
    and NumPy arrays are both accepted (arrays are handled element-wise).
    """
    x0 = (x0 + 1.) / 2.
    x1 = (x1 + 1.) / 2.
    return (1 / 100) * (100 * (x1 - x0**2)**2 + (1 - x0)**2)


np.random.seed(5)

# Source inputs: 20 points on the line X1 = 0, plus 10 points drawn
# uniformly in [-1, 1) x [-1, 1).
Xs = np.stack([np.linspace(-1, 1, 20), np.zeros(20)], -1)
Xs = np.concatenate((Xs, np.random.random((10, 2)) * 2 - 1))

# Target inputs: regular grid with 20 values of X0 and 10 values of X1.
xt_grid, yt_grid = np.meshgrid(np.linspace(-1, 1., 20),
                               np.linspace(-1, 1., 10))
Xt = np.stack([xt_grid.ravel(), yt_grid.ravel()], -1)

# Dense 100 x 100 grid used to draw the output function and the error maps.
x_grid, y_grid = np.meshgrid(np.linspace(-1, 1., 100),
                             np.linspace(-1, 1., 100))

# Labels: source, target, and ground truth on the dense grid.
ys = f(Xs[:, 0], Xs[:, 1])
yt = f(Xt[:, 0], Xt[:, 1])
z_grid = f(x_grid.ravel(), y_grid.ravel())
[3]:
# Three-panel overview: input space, output function, and 1D slices of Y.
fig = plt.figure(figsize=(20, 5))

# Panel 1: source and target points in the input space.
ax1 = fig.add_subplot(1, 3, 1)
ax1.plot(Xt[:, 0], Xt[:, 1], '.', c="C1", label="target",
         ms=8, alpha=0.7, markeredgecolor="black")
ax1.plot(Xs[:, 0], Xs[:, 1], '.', c="C0", label="source",
         ms=14, alpha=0.7, markeredgecolor="black")
ax1.set_yticklabels([])
ax1.set_xticklabels([])
ax1.tick_params(direction='in')
ax1.legend()
ax1.set_xlabel("X0", fontsize=12)
ax1.set_ylabel("X1", fontsize=12)
ax1.set_title("Input Space", fontsize=14)

# Panel 2: ground-truth output on the dense grid, shown as a colour map.
ax2 = fig.add_subplot(1, 3, 2)
ax2.scatter(x_grid.ravel(), y_grid.ravel(), c=z_grid, cmap="jet")
ax2.set_yticklabels([])
ax2.set_xticklabels([])
ax2.tick_params(direction='in')
ax2.set_xlabel("X0", fontsize=12)
ax2.set_ylabel("X1", fontsize=12)
ax2.set_title("Output Function Y=f(X0, X1)", fontsize=14)

# Panel 3: Y as a function of X0 for a few fixed values of X1.
ax3 = fig.add_subplot(1, 3, 3)
for x1 in [0., 0.5, 1., -0.5, -0.8]:
    # 100 points along X0 with X1 held constant at x1.
    X_ = np.concatenate((np.linspace(-1, 1, 100).reshape(-1, 1),
                         np.ones((100, 1)) * x1),
                        axis=1)
    ax3.plot(X_[:, 0], f(X_[:, 0], X_[:, 1]), label="X1 =%.1f" % x1)
ax3.set_yticklabels([])
ax3.set_xticklabels([])
ax3.tick_params(direction='in')
ax3.legend()
ax3.set_xlabel("X0", fontsize=12)
ax3.set_ylabel("Y", fontsize=12)
ax3.set_title("Y in function of X0", fontsize=14)

plt.subplots_adjust(wspace=0.1)
../_images/examples_sample_bias_2d_6_0.png

Estimator

[19]:
np.random.seed(0)
tf.random.set_seed(0)

# Base estimator shared by all methods below: a fully connected network with
# two hidden ReLU layers of 100 units and a single linear output, taking the
# 2D input (X0, X1).
model = Sequential()
model.add(Dense(100, activation="relu", input_shape=(2,)))
model.add(Dense(100, activation="relu"))
model.add(Dense(1, activation=None))
model.compile(loss="mse", optimizer=Adam(0.001))

# Keyword arguments forwarded to every ``fit`` call in this example.
fit_params = dict(epochs=300, batch_size=34, verbose=0)

Source Only

[20]:
np.random.seed(0)
tf.random.set_seed(0)

# Baseline: train on the labelled source data only, without any adaptation.
# The fit is done on a clone, so ``model`` itself is not fitted by this cell.
estimator = clone_model(model)
estimator.compile(loss="mse", optimizer=Adam(0.001))
estimator.fit(Xs, ys, **fit_params);
[21]:
# Absolute prediction error of the source-only estimator on the dense grid.
yp_grid = estimator.predict(
    np.stack([x_grid.ravel(), y_grid.ravel()], -1)).ravel()
error_grid = np.abs(yp_grid - z_grid)

# Mean absolute error on the target points, passed as (y_true, y_pred).
score = mean_absolute_error(yt, estimator.predict(Xt).ravel())

fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.plot(Xs[:, 0], Xs[:, 1], '.', c="C0", ms=14, alpha=0.7,
        markeredgecolor="black")
ax.scatter(x_grid.ravel(), y_grid.ravel(), c=error_grid)
ax.set_xlabel("X0", fontsize=16)
ax.set_ylabel("X1", fontsize=16)
ax.set_title("Error Map -- Source Only -- Target MAE :%.3f" % score)
ax.set_yticklabels([])
ax.set_xticklabels([])
ax.tick_params(direction='in')
plt.show()
../_images/examples_sample_bias_2d_11_0.png

KMM

[22]:
np.random.seed(0)

# KMM: fit source instance weights using the unlabelled target inputs Xt,
# then fit the estimator (see the "Fit weights..." / "Fit Estimator..."
# messages printed by ``fit``).
kmm = KMM(model, gamma=2., random_state=0)
kmm.fit(Xs, ys, Xt, **fit_params);
Fit weights...
     pcost       dcost       gap    pres   dres
 0:  2.7084e+04 -4.3392e+05  1e+07  6e-01  2e-14
 1:  2.3438e+02 -1.0551e+05  2e+05  4e-03  2e-11
 2:  1.7072e+02 -2.1481e+04  2e+04  9e-06  2e-12
 3:  1.6458e+02 -8.2292e+02  1e+03  4e-07  5e-14
 4:  1.0422e+02 -6.2452e+02  7e+02  2e-07  3e-14
 5: -7.2423e+01 -9.3040e+02  9e+02  5e-08  6e-15
 6: -7.7015e+01 -2.8294e+02  2e+02  1e-08  2e-15
 7: -7.9350e+01 -2.8047e+02  2e+02  1e-08  1e-15
 8: -8.7166e+01 -1.2479e+02  4e+01  2e-16  3e-16
 9: -8.9553e+01 -9.7350e+01  8e+00  2e-16  1e-16
10: -9.0635e+01 -9.2636e+01  2e+00  2e-16  2e-16
11: -9.1022e+01 -9.1316e+01  3e-01  2e-16  1e-16
12: -9.1106e+01 -9.1145e+01  4e-02  2e-16  1e-16
13: -9.1116e+01 -9.1120e+01  4e-03  2e-16  2e-16
14: -9.1118e+01 -9.1118e+01  6e-05  2e-16  2e-16
Optimal solution found.
Fit Estimator...
[23]:
# Absolute prediction error of the KMM model on the dense grid.
yp_grid = kmm.predict(
    np.stack([x_grid.ravel(), y_grid.ravel()], -1)).ravel()
error_grid = np.abs(yp_grid - z_grid)

# Mean absolute error on the target points, passed as (y_true, y_pred).
score = mean_absolute_error(yt, kmm.predict(Xt).ravel())

# Source weights scaled by 100 so they can be used as marker areas.
weights = kmm.predict_weights() * 100

fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.scatter(x_grid.ravel(), y_grid.ravel(), c=error_grid)
ax.scatter(Xs[:, 0], Xs[:, 1], c="C0", s=weights, alpha=0.7,
           edgecolor="black")
ax.set_xlabel("X0", fontsize=16)
ax.set_ylabel("X1", fontsize=16)
ax.set_title("Error Map -- KMM -- Target MAE :%.3f" % score)
ax.set_yticklabels([])
ax.set_xticklabels([])
ax.tick_params(direction='in')
plt.show()
../_images/examples_sample_bias_2d_14_0.png

KLIEP

[24]:
np.random.seed(0)

# KLIEP: one of the candidate ``sigmas`` is selected by cross-validation
# (see the J-scores printed by ``fit``), source weights are fitted using the
# unlabelled target inputs Xt, then the estimator is fitted.
kliep = KLIEP(model,
              sigmas=[0.001, 0.01, 0.1, 0.5, 1., 2., 5., 10.],
              random_state=0,
              max_centers=200)
kliep.fit(Xs, ys, Xt, **fit_params);
Fit weights...
Cross Validation process...
Parameter sigma = 0.0010 -- J-score = -0.000 (0.000)
Parameter sigma = 0.0100 -- J-score = -0.004 (0.001)
Parameter sigma = 0.1000 -- J-score = -0.033 (0.009)
Parameter sigma = 0.5000 -- J-score = -0.068 (0.019)
Parameter sigma = 1.0000 -- J-score = -0.002 (0.022)
Parameter sigma = 2.0000 -- J-score = 0.157 (0.026)
Parameter sigma = 5.0000 -- J-score = 0.393 (0.023)
Parameter sigma = 10.0000 -- J-score = 0.467 (0.008)
Fit Estimator...
[25]:
# Absolute prediction error of the KLIEP model on the dense grid.
yp_grid = kliep.predict(
    np.stack([x_grid.ravel(), y_grid.ravel()], -1)).ravel()
error_grid = np.abs(yp_grid - z_grid)

# Mean absolute error on the target points, passed as (y_true, y_pred).
score = mean_absolute_error(yt, kliep.predict(Xt).ravel())

# Source weights scaled by 100 so they can be used as marker areas.
weights = kliep.predict_weights() * 100

fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.scatter(x_grid.ravel(), y_grid.ravel(), c=error_grid)
ax.scatter(Xs[:, 0], Xs[:, 1], c="C0", s=weights, alpha=0.7,
           edgecolor="black")
ax.set_xlabel("X0", fontsize=16)
ax.set_ylabel("X1", fontsize=16)
ax.set_title("Error Map -- KLIEP -- Target MAE :%.3f" % score)
ax.set_yticklabels([])
ax.set_xticklabels([])
ax.tick_params(direction='in')
plt.show()
../_images/examples_sample_bias_2d_17_0.png
[ ]: