@@ -457,8 +457,9 @@ def __call__(self, X, alpha=None, bytes=False):
457457 return the RGBA values ``X*100`` percent along the Colormap line.
458458 For integers, X should be in the interval ``[0, Colormap.N)`` to
459459 return RGBA values *indexed* from the Colormap with index ``X``.
460- alpha : float, None
461- Alpha must be a scalar between 0 and 1, or None.
460+ alpha : float, array-like, None
461+ Alpha must be a scalar between 0 and 1, a sequence of such
462+ floats with shape matching X, or None.
462463 bytes : bool
463464 If False (default), the returned RGBA values will be floats in the
464465 interval ``[0, 1]`` otherwise they will be uint8s in the interval
@@ -484,6 +485,26 @@ def __call__(self, X, alpha=None, bytes=False):
484485xa = xma .filled ()# Fill to avoid infs, etc.
485486del xma
486487
488+ if alpha is not None :
489+ if not cbook .iterable (alpha ):
490+ alphatype = 'scalar'
491+ if alpha < 0 or alpha > 1 :
492+ raise ValueError ("alpha must be in [0, 1];"
493+ " found %s" % alpha )
494+ else :
495+ alphatype = 'array'
496+ alpha = np .asarray (alpha )
497+ _mina = alpha .min ()
498+ _maxa = alpha .max ()
499+ if _mina < 0 or _maxa > 1 :
500+ raise ValueError ("alpha must be in [0, 1];"
501+ " found min %s and max %s" %
502+ (_mina ,_maxa ))
503+ if not (alpha .shape == xa .shape ):
504+ raise ValueError ("alpha is array-like but it's shape"
505+ " %s doesn't match that of X %s" %
506+ (alpha .shape ,xa .shape ))
507+
487508# Calculations with native byteorder are faster, and avoid a
488509# bug that otherwise can occur with putmask when the last
489510# argument is a numpy scalar.
@@ -514,7 +535,7 @@ def __call__(self, X, alpha=None, bytes=False):
514535else :
515536lut = self ._lut .copy ()# Don't let alpha modify original _lut.
516537
517- if alpha is not None :
538+ if alpha is not None and alphatype == 'scalar' :
518539alpha = np .clip (alpha ,0 ,1 )
519540if bytes :
520541alpha = int (alpha * 255 )
@@ -530,6 +551,17 @@ def __call__(self, X, alpha=None, bytes=False):
530551
531552rgba = np .empty (shape = xa .shape + (4 ,),dtype = lut .dtype )
532553lut .take (xa ,axis = 0 ,mode = 'clip' ,out = rgba )
554+
555+ if alpha is not None and alphatype == 'array' :
556+ if bytes :
557+ alpha = (alpha * 255 ).astype (np .uint8 )
558+ rgba [...,- 1 ]= alpha
559+ if (lut [- 1 ]== 0 ).all ()and mask_bad is not None :
560+ if mask_bad .shape == xa .shape :
561+ rgba [mask_bad ]= (0 ,0 ,0 ,0 )
562+ elif mask_bad :
563+ rgba [..., :]= (0 ,0 ,0 ,0 )
564+
533565if vtype == 'scalar' :
534566rgba = tuple (rgba [0 , :])
535567return rgba