propclass/neuralnet.h
00001 /* 00002 Crystal Space Entity Layer 00003 Copyright (C) 2007 by Jorrit Tyberghein 00004 00005 Neural Network Property Class 00006 Copyright (C) 2007 by Mat Sutcliffe 00007 00008 This library is free software; you can redistribute it and/or 00009 modify it under the terms of the GNU Library General Public 00010 License as published by the Free Software Foundation; either 00011 version 2 of the License, or (at your option) any later version. 00012 00013 This library is distributed in the hope that it will be useful, 00014 but WITHOUT ANY WARRANTY; without even the implied warranty of 00015 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00016 Library General Public License for more details. 00017 00018 You should have received a copy of the GNU Library General Public 00019 License along with this library; if not, write to the Free 00020 Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00021 */ 00022 00023 #ifndef __CEL_PF_NEURALNET__ 00024 #define __CEL_PF_NEURALNET__ 00025 00026 #include "cstypes.h" 00027 #include "csutil/scf.h" 00028 #include "csutil/refcount.h" 00029 #include "csutil/array.h" 00030 #include "csgeom/math.h" 00031 00032 #include "physicallayer/datatype.h" 00033 00034 class celNNActivationFunc; 00035 00049 struct iCelNNWeights : public virtual iBase 00050 { 00051 SCF_INTERFACE(iCelNNWeights, 0, 0, 1); 00052 00054 virtual csArray< csArray< csArray<float> > >& Data() = 0; 00055 00057 virtual const csArray< csArray< csArray<float> > >& Data() const = 0; 00058 00060 virtual csPtr<iCelNNWeights> Clone() const = 0; 00061 }; 00062 00096 struct iPcNeuralNet : public virtual iBase 00097 { 00098 SCF_INTERFACE(iPcNeuralNet, 0, 0, 1); 00099 00104 virtual void SetSize(size_t inputs, size_t outputs, size_t layers) = 0; 00105 00116 virtual void SetComplexity(const char *name) = 0; 00117 00124 virtual void SetLayerSizes(const csArray<size_t> &sizes) = 0; 00125 00129 virtual void SetActivationFunc(celNNActivationFunc *) = 0; 00130 00136 virtual bool Validate() = 0; 00137 00141 virtual void SetInput(size_t index, const celData &value) = 0; 00142 00146 virtual const celData& GetOutput(size_t index) const = 0; 00147 00151 virtual void SetInputs(const csArray<celData> &values) = 0; 00152 00156 virtual const csArray<celData>& GetOutputs() const = 0; 00157 00162 virtual void Process() = 0; 00163 00167 virtual csPtr<iCelNNWeights> CreateEmptyWeights() const = 0; 00168 00172 virtual void GetWeights(iCelNNWeights *out) const = 0; 00173 00177 virtual bool SetWeights(const iCelNNWeights *in) = 0; 00178 00182 virtual bool CacheWeights(const char *scope, uint32 id) const = 0; 00183 00187 virtual bool LoadCachedWeights(const char *scope, uint32 id) = 0; 00188 }; 00189 00203 class celNNActivationFunc : public virtual csRefCount 00204 { 00205 public: 00207 virtual void Function(celData &data) = 0; 00208 00210 virtual celDataType GetDataType() = 0; 00211 00213 virtual ~celNNActivationFunc() {} 00214 00215 protected: 00227 template <typename T> 00228 static const T& GetFrom(const celData &input); 00229 00240 template <typename T> 00241 static celDataType DataType(); 00242 }; 00243 00244 template<> 00245 inline const float& celNNActivationFunc::GetFrom<float>(const celData &input) 00246 { 00247 return input.value.f; 00248 } 00249 template<> 00250 inline const int8& celNNActivationFunc::GetFrom<int8>(const celData &input) 00251 { 00252 return input.value.b; 00253 } 00254 template<> 00255 inline const uint8& celNNActivationFunc::GetFrom<uint8>(const celData &input) 00256 { 00257 return input.value.ub; 00258 } 00259 template<> 00260 inline const int16& celNNActivationFunc::GetFrom<int16>(const celData &input) 00261 { 00262 return input.value.w; 00263 } 00264 template<> 00265 inline const uint16& celNNActivationFunc::GetFrom<uint16>(const celData &input) 00266 { 00267 return input.value.uw; 00268 } 00269 template<> 00270 inline const int32& celNNActivationFunc::GetFrom<int32>(const celData &input) 00271 { 00272 return input.value.l; 00273 } 00274 template<> 00275 inline const uint32& celNNActivationFunc::GetFrom<uint32>(const celData &input) 00276 { 00277 return input.value.ul; 00278 } 00279 00280 template<> 00281 inline celDataType celNNActivationFunc::DataType<int8>() 00282 { 00283 return CEL_DATA_BYTE; 00284 } 00285 template<> 00286 inline celDataType celNNActivationFunc::DataType<uint8>() 00287 { 00288 return CEL_DATA_UBYTE; 00289 } 00290 template<> 00291 inline celDataType celNNActivationFunc::DataType<int16>() 00292 { 00293 return CEL_DATA_WORD; 00294 } 00295 template<> 00296 inline celDataType celNNActivationFunc::DataType<uint16>() 00297 { 00298 return CEL_DATA_UWORD; 00299 } 00300 template<> 00301 inline celDataType celNNActivationFunc::DataType<int32>() 00302 { 00303 return CEL_DATA_LONG; 00304 } 00305 template<> 00306 inline celDataType celNNActivationFunc::DataType<uint32>() 00307 { 00308 return CEL_DATA_ULONG; 00309 } 00310 template<> 00311 inline celDataType celNNActivationFunc::DataType<float>() 00312 { 00313 return CEL_DATA_FLOAT; 00314 } 00315 00327 template <typename T> 00328 class celNopActivationFunc : public celNNActivationFunc 00329 { 00330 public: 00331 virtual void Function(celData &data) {} 00332 virtual celDataType GetDataType() { return DataType<T>(); } 00333 virtual ~celNopActivationFunc() {} 00334 }; 00335 00347 template <typename T> 00348 class celStepActivationFunc : public celNNActivationFunc 00349 { 00350 public: 00351 virtual void Function(celData &data) 00352 { 00353 const T &val = GetFrom<T>(data); 00354 data.Set(T (val > 1 ? 1 : 0)); 00355 } 00356 virtual celDataType GetDataType() { return DataType<T>(); } 00357 virtual ~celStepActivationFunc() {} 00358 }; 00359 00371 template <typename T> 00372 class celLogActivationFunc : public celNNActivationFunc 00373 { 00374 public: 00375 virtual void Function(celData &data) 00376 { 00377 const T &val = GetFrom<T>(data); 00378 double e_v = log(fabs((double) val)); // log may return not-a-number 00379 data.Set((T) (csNormal(e_v) ? e_v : 0.0)); 00380 } 00381 virtual celDataType GetDataType() { return DataType<T>(); } 00382 virtual ~celLogActivationFunc() {} 00383 }; 00384 00396 template <typename T> 00397 class celAtanActivationFunc : public celNNActivationFunc 00398 { 00399 public: 00400 virtual void Function(celData &data) 00401 { 00402 const T &val = GetFrom<T>(data); 00403 data.Set((T) atan((double) val)); 00404 } 00405 virtual celDataType GetDataType() { return DataType<T>(); } 00406 virtual ~celAtanActivationFunc() {} 00407 }; 00408 00421 template <typename T> 00422 class celTanhActivationFunc : public celNNActivationFunc 00423 { 00424 public: 00425 virtual void Function(celData &data) 00426 { 00427 const T &val = GetFrom<T>(data); 00428 data.Set((T) tanh((double) val)); 00429 } 00430 virtual celDataType GetDataType() { return DataType<T>(); } 00431 virtual ~celTanhActivationFunc() {} 00432 }; 00433 00445 template <typename T> 00446 class celExpActivationFunc : public celNNActivationFunc 00447 { 00448 public: 00449 virtual void Function(celData &data) 00450 { 00451 const T &val = GetFrom<T>(data); 00452 double e_v = exp((double) val); // exp may return infinite 00453 data.Set((T) (csNormal(e_v) ? e_v : 0.0)); 00454 } 00455 virtual celDataType GetDataType() { return DataType<T>(); } 00456 virtual ~celExpActivationFunc() {} 00457 }; 00458 00468 class celInvActivationFunc : public celNNActivationFunc 00469 { 00470 public: 00471 virtual void Function(celData &data) 00472 { 00473 const float &val = GetFrom<float>(data); 00474 data.Set(1.0f / val); 00475 } 00476 virtual celDataType GetDataType() { return CEL_DATA_FLOAT; } 00477 virtual ~celInvActivationFunc() {} 00478 }; 00479 00491 template <typename T> 00492 class celSqrActivationFunc : public celNNActivationFunc 00493 { 00494 public: 00495 virtual void Function(celData &data) 00496 { 00497 const T &val = GetFrom<T>(data); 00498 data.Set(val * val); 00499 } 00500 virtual celDataType GetDataType() { return DataType<T>(); } 00501 virtual ~celSqrActivationFunc() {} 00502 }; 00503 00515 template <typename T> 00516 class celGaussActivationFunc : public celNNActivationFunc 00517 { 00518 public: 00519 virtual void Function(celData &data) 00520 { 00521 const T &val = GetFrom<T>(data); 00522 data.Set((T) exp((double) -(val * val))); 00523 } 00524 virtual celDataType GetDataType() { return DataType<T>(); } 00525 virtual ~celGaussActivationFunc() {} 00526 }; 00527 00539 template <typename T> 00540 class celSinActivationFunc : public celNNActivationFunc 00541 { 00542 public: 00543 virtual void Function(celData &data) 00544 { 00545 const T &val = GetFrom<T>(data); 00546 data.Set((T) sin((double) val)); 00547 } 00548 virtual celDataType GetDataType() { return DataType<T>(); } 00549 virtual ~celSinActivationFunc() {} 00550 }; 00551 00563 template <typename T> 00564 class celCosActivationFunc : public celNNActivationFunc 00565 { 00566 public: 00567 virtual void Function(celData &data) 00568 { 00569 const T &val = GetFrom<T>(data); 00570 data.Set((T) cos((double) val)); 00571 } 00572 virtual celDataType GetDataType() { return DataType<T>(); } 00573 virtual ~celCosActivationFunc() {} 00574 }; 00575 00587 template <typename T> 00588 class celElliottActivationFunc : public celNNActivationFunc 00589 { 00590 public: 00591 virtual void Function(celData &data) 00592 { 00593 const T &val = GetFrom<T>(data); 00594 data.Set(val / (1 + (T)fabs((double) val))); 00595 } 00596 virtual celDataType GetDataType() { return DataType<T>(); } 00597 virtual ~celElliottActivationFunc() {} 00598 }; 00599 00611 template <typename T> 00612 class celSigActivationFunc : public celNNActivationFunc 00613 { 00614 public: 00615 virtual void Function(celData &data) 00616 { 00617 const T &val = GetFrom<T>(data); 00618 data.Set(T(1) / (T)(1 + exp((double) -val))); 00619 } 00620 virtual celDataType GetDataType() { return DataType<T>(); } 00621 virtual ~celSigActivationFunc() {} 00622 }; 00623 00624 #endif // __CEL_PF_NEURALNET__ 00625
Generated for CEL: Crystal Entity Layer 1.4.1 by doxygen 1.7.1