% Demo of Gradient Ascent
%
% The purpose of this program is to clearly demonstrate gradient ascent,
% which is a simple and fundamental machine learning algorithm.
%
% Please drop me an email if this is helpful to you.
%
% CREATED BY: Ethan Schreiber (els@cs.brown.edu), Brown University, December, 2005

clear
clc
close all;

% Define our function as f(x) = 1 - x^2 (a downward-opening parabola).
% polyval takes polynomial coefficients in descending powers of x, so
% [-1 0 1] means -1*x^2 + 0*x + 1. Type 'help polyval' for details.
fX = [-1 0 1];

% So our derivative is f'(x) = -2x, i.e. the coefficients [-2 0]
% (polyder(fX) returns the same vector).
fXPrime = [-2 0];

% NU is the learning rate, i.e. how fast we climb the hill.
% Experiment with this value to see how it affects the rate of learning
% and whether we actually find the maximum.
NU = .01;

% How close to a slope of 0 do we need to get before we stop?
% We stop when abs(slope) < STOPPING_VALUE.
STOPPING_VALUE = .01;

% In case we don't reach our stopping value, how long are we
% willing to wait before we stop our algorithm?
MAX_ITERATIONS = 1000;

% Create the X_AXIS to go from -100 to 100
X_AXIS = -100:100;

% Set an initial guess for the peak:
% pick a uniform random integer somewhere on our x axis.
INITIAL_GUESS = randi([min(X_AXIS) max(X_AXIS)]);

% Evaluate f at every point on the x axis
y = polyval(fX, X_AXIS);

% Plot the function and prepare for ascent
figure(1);
plot(X_AXIS, y)
xlabel('x');
ylabel('y');
title('Gradient Ascent Example');
hold on;

% Set the initial value
x = INITIAL_GUESS;
fprintf('Initial Guess: %d\n\n', INITIAL_GUESS);

% Now start the ascent
for i = 1:MAX_ITERATIONS
    % Compute the slope by plugging x into the derivative
    slope = polyval(fXPrime, x);
    fprintf('x: %-5.2f slope: %-5.2f\n', x, slope);

    % If the derivative is close enough to 0, we have found a peak
    % (or some other stationary point), so stop!
    if (abs(slope) < STOPPING_VALUE)
        break;
    else
        % Otherwise, ascend by moving along the slope,
        % scaled by the learning rate
        x = x + (NU * slope);
    end
    % P.S. If you would like to do gradient descent,
    % change the plus in x + (NU... to a minus.
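    % Worked check of the update rule. This assumes the quadratic fX and
    % NU = .01 defined above; it does not carry over to other functions.
    % Since slope = -2x, each update is x_new = x + NU*(-2x) = (1 - 2*NU)*x.
    % With NU = .01 every step multiplies x by 0.98, and the loop stops
    % once abs(slope) = 2*abs(x) < STOPPING_VALUE, i.e. abs(x) < 0.005.
    % From the worst-case start x = 100 that takes roughly
    % log(0.005/100)/log(0.98), about 490 steps, well inside MAX_ITERATIONS.
    % More generally, for this f the iteration converges for 0 < NU < 1,
    % flips the sign of x on every step for 0.5 < NU < 1, bounces between
    % x and -x forever at NU = 1, and diverges for NU > 1.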
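    % Plot the next guess
    plot(x, polyval(fX, x), '.r');
    drawnow;
end

% Only report a peak if the stopping condition was actually met
if abs(slope) < STOPPING_VALUE
    fprintf('Found a peak at x=%5.2f\n', x);
else
    fprintf('No peak found within %d iterations (last x=%5.2f)\n', MAX_ITERATIONS, x);
end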